import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
from tensorflow.keras import Sequential,losses,optimizers,layers,datasets
lr=0.01
batchsz=512network=Sequential([#网络容器layers.Conv2D(6,kernel_size=3,strides=1),#第1个卷积层,6个3*3卷积核layers.BatchNormalization(),#插入BN层layers.MaxPooling2D(pool_size=2,strides=2),#高宽各减半的池化层layers.ReLU(),#激活函数layers.Conv2D(16,kernel_size=3,strides=1),#第2个卷积层,16个3*3卷积核layers.BatchNormalization(),#插入BN层layers.MaxPooling2D(pool_size=2,strides=2),#高宽各减半的池化层layers.ReLU(),#激活函数layers.Flatten(),#打平层,方便全连接层处理layers.Dense(120,activation='relu'),#全连接层,120个节点layers.BatchNormalization(),#插入BN层layers.Dense(84,activation='relu'),#全连接层,84个节点layers.BatchNormalization(),#插入BN层layers.Dense(10)#10个节点
])network.build(input_shape=(4,28,28,1))
#network.summary()
(x_train,y_train),(x_test,y_test)=datasets.mnist.load_data()#加载数据集
train_db=tf.data.Dataset.from_tensor_slices((x_train,y_train))#构建Dataset数据集
train_db=train_db.shuffle(10000)#随机打散,防止记忆化
train_db=train_db.batch(batchsz)#批量
def propress(x,y):#预处理函数x=tf.cast(x,dtype=tf.float32)/255.x=tf.reshape(x,[-1,28,28])y=tf.cast(y,dtype=tf.int32)y=tf.one_hot(y,depth=10)return x,y
train_db=train_db.map(propress)
train_db=train_db.repeat(20)
#tf.keras.optimizers.SGD(learning_rate=5e-4) 声明了一个梯度下降 优化器 (Optimizer)
optimizer=optimizers.SGD(learning_rate=lr)
#使用交叉熵损失函数
criteon=losses.CategoricalCrossentropy(from_logits=True)test_db=tf.data.Dataset.from_tensor_slices((x_test,y_test))
test_db=test_db.shuffle(10000)
test_db=test_db.batch(batchsz)
def propress(x,y):x=tf.cast(x,dtype=tf.float32)/255.x=tf.reshape(x,[-1,28,28])y=tf.cast(y,dtype=tf.int32)y=tf.one_hot(y,depth=10)return x,y
test_db=test_db.map(propress)
test_db=test_db.repeat(20)losses=[]
acc=[]for step,(x,y) in enumerate(train_db):with tf.GradientTape() as tape:x=tf.expand_dims(x,axis=3)#设置网络参数的参数training=True区分BN层是训练还是测试模型out=network(x,training=True)loss=criteon(y,out)#前面已经进行了one_hot编码的转化grades=tape.gradient(loss,network.trainable_variables)optimizer.apply_gradients(zip(grades,network.trainable_variables))if step%100 ==0:print(step,'loss:{}'.format(float(loss)))losses.append(float(loss))correct,total=0,0if step%100==0:for x,y in test_db:x=tf.expand_dims(x,axis=3)#(128, 28, 28)#设置网络参数的参数training=False避免BN层采用错误的行为out=network(x,training=False)#(128,10)pred=tf.argmax(out,axis=1)#(128,)y_test=tf.argmax(y,axis=1)correct+=tf.reduce_sum(tf.cast(tf.equal(pred,y_test),dtype=tf.int32)).numpy()total+=x.shape[0]print(step,"test_acc:{}".format(float(correct/total)))acc.append(correct/total)plt.figure()
x=[i*5 for i in range(len(losses))]
plt.plot(x,losses,color='C0',marker='s',label='训练')
plt.xlabel('step')
plt.ylabel('losses')
plt.show()plt.plot(x,acc,color='C0',marker='s',label='c测试')
plt.xlabel('step')
plt.ylabel('acc')
plt.show()
0 loss:2.755319118499756
0 test_acc:0.1306
100 loss:0.45182928442955017
100 test_acc:0.2676
200 loss:0.33671215176582336
200 test_acc:0.4734
300 loss:0.2975693643093109
300 test_acc:0.7508
400 loss:0.21543249487876892
400 test_acc:0.9093
500 loss:0.2360231578350067
500 test_acc:0.9453
600 loss:0.17904561758041382
600 test_acc:0.9547
700 loss:0.1409425437450409
700 test_acc:0.9591
800 loss:0.1463257521390915
800 test_acc:0.9625
900 loss:0.12103059887886047
900 test_acc:0.9632
1000 loss:0.14103813469409943
1000 test_acc:0.9664
1100 loss:0.11426682770252228
1100 test_acc:0.9665
1200 loss:0.1260887235403061
1200 test_acc:0.9682
1300 loss:0.08206073939800262
1300 test_acc:0.9692
1400 loss:0.12712940573692322
1400 test_acc:0.9708
1500 loss:0.11015598475933075
1500 test_acc:0.9706
1600 loss:0.09226851165294647
1600 test_acc:0.9723
1700 loss:0.11278203129768372
1700 test_acc:0.9726
1800 loss:0.08437538146972656
1800 test_acc:0.9733
1900 loss:0.0878986120223999
1900 test_acc:0.9749
2000 loss:0.0847143679857254
2000 test_acc:0.9756
2100 loss:0.08022576570510864
2100 test_acc:0.9763
2200 loss:0.10397559404373169
2200 test_acc:0.9763
2300 loss:0.07622624933719635
2300 test_acc:0.9775


