使用 PyTorch 实现猫狗分类：Python 源码及准确度对比（CNN、VGG16 迁移学习两种方式）
# Dataset download: Dogs vs. Cats | Kaggle
#
# Accuracy of the different methods (as reported by the article):
#
#   method | epochs | accuracy
#   -------+--------+---------
#   cnn    |      5 | 67.64%
#   cnn    |     10 | 74.92%
#   cnn    |     15 | 73.42%
#   cnn    |     20 | 79.28%
#   cnn    |     25 | 78.28%
#   vgg16  |      5 | 86.5%
#   vgg16  |     10 | 86.98%
#   vgg16  |     15 | 85.42%

# ---------------------------------------------------------------------------
# cnn.py
# ---------------------------------------------------------------------------
import os

import torch
import torch.nn as nn
import torch.optim as optim
from PIL import Image
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms

num_epochs = 20
batch_size = 50
learning_rate = 0.001

# The Kaggle "train" folder is split 20000 / 5000 into a training part and a
# held-out part using one random permutation of the file indices.
train_size = 25000
indices = torch.randperm(train_size)
train_indices = indices[:20000]
test_indices = indices[20000:]


class DogsVsCatsDataset(Dataset):
    """Image dataset over one directory of dog/cat pictures.

    The label is derived from the file name: 0 when the name contains
    ``dog``, 1 otherwise.

    Args:
        root: Directory containing the image files.
        train: Select the files at ``train_indices`` when true, otherwise
            the files at ``test_indices``.
        transform: Optional callable applied to each loaded PIL image.
    """

    def __init__(self, root, train=True, transform=None):
        super().__init__()
        self.root = root
        self.transform = transform
        self.classes = ['dog', 'cat']
        self.files = []
        self.labels = []

        # os.listdir() returns entries in arbitrary order. Sorting keeps the
        # index -> file mapping identical for the train and the test
        # instance, so the two index sets can never select the same file.
        names = sorted(os.listdir(root))
        selected = train_indices if train else test_indices
        for i in selected.tolist():
            name = names[i]
            self.files.append(name)
            self.labels.append(0 if 'dog' in name else 1)

    def __len__(self):
        return len(self.files)

    def __getitem__(self, index):
        path = os.path.join(self.root, self.files[index])
        image = Image.open(path).convert('RGB')
        label = self.labels[index]
        if self.transform:
            image = self.transform(image)
        return image, label


# Training pipeline: resize plus random augmentation.
# (The author also tried, and disabled, RandomResizedCrop(size=224,
# scale=(0.08, 1.0), ratio=(0.75, 1.33333)) in place of the plain resize.)
train_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.RandomHorizontalFlip(),
    transforms.RandomVerticalFlip(),
    transforms.RandomRotation(degrees=30),
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
])

# Evaluation pipeline: same resize and normalisation, but no random
# augmentation, so the measured accuracy is deterministic for a given split.
test_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
])

train_dataset = DogsVsCatsDataset(
    root='.\\data\\Dogs Vs Cats\\train', train=True, transform=train_transform
)
test_dataset = DogsVsCatsDataset(
    root='.\\data\\Dogs Vs Cats\\train', train=False, transform=test_transform
)

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


