在将原始数据打乱并分成训练集和验证集之后,进行训练模型,本文是十二层的卷积结构,实验过程中可调batch size、epoch、优化参数等,输出为分类准确率和分类损失率的折线图

# In[1]:
import os
import shutiltrain_dir = r'C:\Users\10133\Desktop\bishe\matlab\traintest\train'
test_dir = r'C:\Users\10133\Desktop\bishe\matlab\traintest\test'#导入训练集和验证集# In[2]建立结构
from keras import layers
from keras import modelsmodel = models.Sequential()  # 由Sequential model定义的Keras
model.add(layers.Conv2D(32, (5, 5), padding='same', activation='relu',input_shape=(32, 32, 3)))  # 卷积层  (samples, timesteps,features)
model.add(layers.MaxPooling2D(2, 2))  # 池化层  (samples, features)model.add(layers.Conv2D(64, (5, 5), padding='same', activation='relu'))  # 卷积层
model.add(layers.MaxPooling2D(2, 2))  # 池化层model.add(layers.Conv2D(128, (5, 5), padding='same', activation='relu'))  # 卷积层
model.add(layers.MaxPooling2D(2, 2))  # 池化层model.add(layers.Conv2D(128, (5, 5), padding='same', activation='relu'))  # 卷积层
model.add(layers.MaxPooling2D(2, 2))  # 池化层model.add(layers.Flatten())  # 多维输入一维化
model.add(layers.Dense(512, activation='relu'))  # 全连接层 512个神经元 激活函数Relu
model.add(layers.Dropout(0.4))  # 正则化
model.add(layers.Dense(5, activation='softmax'))  # 返回,输出四个类别的概率 数组 总和为1
model.summary()# In[3]
from keras import optimizers
import kerasmodel.compile(loss='categorical_crossentropy',  # 损失函数optimizer=optimizers.Adam(lr=1e-4),  # 优化参数metrics=['acc'])  # 度量指标# In[4]:数据读取 数据预处理from keras.preprocessing.image import ImageDataGeneratortrain_datagen = ImageDataGenerator(rescale=1. / 255)
test_datagen = ImageDataGenerator(rescale=1. / 255)train_generator = train_datagen.flow_from_directory(  # 训练路径directory=train_dir,target_size=(256, 256),batch_size=4,  # 将数据分解成小批量,训练过程中这个参数应该保持一致
)test_generator = test_datagen.flow_from_directory(  # 测试路径directory=test_dir,target_size=(256, 256),batch_size=4,
)# In[5]:训练模型,对model.fit()的调用返回一个history对象 包含在训练期间发生的所有事情的数据num_of_train_samples = 158
num_of_test_samples = 39
batch_size =4
epochs = 10#迭代次数,训练过程中的所有迭代次数应该是一样的history = model.fit(train_generator,steps_per_epoch=num_of_train_samples // batch_size,epochs=epochs,validation_data=test_generator,validation_steps=num_of_test_samples // batch_size)# In[6]:#history有四个关键词 用来绘制图像
import numpy as np
from pylab import mpl
import matplotlib.pyplot as plt
from matplotlib.font_manager import _rebuild_rebuild()
mpl.rcParams['font.sans-serif'] = [u'SimHei']
mpl.rcParams['axes.unicode_minus'] = Falseacc = history.history['acc']
val_acc = history.history['val_acc']
loss = history.history['loss']
val_loss = history.history['val_loss']epochs = range(len(acc))def smooth_curve(points, factor=0.8):  # 查看loss曲线 利用代码实现在TensorBoard里面的将曲线变平缓的功能smoothed_points = []for point in points:if smoothed_points:previous = smoothed_points[-1]smoothed_points.append(previous * factor + point * (1 - factor))else:smoothed_points.append(point)return smoothed_pointsplt.plot(epochs, acc, color='red', marker='o', label='Training acc')
plt.plot(epochs, val_acc, color='blue', marker='d', label='Validation acc')
plt.annotate('精度{0}'.format(np.max(val_acc)), xy=(np.argmax(val_acc), np.max(val_acc)),arrowprops=dict(facecolor="orange", shrink=0.05), fontsize=10, color='black')
# plt.title('矩阵图')
plt.xlabel('迭代次数/次', verticalalignment='top', fontsize=12)
plt.ylabel('精度/%', horizontalalignment='right', fontsize=12)
plt.legend()plt.savefig(r'C:\Users\10133\Desktop\bishe\matlab\分类准确率.jpg')
plt.figure()plt.plot(epochs, loss, color='red', marker='o', label='Training loss')
plt.plot(epochs, val_loss, color='blue', marker='p', label='Validation loss')
# plt.title('矩阵图')
plt.xlabel('迭代次数/次', verticalalignment='top', fontsize=12)
plt.ylabel('损失/%', horizontalalignment='right', fontsize=12)
plt.savefig(r'C:\Users\10133\Desktop\bishe\matlab\分类损失率.jpg')
plt.legend()plt.show()# In[7]:acc = history.history['acc']
val_acc = history.history['val_acc']
loss = history.history['loss']
val_loss = history.history['val_loss']# 存储模型训练精度
steps = []
for i in range(1, 11):steps.append(str(i))
steps = list(map(int, steps))import numpy as np
import xlwt# 创建工作薄,但是好像运行的时候表单一直没有出来过
book = xlwt.Workbook()# 创建表单
sheet1 = book.add_sheet(u'sheet1', cell_overwrite_ok=True)for i in range(epochs):sheet1.write(i, 0, steps[i])sheet1.write(i, 1, val_acc[i])sheet1.write(i, 2, acc[i])sheet1.write(i, 3, loss[i])sheet1.write(i, 4, val_loss[i])# sheet1.write(11,1,str(loss_test))# sheet1.write(11,2,str(acc_test))
# 保存文件
book.save(r'C:\Users\10133\Desktop\bishe\matlab\分类训练值.xls')

【4】基于深度神经网络的脑电睡眠分期方法研究(训练模型)相关推荐

  1. 基于深度神经网络的遮挡人脸识别算法的研究(小白初学)

    基于深度神经网络的遮挡人脸识别算法的研究(小白初学) 研究背景 在自然条件下人脸面部的光照变化.角度变化.表情变化以及存在遮挡物,使得采集到的人脸图像存在人脸特征的损失.因此研究遮挡人脸识别算法提高识 ...

  2. 014基于深度学习的脑电癫痫自动检测系统-2018(300引用)

    An automated system for epilepsy detection using EEG brain signals based on deep learning approach   ...

  3. 基于深度强化学习的智能车间调度方法研究

    摘要: 工业物联网的空前繁荣为传统的工业生产制造模式开辟了一条新的道路.智能车间调度是整个生产过程实现全面控制和柔性生产的关键技术之一,要求以最大完工时间最小化分派多道工序和多台机器的生产调度.首先, ...

  4. 基于深度学习的遥感图像场景识别方法研究

    文章目录 概述 方法原理 代码实现 结果分析 SVM Resnet LSTM 概述 从2012年深度卷积神经网络(AlexNet)成功应用于图像识别以来,发展出多个改进的卷积神经网络构架,包括2014 ...

  5. (DEAP)基于图卷积神经网络的脑电情绪识别(附代码)

    1. 数据集介绍以及特征部分见上篇文章: DEAP数据集介绍以及特征提取部分 深度学习基于DEAP的脑电情绪识别情感分类(附代码)_qq_3196288251的博客-CSDN博客 2. 图卷积神经网络 ...

  6. 深度神经网络对脑电信号运动想象动作的在线解码

    目录 简介 网络模型 结果比较 结论 本分享为脑机学习者Rose整理发表于公众号:脑机接口社区 QQ交流群:941473018 简介 近年来,深度学习方法的快速发展使得无需任何特征工程的端到端学习成为 ...

  7. (脑肿瘤分割笔记:四十四)基于深度学习的脑肿瘤分割的综述

    目录 Abstract&Introduction 脑肿瘤分割任务面临的主要挑战 深度学习方法的脑肿瘤分割的方法 脑肿瘤分割方法一:设计有效的架构分割方法 针对于不同目的的模型 对于精度有要求的 ...

  8. 优梦思脑电睡眠监测方案惊艳亮相CMEF 开创睡眠精准医疗新景象

    热讯:第84届中国国际医疗器械(春季)博览会(CMEF2021)近日在上海国家会展中心举办.云睿智能作为鱼跃医疗的全国战略合作伙伴,携优梦思便携式脑电监测设备在呼吸展区惊艳亮相,一经展出就迅速吸引了现 ...

  9. 【深度学习】基于深度神经网络进行权重剪枝的算法(二)

    [深度学习]基于深度神经网络进行权重剪枝的算法(二) 文章目录 1 摘要 2 介绍 3 OBD 4 一个例子 1 摘要 通过从网络中删除不重要的权重,可以有更好的泛化能力.需求更少的训练样本.更少的学 ...

  10. 【深度学习】基于深度神经网络进行权重剪枝的算法(一)

    [深度学习]基于深度神经网络进行权重剪枝的算法(一) 1 pruning 2 代码例子 3 tensorflow2 keras 权重剪裁(tensorflow-model-optimization)3 ...

最新文章

  1. 你知道“啥是佩奇”,却不一定了解佩奇排名算法
  2. 一文读懂Https的安全性原理、数字证书、单项认证、双项认证等
  3. opencl track资料整理
  4. linux 提取某一行内容
  5. 查看mysql本地路径
  6. Gradle 配置jetty启动项目
  7. GameObject数组逐渐消失
  8. 如何测试Java类的线程安全性
  9. 通过QMP/QGA与虚拟机进行交互
  10. 弃用数据库自增ID,曝光一下我自己用到的解决方法之---终结篇
  11. c语言不用临时变量交换两个数程序分析
  12. SpringBoot2整合Shiro实现权限管理
  13. PDF控件Aspose.Pdf 12月新版17.12发布 | 附下载
  14. 速成! | 遗传算法详解及其MATLAB实现
  15. android 投屏代码,android投屏技术:控制设备源码分析
  16. 水晶易表(Xcelsius) 2008 学习
  17. Springboot整合JdbcTemplate实现分页查询
  18. java身份证号/手机号隐藏中间几位
  19. 【BZOJ1135】【POI2009】Lyz
  20. Phpspreadsheet 中文文档(六)读写文件+读取文件

热门文章

  1. WBE漏洞-SQL注入之报错盲注
  2. docker搭建sftp服务器
  3. 黑客网络安全扫描工具
  4. echarts中国地图(省市两级经纬度版本)
  5. 详述人工智能在自动驾驶中的应用
  6. 防范蠕虫式勒索软件病毒攻击的安全预警通告
  7. osgearth加载mapbox在线高程数据
  8. 计算机软考初级信息技术试题及答案,2015年软考信息技术处理员考试模拟试题及答案...
  9. 前端学习笔记 - 移动端Web开发
  10. 电子海图数据购买、安装、更新及使用注意事项