文章目录

  • 1.CNN实现
  • 2.Pytorch实现CNN
  • 3.使用ImangeNet预训练模型

1.CNN实现

CNN基础

2.Pytorch实现CNN

构建一个简单的CNN模型和训练过程

import torch
torch.manual_seed(0)
torch.backends.cudnn.deterministic = False
torch.backends.cudnn.benchmark = Trueimport torchvision.models as models
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.autograd import Variable
from torch.utils.data.dataset import Dataset# 定义模型
class SVHN_Model1(nn.Module):def __init__(self):super(SVHN_Model1, self).__init__()# CNN提取特征模块self.cnn = nn.Sequential(nn.Conv2d(3, 16, kernel_size=(3, 3), stride=(2, 2)),nn.ReLU(),  nn.MaxPool2d(2),nn.Conv2d(16, 32, kernel_size=(3, 3), stride=(2, 2)),nn.ReLU(), nn.MaxPool2d(2),)# self.fc1 = nn.Linear(32*3*7, 11)self.fc2 = nn.Linear(32*3*7, 11)self.fc3 = nn.Linear(32*3*7, 11)self.fc4 = nn.Linear(32*3*7, 11)self.fc5 = nn.Linear(32*3*7, 11)self.fc6 = nn.Linear(32*3*7, 11)def forward(self, img):        feat = self.cnn(img)feat = feat.view(feat.shape[0], -1)c1 = self.fc1(feat)c2 = self.fc2(feat)c3 = self.fc3(feat)c4 = self.fc4(feat)c5 = self.fc5(feat)c6 = self.fc6(feat)return c1, c2, c3, c4, c5, c6model = SVHN_Model1()
# 损失函数
criterion = nn.CrossEntropyLoss()
# 优化器
optimizer = torch.optim.Adam(model.parameters(), 0.005)loss_plot, c0_plot = [], []
# 迭代10个Epoch
for epoch in range(10):for data in train_loader:c0, c1, c2, c3, c4, c5 = model(data[0])loss = criterion(c0, data[1][:, 0]) + \criterion(c1, data[1][:, 1]) + \criterion(c2, data[1][:, 2]) + \criterion(c3, data[1][:, 3]) + \criterion(c4, data[1][:, 4]) + \criterion(c5, data[1][:, 5])loss /= 6optimizer.zero_grad()loss.backward()optimizer.step()loss_plot.append(loss.item())c0_plot.append((c0.argmax(1) == data[1][:, 0]).sum().item()*1.0 / c0.shape[0])print(epoch)

3.使用ImangeNet预训练模型

class SVHN_Model2(nn.Module):def __init__(self):super(SVHN_Model1, self).__init__()model_conv = models.resnet18(pretrained=True)model_conv.avgpool = nn.AdaptiveAvgPool2d(1)model_conv = nn.Sequential(*list(model_conv.children())[:-1])self.cnn = model_convself.fc1 = nn.Linear(512, 11)self.fc2 = nn.Linear(512, 11)self.fc3 = nn.Linear(512, 11)self.fc4 = nn.Linear(512, 11)self.fc5 = nn.Linear(512, 11)def forward(self, img):        feat = self.cnn(img)# print(feat.shape)feat = feat.view(feat.shape[0], -1)c1 = self.fc1(feat)c2 = self.fc2(feat)c3 = self.fc3(feat)c4 = self.fc4(feat)c5 = self.fc5(feat)return c1, c2, c3, c4, c5

其他模型例如YOLO等到做出来会更新。

