方法一:直接在epoch过程中求取准确率

简介:此段代码是LeNet5中截取的。

def train_model(model,train_loader):

optimizer = torch.optim.Adam(model.parameters())

loss_func = nn.CrossEntropyLoss()

EPOCHS = 5

for epoch in range(EPOCHS):

correct = 0

for batch_idx,(X_batch,y_batch) in enumerate(train_loader):

optimizer.zero_grad()

#这里是只取训练数据的意思吗,X_batch和y_batch是怎么分开的?

#答:X_batch和y_batch是一一对应的,只不过顺序打乱了,参考torch.utils.data.ipynb

output = model(X_batch.float()) #X_batch.float()是什么意思

loss = loss_func(output,y_batch)

loss.backward()

optimizer.step()

# Total correct predictions

#第一个1代表取每行的最大值,第二个1代表只取最大值的索引

#这两行代码是求准确率的地方

predicted = torch.max(output.data,1)[1]

correct += (predicted == y_batch).sum()

#print(correct)

if batch_idx % 100 == 0:

print('Epoch :{}[{}/{}({:.0f}%)]\t Loss:{:.6f}\t Accuracy:{:.3f}'.format(epoch,batch_idx * len(X_batch),len(train_loader.dataset),100.*batch_idx / len(train_loader),loss.data.item(),float(correct*100)/float(BATCH_SIZE)*(batch_idx+1)))

if __name__ == '__main__':

myModel = LeNet5()

print(myModel)

train_model(myModel,train_loader)

evaluate(myModel,test_loader,BATCH_SIZE)

方法二:构建函数,然后在epoch中调用该函数

简介:此段代码是对Titanic(泰坦尼克号)数据分析截取。

epochs = 10

log_step_freq = 30

dfhistory = pd.DataFrame(columns = ['epoch','loss',metric_name,'val_loss','val_'+metric_name])

print('Start Training...')

nowtime = datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')

print('========='*8 + '%s'%nowtime)

for epoch in range(1,epochs+1):

#1.训练循环

net.train()

loss_sum = 0.0

metric_sum = 0.0

step = 1

for step,(features,labels) in enumerate(dl_train,1):

#梯度清零

optimizer.zero_grad()

#正向传播求损失

predictions = net(features)

loss = loss_func(predictions,labels)

metric = metric_func(predictions,labels)

#反向传播求梯度

loss.backward()

optimizer.step()

#打印batch级别日志

loss_sum += loss.item()

metric_sum += metric.item()

if step%log_step_freq == 0:

print(('[Step = %d] loss: %.3f,' + metric_name+': %.3f %%')%(step,loss_sum/step,100*metric_sum/step))

#2,验证循环

net.eval()

val_loss_sum = 0.0

val_metric_sum = 0.0

val_step =1

for val_step,(features,labels) in enumerate(dl_valid,1):

#关闭梯度计算

with torch.no_grad():

pred = net(features)

val_loss = loss_func(pred,labels)

val_metric = metric_func(labels,pred)

val_loss_sum += val_loss.item()

val_metric_sum += val_metric.item()

#3,记录日志

info = (epoch,loss_sum/step,100*metric_sum/step,

val_loss_sum/val_step,100*val_metric_sum/val_step)

dfhistory.loc[epoch-1] = info

#打印epoch级别日志

print(('\nEPOCH = %d,loss = %.3f,' + metric_name+\

'=%.3f %%,val_loss = %.3f'+' val_'+metric_name+'= %.3f %%')%info)

nowtime = datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')

print('\n'+'=========='*8 + '%s'%nowtime)

print('Finishing Training...')

