theano学习笔记(2)基础函数

1、随机函数库的调用

2、卷积神经网络

[python] view plaincopy
  1. #-*-coding:utf-8-*-
  2. import theano
  3. import numpy as np
  4. import matplotlib.pyplot as plt
  5. from loaddata import loadmnist
  6. import theano.tensor as T
  7. #softmax函数
  8. class softmax:
  9. #outdata为我们标注的输出,hiddata网络输出层的输入,nin,nout为输入、输出神经元个数
  10. def __init__(self,hiddata,outdata,nin,nout):
  11. self.w=theano.shared(value=np.zeros((nin,nout),dtype=theano.config.floatX),name='w');
  12. self.b=theano.shared(value=np.zeros((nout,),dtype=theano.config.floatX),name='b')
  13. prey=T.nnet.softmax(T.dot(hiddata,self.w)+self.b)#通过softmax函数,得到输出层每个神经元数值(概率)
  14. self.loss=-T.mean(T.log(prey)[T.arange(outdata.shape[0]),outdata])#损失函数
  15. self.para=[self.w,self.b]
  16. self.predict=T.argmax(prey,axis=1)
  17. self.error=T.mean(T.neq(T.argmax(prey,axis=1),outdata))
  18. #输入层到隐藏层
  19. class HiddenLayer:
  20. def __init__(self,inputx,nin,nout):
  21. a=np.sqrt(6./(nin+nout))
  22. ranmatrix=np.random.uniform(-a,a,(nin,nout));
  23. self.w=theano.shared(value=np.asarray(ranmatrix,dtype=theano.config.floatX),name='w')
  24. self.b=theano.shared(value=np.zeros((nout,),dtype=theano.config.floatX),name='b')
  25. self.out=T.tanh(T.dot(inputx,self.w)+self.b)
  26. self.para=[self.w,self.b]
  27. #传统三层感知器
  28. class mlp:
  29. def __init__(self,nin,nhid,nout):
  30. x=T.fmatrix('x')
  31. y=T.ivector('y')
  32. #前向
  33. hlayer=HiddenLayer(x,nin,nhid)
  34. olayer=softmax(hlayer.out,y,nhid,nout)
  35. #反向
  36. paras=hlayer.para+olayer.para
  37. dparas=T.grad(olayer.loss,paras)
  38. updates=[(para,para-0.1*dpara) for para,dpara in zip(paras,dparas)]
  39. self.trainfunction=theano.function(inputs=[x,y],outputs=olayer.loss,updates=updates)
  40. def train(self,trainx,trainy):
  41. return self.trainfunction(trainx,trainy)
  42. #卷积神经网络的每一层,包括卷积、池化、激活映射操作
  43. #img_shape为输入特征图,img_shape=(batch_size,特征图个数,图片宽、高)
  44. #filter_shape为卷积操作相关参数,filter_shape=(输入特征图个数、输出特征图个数、卷积核的宽、卷积核的高)
  45. #,这样总共filter的个数为:输入特征图个数*输出特征图个数*卷积核的宽*卷积核的高
  46. class LeNetConvPoolLayer:
  47. def __init__(self,inputx,img_shape,filter_shape,poolsize=(2,2)):
  48. #参数初始化
  49. assert img_shape[1]==filter_shape[1]
  50. a=np.sqrt(6./(filter_shape[0]+filter_shape[1]))
  51. v=np.random.uniform(low=-a,high=a,size=filter_shape)
  52. wvalue=np.asarray(v,dtype=theano.config.floatX)
  53. self.w=theano.shared(value=wvalue,name='w')
  54. bvalue=np.zeros((filter_shape[0],),dtype=theano.config.floatX)
  55. self.b=theano.shared(value=bvalue,name='b')
  56. covout=T.nnet.conv2d(inputx,self.w)#卷积操作
  57. covpool=T.signal.downsample.max_pool_2d(covout,poolsize)#池化操作
  58. self.out=T.tanh(covpool+self.b.dimshuffle('x', 0, 'x', 'x'))
  59. self.para=[self.w,self.b]
  60. trainx,trainy=loadmnist()
  61. trainx=trainx.reshape(-1,1,28,28)
  62. batch_size=30
  63. m=trainx.shape[0]
  64. ne=m/batch_size
  65. batchx=T.tensor4(name='batchx',dtype=theano.config.floatX)
  66. batchy=T.ivector('batchy')
  67. #
  68. cov1_layer=LeNetConvPoolLayer(inputx=batchx,img_shape=(batch_size,1,28,28),filter_shape=(20,1,5,5))
  69. cov2_layer=LeNetConvPoolLayer(inputx=cov1_layer.out,img_shape=(batch_size,20,12,12),filter_shape=(50,20,5,5))
  70. cov2out=cov2_layer.out.flatten(2)
  71. hlayer=HiddenLayer(cov2out,4*4*50,500)
  72. olayer=softmax(hlayer.out,batchy,500,10)
  73. paras=cov1_layer.para+cov2_layer.para+hlayer.para+olayer.para
  74. dparas=T.grad(olayer.loss,paras)
  75. updates=[(para,para-0.1*dpara) for para,dpara in zip(paras,dparas)]
  76. train_function=theano.function(inputs=[batchx,batchy],outputs=olayer.loss,updates=updates)
  77. test_function=theano.function(inputs=[batchx,batchy],outputs=[olayer.error,olayer.predict])
  78. testx,testy=loadmnist(True)
  79. testx=testx.reshape(-1,1,28,28)
  80. train_history=[]
  81. test_history=[]
  82. for it in range(20):
  83. sum=0
  84. for i in range(ne):
  85. a=trainx[i*batch_size:(i+1)*batch_size]
  86. loss_train=train_function(trainx[i*batch_size:(i+1)*batch_size],trainy[i*batch_size:(i+1)*batch_size])
  87. sum=sum+loss_train
  88. sum=sum/ne
  89. print 'train_loss:',sum
  90. test_error,predict=test_function(testx,testy)
  91. print 'test_error:',test_error
  92. train_history=train_history+[sum]
  93. test_history=test_history+[test_error]
  94. n=len(train_history)
  95. fig1=plt.subplot(111)
  96. fig1.set_ylim(0.001,0.2)
  1. fig1.plot(np.arange(n),train_history,'-')
