本文数据来源于第九届泰迪杯数据挖掘挑战赛,需要的好兄弟可以自行去下载哦,也可以下载我处理好了的数据。

文章目录

  • 前言
  • 一、图片压缩,预处理
  • 二、代码部分
    • 1.数据准备部分
    • 2.模型部分:
  • 总结

前言

深度学习的卷积神经网络是一个比较重要的研究方向,关于卷积的一些理论,在我的另外一篇博客
大家可以去了解一下。


一、图片压缩,预处理

把岩石数据分成了7类:

先通过数据处理把图片分成7类,方便我们后续的导入。

图片压缩会造成损失,所以对原图片还是需要进行处理,比如:

很明显,这个地方图片从35M变成了930KB,所以这种压缩效果是很好的。而且其本身并不会造成损失,emm,这个算我的直观感受把,如有错误,还请回复我哦。这个包就是tinypng,这个是一个很强大的包。你可以去看他的官方介绍。
这个地方推荐一位大佬的博客,这里面讲的有关于tinypng的操作以及原理链接.

在补充一点点大佬文中没有提及的使用方法:
tinypng提供了一个method的参数,有:scale、fit、cover、thumb。
这几个库有什么特点呢:

  1. scale:
    尺度缩小图片比例。您必须提供一个目标 width或一个目标height,但不能同时提供两者。缩放后的图像将完全具有所提供的宽度或高度。
  2. fit:
    缩放图像比例下降,使其内符合给定尺寸。您必须同时提供width和height。缩放后的图像不会超过这两个尺寸。
  3. cover:
    缩放比例的图像和裁剪,如果必要的,这样的结果具有准确的给定尺寸。您必须同时提供 width和height。图像的哪些部分被裁剪掉是自动确定的。智能算法确定图像的最重要区域。
  4. thumb:
    Cover的更高级实现也可以检测出具有纯背景的剪切图像。图像被缩小到 width和height你提供。如果检测到带有独立物体的图像,它将在必要时添加更多背景空间或裁切不重要的部分。

这个方法呢,需要消耗你的次数。
用法:

source = tinify.from_file("large.jpg")
resized = source.resize(method="fit",width=150,height=100
)
resized.to_file("thumbnail.jpg")

二、代码部分

1.数据准备部分

