pytorch一步一步在VGG16上训练自己的数据集-程序员宅基地

技术标签: 数据集加载  pytorch  

准备数据集及加载,ImageFolder

在很多机器学习或者深度学习的任务中,往往我们要提供自己的图片。也就是说我们的数据集不是预先处理好的,像mnist,cifar10等它已经给你处理好了,更多的是原始的图片。比如我们以猫狗分类为例。在data文件下,有两个分别为train和val的文件夹。然后train下是cat和dog两个文件夹,里面存的是自己的图片数据,val文件夹同train。这样我们的数据集就准备好了。
在这里插入图片描述
ImageFolder能够以目录名作为标签来对数据集做划分,下面是pytorch中文文档中关于ImageFolder的介绍:
在这里插入图片描述

#对训练集做一个变换
train_transforms = transforms.Compose([
    transforms.RandomResizedCrop(224),		#对图片尺寸做一个缩放切割
    transforms.RandomHorizontalFlip(),		#水平翻转
    transforms.ToTensor(),					#转化为张量
    transforms.Normalize((.5, .5, .5), (.5, .5, .5))	#进行归一化
])
#对测试集做变换
val_transforms = transforms.Compose([
    transforms.Resize(256),
    transforms.RandomResizedCrop(224),
    transforms.ToTensor(),
    transforms.Normalize((.5, .5, .5), (.5, .5, .5))
])

train_dir = "G:/data/train"           #训练集路径
#定义数据集
train_datasets = datasets.ImageFolder(train_dir, transform=train_transforms)
#加载数据集
train_dataloader = torch.utils.data.DataLoader(train_datasets, batch_size=batch_size, shuffle=True)

val_dir = "G:/datat/val"		
val_datasets = datasets.ImageFolder(val_dir, transform=val_transforms)
val_dataloader = torch.utils.data.DataLoader(val_datasets, batch_size=batch_size, shuffle=True)

迁移学习以VGG16为例

下面是迁移代码的实现:

class VGGNet(nn.Module):
    def __init__(self, num_classes=2):	   #num_classes,此处为 二分类值为2
        super(VGGNet, self).__init__()
        net = models.vgg16(pretrained=True)   #从预训练模型加载VGG16网络参数
        net.classifier = nn.Sequential()	#将分类层置空,下面将改变我们的分类层
        self.features = net		#保留VGG16的特征层
        self.classifier = nn.Sequential(    #定义自己的分类层
                nn.Linear(512 * 7 * 7, 512),  #512 * 7 * 7不能改变 ,由VGG16网络决定的,第二个参数为神经元个数可以微调
                nn.ReLU(True),
                nn.Dropout(),
                nn.Linear(512, 128),
                nn.ReLU(True),
                nn.Dropout(),
                nn.Linear(128, num_classes),
        )

    def forward(self, x):
        x = self.features(x)
        x = x.view(x.size(0), -1)
        x = self.classifier(x)
        return x

完整代码如下

from __future__ import print_function, division

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.autograd import Variable
import numpy as np
from torchvision import models

batch_size = 16
learning_rate = 0.0002
epoch = 10

train_transforms = transforms.Compose([
    transforms.RandomResizedCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((.5, .5, .5), (.5, .5, .5))
])
val_transforms = transforms.Compose([
    transforms.Resize(256),
    transforms.RandomResizedCrop(224),
    transforms.ToTensor(),
    transforms.Normalize((.5, .5, .5), (.5, .5, .5))
])

train_dir = './VGGDataSet/train'
train_datasets = datasets.ImageFolder(train_dir, transform=train_transforms)
train_dataloader = torch.utils.data.DataLoader(train_datasets, batch_size=batch_size, shuffle=True)

val_dir = './VGGDataSet/val'
val_datasets = datasets.ImageFolder(val_dir, transform=val_transforms)
val_dataloader = torch.utils.data.DataLoader(val_datasets, batch_size=batch_size, shuffle=True)


class VGGNet(nn.Module):
    def __init__(self, num_classes=3):
        super(VGGNet, self).__init__()
        net = models.vgg16(pretrained=True)
        net.classifier = nn.Sequential()
        self.features = net
        self.classifier = nn.Sequential(
                nn.Linear(512 * 7 * 7, 512),
                nn.ReLU(True),
                nn.Dropout(),
                nn.Linear(512, 128),
                nn.ReLU(True),
                nn.Dropout(),
                nn.Linear(128, num_classes),
        )

    def forward(self, x):
        x = self.features(x)
        x = x.view(x.size(0), -1)
        x = self.classifier(x)
        return x