from: http://blog.csdn.net/hjimce/article/details/46806923

深度学习(三)theano学习笔记(2)基础函数-未完待续相关推荐

  1. MySQL学习笔记(基础篇未完待补充)

    一.MySQL数据库基 目录 一.MySQL数据库基础篇 1.数据库概述与MySQL安装篇 第1章:数据库概述 1.为什么要使用数据库 2. 数据库与数据库管理系统 2.2 数据库与数据库管理系统的关 ...

  2. 5G网络学习(三)——大白话讲解PDU会话(未完待续)

    在介绍PDU会话之前让我们介绍一下什么是PDU PDU简介 PDU(Protocol Data Unit)是协议层的协议在对等层之间交换的信息叫协议数据单元. 封装 数据要通过网络进行传输,要从高层一 ...

  3. 火箭发射理论(基础篇-未完待续)//2021-1-27

    前言: 嗯,这个就没有那么多为什么了,浩瀚星海,对于人类而言,这是探索宇宙的第一步吧,所以对于我这种只有几十年生命周期的普通生物而言,这不言而喻.正如康德所言:有两种东西,我对它们的思考越是深沉和持久 ...

  4. CMake Cookbook笔记(12/23未完待续,游戏服务器观点阅读,编译器及指令集不涉及)

    文章目录 一.配置环境(略) 二.从可执行文件到库 1)将单个源码文件编译为可执行文件 2)切换生成器(-G) 3)构建和链接静态库和动态库(还有对象库的使用举例) 4)用条件句控制编译 5)向用户显 ...

  5. jQuery基础(未完待续)

    1.       jQuery核心函数 jQuery也可写$,通常情况下$可能会与其他框架中的对象冲突(php有$的用法),所以如果所用的框架没有$的用法,jQuery可用$代替 (1)$(docum ...

  6. 脚本基础(未完待续)

    脚本执行 1.赋予权限,chmod 755 hello.sh  ./hello.sh 2.通过bash执行脚本,bash hello.sh bash快捷键 dos2unix 文件名  windows文 ...

  7. pythonb超分辨成像_Papers | 超分辨 + 深度学习(未完待续)

    1. SRCNN 1.1. Contribution end-to-end深度学习应用在超分辨领域的开山之作(非 end-to-end 见 Story.3 ). 指出了超分辨方向上传统方法( spar ...

  8. 二叉树学习笔记(未完待续)

    摘要 二叉树学习笔记(未完待续). 博客 IT老兵驿站. 前言 昨天(2019-11-07)复习红黑树,发现红黑树和二叉树密不可分,所以这里再复习一下二叉树. 在大学的时候,这块我很认真地学习了一遍. ...

  9. Windows x64内核学习笔记(五)—— KPTI(未完待续)

    Windows x64内核学习笔记(五)-- KPTI(未完待续) KPTI 实验一:构造IDT后门并读取Cr3 参考资料 KPTI 描述:KPTI(Kernel page-table isolati ...

最新文章

  1. c 多文件全局变量_C/CPP : static 关键字 及 变量函数的不同
  2. find -exec 与xargs 区别
  3. pmp每日三题(2022年3月4日)
  4. php中$stu_by,PHP基础案例二:计算学生年龄
  5. 关于相似性度量与各类距离的意义
  6. druid监控页面_Spring boot学习(四)Spring boot整合Druid
  7. java 启动顺序_java语句执行顺序
  8. WebSocket+HTML5实现在线聊天室
  9. sap 用户权限表_系统管理(BASIS)之 SAP用户权限介绍
  10. linux 端口映射 命令,linux查看端口映射命令
  11. Word编辑中的域代码详解
  12. 办公用品管理系统服务器版,恒达办公用品管理系统
  13. LeetCode 三等分(题解+优化过程)
  14. Python崛金系列--4.python量化股票
  15. Android 实现水波纹效果
  16. h3c交换机端口加入vlan命令_7.2.2 H3C交换机VLAN接口基本属性配置
  17. fontsquirrel字体安装(特殊字体 @font-face)
  18. Python输出[m,n]既能被3整除又能被7整除的数的个数
  19. 推特正式起诉马斯克 要求强制其按原协议完成收购
  20. 互联网产品都有哪些类型?

热门文章

  1. 几周内搞定Java的10个方法
  2. x_html语言名词解释,第2章++XHTML标记语言(97页)-原创力文档
  3. 【算法的时间复杂度和空间复杂度】-算法02
  4. Dubbo负载均衡与集群容错
  5. android开发计算器微积分,不到1M的良心之作!连微积分都能算的计算器APP_TOM科技...
  6. 如何删除tmp计算机桌面,Win10系统中tmp文件删除不了应该如何解决?
  7. 创建docker容器时出现 docker: Error response from daemon, The container name is already in use by container
  8. go语言打印日期_判定是否掌握Go语言的最重要标准:对并发的掌握
  9. java做台球时老是闪屏_电脑老是闪屏的原因和解决办法
  10. 计算各种图形的周长(接口与多态)_JAVA