IMDB数据

包含来自互联网电影数据库的50000条两极分化的评论,数据集被分为用于训练的25000条评论与用于测试的25000条评论,训练集和测试集都包含百分之五十的正面评论和百分之五十负面评论。
该数据集也内置于Keras库,它已经过预处理:评论已经被转换为整数序列,其中每个整数代表字典中的某个单词。
加载IMDB数据集

from tensorflow.keras.datasets import imdb
# 仅仅保留训练数据前10000个最常出现的单词,低频单词将被舍弃,加载数据,其中train_label,test_label都是0和1组成的列表,0代表负面,1代表正面
(train_data, train_labels),(test_data, test_labels) = imdb.load_data(num_words=10000)

实现过程

准备数据
将整数序列编码成二进制矩阵
enumerate方法实例:

>>>seasons = ['Spring', 'Summer', 'Fall', 'Winter']
>>> list(enumerate(seasons))
[(0, 'Spring'), (1, 'Summer'), (2, 'Fall'), (3, 'Winter')]
>>> list(enumerate(seasons, start=1))       # 下标从 1 开始
[(1, 'Spring'), (2, 'Summer'), (3, 'Fall'), (4, 'Winter')]
import numpy as np
def vectorize_sequence(sequences,dimension=10000):results = np.zeros((len(sequences), dimension))# enumerate() 函数用于将一个可遍历的数据对象(如列表、元组或字符串)组合为一个索引序列for i, sequence in enumerate(sequences):results[i, sequence] = 1.return results
# 将数据都向量化
x_train = vectorize_sequence(train_data)
x_test = vectorize_sequence(test_data)
#将标签向量化
y_train = np.asarray(train_labels.astype('float32'))
y_test = np.asarray(test_labels.astype('float32'))


构建网络

model= models.Sequential()
model.add(layers.Dense(16,activation='relu',input_shape=(10000,)))
model.add(layers.Dense(16,activation='relu'))
model.add(layers.Dense(1,activation='sigmoid'))

relu是深度学习最常用的激活函数,如果没有relu等激活函数(非线性),Dense层将只包含两个线性运算(点积和加法)。
需要选择损失函数和优化器,由于面对的是一个二分类问题,网络输出是一个概率值,那么最好用binary_crossentropy(二元交叉熵)损失,还可以选择mean_squared_error(均方误差),但是对于输出概率值的模型,交叉熵往往是一个最好的选择。
编译模型
keras内置rmsprop优化器,以及binary_crossentropy和accuracy。

# 编译模型
model.compile(optimizer='rmsprop', loss='binary_crossentropy', metrics=['accuracy'])

配置优化器

# 配置优化器
model.compile(optimizer=optimizers.RMSprop(lr=0.001),loss='binary_crossentropy',metrics=['accuracy'])

使用自定义的损失和指标

# 使用自定义的损失和指标
model.compile(optimizer=optimizers.RMSprop(lr=0.001),loss=losses.binary_crossentropy,metrics=[metrics.binary_accuracy])

验证方法