Datawhale 零基础入门CV赛事-Task3 字符识别模型相关推荐

  1. Datawhale 零基础入门CV赛事-Task4 模型训练与验证

    文章目录 1.构造验证集 2.模型训练与验证 1.构造验证集 在机器学习模型(特别是深度学习模型)的训练过程中,模型是非常容易过拟合的.深度学习模型在不断的训练过程中训练误差会逐渐降低,但测试误差的走 ...

  2. Datawhale 零基础入门CV赛事-Task2 数据读取与数据扩增

    文章目录 数据读取 图像读取 1.pillow 2.opencv 数据读取 数据扩增 数据读取 导入需要的包以及文件路径 import json, glob import numpy as np fr ...

  3. Datawhale 零基础入门CV赛事-Task5 模型集成

    这里写目录标题 1.集成学习方法 2.深度学习中的集成学习 Dropout TTA Snapshot 1.集成学习方法 在机器学习中的集成学习可以在一定程度上提高预测精度,常见的集成学习方法有Stac ...

  4. 零基础入门CV - Task 03 字符识别模型.md

    1. 数据特征提取 学习了解数据特征的概念,实现用python进行数据特征提取. 以sklearn为例进行学习 我们将城市和环境作为字典数据,来进行特征的提取. sklearn.feature_ext ...

  5. 零基础入门CV之街道字符识别----Task1赛题的理解

    Datawhale 零基础入门CV赛事-Task1 赛题理解 本章内容将会对街景字符识别赛题进行赛题背景讲解,对赛题数据的读取进行说明,并给出集中解题思路. 1 赛题理解 赛题名称:零基础入门CV之街 ...

  6. 零基础入门CV赛事,理论结合实践

    Datawhale干货 作者:阿水,Datawhale成员 本次分享的背景是,Datawhle联合天池发布的学习赛:零基础入门CV赛事之街景字符识别.本文以该比赛为例,对计算机视觉赛事中,赛事理解和B ...

  7. 零基础入门CV赛事- 街景字符编码识别

    零基础入门CV赛事- 街景字符编码识别 Task01 学习目标 数据介绍 Task01任务内容 数据读取 解题思路 学习目标 熟悉计算机视觉赛事 完成典型的字符识别问题 掌握CV领域赛事的编程和解题思 ...

  8. 零基础入门CV赛事-Task1 赛题理解

    文章目录 赛题介绍 解题思路 1. 定长字符识别 2.不定长字符识别 3. 检测再识别 赛题介绍 赛题以街道字符为为赛题数据(比赛地址),该数据来自收集的SVHN街道字符,训练集数据包括3W张照片,验 ...

  9. 零基础入门CV之街道字符识别 Task1 赛题理解

    赛题任务 以计算机视觉中字符识别为背景,要求选手预测街道字符编码,这是一个典型的字符识别问题. 赛题数据 数据来源于公开数据集SVHN街道字符. 提供训练集数据3W张照片,验证集数据1W张照片: 每张 ...

最新文章

  1. python中怎么比较两个列表-Python3列表(list)比较操作教程
  2. 【数理知识】《随机过程》方兆本老师-目录
  3. kafka消费的三种模式_快速认识Kafka
  4. 【基础】有关T-SQL的10个好习惯
  5. 带你自学Python系列(三):列表遍历(for循环)
  6. zabbix分布式监控环境搭建
  7. C语言最后一次作业--总结报告
  8. 新手福音,机器学习工具Sklearn 中文文档 0.19版(最新)
  9. js高程笔记1-3章
  10. mysql delphi_delphi 7 连接 MySql
  11. 超级简单的前端 自动复制功能
  12. 绿色版本chrome设为默认浏览器
  13. 梦想家-致停不下来的我们
  14. BUUCTF-Crypto-Quoted-printable题解
  15. React的非受控组件和受控组件
  16. 渡河问题matlab程序,商人渡河问题(MATLAB版)
  17. 音响的灵魂! 世界顶级扬声器品牌介绍
  18. man posix_spawn
  19. angular.min.js:80 Error: [$injector:unpr] http://errors.angularjs.org/1.2.9/
  20. 关于鼠标右键无法正常加载一直转圈

热门文章

  1. Python: 序列list:保持元素顺序同时消除重复值
  2. mysqlreport 文档
  3. 《数据库技术原理与应用教程(第2版)》——第3章 数据管理中的数据模型 3.1 数据模型的基本概念...
  4. PHP函数func_get_args(),func_get_arg(),func_num_args()
  5. Android 图片文件操作、屏幕相关、.9图片的理解
  6. php---header函数的示例代码
  7. .NET深入学习笔记(2):C#中判断空字符串的4种方法性能比较与分析
  8. LaTeX tikz初探——基本图形绘制(1)
  9. mysql shell 配置mysql_Windows Mysql shell 配置
  10. 朋友圈点赞点用例的设计点