本文基于 AI 研习社 2020 年 10 月 30 日发布的“图像场景分类挑战赛”完成，主要从一名初学者（小白）的角度，介绍如何用代码把一个图像分类任务完整跑起来，同时也记录了我自己的学习过程。
图像场景分类挑战赛:https://god.yanxishe.com/97?from=god_home_list
挂载Google Drive (在Colab中将Google云盘载入进来)
# Mount Google Drive inside the Colab runtime so that files under
# /content/drive/My Drive/ (dataset archive, checkpoints) become accessible.
import google.colab.drive as drive

drive.mount('/content/drive')
解压文件（将数据集压缩包从 Google 云盘拷贝到当前运行环境并解压）
# Copy the dataset archive from Google Drive into the local Colab runtime.
!cp /content/drive/My\ Drive/Scene/Image_Classification.zip ./
# Extract it, producing the 'train' folder, the 'test' folder and 'train.csv'.
# -o overwrites existing files without prompting when the cell is re-run;
# -q keeps the output short.
!unzip -o -q Image_Classification.zip
创建一个文件夹存放训练好的模型
# Create the folder that stores trained checkpoints on Google Drive.
# -p makes this a no-op instead of an error when the folder already exists.
!mkdir -p /content/drive/My\ Drive/Scene/checkpoint
导包（导入所有要用到的包；写代码的过程中用到新的包时再补充即可）
import torch import pandas as pd from PIL import Image from torchvision import transforms, models from torch.utils.data import random_split, DataLoader import os import torch.nn as nn import time import torch.optim as optim
选择运行设备（有 GPU 时使用 GPU，否则使用 CPU），并打印当前使用的设备
# Use the first CUDA GPU when one is available, otherwise fall back to the CPU.
# torch.device only accepts lowercase device types, so the fallback must be
# 'cpu' ('CPU' is rejected and would crash CPU-only runtimes).
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
print(device)
读取标签文件（读取训练集的标签文件，此处为 CSV 格式）
def readLabelFile():
    """Read the training label file ``train.csv``.

    Returns:
        A ``(filename, label)`` pair of pandas Series taken from the
        'filename' and 'label' columns of the CSV.
    """
    label_file = pd.read_csv('train.csv')
    return label_file['filename'], label_file['label']


filename, filelabel = readLabelFile()
# Class names; the position of each name is its integer class id.
# NOTE: `map` shadows the builtin, but later cells index this list to turn
# predicted ids back into names, so the name is kept.
map = ['buildings', 'street', 'forest', 'sea', 'mountain', 'glacier']
num_class = len(map)
# Convert label strings to integer class ids. A label missing from `map`
# becomes NaN, and the int64 cast then raises instead of silently leaving a
# string label that would only fail later inside the loss.
filelabel = filelabel.map({name: i for i, name in enumerate(map)}).astype('int64')
# Convert the Series to numpy arrays for positional indexing in the Dataset.
filename = filename.values
filelabel = filelabel.values
定义读取数据集的类(包括训练集和测试集)
class TrainDataset(torch.utils.data.Dataset):
    """Labelled training images read from a directory.

    Args:
        root: directory that holds the image files.
        img_list: indexable sequence of file names relative to ``root``.
        label_list: indexable sequence of labels aligned with ``img_list``.
        transform: optional callable applied to each RGB PIL image.
    """

    def __init__(self, root, img_list, label_list, transform=None):
        self.root = root
        self.img_list = img_list
        self.label_list = label_list
        self.transform = transform

    def __getitem__(self, index):
        """Return ``(image, label)`` for sample ``index``."""
        # os.path.join also works when root has no trailing separator.
        path = os.path.join(self.root, self.img_list[index])
        with Image.open(path) as raw:
            img = raw.convert('RGB')
        label = self.label_list[index]
        if self.transform:
            img = self.transform(img)
        return img, label

    def __len__(self):
        return len(self.img_list)


class TestDataset(torch.utils.data.Dataset):
    """Unlabelled images given by full file paths.

    Args:
        img_path: indexable sequence of image file paths.
        transform: optional callable applied to each RGB PIL image.
    """

    def __init__(self, img_path, transform=None):
        self.img_path = img_path
        self.transform = transform

    def __getitem__(self, index):
        """Return ``(image, index)``; the index identifies the source path."""
        with Image.open(self.img_path[index]) as raw:
            img = raw.convert('RGB')
        if self.transform:
            img = self.transform(img)
        return img, index

    def __len__(self):
        return len(self.img_path)