# 留出验证集,10000个
x_val = x_train[:10000]
partial_x_train = x_train[10000:]
y_val = y_train[:10000]
partial_y_train = y_train[10000:]
# 训练模型,使用512个样本作为小批量,将模型训练20个轮次
model.compile(optimizer='rmsprop',loss='binary_crossentropy',metrics=['acc'])
history = model.fit(partial_x_train,partial_y_train,epochs=20,batch_size=512,#将验证数据传入validation_data参数来完成validation_data=(x_val, y_val))

model.fit()返回了一个history对象,这个对象有一个成员history,它是一个字典,包含训练过程中的所有数据。

绘制图像

# 绘制训练损失和验证损失
history_dict = history.history
loss_values = history_dict['loss']
val_loss_values = history_dict['val_loss']
epochs = range(1, len(loss_values)+1)
plt.plot(epochs,loss_values,'bo',label='Training loss')
plt.plot(epochs,val_loss_values,'b',label='Validation loss')
plt.title("Training and validation loss")
plt.xlabel('Epochs')
plt.ylabel('Loss')
plt.legend()
plt.show()
acc_values = history_dict['acc']
val_acc = history_dict['val_acc']
plt.plot(epochs,loss_values,'bo',label='Training acc')
plt.plot(epochs,val_loss_values,'b',label='Validation acc')
plt.title("Training and validation accuracy")
plt.xlabel('Epochs')
plt.ylabel('Accuracy')
plt.legend()
plt.show()



分析:
如图所见,训练损失每轮都在降低,训练精度每轮都在上升,可见它们似乎在第四轮最佳,再往后模型在训练数据上表现的越来越好,在前所未有的数据上不一定表现的越来越好,也就是过拟合,为了防止过拟合,可以在第三轮后停止训练,通常可以采用多种方法来降低拟合。

完整代码:

from tensorflow.keras.datasets import imdb
from tensorflow.keras import models, layers
from tensorflow.keras import optimizers, losses, metrics
import numpy as np
from tensorflow_core.python.keras import regularizersdef vectorize_sequence(sequences,dimension=10000):results = np.zeros((len(sequences), dimension))# enumerate() 函数用于将一个可遍历的数据对象(如列表、元组或字符串)组合为一个索引序列for i, sequence in enumerate(sequences):results[i, sequence] = 1.return results
(train_data, train_labels),(test_data, test_labels) = imdb.load_data(num_words=10000)
# 将数据都向量化
x_train = vectorize_sequence(train_data)
x_test = vectorize_sequence(test_data)
y_train = np.asarray((train_labels).astype('float32'))
y_test = np.asarray((test_labels).astype('float32'))
model=models.Sequential()
model.add(layers.Dense(16, kernel_regularizer=regularizers.l2(0.001),activation='relu',input_shape=(10000,)))
model.add(layers.Dense(16,kernel_regularizer=regularizers.l2(0.001),activation='relu'))
model.add(layers.Dense(1,activation='sigmoid'))
model.compile(optimizer='rmsprop', loss='binary_crossentropy', metrics=['accuracy'])
history = model.fit(x_train,y_train,epochs=4,batch_size=512)
results = model.evaluate(x_test, y_test)
print(results)
# 使用训练好的网络在新数据上生成预测结果
print(model.predict(x_test))

深度学习1:二分类问题相关推荐

  1. 深度学习的二分类问题(电影评论分类)

    一.目的 会用神经网络解决基本的分类问题. 二.内容 1.准备数据 2.构建网络解决分类问题 3.验证网络,生成预测结果 三.方法与步骤 1.查看keras库的版本 2.IMDB数据集 2.1.加载I ...

  2. 语义分割:基于openCV和深度学习(二)

    语义分割:基于openCV和深度学习(二) Semantic segmentation in images with OpenCV 开始吧-打开segment.py归档并插入以下代码: Semanti ...

  3. 用MXnet实战深度学习之二:Neural art

    用MXnet实战深度学习之二:Neural art - 推酷 题注:本来这是第三集的内容,但是 Eric Xie 勤劳又机智的修复了mxnet和cuDNN的协作问题,我就把这篇当作一个卷积网络Conv ...

  4. 万字总结Keras深度学习中文文本分类

    摘要:文章将详细讲解Keras实现经典的深度学习文本分类算法,包括LSTM.BiLSTM.BiLSTM+Attention和CNN.TextCNN. 本文分享自华为云社区<Keras深度学习中文 ...

  5. 花书+吴恩达深度学习(二)非线性激活函数(ReLU, maxout, sigmoid, tanh)

    目录 0. 前言 1. ReLU 整流线性单元 2. 绝对值整流线性单元 3. 渗漏整流线性单元 4. 参数化整流线性单元 5. maxout 单元 6. logistic sigmoid 单元 7. ...

  6. 深度学习在情感分类中的应用

    简介与背景 情感分类及其作用 情感分类是情感分析的重要组成部分,情感分类是针对文本的情感倾向进行极性分类,分类数量可以是二分类(积极或消极),也可以是多分类(按情感表达的不同程度),情感分析在影音评论 ...

  7. 综述:基于深度学习的文本分类 --《Deep Learning Based Text Classification: A Comprehensive Review》总结(一)

    文章目录 综述:基于深度学习的文本分类 <Deep Learning Based Text Classification: A Comprehensive Review>论文总结(一) 总 ...

  8. 自然语言处理入门实战2:基于深度学习的文本分类

    自然语言处理入门实战2:基于深度学习的文本分类 数据集 数据预处理 模型 模型训练 模型测试 参考 本文参考复旦大学自然语言处理入门练习,主要是实现基于深度学习的文本分类. 环境:python3.7 ...

  9. 自然语言处理入门实战3:基于深度学习的文本分类(2)

    基于深度学习的文本分类(2) 数据集 数据预处理 CNN模型 RNN模型 利用CNN模型进行训练和测试 利用RNN模型进行训练和测试 预测 总结 参考 本文主要是使用CNN和RNN进行文本分类操作. ...

  10. MATLAB与深度学习(二)— 训练神经网络(图像分类识别)

    MATLAB与深度学习(二)- 训练神经网络(图像分类识别) 上一篇,我们介绍了与深度学习相关的MATLAB工具包.这一篇,我们将介绍如何训练神经网络和相关的基础知识.本文借鉴和引用了网上许多前辈的经 ...

最新文章

  1. Nginx的File not found 错误解决
  2. Linux sed删除文件注释行并删除空行
  3. 游戏引擎和编程语言的关系
  4. Redis主从复制配置(原理剖析)
  5. 将来时态:I will fly - I'm going to fly - I'm flying_48
  6. Excel转html
  7. sqlserver200864位下载_SQL2008下载 SQL Server 2008 R2 简体中文版(64位) 下载-脚本之家
  8. (学习笔记)地理加权回归
  9. 总结了一些很实用值得收藏的站点
  10. 微信小程序引入 vant UI组件库
  11. 如何添加使用微信小程序,教程在这里,微信小程序怎样添加使用
  12. An Underwater Image Enhancement Benchmark Dataset and Beyong
  13. 浏览器如何工作(How browsers work)
  14. ppt中流程图旁边怎么添加_word中流程图怎么导入到ppt ppt流程图导入word
  15. php字符串函数处理emoji,PHP中处理内容含有emoji表情的几种方式
  16. DELPHI 颜色表
  17. 浅谈Git原理和常用命令(学习笔记)
  18. java开发面试自我介绍模板_java求职自我介绍范文_java工程师面试个人介绍
  19. apache 安装与修改端口,修改默认页面,配置虚拟主机
  20. Mongoose初使用总结

热门文章

  1. 2021年12月青少年C/C++软件编程(四级)等级考试试卷及答案解析
  2. 学习《华为基本法》(13):市场营销
  3. [导入]discuz!NT整合经验总结
  4. 从简历被拒,到斩获 BAT offer,全靠这些吊炸天的公众号!
  5. 如何把很多照片拼成一张照片_如何能把多张照片拼凑在一张上图片上
  6. scrapy框架之全站数据的爬取
  7. 2021年安徽省大数据与人工智能应用竞赛人工智能(网络赛)-本科组赛题
  8. RFC8314文档中对465端口和587端口的阐述
  9. 课本剧剧本和计算机专业相关,【课本剧】 高中课本剧剧本大全
  10. pmp-相关方权利/利益方格