文章目录

  • pytorch hello word
  • pytorch 物体检测

pytorch hello word

import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import ToTensor# Download training data from open datasets.
training_data = datasets.FashionMNIST(root="data",train=True,download=True,transform=ToTensor(),
)# Download test data from open datasets.
test_data = datasets.FashionMNIST(root="data",train=False,download=True,transform=ToTensor(),
)batch_size = 64# Create data loaders.
train_dataloader = DataLoader(training_data, batch_size=batch_size)
test_dataloader = DataLoader(test_data, batch_size=batch_size)for X, y in test_dataloader:print(f"Shape of X [N, C, H, W]: {X.shape}")print(f"Shape of y: {y.shape} {y.dtype}")break# Get cpu or gpu device for training.
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using {device} device")# Define model
class NeuralNetwork(nn.Module):def __init__(self):super().__init__()self.flatten = nn.Flatten()self.linear_relu_stack = nn.Sequential(nn.Linear(28*28, 512),nn.ReLU(),nn.Linear(512, 512),nn.ReLU(),nn.Linear(512, 10))def forward(self, x):x = self.flatten(x)logits = self.linear_relu_stack(x)return logitsmodel = NeuralNetwork().to(device)
print(model)loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)def train(dataloader, model, loss_fn, optimizer):size = len(dataloader.dataset)model.train()for batch, (X, y) in enumerate(dataloader):X, y = X.to(device), y.to(device)# Compute prediction errorpred = model(X)loss = loss_fn(pred, y)# Backpropagationoptimizer.zero_grad()loss.backward()optimizer.step()if batch % 100 == 0:loss, current = loss.item(), batch * len(X)print(f"loss: {loss:>7f}  [{current:>5d}/{size:>5d}]")def test(dataloader, model, loss_fn):size = len(dataloader.dataset)num_batches = len(dataloader)model.eval()test_loss, correct = 0, 0with torch.no_grad():for X, y in dataloader:X, y = X.to(device), y.to(device)pred = model(X)test_loss += loss_fn(pred, y).item()correct += (pred.argmax(1) == y).type(torch.float).sum().item()test_loss /= num_batchescorrect /= sizeprint(f"Test Error: \n Accuracy: {(100*correct):>0.1f}%, Avg loss: {test_loss:>8f} \n")epochs = 5
for t in range(epochs):print(f"Epoch {t+1}\n-------------------------------")train(train_dataloader, model, loss_fn, optimizer)test(test_dataloader, model, loss_fn)
print("Done!")torch.save(model.state_dict(), "model.pth")
print("Saved PyTorch Model State to model.pth")model = NeuralNetwork()
model.load_state_dict(torch.load("model.pth"))classes = ["T-shirt/top","Trouser","Pullover","Dress","Coat","Sandal","Shirt","Sneaker","Bag","Ankle boot",
]model.eval()
x, y = test_data[0][0], test_data[0][1]
with torch.no_grad():pred = model(x)predicted, actual = classes[pred[0].argmax(0)], classes[y]print(f'Predicted: "{predicted}", Actual: "{actual}"')

pytorch 物体检测