预处理(对数据集进行预处理)
# Image preprocessing pipelines.
# interpolation=2 is PIL's BILINEAR filter; Normalize with mean=std=0.5 maps
# the [0, 1] tensors produced by ToTensor to [-1, 1].
transform = {
    # Training: resize, random horizontal flip for augmentation, then normalize.
    'train': transforms.Compose([
        transforms.Resize((224, 224), interpolation=2),
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
    ]),
    # Evaluation: the same resize and normalization without random augmentation,
    # so predictions are deterministic. An empty Compose would pass raw PIL
    # images through, which the default DataLoader collate cannot batch.
    'val': transforms.Compose([
        transforms.Resize((224, 224), interpolation=2),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
    ]),
}
实例化数据集并创建 DataLoader（包括训练集、验证集和测试集）
# Full labelled set; it is split into training and validation subsets below.
train_dataset = TrainDataset('./train/', filename, filelabel, transform['train'])

# Keep 10000 images for training and use the remainder for validation.
# The fixed seed makes the split identical across Colab sessions, so resuming
# from a checkpoint does not leak earlier validation images into training.
# NOTE(review): val_dataset is a Subset of train_dataset, so it also uses the
# training transform (including the random flip).
tra_size = 10000
val_size = len(train_dataset) - tra_size
tra_dataset, val_dataset = random_split(
    train_dataset, [tra_size, val_size],
    generator=torch.Generator().manual_seed(0))

# Test-time preprocessing: the training pipeline without the random flip, so
# predictions on the test set are deterministic.
test_transform = transforms.Compose(
    [t for t in transform['train'].transforms
     if not isinstance(t, transforms.RandomHorizontalFlip)])
# os.scandir yields entries in arbitrary order; sorting makes the dataset
# indices (written to the result file) reproducible between runs.
with os.scandir('./test/') as entries:
    test_paths = sorted(entry.path for entry in entries)
test_dataset = TestDataset(test_paths, test_transform)

tra_loader = DataLoader(tra_dataset, batch_size=64, shuffle=True, num_workers=2)
val_loader = DataLoader(val_dataset, batch_size=64, shuffle=False, num_workers=2)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False, num_workers=2)
tra_dataset_num = len(tra_dataset)
初始化预训练模型
# torchvision classification constructors that initializeModel accepts.
_SUPPORTED_MODELS = frozenset([
    'alexnet',
    'vgg11', 'vgg11_bn', 'vgg13', 'vgg13_bn',
    'vgg16', 'vgg16_bn', 'vgg19', 'vgg19_bn',
    'resnet18', 'resnet34', 'resnet50', 'resnet101', 'resnet152',
    'squeezenet1_0', 'squeezenet1_1',
    'densenet121', 'densenet169', 'densenet161', 'densenet201',
    'inception_v3', 'googlenet',
    'shufflenet_v2_x0_5', 'shufflenet_v2_x1_0',
    'shufflenet_v2_x1_5', 'shufflenet_v2_x2_0',
    'mobilenet_v2',
    'resnext50_32x4d', 'resnext101_32x8d',
    'wide_resnet50_2', 'wide_resnet101_2',
    'mnasnet0_5', 'mnasnet0_75', 'mnasnet1_0', 'mnasnet1_3',
])


def _replace_head(model, num_class):
    """Replace the final classification layer of ``model`` with one of ``num_class`` outputs."""
    fc = getattr(model, 'fc', None)
    if isinstance(fc, nn.Linear):
        # ResNet / ResNeXt / Wide ResNet / Inception / GoogLeNet / ShuffleNet.
        model.fc = nn.Linear(fc.in_features, num_class)
        return
    head = getattr(model, 'classifier', None)
    if isinstance(head, nn.Linear):
        # DenseNet: the classifier is a single Linear layer.
        model.classifier = nn.Linear(head.in_features, num_class)
        return
    if isinstance(head, nn.Sequential):
        # AlexNet / VGG / MobileNetV2 / MNASNet end with a Linear layer;
        # SqueezeNet classifies with a Conv2d followed by ReLU and pooling.
        for i in reversed(range(len(head))):
            layer = head[i]
            if isinstance(layer, nn.Linear):
                head[i] = nn.Linear(layer.in_features, num_class)
                return
            if isinstance(layer, nn.Conv2d):
                head[i] = nn.Conv2d(layer.in_channels, num_class,
                                    kernel_size=layer.kernel_size)
                if hasattr(model, 'num_classes'):
                    model.num_classes = num_class
                return
    raise ValueError('Cannot find a classification layer on %s' % type(model).__name__)


