事件起源

最近在研究GNN,看了些许GNN的东西,心想着光看不练门外汉啊!这可不行,于是我开始自己动手实现一个GCN识别,一想到整一个模型demo,那必少不了MINIST数据集,反正就移花接木大法(MINIST可能会想,我这么忙,真的屑屑你!)一开始想着自己整一个,但是还没开始我就陷入了沉思,MINIST一个图片数据怎么去变成图边数据,于是万能的百度指导我进入了人均星星星的知乎,在那里,我发现有个日本小帅哥(后文我都讲日帅)已经做过了,那不说了,作为能cv绝不手写的浑水摸鱼星人,开始了偶尔心血来潮的扒代码历程。
原代码传送门

扒代码历程

1. 图片数据变成节点和边

怎么把图片数据变成节点和边?我想了很久,看了日帅的代码(为啥是看代码,因为日文我看不懂(o^^o))我豁然开朗,其把每张图片每个像素点想成Node,其邻近关系考虑为边,具体思路如下(以3*3的数据举例):
第一步:阈值过滤(为啥这样叫,因为我喜欢)

通过设定一个阈值k,源代码为102,我们这里设置为2(不知道为啥选这个,有知道的嘛?),将小于k的变成-1,反之为1000;

第二步:padding(为啥这样叫,大家都这么叫.~.)
源代码将padding_width定为2,其实我想了一下1是不是也可以;

第三步:得到Nodes和Edge的信息
将array中的非-1标记为节点k,k=0,1,2,3…

最后保存每个节点的坐标作为Node feature,以及边信息,例如3节点坐标为(1,1)边为[(3,1),(3,2),(3,4),(3,5),(3,6)],至此图片数据变成图数据(日帅给我的启发很大的,回头我继续思考一下)。代码添加注释如下:

import gzip
import numpy as npdata = 0# 读取gzip图片数据,转换图片格式
with gzip.open('data/train-images-idx3-ubyte.gz', 'rb') as f:data = np.frombuffer(f.read(), np.uint8, offset=16)data = data.reshape([-1, 28, 28])# 把28*28的数据中<102变成1,大于变成1000,为啥取102我也不知道,嘻嘻嘻。
data = np.where(data < 102, -1, 1000)
for e,imgtmp in enumerate(data):# 数组padding,其实我在考虑做padding为1是不是也行img = np.pad(imgtmp, [(2, 2), (2, 2)], "constant", constant_values=(-1))# node标记 0,1,2,3,4.......cnt = 0for i in range(2, 30):for j in range(2, 30):if img[i][j] == 1000:img[i][j] = cntcnt += 1# 记录边和节点信息edges = []nodes = np.zeros((cnt, 2))for i in range(2, 30):for j in range(2, 30):if img[i][j] == -1:continuefilter = img[i - 2:i + 3, j - 2:j + 3].flatten()# Node的八个方位filter1 = filter[[6, 7, 8, 11, 13, 16, 17, 18]]# 记录节点的坐标nodes[filter[12]][0] = i - 2nodes[filter[12]][1] = j - 2# 记录边for tmp in filter1:if not tmp == -1:edges.append([filter[12], tmp])# 保存节点数据和边数据np.save("data/graphs/" + str(e), edges)np.save("data/node_features/" + str(e),nodes)

2. 模型训练

这部分的话,就不过细讲了(主要是我也过细讲不了),大致分为三个部分:
第一部分:加载数据
加载labels和处理好的Nodes以及edge数据,也就是说自建数据集,代码如下:

def load_mnist_graph(data_size=60000):# 获取数据主函数data_list = []labels = 0with gzip.open('data/train-labels-idx1-ubyte.gz', 'rb') as f:labels = np.frombuffer(f.read(), np.uint8, offset=8)for i in range(data_size):edge = torch.tensor(np.load('data/graphs/' + str(i) + '.npy').T, dtype=torch.long)x = torch.tensor(np.load('data/node_features/' + str(i) + '.npy') / 28, dtype=torch.float)# 构建数据集d = Data(x=x, edge_index=edge.contiguous(), t=int(labels[i]))data_list.append(d)if i % 1000 == 999:print("\rData loaded " + str(i + 1), end="  ")print("Complete!")return data_list

第二部分:定义网络
按照自己的喜好定义就行,毕竟我电脑cpu那点算力,还不支持我随心所欲的训练,M1早点出GPU版吧,孩子顶不住了(>﹏<),这里就按照日帅的来吧!

