【火炉炼AI】深度学习003-构建并训练深度神经网络模型

(本文所使用的Python库和版本号: Python 3.6, Numpy 1.14, scikit-learn 0.19, matplotlib 2.2 )

前面我们讲解过单层神经网络模型,发现它结构简单,难以解决一些实际的比较复杂的问题,故而现在发展出了深度神经网络模型。

深度神经网络的深度主要表现在隐含层的层数上,前面的单层神经网络只有一个隐含层,而深度神经网络使用>=2个隐含层。其基本结构为:

图中有两个隐含层,分别是酱色的圆圈(有4个神经元)和绿色的圆圈(有2个神经元),所以这个深度神经网络的结构就是3-4-2结构。图片来源于2017/7/20 朱兴全教授学术讲座观点与总结第二讲:单个神经元/单层神经网络。

对于一些很复杂的深度神经网络,隐含层的个数可能有几百上千个,比如ResNet网络结构的等,其训练过程也更复杂,耗时更长。那么这些模型就非常地“深”了。

使用更深层的神经网络,可以得到更好的表达效果,这可以直观地理解为:在每一个网络层中,输入特征的特点被一步步的抽象出来;下一层网络直接使用上一层抽象的特征进行进一步的线性组合或非线性组合,从而一步一步地得到输出。

1. 构建并训练深度神经网络模型

1.1 准备数据集

本次使用自己生成的一些数据,如下生成代码:

# 准备数据集
# 此处自己生成一些原始的数据点
dataset_X=np.linspace(-10,10,100)
dataset_y=2*np.square(dataset_X)+7 # 即label是feature的平方*2,偏置是7
dataset_y /=np.linalg.norm(dataset_y) # 归一化处理
dataset_X=dataset_X[:,np.newaxis]
复制代码

该数据集的数据分布为:

1.2 构建并训练模型

直接上代码:

# 构建并训练模型
import neurolab as nl
x_min, x_max = dataset_X[:,0].min(), dataset_X[:,0].max()
multilayer_net = nl.net.newff([[x_min, x_max]], [10, 10, 1])
# 模型结构:隐含层有两层,每层有10个神经元,输出层一层。
multilayer_net.trainf = nl.train.train_gd # 设置训练算法为梯度下降
dataset_y=dataset_y[:,np.newaxis]
error = multilayer_net.train(dataset_X, dataset_y, epochs=800, show=100, goal=0.01)
复制代码

-------------------------------------输---------出--------------------------------

Epoch: 100; Error: 2.933891201182385; Epoch: 200; Error: 0.032819979078409965; Epoch: 300; Error: 0.040183833367277225; The goal of learning is reached

--------------------------------------------完-------------------------------------

看来,虽然我们设置要800个循环,但是到达目标0.01时,便自动退出。可以画图看一下error的走势

1.3 用训练好的模型来预测新数据

此处我们没有新数据,假设原始的dataset_X是新数据,那么可以预测出这些新数据的结果,并比较一下真实值和预测值之间的差异,可以比较直观的看出模型的预测效果

# 用训练好的模型来预测
predict_y=multilayer_net.sim(dataset_X)
plt.scatter(dataset_X,dataset_y,label='dataset')
plt.scatter(dataset_X,predict_y,label='predicted')
plt.legend()
plt.title('Comparison of Truth and Predicted')
复制代码

可以看出模型的预测值和真实值大致相同,至少表明模型在训练集上表现比较好。

关于深度神经网络的更具体内容,可以参考博文 神经网络浅讲:从神经元到深度学习.

其实,要解决复杂的问题,不一定要增加模型的深度(即增加隐含层数,但每一层的神经元个数比较少,即模型结构是深而瘦的),还可以增加模型的宽度(即一个或少数几个隐含层,但是增加隐含层的神经元个数,即模型结构是浅而肥的),那么哪一种比较好?

在文章干货|神经网络最容易被忽视的基础知识一中提到:虽然有研究表明,浅而肥的网络结构也能拟合任何函数,但它需要非常的“肥胖”,可能一个隐含层需要成千上万个神经元,这样会导致模型中参数的数量极大地增加。如下比较图:

从上图可以看出:当准确率差不多的时候,参数的数量却相差数倍。这也说明我们一般用深层的神经网络而不是浅层“肥胖”的网络。

########################小**********结###############################

1,深度神经网络的构建和训练已经有成熟的框架来实现,比如Keras,Tensorflow,PyTorch等,用起来更加的简单,此处仅仅用来解释内部结构和进行简单的建模训练。

2,为了解决更加复杂的问题,一般我们选用深而瘦的模型结构,不选用浅而肥的模型,因为这种模型的参数数量非常大,训练耗时长。

#################################################################

注:本部分代码已经全部上传到(我的github)上,欢迎下载。

参考资料:

1, Python机器学习经典实例,Prateek Joshi著,陶俊杰,陈小莉译

