,AlexNet是与LeNet不同的 一种新的深度学习模型。

论文原文百度云资源链接:链接:https://pan.baidu.com/s/1WdZnD6aVzUXvzs9XxshROQ 提取码:hans

第一步:模型实现

import os
import cv2
import numpy as np
import paddle
from paddle.io import Dataset
import paddle.vision.transforms as T
import matplotlib.pyplot as plt
from paddle.io import Dataset
from PIL import Image
from PIL import ImageFile
import paddle.nn as nn
import paddle.nn.functional as F
# 打印所使用的GPU编号
print(paddle.device.get_device())
ImageFile.LOAD_TRUNCATED_IMAGES = True# 搭建Alexnet网络class alexnet(paddle.nn.Layer):def __init__(self, ):super(alexnet, self).__init__()self.conv1 = paddle.nn.Conv2D(in_channels=3, out_channels=96, kernel_size=7, stride=2, padding=2)self.conv2 = paddle.nn.Conv2D(in_channels=96, out_channels=256, kernel_size=5, stride=1, padding=2)self.conv3 = paddle.nn.Conv2D(in_channels=256, out_channels=384, kernel_size=3, stride=1, padding=1)self.conv4 = paddle.nn.Conv2D(in_channels=384, out_channels=384, kernel_size=3, stride=1, padding=1)self.conv5 = paddle.nn.Conv2D(in_channels=384, out_channels=256, kernel_size=3, stride=1, padding=1)self.mp1 = paddle.nn.MaxPool2D(kernel_size=3, stride=2)self.mp2 = paddle.nn.MaxPool2D(kernel_size=3, stride=2)self.L1 = paddle.nn.Linear(in_features=256*3*3, out_features=1024)self.L2 = paddle.nn.Linear(in_features=1024, out_features=512)self.L3 = paddle.nn.Linear(in_features=512, out_features=10)def forward(self, x):x = self.conv1(x)x = paddle.nn.functional.relu(x)x = self.mp1(x)x = self.conv2(x)x = paddle.nn.functional.relu(x)x = self.mp2(x)x = self.conv3(x)x = paddle.nn.functional.relu(x)x = self.conv4(x)x = paddle.nn.functional.relu(x)x = self.conv5(x)x = paddle.nn.functional.relu(x)x = paddle.flatten(x, start_axis=1, stop_axis=-1)x = self.L1(x)x = paddle.nn.functional.relu(x)x = self.L2(x)x = paddle.nn.functional.relu(x)x = self.L3(x)return x

第二步:查看一下网络结构;

# 网络结构   应用paddle.summary检查网络结构是否正确。
model = alexnet()paddle.summary(model, (100,3,32,32))

运行后的输出结果。

---------------------------------------------------------------------------Layer (type)       Input Shape          Output Shape         Param #
===========================================================================Conv2D-1      [[100, 3, 32, 32]]   [100, 96, 15, 15]       14,208     MaxPool2D-1   [[100, 96, 15, 15]]    [100, 96, 7, 7]           0       Conv2D-2      [[100, 96, 7, 7]]     [100, 256, 7, 7]       614,656    MaxPool2D-2    [[100, 256, 7, 7]]    [100, 256, 3, 3]          0       Conv2D-3      [[100, 256, 3, 3]]    [100, 384, 3, 3]       885,120    Conv2D-4      [[100, 384, 3, 3]]    [100, 384, 3, 3]      1,327,488   Conv2D-5      [[100, 384, 3, 3]]    [100, 256, 3, 3]       884,992    Linear-1        [[100, 2304]]         [100, 1024]         2,360,320   Linear-2        [[100, 1024]]          [100, 512]          524,800    Linear-3         [[100, 512]]          [100, 10]            5,130
===========================================================================
Total params: 6,616,714
Trainable params: 6,616,714
Non-trainable params: 0
---------------------------------------------------------------------------
Input size (MB): 1.17
Forward/backward pass size (MB): 39.61
Params size (MB): 25.24
Estimated Total Size (MB): 66.02
---------------------------------------------------------------------------

在网络设计过程中,往往会出现结构性差错的地方就在卷积层与全连接层之间出现,在进行Flatten(扁平化)之后,出现数据维度对不上。可以在网络定义的过程中,首先将Flatten之后的全连接层去掉,通过paddle.summary输出结构确认卷积层数出为 256×3×3之后,再将全连接层接上。如果出现差错,可以进行每一层校验。