# 定义网络结构
class Net(torch.nn.Module):def __init__(self):super(Net, self).__init__()self.conv1 = GCNConv(2, 16)self.conv2 = GCNConv(16, 32)self.conv3 = GCNConv(32, 48)self.conv4 = GCNConv(48, 64)self.conv5 = GCNConv(64, 96)self.conv6 = GCNConv(96, 128)self.linear1 = torch.nn.Linear(128,64)self.linear2 = torch.nn.Linear(64,10)def forward(self, data):x, edge_index = data.x, data.edge_indexx = self.conv1(x, edge_index)x = F.relu(x)x = self.conv2(x, edge_index)x = F.relu(x)x = self.conv3(x, edge_index)x = F.relu(x)x = self.conv4(x, edge_index)x = F.relu(x)x = self.conv5(x, edge_index)x = F.relu(x)x = self.conv6(x, edge_index)x = F.relu(x)x, _ = scatter_max(x, data.batch, dim=0)x = self.linear1(x)x = F.relu(x)x = self.linear2(x)return x

第三部分:训练主函数
训练部分的参数,可以按照自己电脑的算力以及结果定,这里不做过多修改,cv浑水摸鱼星人只是觉得日帅写的很棒(o^^o)!

def main():# 训练主程序data_size = 60000train_size = 50000batch_size = 100epoch_num = 150# 数据获取mnist_list = load_mnist_graph(data_size=data_size)device = torch.device('cpu')model = Net().to(device)trainset = mnist_list[:train_size]optimizer = torch.optim.Adam(model.parameters())trainloader = DataLoader(trainset, batch_size=batch_size, shuffle=True)testset = mnist_list[train_size:]testloader = DataLoader(testset, batch_size=batch_size)criterion = nn.CrossEntropyLoss()history = {"train_loss": [],"test_loss": [],"test_acc": []}print("Start Train")# 训练部分model.train()for epoch in range(epoch_num):train_loss = 0.0for i, batch in enumerate(trainloader):batch = batch.to("cpu")optimizer.zero_grad()outputs = model(batch)loss = criterion(outputs, batch.t)loss.backward()optimizer.step()train_loss += loss.cpu().item()if i % 10 == 9:progress_bar = '[' + ('=' * ((i + 1) // 10)) + (' ' * ((train_size // 100 - (i + 1)) // 10)) + ']'print('\repoch: {:d} loss: {:.3f}  {}'.format(epoch + 1, loss.cpu().item(), progress_bar), end="  ")print('\repoch: {:d} loss: {:.3f}'.format(epoch + 1, train_loss / (train_size / batch_size)), end="  ")history["train_loss"].append(train_loss / (train_size / batch_size))correct = 0total = 0batch_num = 0loss = 0with torch.no_grad():for data in testloader:data = data.to(device)outputs = model(data)loss += criterion(outputs, data.t)_, predicted = torch.max(outputs, 1)total += data.t.size(0)batch_num += 1correct += (predicted == data.t).sum().cpu().item()history["test_acc"].append(correct / total)history["test_loss"].append(loss.cpu().item() / batch_num)endstr = ' ' * max(1, (train_size // 1000 - 39)) + "\n"print('Test Accuracy: {:.2f} %%'.format(100 * float(correct / total)), end='  ')print(f'Test Loss: {loss.cpu().item() / batch_num:.3f}', end=endstr)print('Finished Training')# 最终结果correct = 0total = 0with torch.no_grad():for data in testloader:data = data.to(device)outputs = model(data)_, predicted = torch.max(outputs, 1)total += data.t.size(0)correct += (predicted == data.t).sum().cpu().item()print('Accuracy: {:.2f} %%'.format(100 * float(correct / total)))

完整代码
完整代码,我做了些许简化和修改,提醒一下这里的代码我改成了cpu,有条件的大帅哥可以自行改为cuda,如下:

import numpy as np
import gzip
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.data import Data
from torch_geometric.loader import DataLoader
from torch_geometric.nn import GCNConv
from torch_scatter import  scatter_maxdef load_mnist_graph(data_size=60000):# 获取数据主函数data_list = []labels = 0with gzip.open('data/train-labels-idx1-ubyte.gz', 'rb') as f:labels = np.frombuffer(f.read(), np.uint8, offset=8)for i in range(data_size):edge = torch.tensor(np.load('data/graphs/' + str(i) + '.npy').T, dtype=torch.long)x = torch.tensor(np.load('data/node_features/' + str(i) + '.npy') / 28, dtype=torch.float)# 构建数据集d = Data(x=x, edge_index=edge.contiguous(), t=int(labels[i]))data_list.append(d)if i % 1000 == 999:print("\rData loaded " + str(i + 1), end="  ")print("Complete!")return data_list# 定义网络结构
class Net(torch.nn.Module):def __init__(self):super(Net, self).__init__()self.conv1 = GCNConv(2, 16)self.conv2 = GCNConv(16, 32)self.conv3 = GCNConv(32, 48)self.conv4 = GCNConv(48, 64)self.conv5 = GCNConv(64, 96)self.conv6 = GCNConv(96, 128)self.linear1 = torch.nn.Linear(128,64)self.linear2 = torch.nn.Linear(64,10)def forward(self, data):x, edge_index = data.x, data.edge_indexx = self.conv1(x, edge_index)x = F.relu(x)x = self.conv2(x, edge_index)x = F.relu(x)x = self.conv3(x, edge_index)x = F.relu(x)x = self.conv4(x, edge_index)x = F.relu(x)x = self.conv5(x, edge_index)x = F.relu(x)x = self.conv6(x, edge_index)x = F.relu(x)x, _ = scatter_max(x, data.batch, dim=0)x = self.linear1(x)x = F.relu(x)x = self.linear2(x)return xdef main():# 训练主程序data_size = 60000train_size = 50000batch_size = 100epoch_num = 150# 数据获取mnist_list = load_mnist_graph(data_size=data_size)device = torch.device('cpu')model = Net().to(device)trainset = mnist_list[:train_size]optimizer = torch.optim.Adam(model.parameters())trainloader = DataLoader(trainset, batch_size=batch_size, shuffle=True)testset = mnist_list[train_size:]testloader = DataLoader(testset, batch_size=batch_size)criterion = nn.CrossEntropyLoss()history = {"train_loss": [],"test_loss": [],"test_acc": []}print("Start Train")# 训练部分model.train()for epoch in range(epoch_num):train_loss = 0.0for i, batch in enumerate(trainloader):batch = batch.to("cpu")optimizer.zero_grad()outputs = model(batch)loss = criterion(outputs, batch.t)loss.backward()optimizer.step()train_loss += loss.cpu().item()if i % 10 == 9:progress_bar = '[' + ('=' * ((i + 1) // 10)) + (' ' * ((train_size // 100 - (i + 1)) // 10)) + ']'print('\repoch: {:d} loss: {:.3f}  {}'.format(epoch + 1, loss.cpu().item(), progress_bar), end="  ")print('\repoch: {:d} loss: {:.3f}'.format(epoch + 1, train_loss / (train_size / batch_size)), end="  ")history["train_loss"].append(train_loss / (train_size / batch_size))correct = 0total = 0batch_num = 0loss = 0with torch.no_grad():for data in testloader:data = data.to(device)outputs = model(data)loss += criterion(outputs, data.t)_, predicted = torch.max(outputs, 1)total += data.t.size(0)batch_num += 1correct += (predicted == data.t).sum().cpu().item()history["test_acc"].append(correct / total)history["test_loss"].append(loss.cpu().item() / batch_num)endstr = ' ' * max(1, (train_size // 1000 - 39)) + "\n"print('Test Accuracy: {:.2f} %%'.format(100 * float(correct / total)), end='  ')print(f'Test Loss: {loss.cpu().item() / batch_num:.3f}', end=endstr)print('Finished Training')# 最终结果correct = 0total = 0with torch.no_grad():for data in testloader:data = data.to(device)outputs = model(data)_, predicted = torch.max(outputs, 1)total += data.t.size(0)correct += (predicted == data.t).sum().cpu().item()print('Accuracy: {:.2f} %%'.format(100 * float(correct / total)))if __name__ == '__main__':main()

结语

今天是521,我和日帅的约会让我在对图数据上的收获收益匪浅,也达到了自己动手完成一个小demo的目标,其实说实话,日系帅哥的颜值我还是很吃的,还能写代码的就更爱了!又是一个cv浑水摸鱼的一天万岁!