#--------------------训练过程---------------------------------
model = VGGNet()
if torch.cuda.is_available():
    model.cuda()
params = [{'params': md.parameters()} for md in model.children()
          if md in [model.classifier]]
optimizer = optim.Adam(model.parameters(), lr=learning_rate)
loss_func = nn.CrossEntropyLoss()

Loss_list = []
Accuracy_list = []



for epoch in range(100):
    print('epoch {}'.format(epoch + 1))
    # training-----------------------------
    train_loss = 0.
    train_acc = 0.
    for batch_x, batch_y in train_dataloader:
        batch_x, batch_y = Variable(batch_x).cuda(), Variable(batch_y).cuda()
        out = model(batch_x)
        loss = loss_func(out, batch_y)
        train_loss += loss.data[0]
        pred = torch.max(out, 1)[1]
        train_correct = (pred == batch_y).sum()
        train_acc += train_correct.data[0]
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    print('Train Loss: {:.6f}, Acc: {:.6f}'.format(train_loss / (len(
        train_datasets)), train_acc / (len(train_datasets))))

    # evaluation--------------------------------
    model.eval()
    eval_loss = 0.
    eval_acc = 0.
    for batch_x, batch_y in val_dataloader:
        batch_x, batch_y = Variable(batch_x, volatile=True).cuda(), Variable(batch_y, volatile=True).cuda()
        out = model(batch_x)
        loss = loss_func(out, batch_y)
        eval_loss += loss.data[0]
        pred = torch.max(out, 1)[1]
        num_correct = (pred == batch_y).sum()
        eval_acc += num_correct.data[0]
    print('Test Loss: {:.6f}, Acc: {:.6f}'.format(eval_loss / (len(
        val_datasets)), eval_acc / (len(val_datasets))))
        
	Loss_list.append(eval_loss / (len(val_datasets)))
    Accuracy_list.append(100 * eval_acc / (len(val_datasets)))

x1 = range(0, 100)
x2 = range(0, 100)
y1 = Accuracy_list
y2 = Loss_list
plt.subplot(2, 1, 1)
plt.plot(x1, y1, 'o-')
plt.title('Test accuracy vs. epoches')
plt.ylabel('Test accuracy')
plt.subplot(2, 1, 2)
plt.plot(x2, y2, '.-')
plt.xlabel('Test loss vs. epoches')
plt.ylabel('Test loss')
plt.show()
# plt.savefig("accuracy_loss.jpg")
版权声明:本文为博主原创文章,遵循 CC 4.0 BY-SA 版权协议,转载请附上原文出处链接和本声明。
本文链接:https://blog.csdn.net/hnu_zzt/article/details/85092092

智能推荐

python数据清洗---实战案例(清洗csv文件)_对csv文件进行数据清洗-程序员宅基地

