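To see how other kinds of networks perform on image classification, I tried an LSTM network, and along the way I'll share a basic understanding of recurrent neural networks. Final result: a single-layer network, 7M model, 85% accuracy. For comparison, my earlier CNN reached 95% accuracy with a 7M model but suffered from overfitting; that article, with source code attached, is at https://blog.csdn.net/qq_36187544/article/details/90669462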

Contents

Project source code (Baidu Cloud)

A basic understanding of recurrent neural networks

Tuning

TensorBoard visualization

Source code


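Project source code (Baidu Cloud)

Note: all images have been preprocessed to a uniform size; otherwise the program raises an error! For the image-processing script and file paths, see the CNN article linked above.

Link: https://pan.baidu.com/s/1h0pKo5-p-JDPtM-iUs84_Q
Extraction code: j44p

models, logs: two folders for model files and log files. Both are currently empty; they are included so the program can run directly.
data: the data folder (the original post shows its layout in a screenshot), split into 7 classes with images under each class. To avoid leaking the data, only one image is included, in lh1, so you can see what the images look like.
setting.py: configuration file
rnn_train.py: network training file, the main script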
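The training script imports eight names from setting.py, which is only available in the archive. As a rough sketch of what such a file might contain: the image shape (32 high, 448 wide, 3 channels) and the 7 classes come from this article, while the remaining four values are placeholders I made up for illustration, not the project's real settings.

```python
# setting.py (hypothetical sketch; the real file ships in the Baidu Cloud archive)

# Values taken from the shapes and class count quoted in this article
width = 448      # image width in pixels
height = 32      # image height in pixels
channel = 3      # RGB color images
out_size = 7     # seven classes, lh1 .. lh7

# Placeholder values, NOT from the original project; tune them yourself
batch_size = 64
rnn_size = 128          # num_units of the LSTM cell
learning_rate = 0.01
num_epoch = 10000       # rnn_train.py uses this as the number of training iterations

values_per_image = height * width * channel
print(values_per_image)
```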

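A basic understanding of recurrent neural networks

A Baidu search turns up plenty of detailed explanations of LSTMs and RNNs, so here is only a brief one:

Put plainly, an RNN treats its input as a sequence. Taking a 28×28 image as an example, the network unrolls into 28 cells (one per row of pixels), and in the end only the last (28th) output needs further processing: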
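To make the idea concrete, here is a minimal NumPy sketch of unrolling over the 28 rows of one image and keeping only the final state. It uses a plain tanh RNN rather than an LSTM to stay short, and the function name, hidden size, and random weights are my own assumptions for illustration.

```python
import numpy as np

def simple_rnn_last_output(image, hidden_size=16, seed=0):
    """Run a plain tanh RNN over the rows of one image and return the last state."""
    rng = np.random.default_rng(seed)
    n_steps, n_inputs = image.shape            # 28 rows, 28 pixels per row
    w_x = rng.normal(scale=0.1, size=(n_inputs, hidden_size))
    w_h = rng.normal(scale=0.1, size=(hidden_size, hidden_size))
    b = np.zeros(hidden_size)

    h = np.zeros(hidden_size)                  # initial hidden state
    outputs = []
    for t in range(n_steps):                   # one "cell" per image row
        h = np.tanh(image[t] @ w_x + h @ w_h + b)
        outputs.append(h)
    return outputs[-1], len(outputs)           # only the last output is used

image = np.random.default_rng(1).random((28, 28))
last, n_cells = simple_rnn_last_output(image)
print(n_cells, last.shape)
```

The same weights are reused at every step; a classifier would then apply one more linear layer to `last`.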

So, for RGB color images, let's look at the code first and then the principle. Here is the network-definition part of the code:

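def rnn_graph(x, rnn_size, out_size, width, height, channel):
    '''
    Computation graph of the recurrent neural network
    :param x: input data
    :param rnn_size:
    :param out_size:
    :param width:
    :param height:
    :param channel:
    :return:
    '''
    # Weights and bias of the output layer
    w = weight_variable([rnn_size, out_size])
    b = bias_variable([out_size])
    # LSTM
    # rnn_size is the num_units of BasicLSTMCell, i.e. the dimension of the output vector
    lstm_cell = tf.nn.rnn_cell.BasicLSTMCell(rnn_size)
    # transpose turns shape (?,32,448,3) into (32,?,448,3): ? is the batch size,
    # 32 the height, 448 the width, 3 the number of channels (color image)
    # The image is split over 32 steps, each with a (448,3) input;
    # this is fast and matches the usual way of thinking about an image
    x = tf.transpose(x, [1,0,2,3])
    # -1 lets reshape infer that dimension; each row of the result holds
    # channel*width values, i.e. one row of pixels of one image
    x = tf.reshape(x, [-1, channel*width])
    # split divides along dimension 0 by default, here into height pieces;
    # in effect, each piece holds the same row index from every image in the batch
    x = tf.split(x, height)
    # The RNN returns one output per input step; only the last one is needed
    outputs, status = tf.nn.static_rnn(lstm_cell, x, dtype=tf.float32)
    y_conv = tf.add(tf.matmul(outputs[-1], w), b)
    return y_conv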

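The point of feeding the data in (32,?,448,3) layout is to split each image over 32 cells, where each step's input has 448*3 values, i.e. one horizontal strip of the image in all three colors!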
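The transpose, reshape, split sequence is easy to check on a dummy batch. The sketch below repeats the same three steps with NumPy instead of TensorFlow (an assumption made only so it runs without a session; the batch size of 4 is arbitrary) and confirms that step t really holds row t of every image.

```python
import numpy as np

batch, height, width, channel = 4, 32, 448, 3
x = np.arange(batch * height * width * channel, dtype=np.float32)
x = x.reshape(batch, height, width, channel)      # (?, 32, 448, 3)

steps = np.transpose(x, (1, 0, 2, 3))             # (32, ?, 448, 3)
steps = steps.reshape(-1, channel * width)        # (32 * ?, 1344)
steps = np.split(steps, height)                   # 32 arrays of shape (?, 1344)

print(len(steps), steps[0].shape)

# Step t, sample n is exactly row t of image n, flattened over width and channel
t, n = 5, 2
same_row = np.array_equal(steps[t][n], x[n, t].reshape(-1))
print(same_row)
```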

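To use vertical strips as the sequence instead, change the code as follows. This makes the network much larger, because it unrolls into 448 steps instead of 32: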

# x = tf.transpose(x, [1,0,2,3])
# x = tf.reshape(x, [-1, channel*width])
# x = tf.split(x, height)
x = tf.transpose(x, [2,0,1,3])
x = tf.reshape(x, [-1, channel*height])
x = tf.split(x, width)

Switching to 3 cells, with the image of each primary color as one input, works the same way!
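All three layouts (rows, columns, color channels) follow one pattern: move the chosen axis to the front, flatten everything else per sample, and split. A small NumPy sketch of that pattern follows; the helper name `to_time_steps` and the batch size of 2 are hypothetical and not part of the original code.

```python
import numpy as np

def to_time_steps(x, time_axis):
    """Split a (batch, H, W, C) array into a list of (batch, features) steps."""
    batch = x.shape[0]
    n_steps = x.shape[time_axis]
    moved = np.moveaxis(x, time_axis, 0)           # (n_steps, batch, ...)
    flat = moved.reshape(n_steps * batch, -1)      # one row per (step, sample)
    return np.split(flat, n_steps)

x = np.zeros((2, 32, 448, 3), dtype=np.float32)

rows = to_time_steps(x, 1)       # horizontal strips: transpose [1,0,2,3]
cols = to_time_steps(x, 2)       # vertical strips:   transpose [2,0,1,3]
colors = to_time_steps(x, 3)     # one step per color channel: transpose [3,0,1,2]

for name, steps in (("rows", rows), ("cols", cols), ("colors", colors)):
    print(name, len(steps), steps[0].shape)
```

The trade-off is visible in the shapes: fewer steps mean a wider input per step, and more steps mean a longer unrolled graph.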


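Tuning

1. batch-size: very important. The network only converges well with a suitable batch size, see https://blog.csdn.net/qq_36187544/article/details/90478051

2. Learning rate:

3. How long should the sequence be? One of the core ideas of an RNN is that earlier and later items of a sequence are related, so I wondered whether cutting a rectangular image into horizontal strips or vertical strips would give different results. It turned out to be basically the same... So I train with the shorter sequence, which speeds up training.

4. The num_units parameter of the RNN: the larger it is, the more features are learned and the higher the accuracy. It is equivalent to widening the network.

5. I did not try deepening the network. Tested with a single layer, accuracy is 85%.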
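To get a feel for how num_units widens the network, the parameter count of one basic LSTM layer can be computed directly: the four gates share a kernel of shape (input_dim + num_units, 4 * num_units) plus a bias of 4 * num_units. The sketch below applies that to the horizontal-strip input of 448*3 values; the num_units values are arbitrary examples, since the real rnn_size lives in setting.py and is not shown here.

```python
def lstm_param_count(input_dim, num_units):
    """Trainable parameters of one basic LSTM layer (four gates, kernel plus bias)."""
    return 4 * num_units * (input_dim + num_units + 1)

def model_param_count(input_dim, num_units, out_size):
    """LSTM layer plus the final dense layer (w and b in rnn_graph)."""
    return lstm_param_count(input_dim, num_units) + num_units * out_size + out_size

row_input = 448 * 3      # one horizontal strip, all three colors
for num_units in (64, 128, 256):     # hypothetical values, not the article's setting
    print(num_units, model_param_count(row_input, num_units, out_size=7))
```

Because the weights are shared across time steps, the count depends on the input size per step and on num_units, not on how many steps the sequence has.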


TensorBoard visualization

Data flow graph:

Loss and accuracy:


Source code

Source of rnn_train.py:

import os
from time import time

import numpy as np
import tensorflow as tf

from LSTM.setting import batch_size, width, height, rnn_size, out_size, channel, learning_rate, num_epoch

# Main training script
# tensorboard --logdir=D:\python\LSTM\logs


def weight_variable(shape, w_alpha=0.01):
    '''
    Randomly initialize weights with scaled noise
    :param shape: weight shape
    :param w_alpha: noise scale
    :return:
    '''
    initial = w_alpha * tf.random_normal(shape)
    return tf.Variable(initial)


def bias_variable(shape, b_alpha=0.1):
    '''
    Randomly initialize the bias with scaled noise
    :param shape: bias shape
    :param b_alpha: noise scale
    :return:
    '''
    initial = b_alpha * tf.random_normal(shape)
    return tf.Variable(initial)


def rnn_graph(x, rnn_size, out_size, width, height, channel):
    '''
    Computation graph of the recurrent neural network
    :param x: input data
    :param rnn_size:
    :param out_size:
    :param width:
    :param height:
    :param channel:
    :return:
    '''
    # Weights and bias of the output layer
    w = weight_variable([rnn_size, out_size])
    b = bias_variable([out_size])
    # LSTM
    # rnn_size is the num_units of BasicLSTMCell, i.e. the dimension of the output vector
    lstm_cell = tf.nn.rnn_cell.BasicLSTMCell(rnn_size)
    # transpose turns shape (?,32,448,3) into (32,?,448,3): ? is the batch size,
    # 32 the height, 448 the width, 3 the number of channels (color image)
    # The image is split over 32 steps, each with a (448,3) input
    x = tf.transpose(x, [1,0,2,3])
    # -1 lets reshape infer that dimension; each row holds channel*width values
    x = tf.reshape(x, [-1, channel*width])
    # split divides along dimension 0 by default, here into height pieces;
    # each piece holds the same row index from every image in the batch
    x = tf.split(x, height)
    # The RNN returns one output per input step; only the last one is needed
    outputs, status = tf.nn.static_rnn(lstm_cell, x, dtype=tf.float32)
    y_conv = tf.add(tf.matmul(outputs[-1], w), b)
    return y_conv


def accuracy_graph(y, y_conv):
    '''
    Computation graph of the accuracy
    :param y:
    :param y_conv:
    :return:
    '''
    correct = tf.equal(tf.argmax(y_conv, 1), tf.argmax(y, 1))
    accuracy = tf.reduce_mean(tf.cast(correct, tf.float32))
    return accuracy


def get_batch(image_list, label_list, img_width, img_height, batch_size, capacity, channel):
    '''
    Load batches of images and labels from lists
    :param image_list: list of image paths
    :param label_list: list of labels
    :param img_width: image width
    :param img_height: image height
    :param batch_size:
    :param capacity:
    :param channel:
    :return:
    '''
    image = tf.cast(image_list, tf.string)
    label = tf.cast(label_list, tf.int32)
    input_queue = tf.train.slice_input_producer([image, label], shuffle=True)
    label = input_queue[1]
    image_contents = tf.read_file(input_queue[0])
    image = tf.image.decode_jpeg(image_contents, channels=channel)
    image = tf.cast(image, tf.float32)
    if channel == 3:
        image -= [42.79902, 42.79902, 42.79902]  # subtract the mean
    elif channel == 1:
        image -= 42.79902  # subtract the mean
    image.set_shape((img_height, img_width, channel))
    image_batch, label_batch = tf.train.batch([image, label], batch_size=batch_size, num_threads=64, capacity=capacity)
    label_batch = tf.reshape(label_batch, [batch_size])
    return image_batch, label_batch


def get_file(file_dir):
    '''
    Collect image paths and labels from a directory
    :param file_dir: directory path
    :return:
    '''
    images = []
    for root, sub_folders, files in os.walk(file_dir):
        for name in files:
            images.append(os.path.join(root, name))
    labels = []
    for label_name in images:
        # The class is the name of the folder that contains the image
        # (the original split the path on "\\", which only works on Windows)
        letter = os.path.basename(os.path.dirname(label_name))
        if letter == "lh1":
            labels.append(0)
        elif letter == "lh2":
            labels.append(1)
        elif letter == "lh3":
            labels.append(2)
        elif letter == "lh4":
            labels.append(3)
        elif letter == "lh5":
            labels.append(4)
        elif letter == "lh6":
            labels.append(5)
        elif letter == "lh7":
            labels.append(6)
    print("check for get_file:", images[0], "label is ", labels[0])
    # shuffle
    temp = np.array([images, labels])
    temp = temp.transpose()
    np.random.shuffle(temp)
    image_list = list(temp[:, 0])
    label_list = list(temp[:, 1])
    label_list = [int(float(i)) for i in label_list]
    return image_list, label_list


# Convert labels to one-hot format
def onehot(labels):
    n_sample = len(labels)
    n_class = 7  # max(labels) + 1
    onehot_labels = np.zeros((n_sample, n_class))
    onehot_labels[np.arange(n_sample), labels] = 1
    return onehot_labels


if __name__ == '__main__':
    startTime = time()
    # Placeholders sized to the images
    x = tf.placeholder(tf.float32, [None, height, width, channel])
    y = tf.placeholder(tf.float32)
    # RNN model
    y_conv = rnn_graph(x, rnn_size, out_size, width, height, channel)
    # Convert one-hot vectors back to class indices
    y_conv_prediction = tf.argmax(y_conv, 1)
    y_real = tf.argmax(y, 1)
    # Optimization graph
    loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=y_conv, labels=y))
    optimizer = tf.train.GradientDescentOptimizer(learning_rate).minimize(loss)
    # Accuracy
    accuracy = accuracy_graph(y, y_conv)
    # Training images
    xs, ys = get_file('./data/train1')  # image list and label list
    image_batch, label_batch = get_batch(xs, ys, img_width=width, img_height=height, batch_size=batch_size, capacity=256, channel=channel)
    # Validation set
    xs_val, ys_val = get_file('./data/test1')  # image list and label list
    image_val_batch, label_val_batch = get_batch(xs_val, ys_val, img_width=width, img_height=height, batch_size=455, capacity=256, channel=channel)
    # Start the session and begin training
    sess = tf.Session()
    sess.run(tf.global_variables_initializer())
    saver = tf.train.Saver()
    # Start the queue threads
    coord = tf.train.Coordinator()  # a coordinator manages the threads
    threads = tf.train.start_queue_runners(coord=coord, sess=sess)
    # Logging
    summary_writer = tf.summary.FileWriter('./logs/', graph=sess.graph, flush_secs=15)
    summary_writer2 = tf.summary.FileWriter('./logs/plot2/', flush_secs=15)
    tf.summary.scalar(name='loss_func', tensor=loss)
    tf.summary.scalar(name='accuracy', tensor=accuracy)
    merged_summary_op = tf.summary.merge_all()

    step = 0
    acc_rate = 0.98
    epoch_start_time = time()
    for i in range(num_epoch):
        batch_x, batch_y = sess.run([image_batch, label_batch])
        batch_y = onehot(batch_y)
        merged_summary, _, loss_show = sess.run([merged_summary_op, optimizer, loss], feed_dict={x: batch_x, y: batch_y})
        summary_writer.add_summary(merged_summary, global_step=i)

        # Evaluate on the validation set every 7000 // batch_size iterations
        if i % (int(7000 // batch_size)) == 0:
            batch_x_test, batch_y_test = sess.run([image_val_batch, label_val_batch])
            batch_y_test = onehot(batch_y_test)
            batch_x_test = batch_x_test.reshape([-1, height, width, channel])
            merged_summary_val, acc, prediction_val_out, real_val_out, loss_show = sess.run(
                [merged_summary_op, accuracy, y_conv_prediction, y_real, loss],
                feed_dict={x: batch_x_test, y: batch_y_test})
            summary_writer2.add_summary(merged_summary_val, global_step=i)

            # Accuracy of each class (index 0..6 corresponds to lh1..lh7)
            right = [0] * 7
            wrong = [0] * 7
            for ii in range(len(prediction_val_out)):
                if prediction_val_out[ii] == real_val_out[ii]:
                    right[real_val_out[ii]] += 1
                else:
                    wrong[real_val_out[ii]] += 1
            rates = [right[k] / (right[k] + wrong[k]) for k in range(7)]
            print(step, "correct rate :", *rates)
            print(step, "mean per-class accuracy:", sum(rates) / 7)

            epoch_end_time = time()
            print("takes time:", (epoch_end_time - epoch_start_time), ' step:', step, ' accuracy:', acc, " loss_fun:", loss_show)
            epoch_start_time = epoch_end_time

            # Accuracy meets the target: save the model and stop
            if acc >= acc_rate:
                model_path = os.path.join(os.getcwd(), 'models', str(acc_rate) + "LSTM.model")
                saver.save(sess, model_path, global_step=step)
                break
            # Otherwise save a checkpoint every 10 evaluations
            if step % 10 == 0 and step != 0:
                model_path = os.path.join(os.getcwd(), 'models', str(acc_rate) + "LSTM" + str(step) + ".model")
                print(model_path)
                saver.save(sess, model_path, global_step=step)
            step += 1

    duration = time() - startTime
    print("total takes time:", duration)
    summary_writer.close()
    coord.request_stop()  # ask the threads to stop
    coord.join(threads)  # returns only after the other threads have stopped
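The per-class accuracy printed during validation can also be factored into a small standalone function. The sketch below is one possible version, not the script's own code: the function name is hypothetical, and unlike the training script it skips classes that have no samples in the batch instead of dividing by zero.

```python
def per_class_accuracy(y_true, y_pred, n_class=7):
    """Return the accuracy of each class and the mean over classes that occur."""
    right = [0] * n_class
    total = [0] * n_class
    for t, p in zip(y_true, y_pred):
        total[t] += 1
        if t == p:
            right[t] += 1
    # None marks a class with no samples in this batch
    rates = [r / n if n else None for r, n in zip(right, total)]
    seen = [r for r in rates if r is not None]
    mean_rate = sum(seen) / len(seen) if seen else 0.0
    return rates, mean_rate

# Tiny made-up example with 3 classes
y_true = [0, 0, 1, 1, 2, 2, 2]
y_pred = [0, 1, 1, 1, 2, 0, 2]
rates, mean_rate = per_class_accuracy(y_true, y_pred, n_class=3)
print(rates, mean_rate)
```

Averaging per-class rates weights every class equally, which is why this number can differ from the plain accuracy when the classes are unbalanced.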
