需求说明

  1. 训练过程在 “HuggingFace (transformers) 自定义图像数据集、使用 DeiT 模型、Trainer 进行训练回归任务” 中已经描述过。
  2. 训练结束后,会生成如下的 checkpoints 文件:
  3. 现在想用 checkpoint-500 中保存的模型进行预测,看它在测试集上的效果怎么样,即损失值是多少。

需求解决

关键代码

  • 调用 Trainer 的 predict 方法,参数传入测试集 Dataset
    关于 predict 的更多用法可以参考官方文档:https://huggingface.co/docs/transformers/main_classes/trainer#transformers.Trainer.predict
from dataset import GazeCaptureDataset
from transformers import TrainingArguments
from transformers import DeiTForImageClassification
from torch import nn
from transformers import Trainer
from transformers import DeiTConfig# 数据集根路径
root_path = r"D:\datasets\GazeCapture_new"
# 1. 定义 Dataset
test_dataset = GazeCaptureDataset(root_path, data_type='test')# 2. 定义 DeiT 图像模型
configuration = DeiTConfig(num_labels=2, problem_type="regression")
model = DeiTForImageClassification(configuration).from_pretrained('gaze_trainer/checkpoint-500')# 3. 测试
## 3.1 定义测试参数
testing_args = TrainingArguments(output_dir="pred_trainer")## 3.2 自定义 Trainer
class CustomTester(Trainer):# 重写计算 loss 的函数def compute_loss(self, model, inputs, return_outputs=False):# 获取标签值labels = inputs.get("labels")# 获取输入值x = inputs.get("pixel_values")# 模型输出值outputs = model(x)logits = outputs.get('logits')# 定义损失函数为平滑 L1 损失loss_fct = nn.SmoothL1Loss()# 计算输出值和标签的损失loss = loss_fct(logits, labels)return (loss, outputs) if return_outputs else loss## 3.3 定义 Trainer 对象
tester = CustomTester(model=model,args=testing_args,
)## 3.4 调用 predict 方法,开始测试
output = tester.predict(test_dataset=test_dataset)# 4. 测试结果
print(output)

Dataset

dataset.py 代码如下:

import os.pathfrom torch.utils.data import Dataset
from transform import transform
import numpy as np# 读取数据,如果是训练数据,随即打乱数据顺序
def get_label_list(label_path):# 存储所有标签文件中的所有内容full_lines = []# 获取所有标签文件的名称,如 00002.label, 00003.label, ......label_names = os.listdir(label_path)# 遍历每一个标签文件,并读取其中内容for label_name in label_names:# 标签文件全路径,如 D:\datasets\GazeCapture_new\Label\train\00002.labellabel_abs_path = os.path.join(label_path, label_name)# 读取每一个标签文件中的内容with open(label_abs_path) as flist:# 存储该标签文件中的所有内容full_line = []for line in flist:full_line.append(line.strip())# 移除首行表头 'Face Left Right Grid Xcam, Ycam Xdot, Ydot Device'full_line.pop(0)full_lines.extend(full_line)return full_linesclass GazeCaptureDataset(Dataset):def __init__(self, root_path, data_type):self.data_dir = root_path# 标签文件的根路径,如 D:\datasets\GazeCapture_new\Label\trainlabel_root_path = os.path.join(root_path + '/Label', data_type)# 获取所有标签文件中的所有内容self.full_lines = get_label_list(label_root_path)# 每一行内容的分隔符self.delimiter = ' '# 数据集长度,也就是一共有多少个图片self.num_samples = len(self.full_lines)def __len__(self):return self.num_samplesdef __getitem__(self, idx):# 标签文件的一行,对应一个训练实例line = self.full_lines[idx]# 将标签文件中的一行内容按照分隔符进行分割Face, Left, Right, Grid, XYcam, XYdot, Device = line.split(self.delimiter)# 获取网络的输入:人脸图片face_path = os.path.join(self.data_dir + '/Image/', Face)# 读取人脸图像with open(face_path, 'rb') as f:img = f.read()# 将人脸图像进行格式转化:缩放、裁剪、标准化pixel_values = transform(img)# 获取标签值labels = np.array(XYcam.split(","), np.float32)# 注意返回值的形式一定要是 {"labels": xxx, "pixel_values": xxx}result = {"labels": labels}result["pixel_values"] = pixel_valuesreturn result

输出结果如下:

***** Running Prediction *****Num examples = 1716Batch size = 8
100%|██████████| 215/215 [01:52<00:00,  1.90it/s]
PredictionOutput(predictions=array([[-2.309026 , -2.752627 ],[-2.0178156, -3.0546618],[-1.8222798, -3.309564 ],...,[-2.6463585, -2.3462727],[-2.2149038, -2.7406967],[-1.7267275, -3.3450181]], dtype=float32), label_ids=array([[ 0.969375, -7.525975],[ 0.969375, -7.525975],[ 0.969375, -7.525975],...,[ 5.5845  ,  1.93875 ],[ 5.5845  ,  1.93875 ],[ 5.5845  ,  1.93875 ]], dtype=float32), metrics={'test_loss': 2.8067691326141357, 'test_runtime': 118.2811, 'test_samples_per_second': 14.508, 'test_steps_per_second': 1.818})

