李宏毅ML2021 HW3创建pseudo dataset
在cnn训练中使用semi-supervised扩充数据
输入和输出
输入:CNN model,标签不准的Dataset,过滤阈值
输出:经过阈值筛选后的Pseudo Dataset
解决思路
- 获取满足筛选条件的数据的indice,用于生成筛选后的Dataset和pseudo label
- 新建一个自定义Dataset类
import torch
import torch.nn as nn
from torch.utils.data import ConcatDataset, DataLoader, Subset,Dataset
'''
定义一个Dataset类
包含__init__()和__getitem__()方法
'''# inherit Dataset
class pseudo_dataset(Dataset):def __init__(self,unlabeled_set, indices, pseudo_labels):self.data = Subset(unlabeled_set,indices)self.target = torch.LongTensor(pseudo_labels)[indices]def __getitem__(self,index):if index < 0 : #Handle negative indicesindex += len(self)if index >= len(self):raise IndexError("index %d is out of bounds for axis 0 with size %d"%(index, len(self)))x = self.data[index][0]y = self.target[index].item() return x,y def get_pseudo_labels(dataset, model, threshold=0.65):# This functions generates pseudo-labels of a dataset using given model.# It returns an instance of DatasetFolder containing images whose prediction confidences exceed a given threshold.# You are NOT allowed to use any models trained on external data for pseudo-labeling.device = "cuda" if torch.cuda.is_available() else "cpu"# Construct a data loader.data_loader = DataLoader(dataset, batch_size=batch_size, shuffle=False)# Make sure the model is in eval mode.model.eval()# Define softmax function.softmax = nn.Softmax(dim=-1)#recorde the filtered resultmasks = []# recorde the predicted labelspred_labels = []# Iterate over the dataset by batches.for batch in tqdm(data_loader):img, _ = batch# Forward the data# Using torch.no_grad() accelerates the forward process.with torch.no_grad():logits = model(img.to(device))# Obtain the probability distributions by applying softmax on logits.probs = softmax(logits)# ---------- TODO ----------# Filter the data and construct a new dataset.pred_label = probs.argmax(dim=-1).tolist()pred_labels.extend(pred_label)mask = torch.max(probs,dim=1)[0] > thresholdmasks.extend(mask)indices = torch.arange(0,dataset.length)[masks] # lenpseudo_dataset = pseudo_dataset(dataset,indices, pseudo_labels)print('using {0:.2f}% unlabeld data'.format(100 * len(pseudo_dataset) / len(dataset)))# # Turn off the eval mode.model.train()return pseudo_dataset
李宏毅ML2021 HW3创建pseudo dataset相关推荐
- 李宏毅机器学习hw3
Homework 3 - Convolutional Neural Network 本文是对课程作业代码范例的复现,但也写了一些自己的理解和期间遇到的问题,如有写的不对的地方欢迎各位大佬指正.问题的解 ...
- Spark _24 _读取JDBC中的数据创建DataFrame/DataSet(MySql为例)(三)
两种方式创建DataSet 现在数据库中创建表不能给插入少量数据. javaapi: package SparkSql;import org.apache.spark.SparkConf; impor ...
- 李宏毅ML2021 HW7 BERT-Question Answering
参考代码:Colab 作业PPT: slide 作业所需的数据:data 作业说明:video 作业提交评分:kaggle 目录 1. 作业任务描述 1.1 用BERT做QA的基本原理 1.2 数据描 ...
- 李宏毅2021 HW3
第一次 样例代码执行 在训练过程中能够明显感到过拟合现象,train的accuracy达到99%而valid才0.5左右 第二次 添加了数据增强 train_tfm = transforms.Comp ...
- 2022李宏毅作业hw4 - 挫败感十足的一次作业。
系列文章: 2022李宏毅作业hw1-新冠阳性人员数量预测._亮子李的博客-CSDN博客_李宏毅hw1 hw-2 李宏毅2022年作业2 phoneme识别 单strong-hmm详细解释._亮子李的 ...
- 如何使用TensorFlow中的Dataset API
翻译 | AI科技大本营 参与 | zzq 审校 | reason_W 本文已更新至TensorFlow1.5版本 我们知道,在TensorFlow中可以使用feed-dict的方式输入数据信息,但是 ...
- C#-DataSet和DataTable详解
1.创建DataSet对象: DataSet ds = new DataSet("DataSetName"); 2.查看调用SqlDataAdapter.Fill创建的结构 da. ...
- c# mysql fill_C#里sqlDataAdapter.fill(DataSet,String)的用法
第二个参数 String是指定DataSet 里表的名字,例如 sqlDataAdapter.fill(DataSet,"学生表") 指定后,以后就可以这样调用这张表 DataSe ...
- 数据库-ADONET-使用强类型DataSet
使用强类型DataSet对象 使用ADONET访问DataSet内容的方式,与使用ADO和DAO的Recordset对象具有类似的编程格式. l ADONET和VBNET txtCompanyName ...
- 一个无法捕获ADO.NET Dataset的内存错误
Dataset是ADO.NET在内存保存数据所用的新结构.在某些方面上,Dataset和ADO的Recordset对象相似:不过,Dataset可以把整个schema(包括table.关系.关键字连同 ...
最新文章
- Blender着色器纹理材质创作教程含源文件 Shader Forge
- java获得单元格的值_java – 从单元格值Apache POI获取单元格索引
- Excel 技术篇-解决“单元格不能自动适应大小“问题
- 160个CrackMe005
- 用libevent实现简易的telnet
- 2018刑侦推理 java_2018年刑侦科目推理试题
- ISA server的常见身份验证方式
- Spark数据分析及处理_ELT
- 【蓝桥杯单片机】矩阵键盘和独立键盘新解(更稳定更高复用性)
- ubuntu12.04 安装配置jdk1.7
- html 行自动对齐,html – 行元素不会对齐
- Raphael的set使用
- python中的shallow copy 和 deep copy
- python Copula 模型实现
- 如何将图片调为半色调_为什么我们喜欢粗糙的唱片,半色调网点和其他缺陷?
- postgresql 客户端 uri 设置时区
- 分布式数据库中间件—TDDL
- 我花了一夜用数据结构给女朋友写个H5走迷宫游戏
- 关于泰勒展开的细节-《三体》读后感的读后感...
- 电脑显示黑屏但是鼠标能动怎么处理?