在上面模型的基础上,进行下面相关操作(加载数据,训练,预测)

 第三步,加载Cifar10数据

 原文根据AlexNet的结构,结合 The CIFAR-10 dataset 图片的特点(32×32×3),对AlexNet网络结构进行了微调:

import sys,os,math,time
import matplotlib.pyplot as plt
from numpy import *import paddle
from paddle.vision.transforms import Normalize
normalize = Normalize(mean=[0.5,0.5,0.5], std=[0.5,0.5,0.5], data_format='HWC')from paddle.vision.datasets import Cifar10
cifar10_train = Cifar10(mode='train', transform=normalize)
cifar10_test = Cifar10(mode='test', transform=normalize)train_dataset = [cifar10_train.data[id][0].reshape(3,32,32) for id in range(len(cifar10_train.data))]
train_labels = [cifar10_train.data[id][1] for id in range(len(cifar10_train.data))]class Dataset(paddle.io.Dataset):def __init__(self, num_samples):super(Dataset, self).__init__()self.num_samples = num_samplesdef __getitem__(self, index):data = train_dataset[index]label = train_labels[index]return paddle.to_tensor(data,dtype='float32'), paddle.to_tensor(label,dtype='int64')def __len__(self):return self.num_samples_dataset = Dataset(len(cifar10_train.data))
train_loader = paddle.io.DataLoader(_dataset, batch_size=100, shuffle=True)

第四步 训练网络

test_dataset = [cifar10_test.data[id][0].reshape(3,32,32) for id in range(len(cifar10_test.data))]
test_label = [cifar10_test.data[id][1] for id in range(len(cifar10_test.data))]test_input = paddle.to_tensor(test_dataset, dtype='float32')
test_l = paddle.to_tensor(array(test_label)[:,newaxis])optimizer = paddle.optimizer.Adam(learning_rate=0.001, parameters=model.parameters())
def train(model):model.train()epochs = 2accdim = []lossdim = []testaccdim = []for epoch in range(epochs):for batch, data in enumerate(train_loader()):out = model(data[0])loss = paddle.nn.functional.cross_entropy(out, data[1])acc = paddle.metric.accuracy(out, data[1])loss.backward()optimizer.step()optimizer.clear_grad()accdim.append(acc.numpy())lossdim.append(loss.numpy())predict = model(test_input)testacc = paddle.metric.accuracy(predict, test_l)testaccdim.append(testacc.numpy())if batch%10 == 0 and batch>0:print('Epoch:{}, Batch: {}, Loss:{}, Accuracys:{}{}'.format(epoch, batch, loss.numpy(), acc.numpy(), testacc.numpy()))plt.figure(figsize=(10, 6))plt.plot(accdim, label='Accuracy')plt.plot(testaccdim, label='Test')plt.xlabel('Step')plt.ylabel('Acc')plt.grid(True)plt.legend(loc='upper left')plt.tight_layout()train(model)

训练参数:

BatchSize:100
LearningRate:0.001

  如果BatchSize过小,训练速度变慢。

训练参数:

BatchSize:5000
LearningRate:0.0005

BatchSize:5000,Lr=0.001, DropOut:0.2:

BatchSize:5000,Lr=0.0001, DropOut:0.2:

BatchSize:5000,Lr=0.0005, DropOut:0.5:

参考链接:

(164条消息) 在Paddle中利用AlexNet测试CIFAR10数据集合_卓晴的博客-CSDN博客

