网络搭建:

mynn.py:

import torchfrom torch import nnclass mynn(nn.Module):    def __init__(self):        super(mynn, self).__init__()        self.layer1 = nn.Sequential(            nn.Linear(3520, 4096), nn.BatchNorm1d(4096), nn.ReLU(True)        )        self.layer2 = nn.Sequential(            nn.Linear(4096, 4096), nn.BatchNorm1d(4096), nn.ReLU(True)        )        self.layer3 = nn.Sequential(            nn.Linear(4096, 4096), nn.BatchNorm1d(4096), nn.ReLU(True)        )        self.layer4 = nn.Sequential(            nn.Linear(4096, 4096), nn.BatchNorm1d(4096), nn.ReLU(True)        )        self.layer5 = nn.Sequential(            nn.Linear(4096, 3072), nn.BatchNorm1d(3072), nn.ReLU(True)        )        self.layer6 = nn.Sequential(            nn.Linear(3072, 2048), nn.BatchNorm1d(2048), nn.ReLU(True)        )        self.layer7 = nn.Sequential(            nn.Linear(2048, 1024), nn.BatchNorm1d(1024), nn.ReLU(True)        )        self.layer8 = nn.Sequential(            nn.Linear(1024, 256), nn.BatchNorm1d(256), nn.ReLU(True)        )        self.layer9 = nn.Sequential(            nn.Linear(256, 64), nn.BatchNorm1d(64), nn.ReLU(True)        )        self.layer10 = nn.Sequential(            nn.Linear(64, 32), nn.BatchNorm1d(32), nn.ReLU(True)        )        self.layer11 = nn.Sequential(            nn.Linear(32, 3)        )

    def forward(self, x):        x = self.layer1(x)        x = self.layer2(x)        x = self.layer3(x)        x = self.layer4(x)        x = self.layer5(x)        x = self.layer6(x)        x = self.layer7(x)        x = self.layer8(x)        x = self.layer9(x)        x = self.layer10(x)        x = self.layer11(x)        return x

Dataset重定义:mydataset.py
import osfrom torch.utils import dataimport numpy as npfrom astropy.io import fitsfrom torchvision import transforms as Tfrom PIL import Imageimport pandas as pd

class mydataset(data.Dataset):

    def __init__(self,csv_file,root_dir=None,transform=None):        self.landmarks_frame=np.loadtxt(open(csv_file,"rb"),delimiter=",")             #landmarks_frame是一个numpy矩阵        self.root_dir=root_dir        self.transform=transform    def __len__(self):        return len(self.landmarks_frame)    def __getitem__(self, idx):        lfit=self.landmarks_frame[idx,:]        lable=lfit[len(lfit)-1]        datafit=lfit[0:(len(lfit)-1)]        return lable,datafit主程序:main.py
import torchfrom torch import nn, optimfrom torchvision import datasets, transformsfrom torch.autograd import Variable#from models import Mynet, my_AlexNet, my_VGGfrom sdata import mydatasetimport timeimport numpy as npfrom model import  mynnif __name__ == '__main__':  #如果Dataloader开启num_workers > 0  必须要在'__main__'下才能消除报错

    data_train = mydataset.mydataset(csv_file="G:\\DATA\\train.csv",root_dir=None,transform=None)    #data_test = mydataset(test=True)    data_test = mydataset.mydataset(csv_file="G:\\DATA\\test.csv", root_dir=None, transform=None)    data_loader_train = torch.utils.data.DataLoader(dataset=data_train,                                                    batch_size=256,                                                    shuffle=True,                                                    num_workers=0,                                                    pin_memory=True)    data_loader_test = torch.utils.data.DataLoader(dataset=data_test,                                                    batch_size=256,                                                    shuffle=True,                                                    num_workers=0,                                                    pin_memory=True)    print("**dataloader done**")    model = mynn.mynn()

    if torch.cuda.is_available():        #model = model.cuda()        model.to(torch.device('cuda'))    #损失函数    criterion = nn.CrossEntropyLoss()    #optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)    #优化算法    optimizer = optim.Adam(model.parameters(), lr=0.001, betas=(0.9, 0.999), eps=1e-08, weight_decay=1e-4)    n_epochs = 1000

    global_train_acc = []

    s_time = time.time()

    for epoch in range(n_epochs):        running_loss = 0.0        running_correct = 0.0        print('Epoch {}/{}'.format(epoch, n_epochs))        for label,datafit in data_loader_train:            x_train, y_train = datafit,label            #x_train, y_train = Variable(x_train.cuda()), Variable(y_train.cuda())            x_train, y_train = x_train.to(torch.device('cuda')), y_train.to(torch.device('cuda'))            x_train=x_train.float()            y_train=y_train.long()            #x_train, y_train = Variable(x_train), Variable(y_train)            outputs = model(x_train)            _, pred = torch.max(outputs.data, 1)            optimizer.zero_grad()            loss = criterion(outputs, y_train)            loss.backward()            optimizer.step()

            running_loss += loss.item()            running_correct += torch.sum(pred == y_train.data)

        testing_correct = 0.0        for label,datafit in data_loader_test:            x_test, y_test = datafit,label            x_test=x_test.float()            y_test=y_test.long()            x_test, y_test = Variable(x_test.cuda()), Variable(y_test.cuda())            # x_test, y_test = Variable(x_test), Variable(y_test)            outputs = model(x_test)            _, pred = torch.max(outputs.data, 1)            testing_correct += torch.sum(pred == y_test.data)

        print('Loss is:{:.4f}, Train Accuracy is:{:.4f}%, Test Accuracy '              'is:{:.4f}'.format(running_loss / len(data_train),                                 100 * running_correct / len(data_train),                                 100 * testing_correct / len(data_test)))

    e_time = time.time()    print('time_run is :', e_time - s_time)    print('*******done******')