【火炉炼AI】深度学习003-构建并训练深度神经网络模型相关推荐

  1. unet是残差网络吗_深度学习系列(三)卷积神经网络模型(ResNet、ResNeXt、DenseNet、DenceUnet)...

    深度学习系列(三)卷积神经网络模型(ResNet.ResNeXt.DenseNet.Dence Unet) 内容目录 1.ResNet2.ResNeXt3.DenseNet4.Dence Unet 1 ...

  2. 深度学习(3)之经典神经网络模型整理:神经网络、CNN、RNN、LSTM

    本文章总结以下经典的神经网络模型整理,大体讲下模型结构及原理- 如果想深入了解模型架构及pytorch实现,可参考我的Pytorch总结专栏 -> 划重点!!!Pytorch总结文章之目录归纳 ...

  3. 【火炉炼AI】机器学习018-项目案例:根据大楼进出人数预测是否举办活动

    [火炉炼AI]机器学习018-项目案例:根据大楼进出人数预测是否举办活动 (本文所使用的Python库和版本号: Python 3.5, Numpy 1.14, scikit-learn 0.19, ...

  4. 【火炉炼AI】机器学习055-使用LBP直方图建立人脸识别器

    [火炉炼AI]机器学习055-使用LBP直方图建立人脸识别器 (本文所使用的Python库和版本号: Python 3.6, Numpy 1.14, scikit-learn 0.19, matplo ...

  5. 深度学习入门(一)——深度学习是什么?

    深度学习入门(一)--深度学习是什么? 看了标题,你心中或许已经有了疑惑.什么是深度学习?这和人工智能有什么关系吗?神经网络不是生物学知识吗?什么是全连接神经网络?如果你对本次技术分享内容足够感兴趣且 ...

  6. 【深度学习】如何选择适合深度学习的GPU?

    如何选择适合深度学习的GPU? 为什么GPU比CPU更适合机器学习或者深度学习? 什么是张量处理单元(TPU)? 目前主流的GPU厂商:Nvidia和AMD 选择GPU时需要关注的主要属性 1. GP ...

  7. 【火炉炼AI】深度学习004-Elman循环神经网络

    [火炉炼AI]深度学习004-Elman循环神经网络 (本文所使用的Python库和版本号: Python 3.6, Numpy 1.14, scikit-learn 0.19, matplotlib ...

  8. 【火炉炼AI】深度学习001-神经网络的基本单元-感知器

    [火炉炼AI]深度学习001-神经网络的基本单元-感知器 (本文所使用的Python库和版本号: Python 3.6, Numpy 1.14, scikit-learn 0.19, matplotl ...

  9. 【火炉炼AI】深度学习008-Keras解决多分类问题

    [火炉炼AI]深度学习008-Keras解决多分类问题 参考文章: (1)[火炉炼AI]深度学习008-Keras解决多分类问题 (2)https://www.cnblogs.com/RayDean/ ...

最新文章

  1. 3.83亿开房记录被泄露后,万豪又又又泄露用户数据了
  2. 攻破c语言笔试与机试难点,如何攻破C语言学习、笔试与机试的难点.doc
  3. 8-spark学习笔记-sparksql
  4. pandas对象保存到mysql出错提示“BLOB/TEXT column used in key specification without a key length”解决办法
  5. MySQL利用UDF执行命令
  6. mysql建表语句增加注释_MySQL建表语句+添加注释
  7. Amazon Web Service 雲端運算平台攻略 【3】:免費架WordPress部落格的虛擬主機
  8. 【HTML】获取当前时间并显示在网页上
  9. Android项目中如何用好构建神器Gradle?
  10. html input type=quot;filequot;,input[type='file']默认样式
  11. 【bzoj1614】[Usaco2007 Jan]Telephone Lines架设电话线 二分+SPFA
  12. django.forms生成HTML,python – 在django中为表单自动生成表单字段
  13. UOS桌面操作系统专业版字体
  14. Xshell个人免费版下载
  15. 从ResNet101到ResNet50
  16. html简历如何转换成pdf,将拉勾的HTML简历转成PDF
  17. 转换、刻录DVD影碟光盘教程
  18. springboo集成axis2实现webservice服务
  19. 适当的资本运作能有效提高运营商对产业链的掌控力
  20. 电脑突然上不了网,而且ping网关可以通

热门文章

  1. tomcat - JVM 配置
  2. PHP和MySQL处理树状、分级、无限分类、分层数据的方法
  3. 第21天学习Java的笔记-数学工具类Arrays,Math
  4. 广度优先遍历二叉树(BFS)-C++实现
  5. vb excel遍历列_EXCEL如何把多个表格合并成一个表格
  6. Vscode解决Setting.json报警告:Problems loading reference ... Unable to load schema from ...
  7. 5G | 5G新基建最新进展及投资机会【包含五大板块】
  8. 神经网络 | DeepVO:Towards End-to-End Visual Odometry
  9. pycharm与github相配置连接(上传、删除、更新项目)
  10. mysql 金额 类型,SQL实现根据类型对金额进行归类