## 导入相关模块
import numpy as np
import torchvision
import torch
import torchvision.transforms as transforms
from PIL import Image, ImageDraw, ImageFont
import matplotlib.pyplot as pltmodel = torchvision.models.detection.fasterrcnn_resnet50_fpn(pretrained=True)
model.eval()## 准备需要检测的图像
image = Image.open(r"E:\Users\MyPC\PycharmProjects\untitled3\img\cc.jpg")
transform_d = transforms.Compose([transforms.ToTensor()])
image_t = transform_d(image)    ## 对图像进行变换
pred = model([image_t])         ## 将模型作用到图像上
print(pred)## 定义使用COCO数据集对应的每类的名称
"""fire hydrant 消防栓,stop sign 停车标志, parking meter 停车收费器, bench 长椅。zebra 斑马, giraffe 长颈鹿, handbag 手提包, suitcase 手提箱, frisbee (游戏用)飞盘(flying disc)。skis 滑雪板(ski的复数),snowboard 滑雪板(ski是单板滑雪,snowboarding 是双板滑雪。)kite 风筝, baseball bat 棒球棍, baseball glove 棒球手套, skateboard 滑板, surfboard 冲浪板, tennis racket 网球拍。broccoli 西蓝花,donut甜甜圈,炸面圈(doughnut,空心的油炸面包), cake 蛋糕、饼, couch 长沙发(靠chi)。potted plant 盆栽植物。 dining table 餐桌。 laptop 笔记本电脑,remote 遥控器(=remote control), cell phone 移动电话(=mobile phone)(cellular 细胞的、蜂窝状的), oven 烤炉、烤箱。 toaster 烤面包器(toast 烤面包片)sink 洗碗池, refrigerator 冰箱。(=fridge), scissor剪刀(see, zer), teddy bear 泰迪熊。 hair drier 吹风机。 toothbrush 牙刷。
"""
COCO_INSTANCE_CATEGORY_NAMES = ['__BACKGROUND__', 'person', 'bicycle', 'car', 'motorcycle','airplane', 'bus', 'train', 'trunk', 'boat', 'traffic light','fire hydrant', 'N/A', 'stop sign', 'parking meter', 'bench','bird', 'cat', 'dog', 'horse', 'sheep', 'cow', 'elephant','bear', 'zebra', 'giraffe', 'N/A', 'backpack', 'umbrella', 'N/A','N/A', 'handbag', 'tie', 'suitcase', 'frisbee', 'skis', 'snowboard','sports ball', 'kite', 'baseball bat', 'baseball glove', 'skateboard','surfboard', 'tennis racket', 'bottle', 'N/A', 'wine glass','cup', 'fork', 'knife', 'spoon', 'bowl', 'banana', 'apple','sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza','donut', 'cake', 'chair', 'couch', 'potted plant', 'bed', 'N/A','dining table', 'N/A', 'N/A', 'toilet', 'N/A', 'tv', 'laptop','mouse', 'remote', 'keyboard', 'cell phone', 'microwave', 'oven','toaster', 'toaster', 'sink', 'refrigerator', 'N/A', 'book', 'clock','vase', 'scissors', 'teddy bear', 'hair drier', 'toothbrush'
]## 检测出目标的类别和得分
pred_class = [COCO_INSTANCE_CATEGORY_NAMES[ii] for ii in list(pred[0]['labels'].numpy())]
pred_score = list(pred[0]['scores'].detach().numpy())## 检测出目标的边界框
pred_boxes = [[ii[0], ii[1], ii[2], ii[3]] for ii in list(pred[0]['boxes'].detach().numpy())]## 只保留识别的概率大约 0.5 的结果。
pred_index = [pred_score.index(x) for x in pred_score if x > 0.5]# ## 设置图像显示的字体
# fontsize = np.int16(image.size[1] / 20)
# font1 = ImageFont.truetype("/usr/share/fonts/gnu-free/FreeMono.ttf", fontsize)## 可视化对象
draw = ImageDraw.Draw(image)
for index in pred_index:box = pred_boxes[index]draw.rectangle(box, outline="blue")texts = pred_class[index]+":"+str(np.round(pred_score[index], 2))draw.text((box[0], box[1]), texts, fill="blue")## 显示图像
plt.imshow(image)
plt.show()

册 https://pytorchbook.cn/

