采用pytorch搭建神经网络,解决kaggle平台手写字识别问题。
数据来源:https://www.kaggle.com/competitions/digit-recognizer/data
参考pytorch官网:https://pytorch.org/tutorials/beginner/basics/data_tutorial.html

数据预处理

# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# For example, here's several helpful packages to loadimport numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)# Input data files are available in the read-only "../input/" directory
# For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directoryimport os
for dirname, _, filenames in os.walk('/kaggle/input'):for filename in filenames:print(os.path.join(dirname, filename))# You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All"
# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session

输出:
/kaggle/input/digit-recognizer/sample_submission.csv
/kaggle/input/digit-recognizer/train.csv
/kaggle/input/digit-recognizer/test.csv

读取数据查看数据结构

train_df = pd.read_csv("/kaggle/input/digit-recognizer/train.csv")
train_df.head(5)


可以出第一行为标签,后面的pixel0~783为像素点灰度值

对数据进行归一化和预处理

train_feature_df = train_feature_df/255.0
train_feature_df = train_feature_df.apply(lambda x: x.values.reshape(28, 28),axis=1)


此时的数据格式为data列中存储的是28*28的numpy类型的矩阵

自定义DataSet

在pytorch官网手册中的Dataset部分有关于自定义Dataset的详细讲解(https://pytorch.org/tutorials/beginner/basics/data_tutorial.html)。

import pandas as pd
import torch
from torch.utils.data import Datasetclass CustomImageDataset(Dataset):def __init__(self, img_label, img_data, transform=None, target_transform=None):self.img_labels = img_labelself.images = img_dataself.transform = transformself.target_transform = target_transformdef __len__(self):return len(self.img_labels)def __getitem__(self, idx):# 两个均为series类型所以不能调用loc函数,直接索引取即可image = self.images[idx]label = self.img_labels[idx]if self.transform:image = self.transform(image)if self.target_transform:label = self.target_transform(label)return image, label

使用我们自定的Dataset对数据进行转换,转换为Dataset类型

from torchvision.transforms import ToTensor
train_dataset = CustomImageDataset(train_label_df, train_feature_df, ToTensor())
train_dataset[0]

查看数据结构可知,数据为[(*label数据(数值类型),图像数据(tensor类型))]

查看数据的图像
此部分代码也可在pytorch官网中查看详细讲解

import matplotlib.pyplot as plt
figure = plt.figure(figsize=(8, 8))
cols, rows = 3, 3
for i in range(1, cols * rows + 1):sample_idx = torch.randint(len(train_dataset), size=(1,)).item()img, label = train_dataset[sample_idx]figure.add_subplot(rows, cols, i)plt.title(label)plt.axis("off")plt.imshow(img.squeeze(), cmap="gray")
plt.show()

输出图像:

将Dataset转化为可迭代图像dataloader

先对训练集数据按训练集:测试集=8:2的比例进行分割,分割后的数据分别转换为dataloader。

from torch.utils.data import DataLoader
from torch.utils.data import random_split
train_dataset, test_dataset = random_split(train_dataset, [int(0.8*len(train_dataset)), int(0.2*len(train_dataset))])train_dataloader = DataLoader(train_dataset, batch_size=64, shuffle=True)
test_dataloader = DataLoader(test_dataset, batch_size=64, shuffle=True)

构建标准神经网络(SNN)

import torch.optim as optim
from torch import nnclass NeuralNetWork(nn.Module):def __init__(self):super(NeuralNetWork, self).__init__()self.flatten = nn.Flatten() # 把28*28压平成784self.linear_relu_stack = nn.Sequential(nn.Linear(28*28, 512),nn.ReLU(),nn.Linear(512,256),nn.ReLU(),nn.Linear(256, 10),nn.Softmax(dim=1))def forward(self,x):x = self.flatten(x)logits = self.linear_relu_stack(x)return logitsdevice = "cuda" if torch.cuda.is_available() else "cpu"
print(f"device:{device}")
model = NeuralNetWork().to(device)

定义超参数、损失函数和优化器

迭代次数:n_epochs=0
学习率:learn_rate=0.01
批处理个数:batch=64 #在上方dataloader定义时已指定
损失函数这里使用交叉熵损失函数(CrossEntropyLoss)
优化器使用随机梯度优化器(SGD)

# 定义参数
n_epochs = 10
learn_rate = 0.01def train_loop(dataloader, model, loss_fn, optimizer):size = len(dataloader.dataset)for batch,(X,y) in enumerate(dataloader):X = X.float().to(device)y = y.to(device)pred = model(X)loss = loss_fn(pred,y)optimizer.zero_grad()loss.backward()optimizer.step()if batch%64 == 0:# current代表当前是第几条数据loss,current = loss.item(), (batch + 1) * len(X)print(f"loss: {loss:>7f}  [{current:>5d}/{size:>5d}]")def test_loop(dataloader, model, loss_fn):# test中总的size表示总的图像数size = len(dataloader.dataset)num_batches = len(dataloader)test_loss, correct = 0, 0with torch.no_grad():for X, y in dataloader:X = X.float().to(device)y = y.to(device)pred = model(X)test_loss += loss_fn(pred, y).item()correct += (pred.argmax(1) == y).type(torch.float).sum().item()test_loss /= num_batchescorrect /= size # 小数的形式,下面*100是转成百分数的形式print(f"Test Error: \n Accuracy: {(100*correct):>0.1f}%, Avg loss: {test_loss:>8f} \n")loss_fn = nn.CrossEntropyLoss()# 随机梯度下降进行反向传播
optimizer = optim.SGD(model.parameters(), lr=learn_rate, momentum=0.5)for t in range(n_epochs):print(f"Epoch {t+1}\n-------------------------------")train_loop(train_dataloader, model, loss_fn, optimizer)test_loop(test_dataloader, model, loss_fn)
print("Done!")

从输出的结构可以看出我们预测的准确率可以达到82.8%,属于很不错的结果了。

数据预测

对test.scv文件中的数据进行预测

# 对数据进行预测
test_df = pd.read_csv("/kaggle/input/digit-recognizer/test.csv")
test_ts = torch.from_numpy(test_df.values).reshape(len(test_df),28,28).float().to(device)
print(f"test_ts.shape:{test_ts.shape}")
pred = model(test_ts)
print(f"result:{pred}",f"pred.shape{pred.shape}")
# 将数据从独热编码的格式中还原
pred_data  = pred.argmax(1)# 注意这里的+1是为了和赛题的文件的预测格式保持一致
result_df  = pd.DataFrame({}, index=test_df.index+1)
result_df.index.name = "ImageId"
result_df["Label"] = pred_data.cpu().detach().numpy()
result_df.head()

保存结果

result_df.to_csv("./data/output/result.csv")

上传提交

查看kaggle得分82.59还算不错的分数,后续可以对算法进行改进使用卷积神经网络可能会得到更高的分数哦。

Pytorch入门练习-kaggle手写字识别神经网络(SNN)实现相关推荐

  1. Pytorch入门练习2-kaggle手写字识别神经网络(CNN)实现

    目录 数据预处理 自定义数据集 构建网络结构 对卷积神经网络进行训练和评估 对数据进行预测 保存预测数据,提交代码 SNN由于无法考虑到图片数据的维度关系,在预测精度上会被限制,本章我们采用CNN卷积 ...

  2. 利用卷积神经网络实现手写字识别

    本文我们介绍一下卷积神经网络,然后基于pytorch实现一个卷积神经网络,并实现手写字识别 卷积神经网络介绍 传统神经网络处理图片问题的不足 让我们先复习一下神经网络的工作流程: 搭建一个神经网络 将 ...

  3. 利用神经网络实现手写字识别

    神经网络介绍 神经网络即多层感知机 如果不知道感知机的可以看博主之前的文章感知机及Python实现 神经网络实现及手写字识别 关于数据集: 从http://yann.lecun.com/exdb/mn ...

  4. 基于tensorflow的MNIST手写字识别

    一.卷积神经网络模型知识要点卷积卷积 1.卷积 2.池化 3.全连接 4.梯度下降法 5.softmax 本次就是用最简单的方法给大家讲解这些概念,因为具体的各种论文网上都有,连推导都有,所以本文主要 ...

  5. 最终章 | TensorFlow战Kaggle“手写识别达成99%准确率

    刘颖,某互联网创业公司COO,技术出身,做产品里最懂运营的. 这是一个TensorFlow的系列文章,本文是第三篇,在这个系列中,你讲了解到机器学习的一些基本概念.TensorFlow的使用,并能实际 ...

  6. TensorFlow基于minist数据集实现手写字识别实战的三个模型

    手写字识别 model1:输入层→全连接→输出层softmax model2:输入层→全连接→隐含层→全连接→输出层softmax model3:输入层→卷积层1→卷积层2→全连接→dropout层→ ...

  7. python手写汉字识别_TensorFlow 2.0实践之中文手写字识别

    问题导读: 1.相比于简单minist识别,汉字识别具有哪些难点? 2.如何快速的构建一个OCR网络模型? 3.读取的时候有哪些点需要注意? 4.如何让模型更简单的收敛? 还在玩minist?fash ...

  8. .net 数字转汉字_TensorFlow 2.0 中文手写字识别(汉字OCR)

    TensorFlow 2.0 中文手写字识别(汉字OCR) 在开始之前,必须要说明的是,本教程完全基于TensorFlow2.0 接口编写,请误与其他古老的教程混为一谈,本教程除了手把手教大家完成这个 ...

  9. 用TensorFlow教你手写字识别

    1.MNIST数据集 基于MNIST数据集实现手写字识别可谓是深度学习经典入门必会的技能,该数据集由60000张训练图片和10000张测试图片组成,每张均为28*28像素的黑白图片.关于数据集的获取, ...

最新文章

  1. 解决后退,清空验证码(其它文本框保留)
  2. dataframe保存为txt_竟然可以用 Python 抓取公号文章保存成 PDF
  3. Rabbitmq~对Vhost的配置
  4. 文本挖掘之 文本相似度判定
  5. 从哪些角度进行手机软件测试
  6. 重磅!阿里自研RISC-V处理器玄铁910成功运行安卓10,相关代码开源
  7. vscode下载与安装教程
  8. codeblocks20.03汉化版 附汉化补丁|codeblocks汉化版
  9. 如何在阿里云服务器部署程序并用域名直接访问
  10. 「干活」基因组组装 学习笔记 - 入门知识点和Genome Survey
  11. python2升级python3语法错误总结
  12. 常用药 学着照顾自己和身边的人
  13. 【Pyecharts】20W条淘宝文胸商品评论数据可视化~
  14. 疾病研究:重症肌无力医师指南
  15. cordova在使用jqmobile中出现的问题(使用$(window).on与window.addEventListener)
  16. 计算机学院年会,我校计算机学院应邀出席全国职业院校计算机系主任年会并作报告...
  17. 黑马程序员_日记9_Java学习感言
  18. 数据库系统工程师难考吗?
  19. 根据html改为ftl模板生成pdf文件,支持中文及换行
  20. 计算机启动时最先运行的程序储存在,2012年自考计算机应用基础试题及答案

热门文章

  1. PathFinder机甲大师电控组问题总结
  2. 我在犹太公司的15年
  3. 优维科技7周年庆|未来可“7”,从心出发
  4. 让我们进入面向对象的世界(一)
  5. python调用matlab 性能_Python调用MATLAB实现fmincon函数
  6. DirectX11教程
  7. linux中vim怎么分栏,Vim+Taglist+AutoComplPop之代码目录分栏信息和自动补全提示(Ubuntu环境)...
  8. Censorship
  9. 超分辨率重建学习总结、SR、super resolution、single image super resolution
  10. Ubuntu20.04换源