文章目录

  • 1.介绍
  • 2.CIFAR100实战

1.介绍


2.CIFAR100实战

import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'# 解决了UnknownError: Failed to get convolution algorithm. This is probably
# because cuDNN failed to initialize, so try looking to see if a warning log
#  message was printed above. [Op:Conv2D]
from tensorflow.compat.v1 import ConfigProto
from tensorflow.compat.v1 import InteractiveSession
config = ConfigProto()
config.gpu_options.allow_growth = True
session = InteractiveSession(config=config)# 10层的卷积与3层的全连接层
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers,optimizers,datasets,Sequentialtf.random.set_seed(2345)# 组建Sequential的list
conv_layers = [# 5 units of conv + max pooling# unit 1layers.Conv2D(64,kernel_size=[3,3],padding='same',activation=tf.nn.relu),layers.Conv2D(64,kernel_size=[3,3],padding='same',activation=tf.nn.relu),layers.MaxPool2D(pool_size=[2,2],strides=2,padding='same'),# unit 2layers.Conv2D(128,kernel_size=[3,3],padding='same',activation=tf.nn.relu),layers.Conv2D(128,kernel_size=[3,3],padding='same',activation=tf.nn.relu),layers.MaxPool2D(pool_size=[2,2],strides=2,padding='same'),# unit 3layers.Conv2D(256,kernel_size=[3,3],padding='same',activation=tf.nn.relu),layers.Conv2D(256,kernel_size=[3,3],padding='same',activation=tf.nn.relu),layers.MaxPool2D(pool_size=[2,2],strides=2,padding='same'),# unit 4layers.Conv2D(512,kernel_size=[3,3],padding='same',activation=tf.nn.relu),layers.Conv2D(512,kernel_size=[3,3],padding='same',activation=tf.nn.relu),layers.MaxPool2D(pool_size=[2,2],strides=2,padding='same'),# unit 5layers.Conv2D(512,kernel_size=[3,3],padding='same',activation=tf.nn.relu),layers.Conv2D(512,kernel_size=[3,3],padding='same',activation=tf.nn.relu),layers.MaxPool2D(pool_size=[2,2],strides=2,padding='same')
]# 预处理函数
def preprocess(x,y):# [0-1]x = tf.cast(x,dtype=tf.float32) /255.y = tf.cast(y,dtype=tf.int32)return x,y# 加载数据集
(x,y),(x_test,y_test) = datasets.cifar100.load_data()
# 消去y的1维度
y = tf.squeeze(y,axis=1)
y_test = tf.squeeze(y_test,axis=1)
print(x.shape,y.shape,x_test.shape,y_test.shape)
# (50000, 32, 32, 3) (50000, ) (10000, 32, 32, 3) (10000, )train_db = tf.data.Dataset.from_tensor_slices((x,y))
train_db = train_db.map(preprocess).shuffle(10000).batch(128)test_db = tf.data.Dataset.from_tensor_slices((x_test,y_test))
test_db = test_db.map(preprocess).batch(128)sample = next(iter(train_db))
print('sample:',sample[0].shape,sample[1].shape,tf.reduce_min(sample[0]),tf.reduce_max(sample[0]))
# sample: (64, 32, 32, 3) (64, ) tf.Tensor(0.0, shape=(), dtype=float32) tf.Tensor(1.0, shape=(), dtype=float32)def main():# 构建多层网络# [b,32,32,3] => [b,1,1,512]conv_net = Sequential(conv_layers)# 构建全连接层fc_net = Sequential([layers.Dense(256,activation=tf.nn.relu),layers.Dense(128,activation=tf.nn.relu),layers.Dense(100,activation=None),])conv_net.build(input_shape=[None, 32, 32, 3])fc_net.build(input_shape=[None,512])# 所有的训练参数variables = conv_net.trainable_variables + fc_net.trainable_variables# 优化器optimizer = optimizers.Adam(lr=1e-4)for epoch in range(50):for step,(x,y) in enumerate(train_db):with tf.GradientTape() as tape:# [b,32,32,3] => [b,1,1,512]out = conv_net(x)# 相当于flatten层-打平out = tf.reshape(out,[-1,512])# [b,512] => [b,100]logits = fc_net(out)# [b] =>[b,100]y_onehot = tf.one_hot(y,depth=100)# compute lossloss = tf.losses.categorical_crossentropy(y_onehot,logits,from_logits=True)loss = tf.reduce_mean(loss)grads = tape.gradient(loss,variables)optimizer.apply_gradients(zip(grads,variables))if step % 100 ==0:print(epoch,step,'loss:',float(loss))total_num,total_correct = 0,0for x,y in test_db:out = conv_net(x)out = tf.reshape(out,[-1,512])logits = fc_net(out)prob = tf.nn.softmax(logits,axis=1)pred = tf.argmax(prob,axis=1)pred = tf.cast(pred,dtype=tf.int32)correct = tf.cast(tf.equal(pred,y),dtype=tf.int32)correct = tf.reduce_sum(correct)total_num += x.shape[0]total_correct += int(correct)acc = total_correct / total_numprint(epoch,'acc:',acc)if __name__ == '__main__':main()