tensorflow中的BN层实现相关推荐

  1. caffe中的batchNorm层(caffe 中为什么bn层要和scale层一起使用)

    caffe中的batchNorm层 链接: http://blog.csdn.net/wfei101/article/details/78449680 caffe 中为什么bn层要和scale层一起使 ...

  2. Tensorflow训练和预测中的BN层的坑(转载)

    以前使用Caffe的时候没注意这个,现在使用预训练模型来动手做时遇到了.在slim中的自带模型中inception, resnet, mobilenet等都自带BN层,这个坑在<实战Google ...

  3. Tensorflow训练和预测中的BN层的坑(转)-训练和测试差异性巨大

    以前使用Caffe的时候没注意这个,现在使用预训练模型来动手做时遇到了.在slim中的自带模型中inception, resnet, mobilenet等都自带BN层,这个坑在<实战Google ...

  4. pytorch 模型中的bn层一键转化为同步bn(syncbn)

    pytorch 将模型中的所有BatchNorm2d layer转换为SyncBatchNorm layer: (单机多卡设置下) import torch.distributed as distdi ...

  5. 狠补基础-数学+算法角度讲解卷积层,激活函数,池化层,Dropout层,BN层,全链接层

    狠补基础-数学+算法角度讲解卷积层,激活函数,池化层,Dropout层,BN层,全链接层 在这篇文章中您将会从数学和算法两个角度去重新温习一下卷积层,激活函数,池化层,Dropout层,BN层,全链接 ...

  6. 模型压缩(一)通道剪枝-BN层

    论文:https://arxiv.org/pdf/1708.06519.pdf BN层中缩放因子γ与卷积层中的每个通道关联起来.在训练过程中对这些比例因子进行稀疏正则化,以自动识别不重要的通道.缩放因 ...

  7. 网络骨架:Backbone(神经网络基本组成——BN层、全连接层)

    BN层 为了追求更高的性能,卷积网络被设计得越来越深,然而网络却变得难以训练收敛与调参.原因在于,浅层参数的微弱变化经过多层线性变化与激活函数后会被放大,改变了每一层的输入分布,造成深层的网络需要不断 ...

  8. TF之BN:BN算法对多层中的每层神经网络加快学习QuadraticFunction_InputData+Histogram+BN的Error_curve

    TF之BN:BN算法对多层中的每层神经网络加快学习QuadraticFunction_InputData+Histogram+BN的Error_curve 目录 输出结果 代码设计 输出结果 代码设计 ...

  9. Pytorch中BN层入门思想及实现

    批归一化层-BN层(Batch Normalization) 作用及影响: 直接作用:对输入BN层的张量进行数值归一化,使其成为均值为零,方差为一的张量. 带来影响: 1.使得网络更加稳定,结果不容易 ...

最新文章

  1. 【设计模式】迪米特法则和六种原则的总结
  2. 通过案例对SparkStreaming透彻理解-3
  3. C语言高级编程:数组和指针作为函数形参
  4. Mysql事务,并发问题,锁机制-- 幻读、不可重复读--专题
  5. nginx php 没认,NginX没有执行PHP
  6. 从另一个视角看待逻辑回归
  7. 整理了70个Python实战项目列表,都有完整且详细的教程
  8. Linux源码安装PHP7.3.1
  9. java tfidf_Hanlp分词实例:Java实现TFIDF算法
  10. Spiral Matrix(Medium)
  11. 汉字转拼音源码的两个类
  12. Request模块实战04 ---- 爬取豆瓣电影排行榜
  13. My summery
  14. 为什么你还没有买新能源汽车? 1
  15. 电动汽车热管理粘合剂和密封剂市场现状及未来发展趋势
  16. biosequence analysis using profile hidden Markov models(使用隐马尔可夫模型分析序列)
  17. 国内云服务地域选择和测速
  18. 面试总结-----工程化软件项目开发的流程、步骤
  19. 非银行支付机构网络支付业务管理办法对第三方支付账户的影响
  20. Android作为HTTP服务器--NanoHTTPD源码分析

热门文章

  1. 自然语言处理(NLP)之从文本中提取时间
  2. python中列表、字典和集合推导式
  3. 我的第一个VUE示例
  4. BIG T 下学期选修_python作业
  5. 洛奇6里很喜欢的一段话!洛奇6经典台词!而是你能挨多重,并且坚持向前,你能承受多少并且坚持向前,这样才叫胜利!
  6. Nginx配置文件nginx.conf中文详解(总结)
  7. C语言找最大的int型数!_只愿与一人十指紧扣_新浪博客
  8. C语言的数顺序输出与反序输出_只愿与一人十指紧扣_新浪博客
  9. 基于 CNN 特征区域进行目标检测
  10. 【T07】不要低估tcp的性能