将天文数据写入csv中:main.py
# -*- coding: utf-8 -*-"""Spyder Editor

This is a temporary script file."""

import matplotlib.pyplot as pltfrom astropy.io import fitsimport osimport matplotlibmatplotlib.use('Qt5Agg')from astropy.io import fitsimport numpy as npfrom sklearn.model_selection import train_test_splitfrom sklearn import svmfrom sklearn.decomposition import PCAdef getData(fitPath,cla):    fileList=[]                 #所有.fit文件    files=os.listdir(fitPath)      #返回一个列表,其中包含在目录条目的名称    y=[]    for f in files:        if os.path.isfile(fitPath+'/'+f) and f[-4:-1]==".fi":            fileList.append(fitPath+'/'+f)  #添加文件    len=90000    x=np.ones(3521)    num=1    for path in fileList:        f = fits.open(path)        header = f[0].header  # fit文件中的各种标识

        SPEC_CLN = header['SPEC_CLN']        SN_G = header['SN_G']        NAXIS1 = header['NAXIS1']  # 光谱数据维度        COEFF0 = header['COEFF0']        COEFF1 = header['COEFF1']        wave = np.ones(NAXIS1)  # 光谱图像中的横坐标        for i in range(NAXIS1):            wave[i] = i        logwavelength = COEFF0 + wave * COEFF1        for i in range(NAXIS1):            wave[i] = 10 ** logwavelength[i]        min=0        for i in range(NAXIS1-1):            if wave[i]<=4000 and wave[i+1]>=4000:                min=i        spec = f[0].data[0, :]  # 光谱数据  fit中的第一行数据        spec=spec[min:min+3521]        spec=np.array(spec)        spec[3520]=cla        if num==1:            x=spec            num=2        else:            x=np.row_stack((x,spec))    #np.savetxt(csvPath,x, delimiter=',')    return x

if __name__ == '__main__':    x=getData("G:\DATA\STAR",0)    x_train,x_test=train_test_split(x,test_size=0.1 ,random_state=0)

    y=getData("G:\DATA\QSO",1)    y_train, y_test = train_test_split(y, test_size=0.1, random_state=0)    x_train = np.row_stack((x_train,y_train))    x_test=np.row_stack((x_test,y_test))

    z=getData("G:\DATA\GALAXY",2)    z_train, z_test = train_test_split(z, test_size=0.1, random_state=0)    x_train=np.row_stack((x_train,z_train))    x_test = np.row_stack((x_test,z_test))    np.savetxt("G:\\DATA\\train.csv",x_train, delimiter=',')    np.savetxt("G:\\DATA\\test.csv", x_test, delimiter=',')

转载于:https://www.cnblogs.com/invisible2/p/11523330.html

pytorch简单框架相关推荐

  1. 深度学习调用TensorFlow、PyTorch等框架

    深度学习调用TensorFlow.PyTorch等框架 一.开发目标目标 提供统一接口的库,它可以从C++和Python中的多个框架中运行深度学习模型.欧米诺使研究人员能够在自己选择的框架内轻松建立模 ...

  2. 《预训练周刊》第29期:Swin Transformer V2:扩大容量和分辨率、SimMIM:用于遮蔽图像建模的简单框架...

    No.29 智源社区 预训练组 预 训 练 研究 观点 资源 活动 关于周刊 本期周刊,我们选择了10篇预训练相关的论文,涉及图像处理.图像屏蔽编码.推荐系统.语言模型解释.多模态表征.多语言建模.推 ...

  3. Pytorch Lightning框架:使用笔记【LightningModule、LightningDataModule、Trainer、ModelCheckpoint】

    pytorch是有缺陷的,例如要用半精度训练.BatchNorm参数同步.单机多卡训练,则要安排一下Apex,Apex安装也是很烦啊,我个人经历是各种报错,安装好了程序还是各种报错,而pl则不同,这些 ...

  4. 《预训练周刊》第29期:Swin Transformer V2:扩大容量和分辨率、SimMIM:用于遮蔽图像建模的简单框架

    关于周刊 本期周刊,我们选择了10篇预训练相关的论文,涉及图像处理.图像屏蔽编码.推荐系统.语言模型解释.多模态表征.多语言建模.推理优化.细胞抗原预测.蛋白结构理解和化学反应的探索.此外,在资源分享 ...

  5. 缓冲运动之框架開始一级简单框架实例

    ***********************缓冲运动[框架開始]-1.html********************************************* <!DOCTYPE h ...

  6. SpringMVC+Thymeleaf +HTML的简单框架

    SpringMVC+Thymeleaf +HTML的简单框架 一.问题 项目中需要公众号开发,移动端使用的是H5,但是如果不用前端框架的话,只能考虑JS前端用ajax解析JSON字符串了.今天我们就简 ...

  7. java 简单 框架_java简单框架设计

    设计框架包可以作为一个工具给大家用,需要有完全不同设计思路给出来,不同于我们去做一个web服务.网站. 或者一个业务微服务,需要从原来使用视角转换成一个构建者视角. 框架或者工具,更多是框架来管理或者 ...

  8. Deep CARs:使用Pytorch学习框架实现迁移学习

    全文共13449字,预计学习时长26分钟或更长 图片来源:https://www.pexels.com/photo/vehicles-parked-inside-elevated-parking-lo ...

  9. 机器学习基础概念及简单框架

    机器学习要知道的基础概念和简单框架 机器学习相关的基础概念 机器学习的简单框架 机器学习相关的基础概念 All models are wrong but some are useful(所有模型都是错 ...

最新文章

  1. 为什么我们都要等到失去后才知道珍惜呢(转载)
  2. mysql table combine_Mysql系列-性能优化神器EXPLAIN使用介绍及分析
  3. LaTex 更改文字颜色
  4. HDU - 4289 Control(最小割-最大流)
  5. python反编译exe_实战 Python3.7+64位 Exe 反编译
  6. TOP命令 详解CPU 查看多个核心的利用率按1
  7. b站2020用户画像_B站2020年度动画大选来袭!论引战,还是要看B站
  8. Linux echo
  9. 为什么电脑CPU这么贵?
  10. ModelSim6.3 使用教程By Sunev
  11. 2015年蓝桥杯省赛A组c++第4题
  12. asp.net多语言设置方法
  13. python类库31[命令行解析]
  14. 成人高考计算机基础历年真题,成人高考历年真题及答案
  15. oracle10g在win10上的安装
  16. asr语音转写_利用Real-time ASR语音转写服务实现直播实时弹幕提升用户体验
  17. HPSocket C++控制台版DEMO
  18. GTX1060 Windows10 旧版显卡驱动下载链接
  19. built a JNCIS LAB系列:Chapter 2 OSPF v1.0
  20. MAC升级gcc版本

热门文章

  1. 深度学习之基于GAN实现手写数字生成
  2. 从sqlserver中数据写入mysql_从SQL server数据库导入Mysql数据库的体验
  3. composer在windows中安装失败
  4. HttpServlet概述及应用
  5. loadrunner 只能并发50_loadrunner 场景设计-(一)
  6. 图神经网络(二)GCN的性质(3)GCN是一个低通滤波器
  7. php dirtoarray,PHP Ds\Stack toArray()用法及代码示例
  8. python2exe下载_py2exe下载 0.6.9.win32-py2.7-python转exe工具-pc6下载站
  9. 看了就会的VScode给C++的配置编译环境(Visual Studio Code)
  10. 图论——Tarjan 初步 DFS序+时间戳+欧拉序