AlexNet

AlexNet 可以说是具有历史意义的一个网络结构,可以说在AlexNet之前,深度学习已经沉寂了很久。历史的转折在2012年到来,AlexNet 在当年的ImageNet图像分类竞赛中,top-5错误率比上一年的冠军下降了十个百分点,而且远远超过当年的第二名。

下图所示即为Alexnet的网络结构,主要由九个部分组成:

  1. 图像输入层
  2. Layer1(卷积层+池化层)
  3. Layer2(卷积层+池化层)
  4. Layer3(卷积层)
  5. Layer4(卷积层)
  6. Layer5(卷积层+池化层)
  7. Layer6(全连接层+Dropout)
  8. Layer7(全连接层+Dropout)
  9. Softmax层

不同之处

AlexNet的原始输入数据为224*224*3,但是因为自己电脑的性能不够好,只能通过mnist小型数据集进行替代。除此之外,AlexNet中包含LRN层(局部响应归一化层),由于现在很少使用LRN层,更多的是被L1、L2或者Dropout等代替,所以程序中只使用了Dropout,tensorflow中包含LRN操作,有兴趣可以看下。以下代码可以直接运行使用!

代码

# -*- coding: utf-8 -*-
"""
Created on Thu May 10 12:53:59 2018
@author: new
"""
#Tensorflow在mnist数据集上实现Alexnet
from tensorflow.examples.tutorials.mnist import input_data
import tensorflow as tf
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt
import scipy.misc
import matplotlib.image as mpimg
from skimage import io
#这里可以通过tensorflow内嵌的函数现在mnist数据集
mnist = input_data.read_data_sets('MNIST_data', one_hot=True)x = tf.placeholder(tf.float32, [None, 784])
y_ = tf.placeholder(tf.float32, [None, 10])
sess = tf.InteractiveSession()
#Layer1
W_conv1 =tf.Variable(tf.truncated_normal([3, 3, 1, 32],stddev=0.1))
b_conv1 = tf.Variable(tf.constant(0.1,shape=[32]))
#调整x的大小
x_image = tf.reshape(x, [-1,28,28,1])
h_conv1 = tf.nn.relu(tf.nn.conv2d(x_image, W_conv1,strides=[1, 1, 1, 1], padding='SAME') + b_conv1)
h_pool1 = tf.nn.max_pool(h_conv1, ksize=[1, 2, 2, 1],strides=[1, 1, 1, 1], padding='SAME')
#Layer2
W_conv2 = tf.Variable(tf.truncated_normal([5, 5, 32, 64],stddev=0.1))
b_conv2 = tf.Variable(tf.constant(0.1,shape=[64]))
h_conv2 = tf.nn.relu(tf.nn.conv2d(h_pool1, W_conv2,strides=[1, 1, 1, 1], padding='SAME') + b_conv2)
h_pool2 = tf.nn.max_pool(h_conv2, ksize=[1, 2, 2, 1],strides=[1, 2, 2, 1], padding='SAME')
#Layer3
W_conv3 = tf.Variable(tf.truncated_normal([5, 5, 64, 64],stddev=0.1))
b_conv3 = tf.Variable(tf.constant(0.1,shape=[64]))
h_conv3 = tf.nn.relu(tf.nn.conv2d(h_pool2, W_conv3,strides=[1, 1, 1, 1], padding='SAME') + b_conv3)
#Layer4
W_conv4 = tf.Variable(tf.truncated_normal([5, 5, 64, 32],stddev=0.1))
b_conv4 = tf.Variable(tf.constant(0.1,shape=[32]))
h_conv4 = tf.nn.relu(tf.nn.conv2d(h_conv3, W_conv4,strides=[1, 1, 1, 1], padding='SAME') + b_conv4)
#Layer5
W_conv5 = tf.Variable(tf.truncated_normal([5, 5, 32, 64],stddev=0.1))
b_conv5 = tf.Variable(tf.constant(0.1,shape=[64]))
h_conv5 = tf.nn.relu(tf.nn.conv2d(h_conv4, W_conv5,strides=[1, 1, 1, 1], padding='SAME') + b_conv5)
h_pool3 = tf.nn.max_pool(h_conv5, ksize=[1, 2, 2, 1],strides=[1, 2, 2, 1], padding='SAME')
#Layer6-全连接层
W_fc1 = tf.Variable(tf.truncated_normal([7*7*64,1024],stddev=0.1))
b_fc1 = tf.Variable(tf.constant(0.1,shape=[1024]))
#对h_pool2数据进行铺平
h_pool2_flat = tf.reshape(h_pool3, [-1, 7*7*64])
#进行relu计算,matmul表示(wx+b)计算
h_fc1 = tf.nn.relu(tf.matmul(h_pool2_flat, W_fc1) + b_fc1)
keep_prob = tf.placeholder(tf.float32)
h_fc1_drop = tf.nn.dropout(h_fc1, keep_prob)
#Layer7-全连接层,这里也可以是[1024,其它],大家可以尝试下
W_fc2 = tf.Variable(tf.truncated_normal([1024,1024],stddev=0.1))
b_fc2 = tf.Variable(tf.constant(0.1,shape=[1024]))
h_fc2 = tf.nn.relu(tf.matmul(h_fc1_drop, W_fc2) + b_fc2)
h_fc2_drop = tf.nn.dropout(h_fc2, keep_prob)
#Softmax层
W_fc3 = tf.Variable(tf.truncated_normal([1024,10],stddev=0.1))
b_fc3 = tf.Variable(tf.constant(0.1,shape=[10]))
y_conv = tf.matmul(h_fc2_drop, W_fc3) + b_fc3
#在这里通过tf.nn.softmax_cross_entropy_with_logits函数可以对y_conv完成softmax计算,同时计算交叉熵损失函数
cross_entropy = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels=y_, logits=y_conv))#定义训练目标以及加速优化器
train_step = tf.train.AdamOptimizer(1e-3).minimize(cross_entropy)
#计算准确率
correct_prediction = tf.equal(tf.argmax(y_conv,1), tf.argmax(y_,1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
#初始化变量
saver = tf.train.Saver()
sess.run(tf.global_variables_initializer())
for i in range(20000):batch = mnist.train.next_batch(10)if i%100 == 0:train_accuracy = accuracy.eval(feed_dict={x:batch[0], y_: batch[1], keep_prob: 1.0})print("step %d, training accuracy %g"%(i, train_accuracy))train_step.run(feed_dict={x: batch[0], y_: batch[1], keep_prob: 0.5})#保存模型
save_path = saver.save(sess, "./model/save_net.ckpt")print("test accuracy %g"%accuracy.eval(feed_dict={x: mnist.test.images[:3000], y_: mnist.test.labels[:3000], keep_prob: 1.0}))

训练结果

可以发现在mnist数据集上进行训练,15轮之后即可收敛。

step 0, training accuracy 0
step 100, training accuracy 0.2
step 200, training accuracy 0.2
step 300, training accuracy 0.1
step 400, training accuracy 0.2
step 500, training accuracy 0.4
step 600, training accuracy 0.6
step 700, training accuracy 0.5
step 800, training accuracy 0.9
step 900, training accuracy 0.8
step 1000, training accuracy 0.7
step 1100, training accuracy 0.9
step 1200, training accuracy 0.9
step 1300, training accuracy 0.9
step 1400, training accuracy 0.9
step 1500, training accuracy 0.9
step 1600, training accuracy 1
step 1700, training accuracy 1
step 1800, training accuracy 1
step 1900, training accuracy 1

Tensorflow基于mnist数据集实现AlexNet相关推荐

  1. 机器学习Tensorflow基于MNIST数据集识别自己的手写数字(读取和测试自己的模型)

    机器学习Tensorflow基于MNIST数据集识别自己的手写数字(读取和测试自己的模型)

  2. 基于TensorFlow和mnist数据集的手写数字识别系统 ,可识别电话号码,识别准确率高,有对比实验,两组模型,可讲解代码

    基于TensorFlow和mnist数据集的手写数字识别系统 ,可识别电话号码,识别准确率高,有对比实验,两组模型,可讲解代码

  3. TensorFlow基于cifar10数据集实现进阶的卷积网络

    TensorFlow基于cifar10数据集实现进阶的卷积网络 学习链接 CIFAR10模型及数据集介绍 综述 CIFAR10数据集介绍 CIFAR10数据集可视化 CIFAR10模型 CIFAR10 ...

  4. GAN生成对抗网络基本概念及基于mnist数据集的代码实现

    本文主要总结了GAN(Generative Adversarial Networks) 生成对抗网络的基本原理并通过mnist数据集展示GAN网络的应用. GAN网络是由两个目标相对立的网络构成的,在 ...

  5. 神经网络--基于mnist数据集取得最高的识别准确率

    前言: Hello大家好,我是Dream. 今天来学习一下如何基于mnist数据集取得最高的识别准确率,本文是从零开始的,如有需要可自行跳至所需内容~ 本文目录: 1.调用库函数 2.调用数据集 3. ...

  6. DL之CNN可视化:利用SimpleConvNet算法【3层,im2col优化】基于mnist数据集训练并对卷积层输出进行可视化

    DL之CNN可视化:利用SimpleConvNet算法[3层,im2col优化]基于mnist数据集训练并对卷积层输出进行可视化 导读 利用SimpleConvNet算法基于mnist数据集训练并对卷 ...

  7. TensorFlow读取MNIST数据集错误的问题

    TensorFlow读取mnist数据集错误的问题 运行程序出现"URLError"错误的问题 可能是服务器或路径的原因,可以自行下载数据集后,将数据集放到代码所在的文件夹下,并将 ...

  8. pytorch训练GAN的代码(基于MNIST数据集)

    论文:Generative Adversarial Networks 作者:Ian J. Goodfellow 年份:2014年 从2020年3月多开始看网络,这是我第一篇看并且可以跑通代码的论文,简 ...

  9. 基于MNIST数据集实现车牌识别--初步演示版

    在前几天写的一篇博文<如何从TensorFlow的mnist数据集导出手写体数字图片>中,我们介绍了如何通过TensorFlow将mnist手写体数字集导出到本地保存为bmp文件. 车牌识 ...

  10. TensorFlow基于minist数据集实现手写字识别实战的三个模型

    手写字识别 model1:输入层→全连接→输出层softmax model2:输入层→全连接→隐含层→全连接→输出层softmax model3:输入层→卷积层1→卷积层2→全连接→dropout层→ ...

最新文章

  1. C++知识点14——类与static
  2. 对梯度下降法的简单理解
  3. c语言程序设计一元二次函数,计算一元二次函数的根,大家看看那里有错了。。。。...
  4. Cocos2d-x 在缓存创建图片
  5. 在SqlServer存储过程中使用Cursor(游标)操作记录
  6. CORS解决WebApi跨域问题(转)
  7. Hadoop HIVE
  8. nginx配置php 9000,Nginx支持php配置
  9. 邮件实用技巧九:如何快速查看历史邮件
  10. html ios按钮,ios样式开关按钮jQuery插件
  11. MICRO SIM卡(SIM小卡)尺寸图及剪卡图解
  12. 拆解查看unity游戏资源
  13. C++自动化(模板元)编程基础与应用(4)
  14. iPhone前置排线教程
  15. 解除RAR和ZIP压缩包密码的不同方法
  16. 如何使用ArcGIS制作真实的植被
  17. web api(基于NFine框架) 中接口跳转数据格式错误解决方案
  18. Window10总是自动打开网络代理的解决方案法
  19. Replacing Elements (CodeForces - 1473A)
  20. springframework(九)AOP的advices,中规中矩的使用方式

热门文章

  1. 在向服务器发送请求时发生传输级错误。 (provider: TCP 提供程序, error: 0 - 远程主机强迫关闭了一个现有的连接。)...
  2. 设置布局默认为LinearLayout,却成了RelativeLayout
  3. 淘宝前端框架kissyui
  4. 要开始Ubuntu之旅拉~
  5. C# datagridview 实现按指定某列或多列进行排序
  6. (已解决)Mon Apr 08 14:02:29 CST 2019 WARN: Establishing SSL connection without server's
  7. 批处理(bat)choice命令详解
  8. Multipart生成的临时文件
  9. Consider defining a bean of type ‘com.bsj.system.service.RedisService‘ in your configuration.
  10. C 线程同步的四种方式(Windows)