class CNNModel(nn.Module):
    """Four conv / batch-norm / max-pool stages followed by two linear layers.

    Each stage halves the spatial size, so a 224x224 input reaches the first
    linear layer as a 48-channel 14x14 map (hence ``14 * 14 * 48``).

    NOTE(review): there is no activation function between the layers (only
    BatchNorm and max-pooling) -- confirm whether that is intentional.
    """

    def __init__(self):
        super().__init__()
        self.cnn1 = nn.Sequential(
            nn.Conv2d(3, 24, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(24),
            nn.MaxPool2d(kernel_size=2, stride=2),
        )
        self.cnn2 = nn.Sequential(
            nn.Conv2d(24, 48, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(48),
            nn.MaxPool2d(kernel_size=2, stride=2),
        )
        self.cnn3 = nn.Sequential(
            nn.Conv2d(48, 96, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(96),
            nn.MaxPool2d(kernel_size=2, stride=2),
        )
        self.cnn4 = nn.Sequential(
            nn.Conv2d(96, 48, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(48),
            nn.MaxPool2d(kernel_size=2, stride=2),
        )
        self.dropout = nn.Dropout()
        self.line1 = nn.Linear(14 * 14 * 48, 512)
        self.line2 = nn.Linear(512, 2)

    def forward(self, x):
        """Return the two class logits for a batch of images ``x``."""
        out = self.cnn1(x)
        out = self.cnn2(out)
        out = self.cnn3(out)
        out = self.cnn4(out)
        out = out.reshape(out.size(0), -1)  # flatten everything but the batch
        out = self.dropout(out)
        out = self.line1(out)
        out = self.line2(out)
        return out


model = CNNModel().to(device)

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

model.train()
for epoch in range(num_epochs):
    for i, (image, label) in enumerate(train_loader):
        image = image.to(device)
        label = label.to(device)

        output = model(image)
        loss = criterion(output, label)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Progress is logged every `batch_size` steps (the original script
        # reuses batch_size as the logging interval).
        if (i + 1) % batch_size == 0:
            print('Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}'.format(
                epoch + 1, num_epochs, i + 1, len(train_loader), loss.item()))

model.eval()
with torch.no_grad():
    total = 0
    correct = 0
    for image, label in test_loader:
        image = image.to(device)
        label = label.to(device)
        output = model(image)
        _, predict = torch.max(output, 1)
        total += len(label)
        correct += (predict == label).sum().item()
    print('Accuracy of test {} images: {} %'.format(
        len(test_dataset), correct / total * 100))

# ---------------------------------------------------------------------------
# vgg16.py (listed by the article as a separate file; its imports follow)
# ---------------------------------------------------------------------------
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import models, transforms
from PIL import Image
import os

num_epochs = 5
batch_size = 10
learning_rate = 0.001

# The Kaggle "train" folder is split 20000 / 5000 into a training part and a
# held-out part using one random permutation of the file indices.
train_size = 25000
indices = torch.randperm(train_size)
train_indices = indices[:20000]
test_indices = indices[20000:]


class DogsVsCatsDataset(Dataset):
    """Image dataset over one directory of dog/cat pictures.

    The label is derived from the file name: 0 when the name contains
    ``dog``, 1 otherwise.

    Args:
        root: Directory containing the image files.
        train: Select the files at ``train_indices`` when true, otherwise
            the files at ``test_indices``.
        transform: Optional callable applied to each loaded PIL image.
    """

    def __init__(self, root, train=True, transform=None):
        super().__init__()
        self.root = root
        self.transform = transform
        self.classes = ['dog', 'cat']
        self.files = []
        self.labels = []

        # os.listdir() returns entries in arbitrary order. Sorting keeps the
        # index -> file mapping identical for the train and the test
        # instance, so the two index sets can never select the same file.
        names = sorted(os.listdir(root))
        selected = train_indices if train else test_indices
        for i in selected.tolist():
            name = names[i]
            self.files.append(name)
            self.labels.append(0 if 'dog' in name else 1)

    def __len__(self):
        return len(self.files)

    def __getitem__(self, index):
        path = os.path.join(self.root, self.files[index])
        image = Image.open(path).convert('RGB')
        label = self.labels[index]
        if self.transform:
            image = self.transform(image)
        return image, label


# Training pipeline: random crop to 224x224 plus random augmentation.
train_transform = transforms.Compose([
    transforms.RandomResizedCrop(size=224),
    transforms.RandomHorizontalFlip(),
    transforms.RandomVerticalFlip(),
    transforms.RandomRotation(degrees=30),
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
])

# Evaluation pipeline: fixed resize and the same normalisation, but no random
# augmentation, so the measured accuracy is deterministic for a given split.
test_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
])

train_dataset = DogsVsCatsDataset(
    root='.\\data\\Dogs Vs Cats\\train', train=True, transform=train_transform
)
test_dataset = DogsVsCatsDataset(
    root='.\\data\\Dogs Vs Cats\\train', train=False, transform=test_transform
)

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

model = models.vgg16(weights=models.VGG16_Weights.IMAGENET1K_V1)

# Freeze the pretrained convolutional feature extractor.
for param in model.features.parameters():
    param.requires_grad = False

# Replace the final classifier layer with a fresh 2-way layer. Assigning to
# `out_features` on the existing layer would only change an attribute: the
# weight matrix keeps its ImageNet shape and the model still emits 1000 logits.
model.classifier[6] = nn.Linear(model.classifier[6].in_features, 2)
model = model.to(device)

criterion = nn.CrossEntropyLoss()
# Only hand the parameters that are still trainable to the optimizer.
optimizer = optim.Adam(
    [p for p in model.parameters() if p.requires_grad], lr=learning_rate
)

model.train()
for epoch in range(num_epochs):
    for i, (image, label) in enumerate(train_loader):
        image = image.to(device)
        label = label.to(device)

        output = model(image)
        loss = criterion(output, label)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if (i + 1) % 100 == 0:
            print('Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}'.format(
                epoch + 1, num_epochs, i + 1, len(train_loader), loss.item()))

model.eval()
with torch.no_grad():
    total = 0
    correct = 0
    for image, label in test_loader:
        image = image.to(device)
        label = label.to(device)
        output = model(image)
        _, predict = torch.max(output, 1)
        total += len(label)
        correct += (predict == label).sum().item()
    print('Accuracy of test {} images: {} %'.format(
        len(test_dataset), correct / total * 100))

# Reference: "PyTorch猫狗大战CNN vs VGG16迁移学习谁更胜一筹" - 超腾开源