可以看到该模型在测试集的损失值是 2.8067691326141357

Hugging(transformers)读取自定义 checkpoint、使用 Trainer 进行测试回归任务相关推荐

  1. 使用QSettings保存和读取自定义数据类型

    使用QSettings保存和读取自定义数据类型 Date Author Version Note 2021.02.02 Dog Tao V1.0 完成文档的撰写. 文章目录 使用QSettings保存 ...

  2. 绝地求生自定义服务器租一天多少钱,绝地求生自定义服务器不限号测试_绝地求生自定义服务器设置方法_游戏吧...

    绝地求生官方宣布将要向玩家开放自定义服务器的不限号测试,那么大家怎么参与呢,下面游戏吧小编为打击带来绝地求生自定义服务器不限号测试参与方式介绍. 自定义服务器不限号测试公告 玩家们大家好, 今年初,我 ...

  3. 吃鸡买个自定义服务器,绝地求生怎么玩自定义服务器_自定义服务器不限号测试什么时候开始...

    刚刚,<绝地求生>官方发布了"自定义服务器不限号测试公告"从本周开始,即可参与测试活动.同时对自定义服务器的功能界面进行了更新,具体内容如下,一起来看看吧. 今年初,我 ...

  4. HuggingFace (transformers) 自定义图像数据集、使用 DeiT 模型、Trainer 进行训练回归任务

    资料 Hugging Face 官方文档:https://huggingface.co/ Hugging Face 代码链接:https://github.com/huggingface/transf ...

  5. 编写transformers的自定义pytorch训练循环(Dataset和DataLoader解析和实例代码)

    文章目录 一.Dataset和DataLoader加载数据集 1.torch.utils.data 2. 加载数据流程 3. Dataset 4. dataloader类及其参数 5. dataloa ...

  6. C#中采用自定义方式读取自定义配置文件

    在C#中读取配置文件的时候,我们通常都是通过ConfigurationManager类来直接获取配置文件中的相关数据的,但是如果我们不想读取项目的默认配置文件App.config,则很不爽,所以经过了 ...

  7. 关于Rust读取自定义toml文件

    参考链接:https://github.com/baoyachi/read-toml 思路,关于自定义的toml文件应该有一定配置规则,不清楚或者不了解toml文件配置的点击这个链接 知道toml的使 ...

  8. spark通过实现FileFormat方式读取自定义文件格式

    Spark内部已经实现了很多常用数据源的适配,对于不支持的自定义的数据源,也提供了相应的接口.最近刚好遇到需要通过Spark读取在HDFS上的自定义文件格式的需求,网上找到的很多资料都以实现 Crea ...

  9. NPOI读取自定义的颜色bug修复方案

    有些自定义的颜色,npoi无法读取,需要手动写方案实现,分别是在打开wrokbook的时候,加上你自定义的颜色,然后在读取excel内容的时候通过一些特征判断是否是自定义的颜色,改变其index的值, ...

最新文章

  1. JAVA作业 02 JAVA语法基础
  2. java mysql自动备份_java定时备份数据之二_MySQL
  3. pano2vr怎么制作漫游_春节7天长假,在云南怎么玩?
  4. Gnuplot 简单使用
  5. 幅度调制(AM调制、DSB(双边带)调制、SSB、VSB)
  6. 应对亚洲劳动力成本不断上升的挑战
  7. 【Python - wxpython】- 卫星通信系统链路计算软件
  8. vs2005的MSDN的下载
  9. 94G的kindle电子书btsync分享
  10. mysql now()函数调用系统时间不对修正方法
  11. 2022 华东师范大学 数据学院复试机考
  12. pure-ftpd安装与使用
  13. 红米note3android5.0,小米红米note3MOSBeta5.0安卓8.1.0来去电归属农历等本地化增强适配...
  14. Docker 容器安装监控软件 cAdvisor
  15. 计量经济学笔记6-Eviews操作-自相关的检验与消除(DW、LM检验与FGLS、广义差分变换)
  16. python判断用户名密码是否正确_python实现用户名密码校验
  17. 暑假学习 Python爬虫基础(4)
  18. 清华系VS浙大系 谁才是国内区块链领域的“黄埔军校”?
  19. 关于域名评级标准【PoSEO等级】
  20. 1397 火车票退票费计算(函数专题)

热门文章

  1. 华为 Mate X3折叠屏手机 参数配置 华为 Mate X3评测怎么样
  2. Windows11运行App
  3. Typora+PicGo+GitHub图床的搭建和常见问题的个人解决方法
  4. PicGo + GitHub + Typora 搭建个人图床工具
  5. 17计及电转气协同的含碳捕集与垃圾焚烧虚拟电厂优化调度
  6. 愿每一个人的创造力都能被激发:剪映Mac版来了!
  7. OpenCV4学习笔记(41)——ORB特征提取描述算法
  8. 从零开始的python爬虫速成指南
  9. 客厅装修要不要用集成墙面?
  10. Excel 2016崩溃恢复后数字变成时间格式