在cnn训练中使用semi-supervised扩充数据

输入和输出

输入:CNN model,标签不准的Dataset,过滤阈值
输出:经过阈值筛选后的Pseudo Dataset

解决思路

  1. 获取满足筛选条件的数据的indice,用于生成筛选后的Dataset和pseudo label
  2. 新建一个自定义Dataset类
import torch
import torch.nn as nn
from torch.utils.data import ConcatDataset, DataLoader, Subset,Dataset
'''
定义一个Dataset类
包含__init__()和__getitem__()方法
'''# inherit Dataset
class pseudo_dataset(Dataset):def __init__(self,unlabeled_set, indices, pseudo_labels):self.data = Subset(unlabeled_set,indices)self.target = torch.LongTensor(pseudo_labels)[indices]def __getitem__(self,index):if index < 0 : #Handle negative indicesindex += len(self)if index >= len(self):raise IndexError("index %d is out of bounds for axis 0 with size %d"%(index, len(self)))x = self.data[index][0]y = self.target[index].item()  return x,y    def get_pseudo_labels(dataset, model, threshold=0.65):# This functions generates pseudo-labels of a dataset using given model.# It returns an instance of DatasetFolder containing images whose prediction confidences exceed a given threshold.# You are NOT allowed to use any models trained on external data for pseudo-labeling.device = "cuda" if torch.cuda.is_available() else "cpu"# Construct a data loader.data_loader = DataLoader(dataset, batch_size=batch_size, shuffle=False)# Make sure the model is in eval mode.model.eval()# Define softmax function.softmax = nn.Softmax(dim=-1)#recorde the filtered resultmasks = []# recorde the predicted labelspred_labels = []# Iterate over the dataset by batches.for batch in tqdm(data_loader):img, _ = batch# Forward the data# Using torch.no_grad() accelerates the forward process.with torch.no_grad():logits = model(img.to(device))# Obtain the probability distributions by applying softmax on logits.probs = softmax(logits)# ---------- TODO ----------# Filter the data and construct a new dataset.pred_label = probs.argmax(dim=-1).tolist()pred_labels.extend(pred_label)mask = torch.max(probs,dim=1)[0] > thresholdmasks.extend(mask)indices = torch.arange(0,dataset.length)[masks]  # lenpseudo_dataset = pseudo_dataset(dataset,indices, pseudo_labels)print('using {0:.2f}% unlabeld data'.format(100 * len(pseudo_dataset) / len(dataset)))# # Turn off the eval mode.model.train()return pseudo_dataset

李宏毅ML2021 HW3创建pseudo dataset相关推荐

  1. 李宏毅机器学习hw3

    Homework 3 - Convolutional Neural Network 本文是对课程作业代码范例的复现,但也写了一些自己的理解和期间遇到的问题,如有写的不对的地方欢迎各位大佬指正.问题的解 ...

  2. Spark _24 _读取JDBC中的数据创建DataFrame/DataSet(MySql为例)(三)

    两种方式创建DataSet 现在数据库中创建表不能给插入少量数据. javaapi: package SparkSql;import org.apache.spark.SparkConf; impor ...

  3. 李宏毅ML2021 HW7 BERT-Question Answering

    参考代码:Colab 作业PPT: slide 作业所需的数据:data 作业说明:video 作业提交评分:kaggle 目录 1. 作业任务描述 1.1 用BERT做QA的基本原理 1.2 数据描 ...

  4. 李宏毅2021 HW3

    第一次 样例代码执行 在训练过程中能够明显感到过拟合现象,train的accuracy达到99%而valid才0.5左右 第二次 添加了数据增强 train_tfm = transforms.Comp ...

  5. 2022李宏毅作业hw4 - 挫败感十足的一次作业。

    系列文章: 2022李宏毅作业hw1-新冠阳性人员数量预测._亮子李的博客-CSDN博客_李宏毅hw1 hw-2 李宏毅2022年作业2 phoneme识别 单strong-hmm详细解释._亮子李的 ...

  6. 如何使用TensorFlow中的Dataset API

    翻译 | AI科技大本营 参与 | zzq 审校 | reason_W 本文已更新至TensorFlow1.5版本 我们知道,在TensorFlow中可以使用feed-dict的方式输入数据信息,但是 ...

  7. C#-DataSet和DataTable详解

    1.创建DataSet对象: DataSet ds = new DataSet("DataSetName"); 2.查看调用SqlDataAdapter.Fill创建的结构 da. ...

  8. c# mysql fill_C#里sqlDataAdapter.fill(DataSet,String)的用法

    第二个参数 String是指定DataSet 里表的名字,例如 sqlDataAdapter.fill(DataSet,"学生表") 指定后,以后就可以这样调用这张表 DataSe ...

  9. 数据库-ADONET-使用强类型DataSet

    使用强类型DataSet对象 使用ADONET访问DataSet内容的方式,与使用ADO和DAO的Recordset对象具有类似的编程格式. l ADONET和VBNET txtCompanyName ...

  10. 一个无法捕获ADO.NET Dataset的内存错误

    Dataset是ADO.NET在内存保存数据所用的新结构.在某些方面上,Dataset和ADO的Recordset对象相似:不过,Dataset可以把整个schema(包括table.关系.关键字连同 ...

最新文章

  1. Blender着色器纹理材质创作教程含源文件 Shader Forge
  2. java获得单元格的值_java – 从单元格值Apache POI获取单元格索引
  3. Excel 技术篇-解决“单元格不能自动适应大小“问题
  4. 160个CrackMe005
  5. 用libevent实现简易的telnet
  6. 2018刑侦推理 java_2018年刑侦科目推理试题
  7. ISA server的常见身份验证方式
  8. Spark数据分析及处理_ELT
  9. 【蓝桥杯单片机】矩阵键盘和独立键盘新解(更稳定更高复用性)
  10. ubuntu12.04 安装配置jdk1.7
  11. html 行自动对齐,html – 行元素不会对齐
  12. Raphael的set使用
  13. python中的shallow copy 和 deep copy
  14. python Copula 模型实现
  15. 如何将图片调为半色调_为什么我们喜欢粗糙的唱片,半色调网点和其他缺陷?
  16. postgresql 客户端 uri 设置时区
  17. 分布式数据库中间件—TDDL
  18. 我花了一夜用数据结构给女朋友写个H5走迷宫游戏
  19. 关于泰勒展开的细节-《三体》读后感的读后感...
  20. 电脑显示黑屏但是鼠标能动怎么处理?

热门文章

  1. C++ 面试实战经验分享
  2. Fedora23安装有道词典
  3. 数论复习之费马与欧拉
  4. response.getStatusCode()==200什么意思
  5. shell用户输入数字加法操作
  6. 查看远程计算机ip地址吗,我的电脑跟别人远程过可不可以查对方IP地址
  7. 2 SAP SCC1同一服务器传请求号
  8. 【Android安全】手机Root、刷机、救砖常用命令
  9. 微信开发--IOS微信端confirm以及alert去掉网址的方法
  10. css实现遮罩层动画