立即学习:https://edu.csdn.net/course/play/24719/279509?utm_source=blogtoedu

目录

一、神经网络训练代码

二、思路总结

1、数据集图片数据、目标值的导入

2、目标值转化为one_hot编码

3、神经网络中样本、目标的输入

4、神经网络的搭建

5、损失的计算与优化

6、批处理的方法

7、准确率的计算

8、模型的保存

三、API总结


一、神经网络训练代码

import tensorflow as tf
import cv2
import glob
import os
import numpy as npdef data_reader(dataset_path):data_list = []label_list = []for cls_path in glob.glob(os.path.join(dataset_path, '*')):for file_name in glob.glob((os.path.join(cls_path, '*'))):img = cv2.imread(file_name)data_list.append(img)label_list.append(int(cls_path[-1]))data_np = np.array(data_list)label_np = np.array(label_list)return data_np, label_npdef shuffle_data(data, label):idx = np.arange(len(data))np.random.shuffle(idx)return data[idx, ...], label[idx, ...]def train(data, label):data_in = tf.placeholder(tf.float32, [None, 100, 100, 3], name="data_in")  # None实际上指的是batch的大小,batch的大小可以在运行时改变label_in = tf.placeholder(tf.float32, [None, 3])  # 准备构造一个one_hot的label,而我们的数据总共有3个类# 一个小的卷积网络,用来处理图片# (如果数据量大,需要构造一个大一点的神经网络,可以在卷积网络部分复制一下,就可以构造一个比较深的神经网络)# (但是神经网络并不是越深越好,可能会涉及过拟合、梯度消失等问题)# (目前网络参数的取值,主要还是根据数据集特征和大小,还有靠我们的直觉的经验。)# (所以神经网络也叫作当代炼金术)out = tf.layers.conv2d(data_in, 4, 3, padding='same')  # out形状为[?, 100, 100, 4]out = tf.layers.max_pooling2d(out, 2, 2, padding='same')  # out形状为[?, 50, 50, 4]out = tf.nn.relu(out)  # out形状为[?, 50, 50, 4]# 把提取出来的特征压扁成一个一维数组out = tf.reshape(out, (-1, int(np.prod(out.get_shape()[1:]))))  # out形状为[?, 10000]# 送入全连接层out = tf.layers.dense(out, 2000, activation=tf.nn.relu)# 再加一个全连接层(按理说,我们数据量很少,应该会过拟合,也就数准确率接近100%)out = tf.layers.dense(out, 256, activation=tf.nn.relu)# 得到输出(输出的分类数要和输入的label的分类数一样,不然会报错)pred = tf.layers.dense(out, 3)# 将one_hot变回值0、1、2out_label = tf.argmax(pred, 1, name="output")# 计算交叉熵损失loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits_v2(labels=label_in, logits=pred))# 梯度下降、优化损失train_op = tf.train.AdamOptimizer(learning_rate=0.001).minimize(loss)# 初始化变量的opinit_op = tf.initialize_all_variables()# 定义一些参数batch_size = 16# 进行会话with tf.Session() as sess:sess.run(init_op)# 进行多次训练for epoch in range(50):# (在训练的时候,我们不希望每次取到的数据的顺序都是一样的,这样很容易导致过拟合)# (因此我们一般会加一个shuffle,也就是在输入到网络前,把数据顺序打乱)datas, labels = shuffle_data(data, label)# 准确率统计total_loss = 0avg_accuracy = 0# 按批次训练num_batch = len(data) // batch_sizefor batch_idx in range(num_batch):# 计算每个batch开始和结束时候的下标start_idx = batch_idx * batch_sizeend_idx = (batch_idx + 1) * batch_size# 准备好输入层的数据# (...代表省略后面的形状)feed_dict = {data_in: datas[start_idx: end_idx, ...],label_in: labels[start_idx: end_idx, ...]}# 计算准确率correct_prediction = tf.equal(out_label, tf.argmax(label_in, 1))accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))# 进行训练_, loss_out, acc = sess.run([train_op, loss, accuracy], feed_dict=feed_dict)# 准确率统计total_loss += loss_outavg_accuracy += acc# 统计每批次准确率print("avg_accuracy", avg_accuracy / num_batch)# 准确率够高则保存模型if avg_accuracy / num_batch > 0.94:saver = tf.train.Saver()saver.save(sess, './model/model')breakdef dense_to_one_hot(label, num_class):num_label = label.shape[0]index_offset = np.arange(num_label) * num_classlabel_one_hot = np.zeros((num_label, num_class))label_one_hot.flat[index_offset + label.ravel()] = 1return label_one_hotif __name__ == '__main__':dataset_path = "./dataset"data, label = data_reader(dataset_path)one_hot_label = dense_to_one_hot(label, 3)train(data, one_hot_label)

二、思路总结

1、数据集图片数据、目标值的导入

图片数据

首先,用cv读取图片,将图片一个一个添加到一个list里面去。

其次,用np(numpy)将图片的列表转化为np数组。

图片目标值

首先,由于图片被按照目录分类

在读取图片时,每读取一张图,就将图片的上级目录的名字作为一个目标值。(这里是0、1、2)

2、目标值转化为one_hot编码

