1.导入相应的库:

import os
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
from tensorflow.keras import datasets,losses,Sequential,optimizers

2.加载数据集

(x_train,y_train),(x_test,y_test)=tf.keras.datasets.mnist.load_data()
x_train=tf.reshape(x_train,[-1,28,28,1])
x_test=tf.reshape(x_test,[-1,28,28,1])
Batch_Size=32
print(np.shape(x_train))
print(np.shape(y_train))
print(np.shape(x_test))
print(np.shape(y_test))


3.构建Datasets数据:

train_db=tf.data.Dataset.from_tensor_slices((x_train,y_train))
test_db=tf.data.Dataset.from_tensor_slices((x_test,y_test))
def processing(x,y):x=tf.cast(x,dtype=tf.float32)/255.0y=tf.cast(y,dtype=tf.int32)y=tf.one_hot(y,depth=10)return x,y
train_db=train_db.shuffle(10000)
train_db=train_db.batch(batch_Size)
train_db=train_db.map(processing)
train_db=train_db.repeat(1)
test_db=test_db.shuffle(10000)
test_db=test_db.batch(batch_Size)
test_db=test_db.map(processing)
test_db=test_db.repeat(1)

4.设置全局参数:

EPOCHES=1
batch_Size=32
learning_rate=0.001

5.定义ResNet模型:

#定义每一个ResNetBlock,其中一组中包含多个ResNetBlock,每一个ResNetBlock又包含几层卷积层
class ResNetBlock(tf.keras.Model):def __init__(self,filter_num,stride=1):super(ResNetBlock,self).__init__()self.conv1=tf.keras.layers.Conv2D(filter_num,kernel_size=[3,3],strides=stride,padding='same')self.bn1=tf.keras.layers.BatchNormalization()self.relu=tf.keras.layers.Activation('relu')self.conv2=tf.keras.layers.Conv2D(filter_num,kernel_size=[3,3],strides=1,padding='same')self.bn2=tf.keras.layers.BatchNormalization()#当经过卷积层之后,发现经过卷积层之前的shape和经过卷积之后的shape不相同,#可以通过1*1的卷积层将它们的shape转换为一样的,再叠加。相当于当前的X通过乘以一个W,转换为相同的shapeif stride!=1:self.downSample=Sequential([tf.keras.layers.Conv2D(filter_num,kernel_size=[1,1],strides=stride)])#如果相同的话就直接叠加,不需要转换else:self.downSample=lambda x:xdef call(self,inputs,training=None):#通过当前的第一层卷积层out=self.conv1(inputs)out=self.bn1(out)out=self.relu(out)#通过当前的第二层卷积层out=self.conv2(out)out=self.bn2(out)#最后经过一个f(x)+xidentity=self.downSample(inputs)output=tf.keras.layers.add([identity,out])output=tf.nn.relu(output)return output
class ResNet(tf.keras.Model):def __init__(self,layers_num,num_classes=10):super(ResNet,self).__init__()#开始的输入层经过一个3*3,步长为1的卷积层和最大池化层self.stem=Sequential([tf.keras.layers.Conv2D(64,kernel_size=[3,3],strides=[1,1]),tf.keras.layers.BatchNormalization(),tf.keras.layers.Activation('relu'),tf.keras.layers.MaxPool2D(pool_size=[2,2],strides=[1,1],padding='same') ])#通过第一个组self.layer1=self.build_resblock(64,layers_num[0])#通过第二个组self.layer2=self.build_resblock(128,layers_num[1],stride=2)#通过第三个组self.layer3=self.build_resblock(256,layers_num[2],stride=2)#通过第四个组self.layer4=self.build_resblock(512,layers_num[3],stride=2)#经过全局平均池化层self.avgPool=tf.keras.layers.GlobalAveragePooling2D()#经过最后的全连接层self.fc=tf.keras.layers.Dense(num_classes)#这个函数是在构建一个组中的ResNetBlockdef build_resblock(self,filter_num,blocks,stride=1):res_block=Sequential([ResNetBlock(filter_num,stride)]) for i in range(1,blocks):res_block.add(ResNetBlock(filter_num,stride=1))return res_blockdef call(self,inputs,training=None):x=self.stem(inputs)x=self.layer1(x)x=self.layer2(x)x=self.layer3(x)x=self.layer4(x)out=self.avgPool(x)output=self.fc(out)return output

6.调用模型:

def resnet18():return ResNet([2,2,2,2])
def resnet34():return ResNet([3,4,6,3])
model=resnet18()
model.build(input_shape=(None,28,28,1))
model.summary()


7.模型训练:

optimizer=optimizers.Adam(learning_rate=learning_rate)
losses,acc=[],[]
for step,(x,y) in enumerate(train_db):with tf.GradientTape() as tape:out=model(x)loss=tf.losses.categorical_crossentropy(y,out,from_logits=True)loss=tf.reduce_mean(loss)grads=tape.gradient(loss,model.trainable_variables)optimizer.apply_gradients(zip(grads,model.trainable_variables))if step%100==0:print('The loss: {}'.format(float(loss)))losses.append(float(loss))correct=0
total_correct=0
total=0
for x,y in test_db:out=model(x)pred=tf.argmax(out,axis=1)y_pred=tf.argmax(y,axis=1)correct=tf.equal(pred,y_pred)total_correct+=tf.reduce_sum(tf.cast(correct,dtype=tf.int32)).numpy()total+=x.shape[0]
print('The Test_accuracy:{}'.format(float(total_correct/total)))
acc.append(float(total_correct/total))
The loss: 2.3025636672973633
The loss: 1.3160710334777832
The loss: 0.3155096769332886
The loss: 0.06830280274152756
The loss: 0.07288875430822372
The loss: 0.10025028884410858
The loss: 0.020024564117193222
The loss: 0.1279890239238739
The loss: 0.1865064799785614
The loss: 0.09510926902294159
The loss: 0.08556611090898514
The loss: 0.011156385764479637
The loss: 0.07167764753103256
The loss: 0.06475668400526047
The loss: 0.16705751419067383
The loss: 0.14960885047912598
The loss: 0.004834086634218693
The loss: 0.17640721797943115
The loss: 0.005258279386907816
The Test_accuracy:0.9691

