TensorFlow中的ResNet残差网络实战(1)
1.导入相应的库
import os
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
from tensorflow.keras import datasets,losses,Sequential,optimizers
2.加载MNIST数据集:
(x_train,y_train),(x_test,y_test)=tf.keras.datasets.mnist.load_data()
x_train,x_test=x_train.astype('float')/255.0,x_test.astype('float32')/255.0Batch_Size=32
x_train=tf.reshape(x_train,[-1,28,28,1])
x_test=tf.reshape(x_test,[-1,28,28,1])
3.转one_hot编码:
y_train=tf.one_hot(y_train,depth=10).numpy()
y_test=tf.one_hot(y_test,depth=10).numpy()
4.设置全局参数:
EPOCHES=1
batch_Size=32
learning_rate=0.001
5.定义ResNet模型:
#定义每一个ResNetBlock,其中一组中包含多个ResNetBlock,每一个ResNetBlock又包含几层卷积层
class ResNetBlock(tf.keras.Model):def __init__(self,filter_num,stride=1):super(ResNetBlock,self).__init__()self.conv1=tf.keras.layers.Conv2D(filter_num,kernel_size=[3,3],strides=stride,padding='same')self.bn1=tf.keras.layers.BatchNormalization()self.relu=tf.keras.layers.Activation('relu')self.conv2=tf.keras.layers.Conv2D(filter_num,kernel_size=[3,3],strides=1,padding='same')self.bn2=tf.keras.layers.BatchNormalization()#当经过卷积层之后,发现经过卷积层之前的shape和经过卷积之后的shape不相同,#可以通过1*1的卷积层将它们的shape转换为一样的,再叠加。相当于当前的X通过乘以一个W,转换为相同的shapeif stride!=1:self.downSample=Sequential([tf.keras.layers.Conv2D(filter_num,kernel_size=[1,1],strides=stride)])#如果相同的话就直接叠加,不需要转换else:self.downSample=lambda x:xdef call(self,inputs,training=None):#通过当前的第一层卷积层out=self.conv1(inputs)out=self.bn1(out)out=self.relu(out)#通过当前的第二层卷积层out=self.conv2(out)out=self.bn2(out)#最后经过一个f(x)+xidentity=self.downSample(inputs)output=tf.keras.layers.add([identity,out])output=tf.nn.relu(output)return output
class ResNet(tf.keras.Model):def __init__(self,layers_num,num_classes=10):super(ResNet,self).__init__()#开始的输入层经过一个3*3,步长为1的卷积层和最大池化层self.stem=Sequential([tf.keras.layers.Conv2D(64,kernel_size=[3,3],strides=[1,1]),tf.keras.layers.BatchNormalization(),tf.keras.layers.Activation('relu'),tf.keras.layers.MaxPool2D(pool_size=[2,2],strides=[1,1],padding='same') ])#通过第一个组self.layer1=self.build_resblock(64,layers_num[0])#通过第二个组self.layer2=self.build_resblock(128,layers_num[1],stride=2)#通过第三个组self.layer3=self.build_resblock(256,layers_num[2],stride=2)#通过第四个组self.layer4=self.build_resblock(512,layers_num[3],stride=2)#经过全局平均池化层self.avgPool=tf.keras.layers.GlobalAveragePooling2D()#经过最后的全连接层self.fc=tf.keras.layers.Dense(num_classes)#这个函数是在构建一个组中的ResNetBlockdef build_resblock(self,filter_num,blocks,stride=1):res_block=Sequential([ResNetBlock(filter_num,stride)]) for i in range(1,blocks):res_block.add(ResNetBlock(filter_num,stride=1))return res_blockdef call(self,inputs,training=None):x=self.stem(inputs)x=self.layer1(x)x=self.layer2(x)x=self.layer3(x)x=self.layer4(x)out=self.avgPool(x)output=self.fc(out)return output
6.调用模型:
def resnet18():return ResNet([2,2,2,2])
def resnet34():return ResNet([3,4,6,3])
7.模型编译和优化器选择
model=resnet18()
optimizer=optimizers.Adam(learning_rate=learning_rate)
model.compile(optimizer=optimizer,loss=losses.CategoricalCrossentropy(from_logits=True),metrics=['accuracy'])
model.build(input_shape=(None,28,28,1))
model.summary()
8.模型训练
model.fit(x_train,y_train,epochs=EPOCHES,validation_data=(x_test,y_test),verbose=1,batch_size=batch_Size)
model.save_weights('ResNet_34.h5')
9. 使用测试集数据评估误差和准确率
loss,accuracy=model.evaluate(x_test,y_test,batch_size=batch_Size,verbose=1)
TensorFlow中的ResNet残差网络实战(1)相关推荐
- TensorFlow中的ResNet残差网络实战(2)
1.导入相应的库: import os import tensorflow as tf import numpy as np import matplotlib.pyplot as plt from ...
- ResNet残差网络及变体详解(符代码实现)
本文通过分析深度网络模型的缺点引出ResNet残差网络,并介绍了几种变体,最后用代码实现ResNet18. 文章目录 前言 模型退化 残差结构 ResNet网络结构 Pre Activation Re ...
- 【五一创作】使用Resnet残差网络对图像进行分类(猫十二分类,模型定义、训练、保存、预测)(二)
使用Resnet残差网络对图像进行分类 (猫十二分类,模型定义.训练.保存.预测)(二) 目录 (6).数据集划分 (7).训练集增强 (8).装载数据集 (9).初始化模型 (10).模型训练 (1 ...
- 【五一创作】使用Resnet残差网络对图像进行分类(猫十二分类,模型定义、训练、保存、预测)(一)
使用Resnet残差网络对图像进行分类 (猫十二分类,模型定义.训练.保存.预测)(一) 目录 一.项目简介 二.环境说明 1.安装库 2.导入需要的库 三.分类过程 (1).解压数据集 (2).相关 ...
- (pytorch-深度学习系列)ResNet残差网络的理解-学习笔记
ResNet残差网络的理解 ResNet伴随文章 Deep Residual Learning for Image Recognition 诞生,该文章是MSRA何凯明团队在2015年ImageNet ...
- ResNet 残差网络、残差块
在深度学习中,为了增强模型的学习能力,网络的层数会不断的加深,于此同时,也伴随着一些比较棘手的问题,主要包括: ①模型复杂度上升,网络训练困难 ②出现梯度消失/梯度爆炸问题 ③网络退化,即增加层数并不 ...
- ResNet残差网络Pytorch实现——对花的种类进行训练
ResNet残差网络Pytorch实现--对花的种类进行训练 上一篇:[结合各个残差块] ✌✌✌✌ [目录] ✌✌✌✌ 下一篇:[对花的种类进行单数据预测] 大学生一枚,最近在学习神经网络,写这篇文章 ...
- 目标检测学习笔记2——ResNet残差网络学习、ResNet论文解读
ResNet残差网络学习.ResNet论文解读 一.前言 为什么会提出ResNet? 什么是网络退化现象? 那网络退化现象是什么造成的呢? ResNet要如何解决退化问题? 二.残差模块 三.残差模块 ...
- ResNet残差网络
(二十七)通俗易懂理解--Resnet残差网络 - 梦里寻梦的文章 - 知乎 https://zhuanlan.zhihu.com/p/67860570
最新文章
- rails 如何 支持 bootstrap3
- HttpServletRequest对象方法的用法(转)
- centos7 xfce 中文字体输入法
- java泛型_Java核心知识 基础五 JAVA 泛型
- 炒股要学会向动物学习
- .net程序员写业务代码需要注意的地方
- MyBatis3与Spring3的整合配置(初级篇)
- java自动行走_java数据结构实现机器人行走
- vs2010 c++项目创建简易教程
- svn 1.12.0 版本以及汉化包(百度网盘分享--永久有效)
- Java面试-重写和重载的规则
- x86、ARM和MIPS三种主流芯片架构
- 东汉唯物主义哲学家——王充
- 浅议化学与社会的关系——兼议绿色化学重要性
- Python:使用nltk统计词频并绘制统计图
- 揭秘支付宝中的深度学习引擎:xNN
- html选项卡出现乱码,html乱码
- The requested operation requires elevation问题解决
- 134_人人后台管理系统-立即执行定时任务失败(坑)
- 智能化改造推动企业生产过程更为精准与高效