Kaggle经典案例1

dogs vs cats

Dogs vs,Cats是一个传统的二分类问题,其训练集包含25000张图片,均放置在同一文件夹下,命名格式为,num>,jpg ,如at,10880.jpg 、
dog.18.jpg ,测试集包含1250张图片,命名为num>jpg ,如10.jpg 。参赛者需根据训练集的图片训练模型,并在测试集上进行预测,输出它是狗的杨
率。最后提交的csv文件如下,第一列是图片的~,第二列是图片为狗的概率。

1.要点梳理

程序主要包含以下功能:

  • 模型定义
  • 数据加载
  • 训练和测试

2.数据加载模块

数据的相关处理主要保存在data/dataset.py中。关于数据加载的相关操作,在上一章中我们已经提到过,其基本原理就是使用Dataset提供数据集的封装,再使用Dataloader实现数据并行加载。Kaggle提供的数据包括训练集和测试集,而我们在实际使用中,还需专门从训练集中取出一部分作为验证集。对于这三类数据集,其相应操作也不太一样,而如果专门写三个Dataset,则稍显复杂和冗余,因此这里通过加一些判断来区分。对于训练集,我们希望做一些数据增强处理,如随机裁剪、随机翻转、加噪声等,而验证集和测试集则不需要。下面看dataset.py的代码:

点击展开/折叠代码
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
import os
from PIL import Image
from torch.utils import data
import numpy as np
from torchvision import transforms as T


class DogCat(data.Dataset):

def __init__(self, root, transforms=None, train=True, test=False):
"""
目标:获取所有图片地址,并根据训练、验证、测试划分数据
"""
self.test = test
imgs = [os.path.join(root, img) for img in os.listdir(root)]

# test1: data/test1/8973.jpg
# train: data/train/cat.10004.jpg
if self.test:
imgs = sorted(imgs, key=lambda x: int(x.split('.')[-2].split('/')[-1])) #lambda表示构造函数,传入的参数是x
else:
imgs = sorted(imgs, key=lambda x: int(x.split('.')[-2]))

imgs_num = len(imgs)

# 划分训练、验证集,验证:训练 = 3:7
if self.test:
self.imgs = imgs
elif train:
self.imgs = imgs[:int(0.7*imgs_num)]
else :
self.imgs = imgs[int(0.7*imgs_num):]

if transforms is None:

# 数据转换操作,测试验证和训练的数据转换有所区别

normalize = T.Normalize(mean = [0.485, 0.456, 0.406],
std = [0.229, 0.224, 0.225])

# 测试集和验证集
if self.test or not train:
self.transforms = T.Compose([
T.Resize(224),
T.CenterCrop(224),
T.ToTensor(),
normalize
])
# 训练集
else :
self.transforms = T.Compose([
T.Resize(256),
T.RandomReSizedCrop(224),
T.RandomHorizontalFlip(), #训练集中需要一点杂质
T.ToTensor(),
normalize
])


def __getitem__(self, index):
"""
返回一张图片的数据
对于测试集,没有label,返回图片id,如1000.jpg返回1000
"""
img_path = self.imgs[index]
if self.test:
label = int(self.imgs[index].split('.')[-2].split('/')[-1])
else:
label = 1 if 'dog' in img_path.split('/')[-1] else 0
data = Image.open(img_path)
data = self.transforms(data)
return data, label

def __len__(self):
"""
返回数据集中所有图片的个数
"""
return len(self.imgs)

3.模型定义模块

模型的定义主要保存在models/目录下,其中BasicModule是对nn.Module的简易封装,提供快速加载和保存模型的接口。

点击展开/折叠代码
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
class BasicModule(t.nn.Module):
"""
封装了nn.Module,主要提供save和load两个方法
"""

def __init__(self):
super(BasicModule,self).__init__()
self.model_name = str(type(self)) # 模型的默认名字

def load(self, path):
"""
可加载指定路径的模型
"""
self.load_state_dict(t.load(path))

def save(self, name=None):
"""
保存模型,默认使用“模型名字+时间”作为文件名,
如AlexNet_0710_23:57:29.pth
"""
if name is None:
prefix = 'checkpoints/' + self.model_name + '_'
name = time.strftime(prefix + '%m%d_%H:%M:%S.pth')
t.save(self.state_dict(), name)
return name
Donate
  • Copyright: Copyright is owned by the author. For commercial reprints, please contact the author for authorization. For non-commercial reprints, please indicate the source.
  • Copyrights © 2022-2024 Yutouegg
  • Visitors: | Views:

请我喝杯咖啡吧~

支付宝
微信