import torch
import torch.nn as nn
import torchvision.transforms as transforms
import torch.utils.data as Data
from torch.utils.data import DataLoader
import torchvision
import matplotlib.pyplot as plt
import numpy as np
from skimage import io,transform
import skimage
import os
os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"
os.environ["CUDA_VISIBLE_DEVICES"] = "0,1"rock_1 = []
label_rock_1 = []
rock_2 = []
label_rock_2 = []
rock_3 = []
label_rock_3 = []
rock_4 = []
label_rock_4 = []
rock_5 = []
label_rock_5 = []
rock_6 = []
label_rock_6 = []
rock_7 = []
label_rock_7 = []def get_files(file_dir):for file in os.listdir(file_dir + '/1'):rock_1.append(file_dir + '/1' + '/' + file)label_rock_1.append(1)for file in os.listdir(file_dir + '/2'):rock_2.append(file_dir + '/2' + '/' + file)label_rock_2.append(2)for file in os.listdir(file_dir + '/3'):rock_3.append(file_dir + '/3' + '/' + file)label_rock_3.append(3)for file in os.listdir(file_dir + '/4'):rock_4.append(file_dir + '/4' + '/' + file)label_rock_4.append(4)for file in os.listdir(file_dir + '/5'):rock_5.append(file_dir + '/5' + '/' + file)label_rock_5.append(5)for file in os.listdir(file_dir + '/6'):rock_6.append(file_dir + '/6' + '/' + file)label_rock_6.append(6)for file in os.listdir(file_dir + '/7'):rock_7.append(file_dir + '/7' + '/' + file)label_rock_7.append(7)image_list = np.hstack((rock_1, rock_2, rock_3, rock_4,rock_5,rock_6,rock_7))label_list = np.hstack((label_rock_1, label_rock_2, label_rock_3, label_rock_4,label_rock_5,label_rock_6,label_rock_7))temp = np.array([image_list, label_list])temp = temp.transpose()np.random.shuffle(temp)# 将所有的img和lab转换成listreturn temppath_1 = 'G:/泰迪杯数据挖掘_data/例子'
temp = get_files(path_1)BATCH_SIZE = 5
LR = 0.0004
print(len(temp))
temp_1 = temp[:1500]
temp_2 = temp[1500:]#训练数据
all_image_list_train = list(temp_1[:, 0])
all_label_list_train = list(temp_1[:, 1])
train_img = []
for i in all_image_list_train:img = skimage.io.imread(i)img = transform.resize(img,(128,128))img = img/255.0img = img.astype('float32')train_img.append(img)all_label_list_train_1 = []
for j in all_label_list_train:all_label_list_train_1.append(int(j))train_transform = transforms.Compose([transforms.Normalize((0.5,), (0.5,)), #將matrices转成 Tensor,並把数值normalize到[0,1](data normalization)
])
train_x = np.array(train_img)
train_y = np.array(all_label_list_train_1)
print(train_x.shape)
train_x_1 = train_x.reshape(1500,3,128,128)
train_x_1 = torch.from_numpy(train_x_1)train_y_1 = torch.from_numpy(train_y)
torch_dataset = Data.TensorDataset(train_x_1,train_y_1)
# #测试数据
all_image_list_1 = list(temp_2[:, 0])
all_label_list_1 = list(temp_2[:, 1])
val_img = []
for i in all_image_list_1:img = skimage.io.imread(i)img = transform.resize(img, (128, 128))img = img/255.0img = img.astype('float32')val_img.append(img)
all_label_list_2 = []
for j in all_label_list_1:all_label_list_2.append(int(j))train_loader = DataLoader(torch_dataset, batch_size=BATCH_SIZE, shuffle=True,num_workers=0)val_x = np.array(val_img)
val_y = np.array(all_label_list_2)
val_x=val_x.reshape(390,3,128,128)
val_x=torch.from_numpy(val_x)
#  转换为torch张量
val_y=torch.from_numpy(val_y)

2.模型部分:

#测试集数据
# cs_list_image = []
# cs_name_1 = []
# path_cs = 'G:/泰迪杯数据挖掘_data/B题测试数据/压缩'
# for cs_name in os.listdir(path_cs):
#     cs_name_1.append(cs_name)
#     name_cs_total = path_cs+'/'+cs_name
#     cs_list_image.append(name_cs_total)
# cs_list_image_1 = []
# for image_cs in cs_list_image:
#     img = skimage.io.imread(image_cs)
#     img = img / 255.0
#     img = img.astype('float32')
#     cs_list_image_1.append(img)
# cs_list_image_1 = np.array(cs_list_image_1)
# cs_list_image_1 = cs_list_image_1.reshape(35,3,64,64)
# cs_list_image_1 = torch.from_numpy(cs_list_image_1)class CNN(nn.Module):def __init__(self):super(CNN,self).__init__()self.conv1 = nn.Sequential(nn.Conv2d(   #(1,28,28)(30,64,64,3)in_channels=3,     #1代表着灰度图片,,如果是3这个地方就是代表彩色图片out_channels=32,   #输出的特征值16个kernel_size=3,   #5x5卷积核stride=1,    # 步长,每次移动一个像素padding=1,   #扩充边缘,方便提取边缘特征  padding = (kernel_size-1)/2),  #图片变成(16,28,28)(30,64,64,16)nn.Dropout(0.5),nn.BatchNorm2d(32),nn.ReLU(),nn.MaxPool2d(kernel_size=2),  #这个地方使用2x2的区域再一次卷积/  变成(32,14,14)(30,32,32,16))self.conv2 = nn.Sequential(nn.Conv2d(32,64,3,1,1),   #变成(32,14,14)(30,32,32,32)(16,32,5,1,2)nn.Dropout(0.25),nn.BatchNorm2d(64),nn.ReLU(),nn.MaxPool2d(2),   #变成(32,7,7)(30,16,16,96))self.conv3 = nn.Sequential(nn.Conv2d(64,32,3,1,1),nn.Dropout(0.25),nn.BatchNorm2d(32),nn.ReLU(),nn.MaxPool2d(2),)self.conv4 = nn.Sequential(nn.Conv2d(32, 16, 3, 1, 1),nn.Dropout(0.25),nn.BatchNorm2d(16),nn.ReLU(),nn.MaxPool2d(2),)self.conv5 = nn.Sequential(nn.Conv2d(16,10,3,1,1),nn.Dropout(0.25),nn.BatchNorm2d(10),nn.ReLU(),nn.MaxPool2d(2),)# self.conv6 = nn.Sequential(#     nn.Conv2d(16, 16, 3, 1, 1),#     nn.BatchNorm2d(16),#     nn.ReLU(),#     nn.MaxPool2d(2),# )self.out = nn.Linear(10*4*4,8)#(96*16*16,)(32*7*7,10)def forward(self,x):   #进行展平x = self.conv1(x)x = self.conv2(x)x = self.conv3(x)x = self.conv4(x)x = self.conv5(x)#x = self.conv6(x)x = x.view(x.size(0),-1)   #(batch,32*7*7)output = self.out(x)return outputcnn = CNN()
acc_list = []
optimizer = torch.optim.Adam(cnn.parameters(),lr=LR)  #优化器
loss_fun = nn.CrossEntropyLoss()  #自带softmax
from sklearn.metrics import accuracy_score
train_losses = []
val_losses = []EPOCH = 5
for epoch in range(EPOCH):for step,(b_x,b_y) in enumerate(train_loader):output_train = cnn(b_x)loss = loss_fun(output_train,b_y.long())#loss_val = loss_fun(output_val, y_val.long())optimizer.zero_grad()loss.backward()optimizer.step()train_losses.append(loss)if step % 100 == 0:test_out = cnn(val_x)# cnn.eval()pre = torch.argmax(test_out,1)acc = accuracy_score(val_y,pre)acc_list.append(acc)print('Epoch:', epoch, '| train loss:%.4f' % loss.item(), '| test accuracy:%.4f' % acc)

需要注意的是,GPU运行是不支持你将图片格式改成float32以下的,不包括32,CPU运行则是可以把图片改成float16,float8的。不仅仅是pytorch框架,还有Keras

总结

本文就不对代码部分进行讲解了,只要知道原理,理解,知道怎么用就行,当然也可以考虑直接调用别人已经训练好的模型。
pytorch需要注意的就是数据的格式问题。
文中模型的精度应该就是败笔了。只能达到0.6几,后面有时间调优,应该会来翻改模型。


文中有错误部分,还劳烦指正。
Github源码


害!!!
一写博客就饿!!!

别的不说,再见了。
干饭!!!

2021/7/12更新
兄弟们数据集
链接:https://pan.baidu.com/s/1JgbOBaDBhPvML1tZ7DhXPg
提取码:dtec
放这了