深度学习2.0-31.CIFAR100与VGG13实战相关推荐

  1. 深度学习笔记(31) 迁移与增强

    深度学习笔记(31) 迁移与增强 1. 迁移学习 2. 大训练集的迁移学习 3. 迁移规律 4. 数据增强 1. 迁移学习 如果要做一个计算机视觉的应用,相比于从头训练权重,或者说从随机初始化权重开始 ...

  2. halcon 深度学习标注_HALCON深度学习工具0.4 早鸟版发布了

    原标题:HALCON深度学习工具0.4 早鸟版发布了 HALOCN深度学习工具在整个深度学习过程中扮演着重要的作用,而且在将来将扮演更重要的辅助作用,大大加快深度学习的开发流程,目前发布版本工具的主要 ...

  3. halcon显示坐标_HALCON深度学习工具0.4 早鸟版发布了

    HALOCN深度学习工具在整个深度学习过程中扮演着重要的作用,而且在将来将扮演更重要的辅助作用,大大加快深度学习的开发流程,目前发布版本工具的主要作用是图像数据处理和目标检测和分类中的标注. 标注训练 ...

  4. 神经网络与深度学习——TensorFlow2.0实战(笔记)(二)(开发环境介绍)

    开发环境介绍 Python3 1.结构清晰,简单易学 2.丰富的标准库 3.强大的的第三方生态系统 4.开源.开放体系 5.高可扩展性:胶水语言 6.高可扩展性:胶水语言 7.解释型语言,实现复杂算法 ...

  5. 资源下载| 深度学习Pytoch1.0如何玩?这一门含900页ppt和代码实例的深度学习课程带你飞

    本文来自专知 近日,在NeurIPS 2018 大会上,Facebook 官方宣布 PyTorch 1.0 正式版发布了.如何用Pytorch1.0搞深度学习?对很多小白学生是个问题.瑞士非盈利研究机 ...

  6. 深度学习_TensorFlow2.0基础_张量创建,运算,维度变换,采样

    Tensorflow2.0 基础 一:TensorFlow特性 1.TensorFlow An end-to-end open source machine learning platform end ...

  7. 神经网络与深度学习——TensorFlow2.0实战(笔记)(五)(NumPy科学计算库<2>python)

    数组元素的切片 一维数组 #一维数组 #切片方法和Python序列数据结构的切片一样 a=np.array([0,1,2,3,4],dtype=np.int64)#占用新的内存 #不包括结束位置 pr ...

  8. 深度学习-计算机视觉-0基础-学习历程

    周志华<机器学习>------------------------若是想从基础算法公式开始可以先试着看一下周志华的<机器学习>,由于我对公式推导很头疼,看了几页就跳过了.(在经 ...

  9. 动手学深度学习V2.0(Pytorch)——11.模型选择+过拟合和欠拟合

    文章目录 1. 模型选择 2. 过拟合和欠拟合 3. 代码 4. Q&A 4.1 SVM和神经网络相比,缺点在哪里 4.2 训练集验证集测试集比例 4.3 时序预测问题中的测试集训练集 4.4 ...

  10. 深度学习数字仪表盘识别_【深度学习系列】手写数字识别实战

    上周在搜索关于深度学习分布式运行方式的资料时,无意间搜到了paddlepaddle,发现这个框架的分布式训练方案做的还挺不错的,想跟大家分享一下.不过呢,这块内容太复杂了,所以就简单的介绍一下padd ...

最新文章

  1. 十一好礼,90份新品MCU开发板免费送!
  2. Java 高级 --- 多线程快速入门
  3. Apache ZooKeeper - ZK的数据和文件
  4. hive实现not in
  5. js中执行到一个if就停止的代码_Node 中如何引入一个模块及其细节
  6. 为什么数据可视化很重要
  7. 圈圈教你玩usb第一版硬件实物使用说明
  8. c# 实现图片转双层PDF,PDF转OFD格式文件
  9. JavaScript中浏览器兼容性解决办法
  10. 五大云存储服务对比:iCloud、Google...
  11. 考研数学——全程复习建议(汤)
  12. Symbol Factory Universal v3.X 工业图形库
  13. 【工程光学】理想光学系统
  14. 电脑开机黑屏---只有一个鼠标箭头处理办法
  15. 我的(此)电脑里面除了磁盘以外,多了一个CD驱动器,删除方法,亲测有效
  16. Linux命令·chmod
  17. 小案例 CSS之旋转的可乐瓶
  18. yocto linux dns,ZYNQ_LINUX的根文件系统设置为QSPI_FLASH,JFFS2。
  19. 一次buge寻找过程
  20. 船舶领域研究综述(截至2018)

热门文章

  1. J2EE Architecture(6)
  2. 为何近期QQ和MSN老是被攻击
  3. 【linux】修改某一行
  4. 解数独(Python)
  5. jquery 初步(四)内容过滤器
  6. 遍历josn的三种方式
  7. .ashx文件与.ashx.cs
  8. 类数据源Visual C++对ODBC数据库资源的访问
  9. SVN server
  10. LightSpeed ORM .NET简单运用