8.画出loss图:

x=[i for i in range(len(losses))]
fig=plt.figure(figsize=(3,4))
plt.plot(x,losses,color='red')
plt.title('Model losses')
plt.xlabel('epoch')
plt.ylabel('losses')
plt.legend()
plt.show()
plt.close()

TensorFlow中的ResNet残差网络实战(2)相关推荐

  1. TensorFlow中的ResNet残差网络实战(1)

    1.导入相应的库 import os import tensorflow as tf import numpy as np import matplotlib.pyplot as plt from t ...

  2. ResNet残差网络及变体详解(符代码实现)

    本文通过分析深度网络模型的缺点引出ResNet残差网络,并介绍了几种变体,最后用代码实现ResNet18. 文章目录 前言 模型退化 残差结构 ResNet网络结构 Pre Activation Re ...

  3. 【五一创作】使用Resnet残差网络对图像进行分类(猫十二分类,模型定义、训练、保存、预测)(二)

    使用Resnet残差网络对图像进行分类 (猫十二分类,模型定义.训练.保存.预测)(二) 目录 (6).数据集划分 (7).训练集增强 (8).装载数据集 (9).初始化模型 (10).模型训练 (1 ...

  4. 【五一创作】使用Resnet残差网络对图像进行分类(猫十二分类,模型定义、训练、保存、预测)(一)

    使用Resnet残差网络对图像进行分类 (猫十二分类,模型定义.训练.保存.预测)(一) 目录 一.项目简介 二.环境说明 1.安装库 2.导入需要的库 三.分类过程 (1).解压数据集 (2).相关 ...

  5. (pytorch-深度学习系列)ResNet残差网络的理解-学习笔记

    ResNet残差网络的理解 ResNet伴随文章 Deep Residual Learning for Image Recognition 诞生,该文章是MSRA何凯明团队在2015年ImageNet ...

  6. ResNet 残差网络、残差块

    在深度学习中,为了增强模型的学习能力,网络的层数会不断的加深,于此同时,也伴随着一些比较棘手的问题,主要包括: ①模型复杂度上升,网络训练困难 ②出现梯度消失/梯度爆炸问题 ③网络退化,即增加层数并不 ...

  7. ResNet残差网络Pytorch实现——对花的种类进行训练

    ResNet残差网络Pytorch实现--对花的种类进行训练 上一篇:[结合各个残差块] ✌✌✌✌ [目录] ✌✌✌✌ 下一篇:[对花的种类进行单数据预测] 大学生一枚,最近在学习神经网络,写这篇文章 ...

  8. 目标检测学习笔记2——ResNet残差网络学习、ResNet论文解读

    ResNet残差网络学习.ResNet论文解读 一.前言 为什么会提出ResNet? 什么是网络退化现象? 那网络退化现象是什么造成的呢? ResNet要如何解决退化问题? 二.残差模块 三.残差模块 ...

  9. ResNet残差网络

    (二十七)通俗易懂理解--Resnet残差网络 - 梦里寻梦的文章 - 知乎 https://zhuanlan.zhihu.com/p/67860570

最新文章

  1. 【亲测可用→防止入坑Routes】设置angular10项目异步加载、惰性加载、懒加载路由
  2. 60篇论文入选,两度夺魁,“史上最难ECCV”商汤再攀高峰
  3. Error message Exception raised without specific error
  4. mysql导入的excel更新_excel导入数据库,存在则更新不存在添加
  5. C#后台调用前台js(RegisterStartupScript)
  6. java开发者工具开源版_6种开源工具可帮助教育工作者保持井井有条
  7. hbase Java API 介绍及使用示例
  8. C#序列化出现“因其保护级别而不可访问。只能处理公共类型。”
  9. webstorm开发工具找回被误删除的代码
  10. 2019新闻自动挂机阅读脚本
  11. 【如何快速的开发一个完整的iOS直播app】(原理篇)
  12. 单片机中断程序详解(转)
  13. win10文件服务器ssd当缓存盘,Win10开启写入缓存策略来提高SSD固态硬盘性能
  14. Excel技能培训之八合并计算,多区域合并计算,分类汇总,展开隐藏列
  15. 【转帖】财务尽职调查资料收集总结
  16. 使用自定义的Layer和Cell实现手写汉字生成(Tensorflow2)
  17. Kubernetes基础:包含多个容器的Pod
  18. CRM系统能给企业带来什么? CRM系统推荐
  19. webp图片怎么批量转换成jpg等常用格式
  20. 2022-09-11-cloud-init

热门文章

  1. 防火墙产品原理与应用:防火墙产品的技术及实现关于IPV6【寒假】
  2. 刘强东的代码水平到底有多强?30年前就已破万!
  3. zabbix系列~ 监控模式
  4. javaEE之------Spring-----》 AspectJ注解
  5. 怎么申请微信支付接口
  6. sql server 经典SQL——分组统计
  7. BZOJ3262 : 陌上花开
  8. Oracle10g 回收站及彻底删除table : drop table xx purge
  9. 局域网怎样自动安装FLASH插件(浏览器不安装flashplayer都可以浏览.swf文件)
  10. java.lang.RuntimeException: Expected one of local, maven-local, maven-central, scala-tools-releases,