mxnet.gluon 加载预训练
import mxnet as mx
from mxnet.gluon import nn
from mxnet import gluon,nd,autograd,init
from mxnet.gluon.data.vision import datasets,transforms
from IPython import display
import matplotlib.pyplot as plt
import time
import numpy as np#下载fashionMNIST数据集
fashion_train_data = datasets.FashionMNIST(train=True)
#获取图片数据和对应的标签
images,labels = fashion_train_data[:]#transforms链式转换数据
transformer = transforms.Compose([transforms.ToTensor(),transforms.Normalize(0.13,0.31)])
#转换数据
fashion_data = fashion_train_data.transform_first(transformer)#设置batch的大小
batch_size = 256
#在windows系统上,请将num_workers设置为0,否则会导致线程错误
train_data = gluon.data.DataLoader(fashion_data,batch_size=batch_size,shuffle=True,num_workers=0)#加载验证数据
fashion_val_data = gluon.data.vision.FashionMNIST(train=False)
val_data = gluon.data.DataLoader(fashion_val_data.transform_first(transformer),batch_size=batch_size,num_workers=0)
#定义使用的GPU,使用GPU加速训练,如果有多个GPU,可以定义多个
gpu_devices = [mx.gpu(0)]
#定义网络结构
LeNet = nn.HybridSequential()
#构建一个LeNet的网络结构
LeNet.add(nn.Conv2D(channels=6,kernel_size=5,activation="relu"),nn.MaxPool2D(pool_size=2,strides=2),nn.Conv2D(channels=16,kernel_size=3,activation="relu"),nn.MaxPool2D(pool_size=2,strides=2),nn.Flatten(),nn.Dense(120,activation="relu"),nn.Dense(84,activation="relu"),nn.Dense(10)
)
LeNet.hybridize()
#初始化神经网络的权重参数,使用GPU来加速训练
LeNet.collect_params().initialize(force_reinit=True,ctx=gpu_devices)
#定义softmax损失函数
softmax_cross_entropy = gluon.loss.SoftmaxCrossEntropyLoss()
#设置优化算法,使用随机梯度下降sgd算法,学习率设置为0.1
trainer = gluon.Trainer(LeNet.collect_params(),"sgd",{"learning_rate":0.1})
#计算准确率
def acc(output,label):return (output.argmax(axis=1) == label.astype("float32")).mean().asscalar()#设置迭代的轮数
epochs = 10
#训练模型
for epoch in range(epochs):train_loss,train_acc,val_acc = 0,0,0epoch_start_time = time.time()for data,label in train_data:#使用GPU来加载数据加速训练data_list = gluon.utils.split_and_load(data,gpu_devices)label_list = gluon.utils.split_and_load(label,gpu_devices)#前向传播with autograd.record():#获取多个GPU上的预测结果pred_Y = [LeNet(x) for x in data_list]#计算多个GPU上预测值的损失losses = [softmax_cross_entropy(pred_y,Y) for pred_y,Y in zip(pred_Y,label_list)]#反向传播更新参数for l in losses:l.backward()trainer.step(batch_size)#计算训练集上的总损失train_loss += sum([l.sum().asscalar() for l in losses])#计算训练集上的准确率train_acc += sum([acc(output_y,y) for output_y,y in zip(pred_Y,label_list)])for data,label in val_data:data_list = gluon.utils.split_and_load(data,ctx_list=gpu_devices)label_list = gluon.utils.split_and_load(label,ctx_list=gpu_devices)#计算验证集上的准确率val_acc += sum(acc(LeNet(val_X),val_Y) for val_X,val_Y in zip(data_list,label_list))print("epoch %d,loss:%.3f,train acc:%.3f,test acc:%.3f,in %.1f sec"%(epoch+1,train_loss/len(labels),train_acc/len(train_data),val_acc/len(val_data),time.time()-epoch_start_time))
#保存模型参数
LeNet.export("lenet",epoch=1)#加载模型文件
LeNet = gluon.nn.SymbolBlock.imports("lenet-symbol.json",["data"],"lenet-0001.params")
mxnet.gluon 加载预训练相关推荐
- 使用torchvision.models.inception_v3(pretrained=True)加载预训练的模型每次都特别慢
欢迎大家关注笔者,你的关注是我持续更博的最大动力 原创文章,转载告知,盗版必究 使用torchvision.models.inception_v3(pretrained=True)加载预训练的模型每次 ...
- PyTorch 加载预训练权重
前言 使用PyTorch官方提供的权重或者其他第三方提供的权重对相同模型的参数进行初始化,在数据量较少的前提下,可以帮助模型更快地收敛到最优点,达到更好的效果,即迁移学习. 在大部分的迁移学习场景 ...
- torch编程-加载预训练权重-模型冻结-解耦-梯度不反传
1)加载预训练权重 net = torchvision.models.resnet50(pretrained=False) # 构建模型 pretrained_model = torch.load(p ...
- 实践:jieba分词和pkuseg分词、去除停用词、加载预训练词向量
一:jieba分词和pkuseg分词 原代码文件 链接:https://pan.baidu.com/s/1J8kmTFk8lec5ubfwBaSnLg 提取码:e4nv 目录: 1:分词介绍: 目标: ...
- paddlepaddle加载预训练词向量
文章目录 1.一些用到的api文档 2.加载预训练词向量 2.1小数据 2.2核心代码 2.3验证结果 3.可能有用的 tensorflow的加载方法可以看我之前写的: tensorflow加载词向量 ...
- 深度学习加载预训练权重好处
深度学习加载预训练权重好处: 在模型开始训练前,使模型参数得到一个好的初始化,对于后面的训练学习有非常大的帮助.
- mxnet加载预训练
关乎symbol和module的一些基本属性 # 查看json每一个op的属性:kernel size.padding.stride等 sym.attr_dict() # 返回一个字典,根据key获取 ...
- Pytorch加载预训练网络,替换分类层并重新训练
定义网络时,在网络类的构造函数网络结构定义中添加如下语句: for p in self.parameters():p.requires_grad = False 该语句的功能是固定定义在该语句之前的网 ...
- pytorch加载预训练 加载部分参数
最简单的: state_dict = torch.load(weight_path) self.load_state_dict(state_dict,strict=False) 加载cpu: m ...
最新文章
- “刷脸”之后 声纹识别有望成为新秀
- Flex中如何通过horizontalTickAligned和verticalTickAligned样式指定线图LineChart横竖方向轴心标记的例子...
- 【学术相关】翻倍!研究生招生规模持续扩张!
- go 17个字符串函数使用示例
- 用int还是用Integer?
- linux c99 可变长数组,C中不支持可变长度数组C99(Variable length arrays C99 not supported in C)...
- [译] 用 Swift 创建自定义的键盘
- 无线充qi协议c语言详解,无线充电Qi协议正向通信FSK的解调设计
- mapxtreme java_用mapXtreme Java开发web gis应用 (下)
- office2010 启动man_发现office2010启动挺慢的,各位一样吗
- 使用this.$refs.XXX修改某个元素样式并添加点击事件
- 计算机英语截短词,英语词汇构词法(Word Formation)——截短法
- 挑选电脑免费加密软件特别注意哪些?
- scrapy实战项目(简单的爬取知乎项目)
- 使用Sivarc使PLC程序标准化
- 关于计算机系调查问卷表,计算机系统调查问卷.xls
- 安全漏洞SCAP规范标准
- 字符串常见方法总结: 构造方法、静态方法、 其它方法
- 卡奴、车奴、房奴,你是哪种?
- 基于android记事本毕业论文,基于Android的记事本应用的设计与实现-毕业论文.doc...
热门文章
- linux shell 2 /dev/null的解释
- Linux2.6内核--VFS层中和进程相关的数据结构
- 《UNIX环境高级编程》--符号链接
- Android 开发中的多线程编程技术
- Ubuntu 无线密码破解利器aircrack-ng
- c++ 取成员函数地址_c及c++指针及引用简单解释(自学学习心得)
- 打印机尚未链接到此计算机,win10系统无法连接打印机显示未指定设备的解决方法...
- NeHe教程Qt实现——lesson07
- 定义一个不能被拷贝的类
- wpf listview mysql_Kivy:使用MySQL的Kivy页面的Listview实现