我在日本小帅哥那学习了GCN相关推荐

  1. python turtle怎么画海绵宝宝_画师绘制海绵宝宝性转拟人,派大星变小帅哥,又脑补一出甜蜜大戏...

    我已经工作了有一段时间了,但是我依然很喜欢看<海绵宝宝>这部动漫,每次看的时候都会笑得没心没肺,十分欢乐. 好羡慕海绵宝宝和派大星他们啊,海绵宝宝还要上班,有自己的理想和工作,派大星真的是 ...

  2. 来我主页的小仙女小帅哥给你们一道很有深度(底层原理)的题(能够看到这篇文章的人希望你做一下))

    不懂JavaScript底层原理,你以后工作调试程序的时间会大大延长和延长学习新技术的时间 <!DOCTYPE html> <html lang="en"> ...

  3. 小码哥iOS学习笔记第二天: OC对象的分类

    Objective-C中的对象, 简称OC对象, 主要可以分为3种 instance对象(实例对象) class对象(类对象) meta-class对象(元类对象) 一.instance instan ...

  4. 小码哥iOS学习笔记第八天: block的底层结构

    一.最简单的block 1.最简单的block结构 ^{NSLog(@"this is a block");NSLog(@"this is a block"); ...

  5. 小帅哥~小美女~快点进来看看内部类鸭~

    目录 前言: 内部类: 1.成员内部类: 2.局部内部类: 3.匿名内部类: 4.静态内部类: 5.内部类关于static修饰的注意事项: 前言: 大家好啊!我有一个朋友...咳咳,给大家介绍一下它, ...

  6. 小码哥iOS学习笔记第十二天:Class结构

    一.Class的结构 通过查看源码, 可以得出Class的底层结构如下图 一开始class_data_bits_t bits;指向ro, 在加载的过程中创建了rw, 此时的指向顺序是bits-> ...

  7. 小蟑螂与帅哥的故事~

    1   小蟑螂趴在碗柜上偷kui. 那个帅哥哼着小曲儿在洗碗, 嗯哼嗯哼,哗啦哗啦.    帅哥洗碗都这么好听吗. 小蟑螂爱屋及乌了,帅哥干什么都好看. 小蟑螂黑黑的小脸儿红了. 2  小蟑螂今年一岁 ...

  8. [转载] 我叫李小帅

    离婚时,前夫说,房子是我奶奶的,不能给你.我说,行.他又说,家里的钱也就是我公司里的那些产品,你要它也没用.我说,好.他说,我家三代单传,小帅得跟我.我拍案:不行! 晚上,我去接小帅放学,第一次在路上 ...

  9. es拼音分词 大帅哥_机器学习

    1. 赌场风云(背景介绍) 最近一个赌场的老板发现生意不畅,于是派出手下去赌场张望.经探子回报,有位大叔在赌场中总能赢到钱,玩得一手好骰子,几乎是战无不胜.而且每次玩骰子的时候周围都有几个保镖站在身边 ...

最新文章

  1. 历史版本_DNF:历史版本十大经典地图,没经历过那个时代的人不会明白的
  2. 图集打包算法_UGUI打包图集工具-插件Simple Sprite Packer详解
  3. 华为云提供针对Nuget包管理器的缓存加速服务
  4. logo qt添加_linux下如何给qt程序添加图标?
  5. 沪港通:利好出尽就是利空
  6. while循环里面scanf_5.1 for循环
  7. 深入浅出Mybatis系列(一)---Mybatis入门[转]
  8. .NET 云原生架构师训练营(模块二 基础巩固 安全)--学习笔记
  9. C#算法设计排序篇之02-快速排序(附带动画演示程序)
  10. [转] C#中Dispose和Close的区别
  11. 学习OpenCV——OpenMP
  12. 树的遍历 (和) 玩转二叉树 的总结博客
  13. 思维导图案例之VeritasDCG
  14. 道路照明之电缆线路 - 设计笔记
  15. Pycharm安装chardet模块
  16. 软件测试基本流程【车机测试】
  17. 退出android recovery界面,怎么退出recovery模式?
  18. i.MX283开发板移植RTL8188ETV无线网卡驱动
  19. SX1278传输距离测试
  20. Word2Vec词向量模型代码

热门文章

  1. itunes计算机无法启动,解决:Apple移动设备服务无法启动
  2. 【poj 2488】A Knight's Journey 中文题意题解代码(C++)
  3. 解决pycharm导入自己写的模块飘红问题
  4. 任建新照常去办公室领取工资
  5. 学计算机编程配置需求,编程对电脑配置要求高吗?
  6. 一维离散动力学系统的混沌研究【基于matlab的动力学模型学习笔记_8】
  7. css 实现单行、多行文本显示
  8. 华北理工计算机学院官网,2019上半年华北理工大学计算机等级考试报名通知
  9. Svn中的tag标签的用法和意义
  10. for循环里面的break;和continue;语句