paddlepaddle 实现AlexNet模型,复现原创论文相关推荐

  1. COLING 2018 最佳论文解读:序列标注经典模型复现

    在碎片化阅读充斥眼球的时代,越来越少的人会去关注每篇论文背后的探索和思考. 在这个栏目里,你会快速 get 每篇精选论文的亮点和痛点,时刻紧跟 AI 前沿成果. 点击本文底部的「阅读原文」即刻加入社区 ...

  2. home credit default risk捷信消费金融违约风险模型复现(论文_毕业设计_作业)

    你能预测每个申请人偿还贷款的能力吗?由于信用记录不足或不存在,许多人难以获得贷款.而且,不幸的是,这些人经常被不可靠的贷方利用,例如高利贷,校园贷. 捷信努力为没有银行账户的人群扩大金融包容性.为了确 ...

  3. home credit default risk(捷信违约风险)机器学习模型复现(论文_毕业设计_作业)

    你能预测每个申请人偿还贷款的能力吗?由于信用记录不足或不存在,许多人难以获得贷款.而且,不幸的是,这些人经常被不可靠的贷方利用,例如高利贷,校园贷. 捷信努力为没有银行账户的人群扩大金融包容性.为了确 ...

  4. 《天池精准医疗大赛-人工智能辅助糖尿病遗传风险预测》模型复现和数据挖掘-论文_企业

    大赛概况 进入21世纪,生命科学特别是基因科技已经广泛而且深刻影响到每个人的健康生活,于此同时,科学家们借助基因科技史无前例的用一种全新的视角解读生命和探究疾病本质.人工智能(AI)能够处理分析海量医 ...

  5. 复现计算机论文模型,深度学习模型复现难?看看这篇句子对模型的复现论文

    原标题:深度学习模型复现难?看看这篇句子对模型的复现论文 在碎片化阅读充斥眼球的时代,越来越少的人会去关注每篇论文背后的探索和思考. 在这个栏目里,你会快速 get 每篇精选论文的亮点和痛点,时刻紧跟 ...

  6. 复现计算机论文模型,COLING 2018 最佳论文解读:序列标注经典模型复现

    原标题:COLING 2018 最佳论文解读:序列标注经典模型复现 在碎片化阅读充斥眼球的时代,越来越少的人会去关注每篇论文背后的探索和思考. 在这个栏目里,你会快速 get 每篇精选论文的亮点和痛点 ...

  7. 《天池精准医疗大赛-人工智能辅助糖尿病遗传风险预测》模型复现和数据挖掘-企业科研_论文作业

    大赛概况 进入21世纪,生命科学特别是基因科技已经广泛而且深刻影响到每个人的健康生活,于此同时,科学家们借助基因科技史无前例的用一种全新的视角解读生命和探究疾病本质.人工智能(AI)能够处理分析海量医 ...

  8. AlexNet网络复现

    AlexNet 学习流程 阅读AlexNet论文原文 搜集学习资源:视频讲解-博客资源 熟悉AlexNet网络结构 代码复现,清楚网络结构中层与层之间的操作 AlexNet论文 原论文:imagene ...

  9. 动手学深度学习(PyTorch实现)(八)--AlexNet模型

    AlexNet模型 1. AlexNet模型介绍 1.1 AlexNet的特点 1.2 AlexNet的结构 1.3 AlexNet参数数量 2. AlexNet的PyTorch实现 2.1 导入相应 ...

最新文章

  1. C语言求3x3数组对角线元素之和
  2. php读取部分文章显示不出来了,织梦使用PHP5.3环境时遇到部分文章出现”读取附加信息出错“的解决办法jz1...
  3. conda(pip) bad interpreter的解决办法
  4. php thumb 生成缩略图
  5. 怎么进bios设置硬盘启动顺序|电脑bios硬盘启动设置方法
  6. C++实现PCA变换
  7. JavaWeb开发框架——Spring
  8. [资讯]北京二套学区房奋斗目标
  9. laravel手册链接
  10. 化繁为简|华天软件参数化,将轴承设计变为数与数的组合
  11. 折线统计html,canvas制作简单的HTML图表,折线或者矩形统计(原创)
  12. 探真无阻塞加载javascript脚本技术,我们会发现很多意想不到的秘密
  13. 我学会了学计算机,我学会了电脑打字
  14. 2023秋招--梦加网络--游戏客户端--二面面经
  15. 淘宝卖家中心打开淘宝客推广网页空白
  16. 如何查看Steam的17位Id
  17. 【微机原理作业】8086存储器读写实验
  18. 着眼未来 巅峰对决 | “智算之道—2020人工智能应用挑战赛”圆满收官!
  19. 人民的名义泄漏版百度云46-56集百度网盘下载
  20. Game boy模拟器(5):集成

热门文章

  1. 绩效考核方面的书籍推荐:《绩效管理必读12篇》
  2. 制造企业如何通过APS智能排产进行生产计划规划?
  3. 新手小白H5微应用接入浙里办流程指南
  4. 竞争性传输函数:compet
  5. 微信支付接口调用之二维码失效时间的设置
  6. android 百度地图禁止双击放大缩小,百度地图API 在使用点聚合时,如果放大、缩小或移动地图时,添加的文字标签会消失...
  7. 真正解决layer弹层遮罩挡住窗体的问题
  8. 七巧板的制作(结合js 数组对象 for循环)
  9. 建筑间距对住房有什么影响
  10. org.quartz