我也是最近才开始这方面的学习,这篇就当作学习的笔记,记录一下学习的过程所要处理的数据数据中主要存在的问题主要包括:1.列名中存在空格2.存在重复数据3.存在缺失数据下面开始对数据进行清洗导入pandas模块,打开数据文件import pandas as pddf = pd.read_csv("ResourceFile.csv")我们输出指定列名print(df.名称)但此时会报错,因为列名"名称"中含有空格,我们输出列名看一下,两种方法方法一:print(df.des._对csv文件进行数据清洗

哈工大计算机系统_Amnesia0810的博客-程序员宅基地

计算机系统大作业题 目程序人生-Hello’s P2P 专 业 计算机学   号 120L022309班 级2003009学 生 孙浩翔 指 导 教 师吴锐 计算机科学与技术学院2021年5月摘 要..._哈工大计算机系统

Python 标准库 zipfile 压缩文件/文件夹_python zipfile压缩文件夹_Likianta Me的博客-程序员宅基地

压缩单个目录时, ZipFile 需要 write 目录, 以及目录下每一个文件 (包括子文件夹的文件) 的路径._python zipfile压缩文件夹

nginx正向代理+反向代理_www.xxoo_学不会go不改名的博客-程序员宅基地

一个入门小白对Nginx的理解,希望可以帮助到你_www.xxoo

Docker安装Ubuntu-程序员宅基地

localfootstep@ubuntu:~$ ifconfigdocker0: flags=4163<UP,BROADCAST,RUNNING,MULTICAST> mtu 1500 inet 172.17.0.1 netmask 255.255.0.0 broadcast 172.17.255.255 inet6 fe80::42:..._docker安装ubuntu

Android开发之华为手机无法看log日志解决方法(亲测可用华为荣耀6)_安卓文件管理器无法打开.log文件】-程序员宅基地

原文 : 转载自程序猿小冰的博客 : http://blog.csdn.net/qq_21376985/article/details/51798992华为荣耀的测试机,发现在Android Studio下无法查看log日志,看不了日志就不方便解决问题了。解决方法:进入手机拨号界面输入:*#*#2846579#*#*依次选择ProjectMenu—后台设置—-LOG_安卓文件管理器无法打开.log文件】

随便推点

jmeter 函数助手使用 之 字符串函数_jmeter中string的应用-程序员宅基地

${__RandomString(,)},该函数有3个参数,第一个参数表示生成的随机字符串位数,第二个表示在哪些字母下生成,第三个表示变量名称,也可以不要变量名称。使用方法:1、点击工具,点击函数助手对话框;2、找到RandomString,点击;3、填写需要的字符串长度,字符串取值范围,以及变量名(可选)。..._jmeter中string的应用

React框架创建项目详细流程-项目的基本配置-项目的代码规范_react创建项目-程序员宅基地

文件夹、文件名称统一小写、多个单词以连接符(-)连接, 组件采用大驼峰;JavaScript变量名称采用小驼峰标识,常量全部使用大写字母;CSS采用普通CSS和styled-component结合来编写(全局采用普通CSS或Less、局部采用styled-component);整个项目不再使用class组件,统一使用函数式组件,并且全面拥抱Hooks;所有的函数式组件,为了避免不必要的渲染,全部使用memo进行包裹;组件内部的状态,使用useState、业务数据全部放在redux中管理;_react创建项目

SpringBoot之parent、starter、引导类、内嵌tomcat_spring-boot-starter-parent tomcat_冬天vs不冷的博客-程序员宅基地

使用parent可以帮助开发者进行版本的统一管理打开后可以查阅到其中又继承了一个坐标这个坐标中定义了两组信息,第一组是各式各样的依赖版本号属性第二组是各式各样的的依赖坐标信息,可以看出依赖坐标定义中没有具体的依赖版本号,而是引用了第一组信息中定义的依赖版本属性值第二组依赖坐标是在依赖管理标签内,则表示只是引入申明,只有在子pom中使用(不用写版本号)依赖才会导入,所有即使..._spring-boot-starter-parent tomcat

memset初始化出错_memset初始化数组最后一个初始化为0有问题-程序员宅基地

memset初始化时只能将数组的值初始化为0或者-1输入其他值则会出错原因:很简单,memset是一个字节一个字节设置的,取要赋的值的后8位二进制进行赋值。1的二进制是(00000000 00000000 00000000 00000001),取后8位(00000001),int型占4个字节,当初始化为1时,它把一个int的每个字节都设置为1,也就是0x01010101,二进制是00000001 000..._memset初始化数组最后一个初始化为0有问题

基于(ztmap)BIM的数字孪生建造智慧机房管理后台展示系统_企业网站bim带后台-程序员宅基地

我国现阶段针对于连续性的中心业务数据量要求相对较大,为了保障中心关键业务数据的相对安全,构建了越来越复杂的设备系统,随之而来的则是数据中心的故障排查,呈现出较大的困难,管理系统复杂而庞大,针对于数据中心的安全维护以及故障排查等问题亟待解决,管理成本的提升,运维效率的无法提高都成为了现阶段数据中心管理的重点任务,因此,为了有效实现数据中心管理的高效能,必须结合现代化先进技术进行创新改良。1、 数据中心管理系统存在的问题1.1管理分散我国现阶段经济高速发展,随之来的在各行各业当中产生了大量的信息化数据,通_企业网站bim带后台

图论总结——最小生成树算法:prim算法、kruskal算法_种下一颗草莓的博客-程序员宅基地

什么是最小生成树?一个有 n 个结点的连通图的生成树是原图的极小连通子图,且包含原图中的所有 n 个结点,并且有保持图连通的最少的边。因此,只有连通图才存在最小生成树,因为所有顶点都可达,如果是非连通图,一定不存在最小生成树。求最短路径时的图一般是有向图求最小生成树时的图一般是无向图最小生成树的应用例如要在n个城市之间铺设光缆,主要目标是要使这 n 个城市的任意两个之间都可以通信,但铺设光缆的费用很高,且各个城市之间铺设光缆的费用不同,因此另一个目标是要使铺设光缆的总费用最低。这就需要找到