用python计算准确率_Python中计算模型精度的几种方法,Pytorch,中求,准确率相关推荐

  1. python随机数权重_Python实现基于权重的随机数2种方法

    问题: 例如我们要选从不同省份选取一个号码,每个省份的权重不一样,直接选随机数肯定是不行的了,就需要一个模型来解决这个问题. 简化成下面的问题: 字典的key代表是省份,value代表的是权重,我们现 ...

  2. Linux中增加软路由的两种方法,Linux中增加软路由的三种方法

    # route add –net IP netmask MASK eth0 # route add –net IP netmask MASK gw IP # route add –net IP/24 ...

  3. php中的数组有哪几种方法,PHP中常用的遍历数组方法有几种,分别是什么?( )...

    PHP中常用的遍历数组方法有几种,分别是什么?( ) 更多相关问题 序列对心电触发的原理叙述,正确的是()A.是利用心电图的R波触发采集MR信号B.是利用心电图的T波触 静脉输血法的评价 有关急性梗阻 ...

  4. html整体页面缩放的方法,html5中让页面缩放的4种方法

    1.viewport 这种方法,不是所有的浏览器都兼容 2.百分比 这种方法,可以兼容大部分浏览器,但是修改幅度比较大 .main .login .txt1{margin-top:8.59375%; ...

  5. python可以实现哪些功能_Python中实现机器学习功能的四种方法介绍

    本篇文章给大家带来的内容是关于Python中实现机器学习功能的四种方法介绍,有一定的参考价值,有需要的朋友可以参考一下,希望对你有所帮助. 在本文中,我们将介绍从数据集中选择要素的不同方法; 并使用S ...

  6. python怎么清屏_python实现清屏的方法 Python Shell中清屏一般有两种方法。

    Python Shell 怎样清屏? Python Shell中清屏一般有两种方法. 奈何一个人随着年龄增长,梦想便不复轻盈:他开始用双手掂量生活,更看重果实而非花朵.--叶芝<凯尔特的搏暮&g ...

  7. python list去重函数_python中对list去重的几种方法

    这篇文章主要介绍了python中对list去重的多种方法,现在分享给大家,需要的朋友可以参考下 今天遇到一个问题,在同事随意的提示下,用了 itertools.groupby 这个函数.不过这个东西最 ...

  8. python中字符串怎么引用_Python:字符串中引用外部变量的3种方法

    方法一: username=input('username:') age=input('age:') job=input('job:') salary=input('salary') info1='' ...

  9. arcgis用python字段自动编号,arcgis中字段自动编号的两种方法

    <arcgis中字段自动编号的两种方法>由会员分享,可在线阅读,更多相关<arcgis中字段自动编号的两种方法(4页珍藏版)>请在人人文库网上搜索. 1.精选文档关于ARCGI ...

最新文章

  1. 利用编码特长,我赚取了每月1000美元的额外收入
  2. 使用initramfs启动Linux成功
  3. 【脚下有根】之Skia库的matrix代码解读
  4. C#将Excel数据表导入SQL数据库的两种方法(转)
  5. Java - 死锁 Dead Lock 定位分析
  6. WordPress的varnish内存缓存方案
  7. tf.unstack\tf.unstack
  8. s5-11 距离矢量路由选择协议
  9. C与java通讯小结
  10. 浙大计算机基础知识题1,浙大作业1计算机基础知识题.docx
  11. java虚拟机内存模型与垃圾回收知识复习总结
  12. HTML5手机游戏将迎美好未来 .
  13. 转为字符数组_py字符打印照片
  14. asp.net 微信小程序源码 微信分销源码 源文件完全开源 源码
  15. Apple Watch Ultra和Apple Watch Series 8 区别 续航 功能介绍
  16. AQSW公司OA系统需求分析
  17. 第一章 集总参数电路中电压、电流的约束关系
  18. pdf压缩工具_18MB秒变1MB,最好用的PDF在线压缩工具
  19. 【C++复习总结回顾】—— 【一】基础知识+字符串/string类
  20. 2020年INTERNET考试题

热门文章

  1. uva1511(找规律。。。)
  2. hdu3694(四边形的费马点)
  3. bzoj 1179 抢掠计划atm (缩点+有向无环图DP)
  4. Oracle 安装完怎么用,oracle 11g 安装完怎么用
  5. python中的utils模块_使用Python的package机制如何简化utils包设计详解
  6. 导航模块自带的rtk算法_这款百元国产RTK板卡要改变高精度定位市场格局吗?
  7. PKUWC2019游记WC2019游记
  8. [Design Pattern] 抽象工厂模式
  9. 【设计模式】——工厂方法FactoryMethod
  10. Linux sftp用法