当前位置:网站首页>Pytorch:torchvision包-总结
Pytorch:torchvision包-总结
2022-07-19 05:21:00 【三世】
TORCHVISION 官网地址:torchvision — Torchvision 0.12 documentation
计算机视觉是深度学习中最重要的一类应用,为了方便研究者使用,PyTorch团队专门开发了一个视觉工具包torchvion
,这个包独立于PyTorch,需通过pip instal torchvision
安装。在之前的例子中我们已经见识到了它的部分功能,这里再做一个系统性的介绍。torchvision它是一个视觉工具包,提供了很多视觉图像处理的工具,主要包含三部分:
- datasets: 提供常用的数据集加载,设计上都是继承
torhc.utils.data.Dataset
,主要包括MNIST
、CIFAR10/100
、ImageNet
、COCO
等。 - models:提供深度学习中各种经典网络的网络结构以及预训练好的模型,包括
AlexNet
、VGG系列、ResNet系列、Inception系列等。 - transforms:提供常用的数据预处理操作,主要包括对Tensor以及PIL Image对象的操作。
- utils:用于把形似 (3 x H x W) 的张量保存到硬盘中,给一个mini-batch的图像可以产生一个图像格网。
1、datasets
torchvision.datasets 是用来进行数据加载的,Torchvision 在 torchvision.datasets 模块中提供了许多内置数据集,以及用于构建您自己的数据集的实用程序类。
内置数据集:所有数据集都是 torch.utils.data.Dataset 的子类,即它们实现了 __getitem__ 和 __len__ 方法。 因此,它们都可以传递给 torch.utils.data.DataLoader,它可以使用 torch.multiprocessing 工作线程并行加载多个样本。 内置数据集包含图像分类(Caltech 101物体分类 、CelebA人脸分类、CIFAR10 物体汽车分类、MNIST 手写数据分类、LFW人脸验证、ImageNet图像分类等等),图像检测与分割(Cityscapes Dataset、Pascal VOC Segmentation Dataset),以及光流法Optical Flow数据集、视频分类数据集等等。
例如:
imagenet_data = torchvision.datasets.ImageNet('path/to/imagenet_root/')
data_loader = torch.utils.data.DataLoader(imagenet_data,
batch_size=4,
shuffle=True,
num_workers=args.nThreads)
import os
from PIL import Image
import numpy as np
from torchvision import transforms as T
from torchvision import datasets
from torch.utils.data import DataLoader
normalize = T.Normalize(mean=3, std=0.2)
transform = T.Compose([
#T.RandomResizedCrop(224),
#T.RandomHorizontalFlip(),
T.ToTensor(),
normalize,
])
# 指定数据集路径为data,如果数据集不存在则进行下载
# 通过train=False获取测试集
dataset = datasets.MNIST('data/', download=True, train=False, transform=transform)
dataloader = DataLoader(dataset, shuffle=True, batch_size=16)
from torchvision.utils import make_grid, save_image
dataiter = iter(dataloader)
img = make_grid(next(dataiter)[0], 4) # 拼成4*4网格图片,且会转成3通道
to_img = T.ToPILImage()
to_img(img)
save_image(img, 'a.png')
Image.open('a.png')
自定义datasets类的基类:
DatasetFolder(root, loader, Any], …) | A generic data loader.通用数据加载 |
ImageFolder(root, transform, …) | A generic data loader where the images are arranged in this way by default: .通用图像数据加载 |
VisionDataset(root, transforms, transform, …) | Base Class For making datasets which are compatible with torchvision. |
2、models
torchvision.models 包含用于解决不同任务的模型定义,包括:图像分类、像素语义分割、对象检测、实例分割、人物关键点检测、视频分类和光流。
其中包含如下模型:
Inception v3
ShuffleNet v2
通过如下方法构造:
import torchvision.models as models
resnet18 = models.resnet18()
alexnet = models.alexnet()
vgg16 = models.vgg16()
squeezenet = models.squeezenet1_0()
convnext_tiny = models.convnext_tiny()
convnext_small = models.convnext_small()
我们使用 PyTorch torch.utils.model_zoo 提供预训练模型。 这些可以通过传递 pretrained=True 来构造:
import torchvision.models as models
resnet18 = models.resnet18(pretrained=True)
alexnet = models.alexnet(pretrained=True)
squeezenet = models.squeezenet1_0(pretrained=True)
vgg16 = models.vgg16(pretrained=True)
densenet = models.densenet161(pretrained=True)
inception = models.inception_v3(pretrained=True)
The process for obtaining the values of mean and std is roughly equivalent to:
import torch
from torchvision import datasets, transforms as T
transform = T.Compose([T.Resize(256), T.CenterCrop(224), T.ToTensor()])
dataset = datasets.ImageNet(".", split="train", transform=transform)
means = []
stds = []
for img in subset(dataset):
means.append(torch.mean(img))
stds.append(torch.std(img))
mean = torch.mean(torch.tensor(means))
std = torch.mean(torch.tensor(stds))
以下架构为 INT8 量化模型提供支持。 您可以通过调用其构造函数来获得具有随机权重的模型:
googlenet = models.quantization.googlenet()
inception_v3 = models.quantization.inception_v3()
mobilenet_v2 = models.quantization.mobilenet_v2()
mobilenet_v3_large = models.quantization.mobilenet_v3_large()
resnet18 = models.quantization.resnet18()
resnet50 = models.quantization.resnet50()
只需几行代码即可获得预训练的量化模型:
import torchvision.models as models
model = models.quantization.mobilenet_v2(pretrained=True, quantize=True)
model.eval()
# run the model with quantized inputs and weights
out = model(torch.rand(1, 3, 224, 224))
语义分割
模型包包含以下语义分割模型架构的定义:
目标检测、实例分割和人物关键点检测
模型包包含以下用于检测的模型架构的定义:
3、transforms
其中transforms
模块提供了对PIL Image
对象和Tensor
对象的常用操作。张量图像是具有 (C, H, W) 形状的张量,其中 C 是通道数,H 和 W 是图像的高度和宽度。 A batch of Tensor Images 是一个 (B, C, H, W) 形状的张量,其中 B 是 batch 中的图像数量。
对PIL Image的操作包括:
Scale
:调整图片尺寸,长宽比保持不变CenterCrop
、RandomCrop
、RandomResizedCrop
: 裁剪图片Pad
:填充ToTensor
:将PIL Image对象转成Tensor,会自动将[0, 255]归一化至[0, 1]
对Tensor的操作包括:
- Normalize:标准化,即减均值,除以标准差
- ToPILImage:将Tensor转为PIL Image对象
如果要对图片进行多个操作,可通过Compose
函数将这些操作拼接起来,类似于nn.Sequential
。注意,这些操作定义后是以函数的形式存在,真正使用时需调用它的__call__
方法,这点类似于nn.Module
。例如要将图片调整为224×224,首先应构建这个操作trans = Resize((224, 224))
,然后调用trans(img)
。下面我们就用transforms的这些操作来优化上面实现的dataset。
官方用例子解释了transforms具体的用法:
1、Illustration of transforms
4、utils
The torchvision.utils
module contains various utilities, mostly for vizualization.
draw_bounding_boxes(image, boxes[, labels, …]) | Draws bounding boxes on given image. |
draw_segmentation_masks(image, masks[, …]) | Draws segmentation masks on given RGB image. |
draw_keypoints(image, keypoints[, …]) | Draws Keypoints on given RGB image. |
flow_to_image(flow) | Converts a flow to an RGB image. |
make_grid(tensor[, nrow, padding, …]) | Make a grid of images. |
save_image(tensor, fp[, format]) | Save a given Tensor into an image file. |
此示例说明了 torchvision 提供的用于可视化图像、边界框、分割掩码和关键点的一些实用程序。
代码例子地址:Visualization utilities — Torchvision 0.12 documentation
边界框代码
from torchvision.utils import draw_bounding_boxes
boxes = torch.tensor([[50, 50, 100, 200], [210, 150, 350, 430]], dtype=torch.float)
colors = ["blue", "yellow"]
result = draw_bounding_boxes(dog1_int, boxes, colors=colors, width=5)
show(result)
5、READING/WRITING IMAGES AND VIDEOS
torchvision.io 包提供了执行 IO 操作的函数。 它们目前专门用于读取和写入视频和图像。
Video
read_video(filename[, start_pts, end_pts, …]) | Reads a video from a file, returning both the video frames as well as the audio frames |
read_video_timestamps(filename[, pts_unit]) | List the video frames timestamps. |
write_video(filename, video_array, fps[, …]) | Writes a 4d tensor in [T, H, W, C] format in a video file |
Image
ImageReadMode(value) | Support for various modes while reading images. |
read_image(path[, mode]) | Reads a JPEG or PNG image into a 3 dimensional RGB or grayscale Tensor. |
decode_image(input[, mode]) | Detects whether an image is a JPEG or PNG and performs the appropriate operation to decode the image into a 3 dimensional RGB or grayscale Tensor. |
encode_jpeg(input[, quality]) | Takes an input tensor in CHW layout and returns a buffer with the contents of its corresponding JPEG file. |
decode_jpeg(input[, mode, device]) | Decodes a JPEG image into a 3 dimensional RGB or grayscale Tensor. |
write_jpeg(input, filename[, quality]) | Takes an input tensor in CHW layout and saves it in a JPEG file. |
encode_png(input[, compression_level]) | Takes an input tensor in CHW layout and returns a buffer with the contents of its corresponding PNG file. |
decode_png(input[, mode]) | Decodes a PNG image into a 3 dimensional RGB or grayscale Tensor. |
write_png(input, filename[, compression_level]) | Takes an input tensor in CHW layout (or HW in the case of grayscale images) and saves it in a PNG file. |
read_file(path) | Reads and outputs the bytes contents of a file as a uint8 Tensor with one dimension. |
write_file(filename, data) | Writes the contents of a uint8 tensor with one dimension to a file. |
边栏推荐
- Disk and file system
- MySQL Galera cluster configuration
- VLAN overview
- [resource record] Introduction to Bayesian neural network (BNN), common packages and differences
- String cache pool and integer cache pool
- 常用软件快捷键
- ZABBIX automatically discovers and monitors GPU
- CEPH detailed mon_ osd_ max_ split_ count
- invalid syntax
- zabbix自动发现并监控GPU
猜你喜欢
Manually compile and install Apache
【资源记录】作为程序员 对统计学中的卡方分布/检测,t分布/检测,f分布/检测 的自学记录
System safety and Application
Variational Inference 笔记 from UCB CS 285 Sergey Levine
gocore-v2框架-脚手架生成项目结构介绍
C语言基础知识
Network packet capturing to understand the establishment process of TCP triple handshake
关于正向代理和反向代理的理解
狂神。SMBMS(超市订单管理系统)
pytorch梯度的计算过程
随机推荐
gocore-v2框架-API接口开发理念
Scala case (companion object)
System safety and Application
Solution to unmount failure using umount command
Transport layer protocol
Lombok cooperates with logback to realize the simplest log output
Manually compile and install Apache
二叉树
pytorch之nn.Conv1d详解
lambda用法
静态路由工作原理与配置
【转】解决内存/显存泄露的方法 pytorch
猫和狗的分类例子-Kaggle
Lunix boot and troubleshooting
kvm虚拟机迁移到openstack环境,提示InvalidDiskInfo Disk info file is invalid qemu-img fai
Overview of key core technologies of intelligent operation and maintenance aiops worth seeing
js 使元素获取或失去焦点
ModuleNotFoundError: No module named 'cv2'
[turn] method to solve memory / video memory leakage pytorch
Mikrotik ROS软路由配置PCC负载均衡实现双宽带叠加