第一步 下载数据集到本地

提取码:9h3u

存储位置:C:/用户/用户名/.keras/datasets

(用户名不同人不一样,可能电脑不一样存储位置也略有差异)

第二步 导入数据集

import keras

import numpy as np

# load data

from keras.datasets import imdb

(train_data, train_labels), (test_data, test_labels) = imdb.load_data(num_words=10000)

----------

查看数据集是否导入正确

print(train_labels[0]) #1

print(max([max(sequence) for sequence in train_data])) #9999

----------

遇到的一些小问题以及解决办法:

若出现了几个问题,最后差不多是这样:raise ValueError("Object arrays cannot be loaded when " ValueError: Object arrays cannot be loaded ……

这说明numpy版本太高了,我一开始的版本是1.16.4,之后转换成了1.16.2

版本转换:

cmd输入xxxxxxxxxxxxxxxx numpy==1.16.2

xxxxxxxxxx为https://mirrors.tuna.tsinghua.edu.cn/help/pypi/中代码,可加快下载速度,直接复制,只需要将some-package改成numpy==1.16.2即可

第三步 电影评论二分类完整代码示例

import keras

import numpy as np

# load data

from keras.datasets import imdb

(train_data, train_labels), (test_data, test_labels) = imdb.load_data(num_words=10000)

print(train_labels[0]) #1

print(max([max(sequence) for sequence in train_data])) #9999

# 将索引解码为单词,需要下载imdb_word_index.json至C:/用户/用户名/.keras/datasets

# 链接:https://pan.baidu.com/s/1kkmpXrr1tkFtg7D3LX_lcw 提取码:wzjw

word_index = imdb.get_word_index() #将单词映射为整数索引的字典

reverse_word_index = dict([(value, key) for (key, value) in word_index.items()]) #键值颠倒,将整数索引映射为单词

decoded_review = ' '.join([reverse_word_index.get(i - 3, '?') for i in train_data[0]])

#print(decoded_review)

# 对列表进行one-shot编码,eg.将[3,5]转换成[0.0,0.0,1.0,0.0,1.0,0.0,0.0,0.0,...]

def vectorize_sequences(sequences, dimension=10000):

results = np.zeros((len(sequences), dimension))

for i, sequence in enumerate(sequences):

results[i,sequence] = 1. #将 results[i] 的指定索引设为 1

return results

# handle input data

x_train = vectorize_sequences(train_data)

x_test = vectorize_sequences(test_data)

#print(x_train[0]) #[0. 1. 1. ... 0. 0. 0.]

# handle output data

y_train = np.asarray(train_labels).astype('float32')

y_test = np.asarray(test_labels).astype('float32')

# 验证集预留

x_val = x_train[:10000]

partial_x_train = x_train[10000:]

y_val = y_train[:10000]

partial_y_train = y_train[10000:]

# build model

from keras import models

from keras import layers

model = models.Sequential()

model.add(layers.Dense(16, activation='relu', input_shape=(10000,)))

model.add(layers.Dense(16, activation='relu'))

model.add(layers.Dense(1, activation='sigmoid'))

# train model

model.compile(optimizer='rmsprop',

loss='binary_crossentropy',

metrics=['accuracy'])

history = model.fit(partial_x_train,

partial_y_train,

epochs=20,

batch_size=512,

validation_data=(x_val, y_val))

history_dict = history.history

#print(history_dict.keys()) #dict_keys(['val_loss', 'val_acc', 'loss', 'acc'])

# 绘制训练损失、验证损失、训练精度、验证精度

import matplotlib.pyplot as plt

# plot loss

acc = history.history['acc']

val_acc = history.history['val_acc']

loss = history.history['loss']

val_loss = history.history['val_loss']

epochs = range(1, len(acc)+1)

plt.plot(epochs, loss, 'bo', label='Training loss') #blue o

plt.plot(epochs, val_loss, 'b', label='Validation loss') #blue solid line

plt.title('Training and validation loss')

plt.xlabel('Epochs')

plt.ylabel('Loss')

plt.legend()

plt.show()

# plot accuracy

plt.clf() #清空图像

plt.plot(epochs, acc, 'bo', label='Training acc')

plt.plot(epochs, val_acc, 'b', label='Validation acc')

plt.title('Training and validation accuracy')

plt.xlabel('Epochs')

plt.ylabel('Accuracy')

plt.legend()

plt.show()

----------

此处验证集用来确定训练NN可采用的最佳epoch,训练集--->NN参数,得出epochs=4

用新参数搭建的NN去训练train_data,注释掉history以及history之后的代码:

model.fit(x_train, y_train, epochs=4, batch_size=512)

results = model.evaluate(x_test, y_test)

print(results)

print(model.predict(x_test))

------------

进一步实验的实验结果,(控制变量:

[0.29455984374523164, 0.88312] #原结构共三层

[0.2833905682277679, 0.88576] #共两层

[0.30949291754722597, 0.87984] #神经元个数为32

[0.08610797638118267, 0.88308] #mse

[0.32080167996406556, 0.87764] #用tanh代替relu

选定的结构较为合适。

imdb导mysql_keras如何导入本地下载的imdb数据集?相关推荐

  1. 【python】pycharm 中导入本地下载好的库

    Pycharm中导入库基本上都是使用在Interpreter中连网在线下载添加.(下图中的加号) 但是,有的时候不知为什么,总会出现导库失败.然后就想着直接下载库,之后导入环境中.下面来介绍一下简单的 ...

  2. python模块导入红色波浪线_解决pycharm导入本地py文件时,模块下方出现红色波浪线的问题...

    有时候导入本地模块或者py文件时,下方会出现红色的波浪线,但不影响程序的正常运行,但是在查看源函数文件时,会出现问题 问题如下: 解决方案: 1. 进入设置,找到Console下的Python Con ...

  3. 哪个读书app可以导入txt_QQ阅读iphone版怎么导入电子书 三种手机QQ阅读器导入本地图书图文教程...

    QQ阅读iphone版是一款比较方便的移动终端阅读软件,除了从电子书城下载或者购买电子书外,我们也可以把自己电脑中的电子书上传到QQ阅读软件中,不过果粉们都知道苹果手机没有文件管理器,无法使用qq阅读 ...

  4. go mod导入本地包

    利用go mod导入本地包 在实际项目开发过程中,为了完成一些功能,往往需要自己在本地新建一些包,然后在项目的其他go文件中调用该包.当使用go mod管理 包时,会出现一些错误,比如:如果我们本地的 ...

  5. 鸢尾花(iris)数据集保存到本地以及sklearn其他数据集下载保存

    鸢尾花数据集 问题起源 在机器学习到分类问题时,使用sklearn下载数据集的时候,不是很明白具体怎么下载的,以及如何下载其他数据集,于是仔细思考了一番 查看鸢尾花数据集 首先先看代码块 #从skle ...

  6. ABAP-1-会计凭证批量数据导入本地ACCESS

    公司会计凭证导入ACCESS数据库,需要发送给审计,原先的方案是采用DEPHI开发的功能(调用函数获取会计凭证信息,然后INSERT到ACCESS数据表),运行速度非常慢,业务方要求对该功能进行优化, ...

  7. 计算机视觉两个入门数据集(mnist和fashion mnist)本地下载地址

    1.计算机视觉经典数据集 1.mnist数据集 MNIST(Mixed National Institute of Standards andTechnology database)数据集大家可以说是 ...

  8. 批量地导入本地的scholar.enw到endnote

    批量地导入本地的scholar.enw到endnote EndNote 软件对于科研工作者来说,它是亲密无间的得力助手.可是,有时候也会有不尽人意的事情发生,比如在导入从谷歌学术精挑细选之后下载的的s ...

  9. idea在离线状态下使用maven导入本地仓库

    针对idea在离线状态下使用maven导入本地仓库的问题 当idea处于离线状态下,例如没有互联网,或者从事保密性质开发,电脑设备不允许联网,如何通过导入拷贝的maven本地仓库进行开发. (今天查了 ...

最新文章

  1. 机器学习之类别性特征
  2. C++中的const成员函数
  3. 07-CA/TA编程:rsakey demo
  4. WPF效果(GIS三维续篇)
  5. 数字逻辑基础与verilog设计_数字电路学习笔记(五):逻辑设计基础
  6. 阿里宜搭重磅发布专有云版本、精品应用市场,助力政企数字化转型
  7. 学JAVA的诗句_学Java有感(终)
  8. 躬身入境DIY - 《传奇动物园》北京沙盘活动精彩回顾
  9. 国培计算机培训奥鹏,3515011349奥鹏国培培训网络研修总结
  10. QT修改releas发布的exe图标
  11. 车主因眼睛小被自动驾驶误判?——智能座舱CV体验的经典corner case剖析 by 资深AI产品经理@方舟...
  12. Qgis 如何根据范围来裁剪地图,高程图等
  13. Java MultipartFile实现文件上传并为图片加上水印(二)
  14. 电力设备事故演练仿真培训_电力事故VR培训_广州华锐互动
  15. java支付宝rsa2签名_JAVA RSA签名 解签(利用支付宝封装的函数)
  16. sql注入--基本注入语句学习笔记
  17. Mysql创建自增序列方案(模拟Oracle序列)
  18. HTML怎么把图片颜色加深,怎么把Photoshop的图片整体颜色加深?
  19. 多机房UPS及环境集中监控方案丨UPS环境综合监控主机
  20. MFC下调用yolo_cpp_dll.dll

热门文章

  1. fsb,fev文件格式转换,提取与打包
  2. C++ priority_queue用法
  3. 【网络安全】浅析跨域原理及如何实现跨域
  4. 管理员请注意 一条后门病毒攻击链正在针对服务器发起入侵
  5. initdz linux挖坑病毒分析
  6. addr 与 offset 区别
  7. Python操作Json、Csv、Excel文件
  8. 图的邻接矩阵存储和邻接表存储定义方法
  9. 1.2.3 算法的空间复杂度
  10. C语言易错题集 第二部