def initializeModel(model_name, num_class, finetuning=False, pretrained=True):
    """Build a torchvision classifier whose output layer has ``num_class`` classes.

    Args:
        model_name: name of a torchvision model constructor (see _SUPPORTED_MODELS).
        num_class: number of output classes for the new head.
        finetuning: if True, every backbone parameter stays trainable; if False,
            the backbone is frozen and only the newly created head is trainable.
        pretrained: forwarded to the torchvision constructor (ImageNet weights).

    Returns:
        The model, moved to the module-level ``device``.

    Raises:
        ValueError: if ``model_name`` is not supported or no head layer is found.
    """
    if model_name not in _SUPPORTED_MODELS:
        raise ValueError('No such Model %s' % model_name)
    # Dispatch by name, which rules out copy-paste mismatches between the
    # requested and the constructed architecture.
    model = getattr(models, model_name)(pretrained=pretrained)
    # Set trainability before replacing the head: freshly created layers have
    # requires_grad=True, so the new head is always trainable.
    for param in model.parameters():
        param.requires_grad = bool(finetuning)
    _replace_head(model, num_class)
    # NOTE(review): inception_v3 (and googlenet, depending on its aux_logits
    # setting) keep 1000-class auxiliary classifiers. In train mode they may
    # return a tuple instead of a tensor, which the training loop here does
    # not handle.
    return model.to(device)
定义训练方法
def traWay(model, criterion, optimizer, epochs):
    """Train ``model`` on the module-level ``tra_loader`` for ``epochs`` epochs.

    Batches are moved to the module-level ``device``. After every epoch the
    average loss, accuracy and wall time are printed, followed by the total
    training time at the end.

    Args:
        model: network to train (expected to already live on ``device``).
        criterion: loss function called as ``criterion(outputs, labels)``.
        optimizer: optimizer over the parameters that should be updated.
        epochs: number of passes over the training data.

    Returns:
        The trained model (the same object that was passed in).
    """
    num_samples = len(tra_loader.dataset)
    begin_time = time.time()
    once_begin_time = begin_time
    for epoch in range(epochs):
        print('Epoch {}/{}'.format(epoch + 1, epochs))
        print('-' * 10)
        # Enable training behaviour of Dropout/BatchNorm; an earlier
        # model.eval() call would otherwise remain in effect.
        model.train()
        running_loss = 0.0
        running_corrects = 0
        for img, labels in tra_loader:
            img = img.to(device)
            labels = labels.to(device)
            optimizer.zero_grad()  # clear gradients from the previous step
            outputs = model(img)  # forward pass
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()  # update the parameters
            preds = torch.argmax(outputs, dim=1)
            # Assumes a mean-reduced criterion (e.g. default CrossEntropyLoss):
            # scale back to a per-sample sum before averaging over the epoch.
            running_loss += loss.item() * img.size(0)
            running_corrects += (preds == labels).sum().item()
        epoch_loss = running_loss / num_samples
        epoch_acc = running_corrects / num_samples
        print('Loss: {:.4f} Acc: {:.4f}'.format(epoch_loss, epoch_acc))
        print('Training Time per Epoch {}'.format(time.time() - once_begin_time))
        once_begin_time = time.time()
    end_time = time.time() - begin_time
    print('Training complete in {:.0f}m {:.0f}s'.format(end_time // 60, end_time % 60))
    return model
训练
# Train ResNet-152 with a frozen backbone. The optimizer below only updates
# trainable parameters, so with FINETUNE=False that is exactly model.fc and no
# backbone gradients are computed for nothing. Set FINETUNE=True to train the
# whole network instead.
FINETUNE = False
checkpoint_dir = r'/content/drive/My Drive/Scene/checkpoint'
check_name = os.path.join(checkpoint_dir, '152_state_best.tar')
os.makedirs(checkpoint_dir, exist_ok=True)

model = initializeModel('resnet152', num_class, FINETUNE)
optimizer = optim.Adam([p for p in model.parameters() if p.requires_grad], lr=0.001)
criterion = nn.CrossEntropyLoss()

# Resume from the checkpoint of a previous session, if there is one.
# map_location lets a checkpoint saved on GPU load on a CPU-only runtime.
if os.path.exists(check_name):
    print('loading previous state......')
    checkpoint = torch.load(check_name, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

model = traWay(model, criterion, optimizer, 1)

# NOTE: this overwrites the checkpoint after every run; despite the "_best"
# file name, no validation-based model selection happens here.
torch.save({
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
}, check_name)
测试
# Rebuild the network and load the trained weights. pretrained=False skips
# downloading ImageNet weights that the checkpoint overwrites anyway.
model = initializeModel('resnet152', num_class, False, pretrained=False)
check_name = r'/content/drive/My Drive/Scene/checkpoint/152_state_best.tar'
checkpoint = torch.load(check_name, map_location=device)
model.load_state_dict(checkpoint['model_state_dict'])
# Inference mode: BatchNorm uses running statistics and Dropout is disabled.
model.eval()

# Each output line is "<dataset index>,<class name>".
# NOTE(review): the index is the sample's position in test_dataset.img_path,
# not the image file name; confirm this matches the submission format.
with open('./result.txt', mode='w') as result_file, torch.no_grad():
    for img, index in test_loader:
        img = img.to(device)
        preds = torch.argmax(model(img), dim=1)
        for idx, pred in zip(index.tolist(), preds.tolist()):
            line = str(idx) + ',' + str(map[pred])
            print(line)
            result_file.write(line + '\n')
联系客服