简述

基于之前的框架,修改了一下CNN的结构
【Kaggle-MNIST之路】CNN+改进过的损失函数+多次的epoch(四)

  • 评分:0.988
  • 排名:1200+

代码

  • 和之前的一样,会把模型生成出来,用于后续的保存等工作。
  • 可以用之前的方法来改一下代码。但是需要把生成数据的那个版本的类给替换掉才行。
import pandas as pd
import torch.utils.data as data
import torch
import torch.nn as nnfile = './all/train.csv'
LR = 0.01class MNISTCSVDataset(data.Dataset):def __init__(self, csv_file, Train=True):self.dataframe = pd.read_csv(csv_file, iterator=True)self.Train = Traindef __len__(self):if self.Train:return 42000else:return 28000def __getitem__(self, idx):data = self.dataframe.get_chunk(100)ylabel = data['label'].as_matrix().astype('float')xdata = data.ix[:, 1:].as_matrix().astype('float')return ylabel, xdataclass CNN(nn.Module):def __init__(self):super(CNN, self).__init__()self.layer1 = nn.Sequential(# (1, 28, 28)nn.Conv2d(in_channels=1,out_channels=32,kernel_size=3,  # 卷积filter, 移动块长stride=1,  # filter的每次移动步长),nn.ReLU(),nn.BatchNorm2d(32),nn.Conv2d(in_channels=32,out_channels=32,kernel_size=3,  # 卷积filter, 移动块长stride=1,  # filter的每次移动步长),nn.ReLU(),nn.BatchNorm2d(32),nn.Conv2d(in_channels=32,out_channels=32,kernel_size=5,  # 卷积filter, 移动块长stride=2,  # filter的每次移动步长padding=2,),nn.ReLU(),nn.BatchNorm2d(32),nn.Dropout(0.4),)self.layer2 = nn.Sequential(nn.Conv2d(in_channels=32,out_channels=64,kernel_size=3,  # 卷积filter, 移动块长stride=1,  # filter的每次移动步长),nn.ReLU(),nn.BatchNorm2d(64),nn.Conv2d(in_channels=64,out_channels=64,kernel_size=3,  # 卷积filter, 移动块长stride=1,  # filter的每次移动步长),nn.ReLU(),nn.BatchNorm2d(64),nn.Conv2d(in_channels=64,out_channels=64,kernel_size=5,  # 卷积filter, 移动块长stride=2,  # filter的每次移动步长padding=2,),nn.ReLU(),nn.BatchNorm2d(64),nn.Dropout(0.4),)self.layer3 = nn.Linear(64 * 4 * 4, 10)def forward(self, x):# print(x.shape)x = self.layer1(x)# print(x.shape)x = self.layer2(x)# print(x.shape)x = x.view(x.size(0), -1)# print(x.shape)x = self.layer3(x)# print(x.shape)return xnet = CNN()
loss_function = nn.MultiMarginLoss()
optimizer = torch.optim.Adam(net.parameters(), lr=LR)
EPOCH = 10
for epoch in range(EPOCH):mydataset = MNISTCSVDataset(file)train_loader = torch.utils.data.DataLoader(mydataset, batch_size=1, shuffle=True)print('epoch %d' % epoch)for step, (yl, xd) in enumerate(train_loader):xd = xd.reshape(100, 1, 28, 28).float()output = net(xd)yl = yl.long()loss = loss_function(output, yl.squeeze())optimizer.zero_grad()loss.backward()optimizer.step()if step % 40 == 0:print('step %d' % step, loss)torch.save(net, 'divided-net.pkl')

【Kaggle-MNIST之路】CNN结构改进+改进过的损失函数(五)相关推荐

  1. 【Kaggle-MNIST之路】CNN结构再改进+交叉熵损失函数(六)

    简述 这里再添加了一个卷积层.用一个kernal=4的卷积层做一个卷积之后,再做映射. 基于之前的一个版本 [Kaggle-MNIST之路]CNN结构改进+改进过的损失函数(五) 成绩:0.9898 ...

  2. 多种深度模型实现手写字母MNIST的识别(CNN,RNN,DNN,逻辑回归,CRNN,LSTM/Bi-LSTM,GRU/Bi-GRU)

    多种深度模型实现手写字母MNIST的识别(CNN,RNN,DNN,逻辑回归,CRNN,LSTM/Bi-LSTM,GRU/Bi-GRU 1.CNN模型 1.1 代码 1.2 运行结果 2.RNN模型 2 ...

  3. python机器学习及实践_机器学习入门之《Python机器学习及实践:从零开始通往Kaggle竞赛之路》...

    本文主要向大家介绍了机器学习入门之<Python机器学习及实践:从零开始通往Kaggle竞赛之路>,通过具体的内容向大家展现,希望对大家学习机器学习入门有所帮助. <Python 机 ...

  4. 【Kaggle-MNIST之路】CNN+改进过的损失函数+多次的epoch(四)

    简述 上一个代码 在看完很多版本的代码,看了下,发现一个问题,随着epoch的次数逐渐上升,精度会一定程度上上升.(有时候也不一定) 所以,怀疑自己的这个代码还有改进的空间,所以,在提高了一下epoc ...

  5. 【Kaggle-MNIST之路】CNN+改进过的损失函数(三)

    简述 在上一个版本上的CNN的框架的基础上. 上一个版本 卷积神经网络CNN入门[pytorch学习] 调用了上面的框架. 目前: 分数:0.9160 排名:2400+ 框架代码 import pan ...

  6. Keras在mnist上的CNN实践,并且自定义loss函数曲线图

    使用keras实现CNN,直接上代码: from keras.datasets import mnist from keras.models import Sequential from keras. ...

  7. Tensorflow的MNIST进阶教程CNN网络参数理解

    背景 问题说明 分析 LeNet5参数 MNIST程序参数 遗留问题 小结 背景 之前博文中关于CNN的模型训练功能上是能实现,但是研究CNN模型内部结构的时候,对各个权重系数ww,偏差bb的shap ...

  8. Python机器学习及实践+从零开始通往Kaggle竞赛之路

    内容简介 本书面向所有对机器学习与数据挖掘的实践及竞赛感兴趣的读者,从零开始,以Python编程语言为基础,在不涉及大量数学模型与复杂编程知识的前提下,逐步带领读者熟悉并且掌握当下最流行的机器学习.数 ...

  9. 《Python机器学习及实践:从零开始通往Kaggle竞赛之路》第1章 简介篇 学习笔记(三)“良/恶性乳腺癌肿瘤预测”总结

    目录 "良/恶性乳腺癌肿瘤预测" 1.机器学习的三个关键术语 (1)任务 (2)经验 (3)性能 2.机器学习的学习过程 (1)观察测试集数据分布 (2)初始化二类分类器 (3)训 ...

最新文章

  1. 257.二叉树的所有路径
  2. GNU C之__attribute__
  3. MySQL LIMIT:限制查询结果的记录条数
  4. Linux基础命令(常用的)
  5. sublime-text3按tab跳出括号
  6. 帝国cms模板仿企业网站
  7. 连接池dbcp跟c3p0
  8. 除了随机还要进化——对Infinity进一步的想法
  9. python 几何计算_计算几何-凸包算法 Python实现与Matlab动画演示
  10. 浅析EL表达式注入漏洞
  11. 王者荣耀背景html,《王者荣耀》登录界面背景怎么修改 登录背景图片更换方法...
  12. matlab fvtool 滤波器频响
  13. 超好用的mac虚拟机软件:VM虚拟机 mac中文版
  14. 函数签名function signature是什么意思
  15. JavaEE 13个核心规范
  16. 软件工程师嵌入式开发交流论坛推荐排行
  17. jQuery—弹窗广告
  18. 常见的agv控制系统及功能有哪些?
  19. 货郎担问题(TSP问题)
  20. 22款受欢迎的计算机取证工具

热门文章

  1. Linux Platform Device and Driver
  2. 运行Android应用时提示ADB是否存在于指定路径问题
  3. 面向切面编程-日志切面应用
  4. 云计算之KVM虚拟化实战
  5. 树莓派3B用Ubuntu MATE安装ros
  6. 详解设计模式之工厂模式(简单工厂+工厂方法+抽象工厂)
  7. JQuery中的html(),text(),val()区别
  8. 黄猫被汽车撞死 花猫雨夜苦守
  9. activity 变成后台进程后被杀死_Android后台杀死系列之二:ActivityManagerService与App现场恢复机制...
  10. 开始→运行→输入的命令集锦(转载)