1.导入相应的库

import os
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
from tensorflow.keras import datasets,losses,Sequential,optimizers

2.加载MNIST数据集:

(x_train,y_train),(x_test,y_test)=tf.keras.datasets.mnist.load_data()
x_train,x_test=x_train.astype('float')/255.0,x_test.astype('float32')/255.0Batch_Size=32
x_train=tf.reshape(x_train,[-1,28,28,1])
x_test=tf.reshape(x_test,[-1,28,28,1])

3.转one_hot编码:

y_train=tf.one_hot(y_train,depth=10).numpy()
y_test=tf.one_hot(y_test,depth=10).numpy()

4.设置全局参数:

EPOCHES=1
batch_Size=32
learning_rate=0.001

5.定义ResNet模型:

#定义每一个ResNetBlock,其中一组中包含多个ResNetBlock,每一个ResNetBlock又包含几层卷积层
class ResNetBlock(tf.keras.Model):def __init__(self,filter_num,stride=1):super(ResNetBlock,self).__init__()self.conv1=tf.keras.layers.Conv2D(filter_num,kernel_size=[3,3],strides=stride,padding='same')self.bn1=tf.keras.layers.BatchNormalization()self.relu=tf.keras.layers.Activation('relu')self.conv2=tf.keras.layers.Conv2D(filter_num,kernel_size=[3,3],strides=1,padding='same')self.bn2=tf.keras.layers.BatchNormalization()#当经过卷积层之后,发现经过卷积层之前的shape和经过卷积之后的shape不相同,#可以通过1*1的卷积层将它们的shape转换为一样的,再叠加。相当于当前的X通过乘以一个W,转换为相同的shapeif stride!=1:self.downSample=Sequential([tf.keras.layers.Conv2D(filter_num,kernel_size=[1,1],strides=stride)])#如果相同的话就直接叠加,不需要转换else:self.downSample=lambda x:xdef call(self,inputs,training=None):#通过当前的第一层卷积层out=self.conv1(inputs)out=self.bn1(out)out=self.relu(out)#通过当前的第二层卷积层out=self.conv2(out)out=self.bn2(out)#最后经过一个f(x)+xidentity=self.downSample(inputs)output=tf.keras.layers.add([identity,out])output=tf.nn.relu(output)return output
class ResNet(tf.keras.Model):def __init__(self,layers_num,num_classes=10):super(ResNet,self).__init__()#开始的输入层经过一个3*3,步长为1的卷积层和最大池化层self.stem=Sequential([tf.keras.layers.Conv2D(64,kernel_size=[3,3],strides=[1,1]),tf.keras.layers.BatchNormalization(),tf.keras.layers.Activation('relu'),tf.keras.layers.MaxPool2D(pool_size=[2,2],strides=[1,1],padding='same') ])#通过第一个组self.layer1=self.build_resblock(64,layers_num[0])#通过第二个组self.layer2=self.build_resblock(128,layers_num[1],stride=2)#通过第三个组self.layer3=self.build_resblock(256,layers_num[2],stride=2)#通过第四个组self.layer4=self.build_resblock(512,layers_num[3],stride=2)#经过全局平均池化层self.avgPool=tf.keras.layers.GlobalAveragePooling2D()#经过最后的全连接层self.fc=tf.keras.layers.Dense(num_classes)#这个函数是在构建一个组中的ResNetBlockdef build_resblock(self,filter_num,blocks,stride=1):res_block=Sequential([ResNetBlock(filter_num,stride)]) for i in range(1,blocks):res_block.add(ResNetBlock(filter_num,stride=1))return res_blockdef call(self,inputs,training=None):x=self.stem(inputs)x=self.layer1(x)x=self.layer2(x)x=self.layer3(x)x=self.layer4(x)out=self.avgPool(x)output=self.fc(out)return output

6.调用模型:

def resnet18():return ResNet([2,2,2,2])
def resnet34():return ResNet([3,4,6,3])

7.模型编译和优化器选择

model=resnet18()
optimizer=optimizers.Adam(learning_rate=learning_rate)
model.compile(optimizer=optimizer,loss=losses.CategoricalCrossentropy(from_logits=True),metrics=['accuracy'])
model.build(input_shape=(None,28,28,1))
model.summary()


8.模型训练

model.fit(x_train,y_train,epochs=EPOCHES,validation_data=(x_test,y_test),verbose=1,batch_size=batch_Size)
model.save_weights('ResNet_34.h5')

9. 使用测试集数据评估误差和准确率

