【Kaggle-MNIST之路】CNN结构改进+改进过的损失函数(五)
简述
基于之前的框架,修改了一下CNN的结构
【Kaggle-MNIST之路】CNN+改进过的损失函数+多次的epoch(四)
- 评分:0.988
- 排名:1200+
代码
- 和之前的一样,会把模型生成出来,用于后续的保存等工作。
- 可以用之前的方法来改一下代码。但是需要把生成数据的那个版本的类给替换掉才行。
import pandas as pd
import torch.utils.data as data
import torch
import torch.nn as nnfile = './all/train.csv'
LR = 0.01class MNISTCSVDataset(data.Dataset):def __init__(self, csv_file, Train=True):self.dataframe = pd.read_csv(csv_file, iterator=True)self.Train = Traindef __len__(self):if self.Train:return 42000else:return 28000def __getitem__(self, idx):data = self.dataframe.get_chunk(100)ylabel = data['label'].as_matrix().astype('float')xdata = data.ix[:, 1:].as_matrix().astype('float')return ylabel, xdataclass CNN(nn.Module):def __init__(self):super(CNN, self).__init__()self.layer1 = nn.Sequential(# (1, 28, 28)nn.Conv2d(in_channels=1,out_channels=32,kernel_size=3, # 卷积filter, 移动块长stride=1, # filter的每次移动步长),nn.ReLU(),nn.BatchNorm2d(32),nn.Conv2d(in_channels=32,out_channels=32,kernel_size=3, # 卷积filter, 移动块长stride=1, # filter的每次移动步长),nn.ReLU(),nn.BatchNorm2d(32),nn.Conv2d(in_channels=32,out_channels=32,kernel_size=5, # 卷积filter, 移动块长stride=2, # filter的每次移动步长padding=2,),nn.ReLU(),nn.BatchNorm2d(32),nn.Dropout(0.4),)self.layer2 = nn.Sequential(nn.Conv2d(in_channels=32,out_channels=64,kernel_size=3, # 卷积filter, 移动块长stride=1, # filter的每次移动步长),nn.ReLU(),nn.BatchNorm2d(64),nn.Conv2d(in_channels=64,out_channels=64,kernel_size=3, # 卷积filter, 移动块长stride=1, # filter的每次移动步长),nn.ReLU(),nn.BatchNorm2d(64),nn.Conv2d(in_channels=64,out_channels=64,kernel_size=5, # 卷积filter, 移动块长stride=2, # filter的每次移动步长padding=2,),nn.ReLU(),nn.BatchNorm2d(64),nn.Dropout(0.4),)self.layer3 = nn.Linear(64 * 4 * 4, 10)def forward(self, x):# print(x.shape)x = self.layer1(x)# print(x.shape)x = self.layer2(x)# print(x.shape)x = x.view(x.size(0), -1)# print(x.shape)x = self.layer3(x)# print(x.shape)return xnet = CNN()
loss_function = nn.MultiMarginLoss()
optimizer = torch.optim.Adam(net.parameters(), lr=LR)
EPOCH = 10
for epoch in range(EPOCH):mydataset = MNISTCSVDataset(file)train_loader = torch.utils.data.DataLoader(mydataset, batch_size=1, shuffle=True)print('epoch %d' % epoch)for step, (yl, xd) in enumerate(train_loader):xd = xd.reshape(100, 1, 28, 28).float()output = net(xd)yl = yl.long()loss = loss_function(output, yl.squeeze())optimizer.zero_grad()loss.backward()optimizer.step()if step % 40 == 0:print('step %d' % step, loss)torch.save(net, 'divided-net.pkl')
【Kaggle-MNIST之路】CNN结构改进+改进过的损失函数(五)相关推荐
- 【Kaggle-MNIST之路】CNN结构再改进+交叉熵损失函数(六)
简述 这里再添加了一个卷积层.用一个kernal=4的卷积层做一个卷积之后,再做映射. 基于之前的一个版本 [Kaggle-MNIST之路]CNN结构改进+改进过的损失函数(五) 成绩:0.9898 ...
- 多种深度模型实现手写字母MNIST的识别(CNN,RNN,DNN,逻辑回归,CRNN,LSTM/Bi-LSTM,GRU/Bi-GRU)
多种深度模型实现手写字母MNIST的识别(CNN,RNN,DNN,逻辑回归,CRNN,LSTM/Bi-LSTM,GRU/Bi-GRU 1.CNN模型 1.1 代码 1.2 运行结果 2.RNN模型 2 ...
- python机器学习及实践_机器学习入门之《Python机器学习及实践:从零开始通往Kaggle竞赛之路》...
本文主要向大家介绍了机器学习入门之<Python机器学习及实践:从零开始通往Kaggle竞赛之路>,通过具体的内容向大家展现,希望对大家学习机器学习入门有所帮助. <Python 机 ...
- 【Kaggle-MNIST之路】CNN+改进过的损失函数+多次的epoch(四)
简述 上一个代码 在看完很多版本的代码,看了下,发现一个问题,随着epoch的次数逐渐上升,精度会一定程度上上升.(有时候也不一定) 所以,怀疑自己的这个代码还有改进的空间,所以,在提高了一下epoc ...
- 【Kaggle-MNIST之路】CNN+改进过的损失函数(三)
简述 在上一个版本上的CNN的框架的基础上. 上一个版本 卷积神经网络CNN入门[pytorch学习] 调用了上面的框架. 目前: 分数:0.9160 排名:2400+ 框架代码 import pan ...
- Keras在mnist上的CNN实践,并且自定义loss函数曲线图
使用keras实现CNN,直接上代码: from keras.datasets import mnist from keras.models import Sequential from keras. ...
- Tensorflow的MNIST进阶教程CNN网络参数理解
背景 问题说明 分析 LeNet5参数 MNIST程序参数 遗留问题 小结 背景 之前博文中关于CNN的模型训练功能上是能实现,但是研究CNN模型内部结构的时候,对各个权重系数ww,偏差bb的shap ...
- Python机器学习及实践+从零开始通往Kaggle竞赛之路
内容简介 本书面向所有对机器学习与数据挖掘的实践及竞赛感兴趣的读者,从零开始,以Python编程语言为基础,在不涉及大量数学模型与复杂编程知识的前提下,逐步带领读者熟悉并且掌握当下最流行的机器学习.数 ...
- 《Python机器学习及实践:从零开始通往Kaggle竞赛之路》第1章 简介篇 学习笔记(三)“良/恶性乳腺癌肿瘤预测”总结
目录 "良/恶性乳腺癌肿瘤预测" 1.机器学习的三个关键术语 (1)任务 (2)经验 (3)性能 2.机器学习的学习过程 (1)观察测试集数据分布 (2)初始化二类分类器 (3)训 ...
最新文章
- 257.二叉树的所有路径
- GNU C之__attribute__
- MySQL LIMIT:限制查询结果的记录条数
- Linux基础命令(常用的)
- sublime-text3按tab跳出括号
- 帝国cms模板仿企业网站
- 连接池dbcp跟c3p0
- 除了随机还要进化——对Infinity进一步的想法
- python 几何计算_计算几何-凸包算法 Python实现与Matlab动画演示
- 浅析EL表达式注入漏洞
- 王者荣耀背景html,《王者荣耀》登录界面背景怎么修改 登录背景图片更换方法...
- matlab fvtool 滤波器频响
- 超好用的mac虚拟机软件:VM虚拟机 mac中文版
- 函数签名function signature是什么意思
- JavaEE 13个核心规范
- 软件工程师嵌入式开发交流论坛推荐排行
- jQuery—弹窗广告
- 常见的agv控制系统及功能有哪些?
- 货郎担问题(TSP问题)
- 22款受欢迎的计算机取证工具
热门文章
- Linux Platform Device and Driver
- 运行Android应用时提示ADB是否存在于指定路径问题
- 面向切面编程-日志切面应用
- 云计算之KVM虚拟化实战
- 树莓派3B用Ubuntu MATE安装ros
- 详解设计模式之工厂模式(简单工厂+工厂方法+抽象工厂)
- JQuery中的html(),text(),val()区别
- 黄猫被汽车撞死 花猫雨夜苦守
- activity 变成后台进程后被杀死_Android后台杀死系列之二:ActivityManagerService与App现场恢复机制...
- 开始→运行→输入的命令集锦(转载)