步骤如下:

  1. 原本一个样本只有一个目标值

  2. 一共600个样本,目标值总的可能性有3种(某种目标的one_hot编码就只可能是[1, 0, 0]、[0, 1, 0]、[0, 0, 1])

  3. 先生成好一个全零的、形状为(600, 3)的np数组

    【手把手带你入门深度学习之150行代码的汉字识别系统】学习笔记 ·002 训练神经网络相关推荐

    1. 【手把手带你入门深度学习之150行代码的汉字识别系统】学习笔记 ·001 用OpenCV制作数据集

      立即学习:https://edu.csdn.net/course/play/24719/279505?utm_source=blogtoedu 目录 一.制作数据集代码 二.思路总结 1.数据集目录的 ...

    2. 【手把手带你入门深度学习之150行代码的汉字识别系统】学习笔记 ·003 用训练模型进行预测

      立即学习:https://edu.csdn.net/course/play/24719/279510?utm_source=blogtoedu 目录 一.用训练模型进行预测代码 二.思路总结 1.模型 ...

    3. 手把手带你入门深度学习(一):保姆级Anaconda和PyTorch环境配置指南

      手把手带你入门深度学习(一):保姆级Anaconda和PyTorch环境配置指南 一. 前言和准备工作 1.1 python.anaconda和pytorch的关系 二. Anconda安装 2.1 ...

    4. linux 中国-新手村,从新手村开始,手把手带你入门梳理内核代码

      原标题:从新手村开始,手把手带你入门梳理内核代码 在上一期内容中,Java离Linux内核有多远? 我们介绍了从 JVM 到内核的编译原理,告诉大家应用和系统工程师如何接触到内核. 本文将 从一个简单 ...

    5. RPA之家手把手带你入门Blue Prism教程系列4_认识Blue Prism的界面

      RPA之家手把手带你入门Blue Prism 1. Home & Analytics 2. Studio 2.1 Process 2.2 Object 2.3 Process与Object的关 ...

    6. RPA之家手把手带你入门Blue Prism教程系列7_深入了解Data Item

      RPA之家手把手带你入门Blue Prism 1. Data Item类型 2. Data Item的表现形式 2.1 Environment Variable(环境变量) 2.2 Session V ...

    7. RPA之家手把手带你入门Blue Prism教程系列 -汇总

      RPA之家手把手带你入门Blue Prism 基础篇 -本文章由RPA之家(rpazj.com)提供, 学习交流群QQ群465620839 微信交流群: 基础篇 RPA之家手把手带你入门Blue Pr ...

    8. RPA之家手把手带你入门Blue Prism教程系列3_如何新建用户和配置数据库

      RPA之家手把手带你入门Blue Prism 创建用户 第一步:寻找Security标签下的Users 第二步:配置Users 配置数据库 第一步:新建数据库 第二步:配置数据库 -本文章由RPA之家 ...

    9. RPA之家手把手带你入门Blue Prism教程系列1_如何申请Blue Prism免费试用版

      RPA之家手把手带你入门Blue Prism Blue Prism 免费试用版 第一步:申请一个BluePrism Portal账号 第二步:在DX网站申请一个测试license 第三步:申请成功后, ...

    最新文章

    1. cocos2d 0.99.5版本屏幕默认是横屏,怎么修改为竖屏呢?
    2. TensorFlow深度学习算法原理与编程实战 人工智能机器学习技术丛书
    3. Fiddler抓取https的设置
    4. linux下curl指令常见使用
    5. Educational Codeforces Round 32 G. Xor-MST 01tire + 分治 + Boruvka
    6. 离职通知邮件主题写什么好_(原创)拿到了企业的offer后要注意什么?
    7. 数据结构之基于Java的顺序栈实现
    8. android传感器开发与智能设备案例实战_【我的物联网成长记2】设备如何进行选型?...
    9. java 按回车键查询
    10. 免费学习编程-值得收藏
    11. Python -- 关于函数的学习(五) — 传递任意数量的实参
    12. ArcGIS 把字段允许空值设为否
    13. 力扣刷题 DAY_72 回溯
    14. learning and evaluating representations for deep one-class classification
    15. 网络信息安全的重要性
    16. 蚂蚁金服张洁:基于深度学习的支付宝人脸识别技术解秘
    17. 本以为java语言很难学,其实就学完下面这些知识,就能理解了
    18. linux下/proc/sysrq-trigger详解
    19. MetaQ 简单使用(数据同步框架)
    20. 基于Java的Minecraft游戏后端自定义插件 06绘制简单粒子特效与BukkitRunable定时器

    热门文章

    1. go mysql es 不要分词_ElasticSearch踩坑记录-Go语言中文社区
    2. python32位系统下载_pythonwin下载-PythonWin 32位(Python集成开发环境) 3.6 官方版 - 河东下载站...
    3. python多线程异步爬虫-Python异步爬虫试验[Celery,gevent,requests]
    4. 不添加外键能关联查询_SpringDataJPA关联关系
    5. 人脸关键点: DCNN-Deep Convolutional Network Cascade for Facial Point Detection
    6. 机器学习从入门到进阶✅
    7. ShuffleNetv2的学习笔记
    8. Cross-Modal Retrieval——为什么要使用GAN呢?
    9. XGBoost深度理解
    10. Python import容易犯的一个错误