pytorch-CNN岩石分类(本地数据)相关推荐

  1. Pytorch+CNN+猫狗分类实战

    文章目录 0.前言 1.猫狗分类数据集 1.1数据集下载(可选部分) 1.2数据集分析 2.猫狗分类数据集预处理 2.1训练集和测试集划分 2.2训练集和测试集读取 3.剩余代码 4.总结 0.前言 ...

  2. Kaggle猫狗大战——基于Pytorch的CNN网络分类:数据获取、预处理、载入(1)

    Kaggle猫狗大战--基于Pytorch的CNN网络分类:数据获取.预处理.载入(1) 第一次写CSDN博客,之前一直是靠着CSDN学学代码,这次不得不亲自上场了,就想着将学习的过程都记录下来.新人 ...

  3. 《Pytorch - CNN模型》

    2020年10月5号,依然在家学习. 今天是我写的第四个 Pytorch程序, 这一次我想把之前基于PyTorch实现的简易的传统的BP全连接神经网络改写成CNN网络,想看看对比和效果差异. 这一次我 ...

  4. LESSON 10.110.210.3 SSE与二分类交叉熵损失函数二分类交叉熵损失函数的pytorch实现多分类交叉熵损失函数

    在之前的课程中,我们已经完成了从0建立深层神经网络,并完成正向传播的全过程.本节课开始,我们将以分类深层神经网络为例,为大家展示神经网络的学习和训练过程.在介绍PyTorch的基本工具AutoGrad ...

  5. 【项目实战课】人人免费可学!基于Pytorch的图像分类简单任务数据增强实战

    欢迎大家来到我们的项目实战课,本期内容是<基于Pytorch的图像分类简单任务数据增强实战>.所谓项目实战课,就是以简单的原理回顾+详细的项目实战的模式,针对具体的某一个主题,进行代码级的 ...

  6. vuejs实现本地数据的筛选分页

    今天项目需要一份根据本地数据的筛选分页功能,好吧,本来以为很简单,网上搜了搜全是ajax获取的数据,这不符合要求啊,修改起来太费力气,还不如我自己去写,不多说直接上代码 效果图: 项目需要:点击左侧进 ...

  7. ios网络学习------4 UIWebView的加载本地数据的三种方式

    ios网络学习------4 UIWebView的加载本地数据的三种方式 分类: IOS2014-06-27 12:56 959人阅读 评论(0) 收藏 举报 UIWebView是IOS内置的浏览器, ...

  8. pytorch实现文本分类_使用变形金刚进行文本分类(Pytorch实现)

    pytorch实现文本分类 'Attention Is All You Need' "注意力就是你所需要的" New deep learning models are introd ...

  9. Pytorch搭建常见分类网络模型------VGG、Googlenet、ResNet50 、MobileNetV2(4)

    接上一节内容:Pytorch搭建常见分类网络模型------VGG.Googlenet.ResNet50 .MobileNetV2(3)_一只小小的土拨鼠的博客-CSDN博客 mobilenet系列: ...

最新文章

  1. 出现 java.util.ConcurrentModificationException 时的解决办法
  2. JDK1.6安装与环境变量设置详细图解
  3. Android Linux自带iptables配置IP访问规则
  4. 计算机控制实验教程,新)《计算机控制技术》实验教程.doc
  5. kudu 存储引擎简析
  6. 工作两年多的一个菜鸟感想
  7. javamail 解码 base64 html格式邮件_Spring整合javaMail
  8. Linux清除用户登录记录和命令历史方法
  9. 【优化求解】基于matlab遗传算法求解仓库货位优化问题【含Matlab源码 1770期】
  10. VC6.0和VC2012的全局对象的释放!!!
  11. 【使用Mac制作手写签名的方法】
  12. ASP.NET 实现简单的注册界面(使用asp控件)
  13. 【bazel】根据.proto文件生成.h、.cc文件
  14. PYTHON:已知一点经纬度、方位角和距离,求另一点的经纬度
  15. 11. JS编程之查找元素在数组中的位置
  16. 神经网络中定义网络模型中的forward方法
  17. C语言入门——时间换算
  18. python协程实现一万并发_python进阶:服务端实现并发的八种方式
  19. 对于一颗满二叉排序树深度为K,求最小子树根节点值 Python代码实现
  20. 开发效率提升300%,Vue3新特性已成气候!

热门文章

  1. 【毕业设计】37-基于单片机智能楼宇消防监控系统设计(原理图工程+仿真工程+源代码+答辩论文+答辩PPT)
  2. IEC 61968 和 IEC 61850 量测模型的差异性分析(论文学习)
  3. 2023Matlab初级教程- 第一章 初识Matlab与界面介绍
  4. 几个好玩(整人)的vbs小程序
  5. arcgis重心迁移分析_arcgis人口重心迁移图
  6. mysql锁全表语句_MySql锁表语句
  7. 关于Oracle数据库字符集的选择及乱码情况
  8. 解决mido打开midi文件失败的问题
  9. 杰理AD14N/AD15N---UART串口使用问题
  10. Linux-网络部分总结(二实验)