TensorFlow中的ResNet残差网络实战(2)
1.导入相应的库:
import os
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
from tensorflow.keras import datasets,losses,Sequential,optimizers
2.加载数据集:
(x_train,y_train),(x_test,y_test)=tf.keras.datasets.mnist.load_data()
x_train=tf.reshape(x_train,[-1,28,28,1])
x_test=tf.reshape(x_test,[-1,28,28,1])
Batch_Size=32
print(np.shape(x_train))
print(np.shape(y_train))
print(np.shape(x_test))
print(np.shape(y_test))
3.构建Datasets数据:
train_db=tf.data.Dataset.from_tensor_slices((x_train,y_train))
test_db=tf.data.Dataset.from_tensor_slices((x_test,y_test))
def processing(x,y):x=tf.cast(x,dtype=tf.float32)/255.0y=tf.cast(y,dtype=tf.int32)y=tf.one_hot(y,depth=10)return x,y
train_db=train_db.shuffle(10000)
train_db=train_db.batch(batch_Size)
train_db=train_db.map(processing)
train_db=train_db.repeat(1)
test_db=test_db.shuffle(10000)
test_db=test_db.batch(batch_Size)
test_db=test_db.map(processing)
test_db=test_db.repeat(1)
4.设置全局参数:
EPOCHES=1
batch_Size=32
learning_rate=0.001
5.定义ResNet模型:
#定义每一个ResNetBlock,其中一组中包含多个ResNetBlock,每一个ResNetBlock又包含几层卷积层
class ResNetBlock(tf.keras.Model):def __init__(self,filter_num,stride=1):super(ResNetBlock,self).__init__()self.conv1=tf.keras.layers.Conv2D(filter_num,kernel_size=[3,3],strides=stride,padding='same')self.bn1=tf.keras.layers.BatchNormalization()self.relu=tf.keras.layers.Activation('relu')self.conv2=tf.keras.layers.Conv2D(filter_num,kernel_size=[3,3],strides=1,padding='same')self.bn2=tf.keras.layers.BatchNormalization()#当经过卷积层之后,发现经过卷积层之前的shape和经过卷积之后的shape不相同,#可以通过1*1的卷积层将它们的shape转换为一样的,再叠加。相当于当前的X通过乘以一个W,转换为相同的shapeif stride!=1:self.downSample=Sequential([tf.keras.layers.Conv2D(filter_num,kernel_size=[1,1],strides=stride)])#如果相同的话就直接叠加,不需要转换else:self.downSample=lambda x:xdef call(self,inputs,training=None):#通过当前的第一层卷积层out=self.conv1(inputs)out=self.bn1(out)out=self.relu(out)#通过当前的第二层卷积层out=self.conv2(out)out=self.bn2(out)#最后经过一个f(x)+xidentity=self.downSample(inputs)output=tf.keras.layers.add([identity,out])output=tf.nn.relu(output)return output
class ResNet(tf.keras.Model):def __init__(self,layers_num,num_classes=10):super(ResNet,self).__init__()#开始的输入层经过一个3*3,步长为1的卷积层和最大池化层self.stem=Sequential([tf.keras.layers.Conv2D(64,kernel_size=[3,3],strides=[1,1]),tf.keras.layers.BatchNormalization(),tf.keras.layers.Activation('relu'),tf.keras.layers.MaxPool2D(pool_size=[2,2],strides=[1,1],padding='same') ])#通过第一个组self.layer1=self.build_resblock(64,layers_num[0])#通过第二个组self.layer2=self.build_resblock(128,layers_num[1],stride=2)#通过第三个组self.layer3=self.build_resblock(256,layers_num[2],stride=2)#通过第四个组self.layer4=self.build_resblock(512,layers_num[3],stride=2)#经过全局平均池化层self.avgPool=tf.keras.layers.GlobalAveragePooling2D()#经过最后的全连接层self.fc=tf.keras.layers.Dense(num_classes)#这个函数是在构建一个组中的ResNetBlockdef build_resblock(self,filter_num,blocks,stride=1):res_block=Sequential([ResNetBlock(filter_num,stride)]) for i in range(1,blocks):res_block.add(ResNetBlock(filter_num,stride=1))return res_blockdef call(self,inputs,training=None):x=self.stem(inputs)x=self.layer1(x)x=self.layer2(x)x=self.layer3(x)x=self.layer4(x)out=self.avgPool(x)output=self.fc(out)return output
6.调用模型:
def resnet18():return ResNet([2,2,2,2])
def resnet34():return ResNet([3,4,6,3])
model=resnet18()
model.build(input_shape=(None,28,28,1))
model.summary()
7.模型训练:
optimizer=optimizers.Adam(learning_rate=learning_rate)
losses,acc=[],[]
for step,(x,y) in enumerate(train_db):with tf.GradientTape() as tape:out=model(x)loss=tf.losses.categorical_crossentropy(y,out,from_logits=True)loss=tf.reduce_mean(loss)grads=tape.gradient(loss,model.trainable_variables)optimizer.apply_gradients(zip(grads,model.trainable_variables))if step%100==0:print('The loss: {}'.format(float(loss)))losses.append(float(loss))correct=0
total_correct=0
total=0
for x,y in test_db:out=model(x)pred=tf.argmax(out,axis=1)y_pred=tf.argmax(y,axis=1)correct=tf.equal(pred,y_pred)total_correct+=tf.reduce_sum(tf.cast(correct,dtype=tf.int32)).numpy()total+=x.shape[0]
print('The Test_accuracy:{}'.format(float(total_correct/total)))
acc.append(float(total_correct/total))
The loss: 2.3025636672973633
The loss: 1.3160710334777832
The loss: 0.3155096769332886
The loss: 0.06830280274152756
The loss: 0.07288875430822372
The loss: 0.10025028884410858
The loss: 0.020024564117193222
The loss: 0.1279890239238739
The loss: 0.1865064799785614
The loss: 0.09510926902294159
The loss: 0.08556611090898514
The loss: 0.011156385764479637
The loss: 0.07167764753103256
The loss: 0.06475668400526047
The loss: 0.16705751419067383
The loss: 0.14960885047912598
The loss: 0.004834086634218693
The loss: 0.17640721797943115
The loss: 0.005258279386907816
The Test_accuracy:0.9691
8.画出loss图:
x=[i for i in range(len(losses))]
fig=plt.figure(figsize=(3,4))
plt.plot(x,losses,color='red')
plt.title('Model losses')
plt.xlabel('epoch')
plt.ylabel('losses')
plt.legend()
plt.show()
plt.close()
TensorFlow中的ResNet残差网络实战(2)相关推荐
- TensorFlow中的ResNet残差网络实战(1)
1.导入相应的库 import os import tensorflow as tf import numpy as np import matplotlib.pyplot as plt from t ...
- ResNet残差网络及变体详解(符代码实现)
本文通过分析深度网络模型的缺点引出ResNet残差网络,并介绍了几种变体,最后用代码实现ResNet18. 文章目录 前言 模型退化 残差结构 ResNet网络结构 Pre Activation Re ...
- 【五一创作】使用Resnet残差网络对图像进行分类(猫十二分类,模型定义、训练、保存、预测)(二)
使用Resnet残差网络对图像进行分类 (猫十二分类,模型定义.训练.保存.预测)(二) 目录 (6).数据集划分 (7).训练集增强 (8).装载数据集 (9).初始化模型 (10).模型训练 (1 ...
- 【五一创作】使用Resnet残差网络对图像进行分类(猫十二分类,模型定义、训练、保存、预测)(一)
使用Resnet残差网络对图像进行分类 (猫十二分类,模型定义.训练.保存.预测)(一) 目录 一.项目简介 二.环境说明 1.安装库 2.导入需要的库 三.分类过程 (1).解压数据集 (2).相关 ...
- (pytorch-深度学习系列)ResNet残差网络的理解-学习笔记
ResNet残差网络的理解 ResNet伴随文章 Deep Residual Learning for Image Recognition 诞生,该文章是MSRA何凯明团队在2015年ImageNet ...
- ResNet 残差网络、残差块
在深度学习中,为了增强模型的学习能力,网络的层数会不断的加深,于此同时,也伴随着一些比较棘手的问题,主要包括: ①模型复杂度上升,网络训练困难 ②出现梯度消失/梯度爆炸问题 ③网络退化,即增加层数并不 ...
- ResNet残差网络Pytorch实现——对花的种类进行训练
ResNet残差网络Pytorch实现--对花的种类进行训练 上一篇:[结合各个残差块] ✌✌✌✌ [目录] ✌✌✌✌ 下一篇:[对花的种类进行单数据预测] 大学生一枚,最近在学习神经网络,写这篇文章 ...
- 目标检测学习笔记2——ResNet残差网络学习、ResNet论文解读
ResNet残差网络学习.ResNet论文解读 一.前言 为什么会提出ResNet? 什么是网络退化现象? 那网络退化现象是什么造成的呢? ResNet要如何解决退化问题? 二.残差模块 三.残差模块 ...
- ResNet残差网络
(二十七)通俗易懂理解--Resnet残差网络 - 梦里寻梦的文章 - 知乎 https://zhuanlan.zhihu.com/p/67860570
最新文章
- 【亲测可用→防止入坑Routes】设置angular10项目异步加载、惰性加载、懒加载路由
- 60篇论文入选,两度夺魁,“史上最难ECCV”商汤再攀高峰
- Error message Exception raised without specific error
- mysql导入的excel更新_excel导入数据库,存在则更新不存在添加
- C#后台调用前台js(RegisterStartupScript)
- java开发者工具开源版_6种开源工具可帮助教育工作者保持井井有条
- hbase Java API 介绍及使用示例
- C#序列化出现“因其保护级别而不可访问。只能处理公共类型。”
- webstorm开发工具找回被误删除的代码
- 2019新闻自动挂机阅读脚本
- 【如何快速的开发一个完整的iOS直播app】(原理篇)
- 单片机中断程序详解(转)
- win10文件服务器ssd当缓存盘,Win10开启写入缓存策略来提高SSD固态硬盘性能
- Excel技能培训之八合并计算,多区域合并计算,分类汇总,展开隐藏列
- 【转帖】财务尽职调查资料收集总结
- 使用自定义的Layer和Cell实现手写汉字生成(Tensorflow2)
- Kubernetes基础:包含多个容器的Pod
- CRM系统能给企业带来什么? CRM系统推荐
- webp图片怎么批量转换成jpg等常用格式
- 2022-09-11-cloud-init
热门文章
- 防火墙产品原理与应用:防火墙产品的技术及实现关于IPV6【寒假】
- 刘强东的代码水平到底有多强?30年前就已破万!
- zabbix系列~ 监控模式
- javaEE之------Spring-----》 AspectJ注解
- 怎么申请微信支付接口
- sql server 经典SQL——分组统计
- BZOJ3262 : 陌上花开
- Oracle10g 回收站及彻底删除table : drop table xx purge
- 局域网怎样自动安装FLASH插件(浏览器不安装flashplayer都可以浏览.swf文件)
- java.lang.RuntimeException: Expected one of local, maven-local, maven-central, scala-tools-releases,