loss,accuracy=model.evaluate(x_test,y_test,batch_size=batch_Size,verbose=1)

TensorFlow中的ResNet残差网络实战(1)相关推荐

  1. TensorFlow中的ResNet残差网络实战(2)

    1.导入相应的库: import os import tensorflow as tf import numpy as np import matplotlib.pyplot as plt from ...

  2. ResNet残差网络及变体详解(符代码实现)

    本文通过分析深度网络模型的缺点引出ResNet残差网络,并介绍了几种变体,最后用代码实现ResNet18. 文章目录 前言 模型退化 残差结构 ResNet网络结构 Pre Activation Re ...

  3. 【五一创作】使用Resnet残差网络对图像进行分类(猫十二分类,模型定义、训练、保存、预测)(二)

    使用Resnet残差网络对图像进行分类 (猫十二分类,模型定义.训练.保存.预测)(二) 目录 (6).数据集划分 (7).训练集增强 (8).装载数据集 (9).初始化模型 (10).模型训练 (1 ...

  4. 【五一创作】使用Resnet残差网络对图像进行分类(猫十二分类,模型定义、训练、保存、预测)(一)

    使用Resnet残差网络对图像进行分类 (猫十二分类,模型定义.训练.保存.预测)(一) 目录 一.项目简介 二.环境说明 1.安装库 2.导入需要的库 三.分类过程 (1).解压数据集 (2).相关 ...

  5. (pytorch-深度学习系列)ResNet残差网络的理解-学习笔记

    ResNet残差网络的理解 ResNet伴随文章 Deep Residual Learning for Image Recognition 诞生,该文章是MSRA何凯明团队在2015年ImageNet ...

  6. ResNet 残差网络、残差块

    在深度学习中,为了增强模型的学习能力,网络的层数会不断的加深,于此同时,也伴随着一些比较棘手的问题,主要包括: ①模型复杂度上升,网络训练困难 ②出现梯度消失/梯度爆炸问题 ③网络退化,即增加层数并不 ...

  7. ResNet残差网络Pytorch实现——对花的种类进行训练

    ResNet残差网络Pytorch实现--对花的种类进行训练 上一篇:[结合各个残差块] ✌✌✌✌ [目录] ✌✌✌✌ 下一篇:[对花的种类进行单数据预测] 大学生一枚,最近在学习神经网络,写这篇文章 ...

  8. 目标检测学习笔记2——ResNet残差网络学习、ResNet论文解读

    ResNet残差网络学习.ResNet论文解读 一.前言 为什么会提出ResNet? 什么是网络退化现象? 那网络退化现象是什么造成的呢? ResNet要如何解决退化问题? 二.残差模块 三.残差模块 ...

  9. ResNet残差网络

    (二十七)通俗易懂理解--Resnet残差网络 - 梦里寻梦的文章 - 知乎 https://zhuanlan.zhihu.com/p/67860570

最新文章

  1. rails 如何 支持 bootstrap3
  2. HttpServletRequest对象方法的用法(转)
  3. centos7 xfce 中文字体输入法
  4. java泛型_Java核心知识 基础五 JAVA 泛型
  5. 炒股要学会向动物学习
  6. .net程序员写业务代码需要注意的地方
  7. MyBatis3与Spring3的整合配置(初级篇)
  8. java自动行走_java数据结构实现机器人行走
  9. vs2010 c++项目创建简易教程
  10. svn 1.12.0 版本以及汉化包(百度网盘分享--永久有效)
  11. Java面试-重写和重载的规则
  12. x86、ARM和MIPS三种主流芯片架构
  13. 东汉唯物主义哲学家——王充
  14. 浅议化学与社会的关系——兼议绿色化学重要性
  15. Python:使用nltk统计词频并绘制统计图
  16. 揭秘支付宝中的深度学习引擎:xNN
  17. html选项卡出现乱码,html乱码
  18. The requested operation requires elevation问题解决
  19. 134_人人后台管理系统-立即执行定时任务失败(坑)
  20. 智能化改造推动企业生产过程更为精准与高效

热门文章

  1. 定义Serializer序列化器
  2. 数据结构-图-遍历-搜索
  3. MATLAB知识点总结
  4. python里的tplt什么意思 Python的format格式化输出
  5. 干货|深入浅出YOLOv5
  6. ORDNet:为场景分割捕获全范围依赖关系
  7. 基于OpenCV的条形码区域分割
  8. EP936E的IIC
  9. V神再为BCH发声!
  10. java-第十一章-类的无参方法-计算器运算