CIFAIR-10分类

入门实战篇

小试牛刀: CIFAR-10分类
下面我们来尝试实现对CIFAR-10数据集的分类,步骤如下:
1.使用torchvision加载并预处理CIFAR-10数据集
2.定义网络
3定义损失函数和优化器
4.训练网络并更新网络参数
5.测试网络

数据预处理部分:

1
2
3
4
5
6
7
8
9
10
import torchvision as tv
import torchvision.transforms as transforms
from torchvision.transforms import ToPILImage
show = ToPILImage() # 可以把Tensor转成Image,方便可视化
```
commandline

# 定义对数据的预处理
<details>
<summary>点击展开/和并代码</summary>

transform = transforms.Compose([
transforms.ToTensor(), # 转为Tensor
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)), # 归一化
])

训练集

trainset = tv.datasets.CIFAR10(
root=’/home/cy/tmp/data/‘,
train=True,
download=True,
transform=transform)

trainloader = t.utils.data.DataLoader(
trainset,
batch_size=4,
shuffle=True,
num_workers=2)

测试集

testset = tv.datasets.CIFAR10(
‘/home/cy/tmp/data/‘,
train=False,
download=True,
transform=transform)

testloader = t.utils.data.DataLoader(
testset,
batch_size=4,
shuffle=False,
num_workers=2)

classes = (‘plane’, ‘car’, ‘bird’, ‘cat’,
‘deer’, ‘dog’, ‘frog’, ‘horse’, ‘ship’, ‘truck’)


</details>

Donate
  • Copyright: Copyright is owned by the author. For commercial reprints, please contact the author for authorization. For non-commercial reprints, please indicate the source.

扫一扫,分享到微信

微信分享二维码
  • Copyrights © 2022-2024 Yutouegg
  • Visitors: | Views:

请我喝杯咖啡吧~

支付宝
微信