python-pytorch hello world相关推荐

  1. python pytorch 包的安装

    python pytorch 包的安装 打开官网:https://pytorch.org/ https://pytorch.org/get-started/locally/

  2. THCudaCheck FAIL file=/opt/conda/conda-bld/python/pytorch/work/aten/src/THC/THCCachingHostAllocator.

    各位大佬好,我想跑YOLOV5,用极链云租了个实例, 按照帮助文档:https://cloud.videojj.com/help/.配置好了环境,pycharm deployment也配置成功了,可以 ...

  3. win10+centos7+Anaconda3+python+pytorch

    安装Anaconda3 直接去清华的镜像进行下载,因为官网进不去,下载Anaconda3-2019.03-Linux-x86_64.sh 也可以命令行下载 wget https://mirrors.t ...

  4. TX2+python+pytorch install

    前半部分可以参考这一篇大神的https://blog.csdn.net/qq_33869371/article/details/88168202 Installing PyTorch on TX2 T ...

  5. 【深度学习】使用Python+PyTorch预测野外火灾

    作者 | Aishwarya Srinivasan 编译 | VK 来源 | Towards Data Science 联合国在实现其可持续发展目标方面面临的主要障碍之一是与自然灾害作斗争,而造成巨大 ...

  6. python pytorch fft_看PyTorch源代码的心路历程

    1. 起因 曾经碰到过别人的模型prelu在内部的推理引擎算出的结果与其在原始框架PyTorch中不一致的情况,虽然理论上大家实现的都是一个算法,但是从参数上看,因为经过了模型转换,中间做了一些调整. ...

  7. python pytorch 版本,python 如何查看pytorch版本

    看代码吧~ import torch print(torch.__version__) 补充:pytorch不同版本安装以及版本查看 一:基于conda安装 conda create --name p ...

  8. 学习python/pytorch过程中遇到的知识点

    Pytorch torch.backends.cudnn.deterministic 和 torch.backends.cudnn.benchmark 这两个参数,用于固定算法,使每次运行结果都一样. ...

  9. Python Pytorch

    学习基础知识 大多数机器学习工作流程都涉及处理数据.创建模型.优化模型参数和保存经过训练的模型.本教程向您介绍在 PyTorch 中实现的完整 ML 工作流,并提供链接以了解有关每个概念的更多信息. ...

  10. python pytorch语音识别_PyTorch通过ASR实现语音到文本端的模型以及pytorch语音识别(speech) - pytorch中文网...

    ASR,英文的全称是Automated Speech Recognition,即自动语音识别技术,它是一种将人的语音转换为文本的技术.今天我们主要了解pytorch实现语音到文本的端到端模型. spe ...

最新文章

  1. Android应用中如何保护JAVA代码
  2. 血泪史:阿里云+ubuntu+vnc+xfce4
  3. Oracle用户创建及设置
  4. DOM4J介绍与代码示例【转载】
  5. C++/CX:类的继承
  6. Android Studio中导入第三方库
  7. Oracle11gR1中细粒度访问网络服务(转)
  8. iOS开发之SQLite的Object-C封装
  9. @param注解什么意思_你对Java注解真的理解吗?
  10. PowerPoint 消除所有动画VBA指令
  11. 孩子学python_小孩子的内心世界
  12. Unity 接入Apple登录
  13. CSS画出半圆,四分之一圆,三角等图形
  14. 学生台灯哪个品牌的专业?盘点小学生台灯品牌排行榜
  15. 2021-05-30_蓝桥杯嵌入式拓展板STM32G431--光敏电阻
  16. 网站403报错问题原因解答
  17. FreeType字体引擎介绍
  18. 1 -【第十一届】蓝桥杯物联网试题(模拟题)
  19. serv-u ftp server是什么?如何利用花生壳搭建ftp服务器?
  20. Python爬虫系列(2)

热门文章

  1. 华为计算机单机pc游戏软件,华为应用市场pc端
  2. 【MCMC】PyMC2库进行MCMC估计线性回归参数
  3. java设计模式 (二) 创建模式
  4. 腾讯AI在星际2完整对战中击败“作弊级”内建Bot
  5. 英语词性篇 - 英语疑问词
  6. NOIP2016 买铅笔【模拟】
  7. 架构师接龙:岳旭强 VS. 杨卫华
  8. 日记侠:朋友圈一定要刷屏吗?
  9. 网络信息安全:一、端口安全
  10. 黑板模式(Blackboard Design Pattern)。