Pytorch学习1
Pytorch深度学习(1)
如果想要用pytorch让计算机认出一只猫,一共需要5步
准备数据
定义模型
训练模型
评估模型
做出预测
第一步
如果我们遇到不会的题目,我们可以通过学习资料的帮助,来学会这个问题
计算机同理,我们需要准备大量的数据,让计算机明白这个问题的解决方案
pytorch中
import torchvision #模块
该模块提供了一些已有数据集的直接导入
包括CIFAR,COCO,imagineNet,MNIST等文本数据集以及音频数据集等常见数据集
以下是一段获取数据的方式
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST
from torchvision.transforms import Compose
from torchvision.transforms import ToTensor
from matplotlib import pyplot
#以上为导入模块
path = './datasets/mnist'
#定义数据集的位置
trans = Compose([ToTensor()])
#下载并定义数据集
train = MNIST(path, train=True , download=True , transform=trans)
test = MNIST(path, train=False , download=True , transform=trans)
#定义如何枚举数据集
train_dl = DataLoader(train, batch_size=32, shuffle=True)
test_dl = DataLoader(test, batch_size=32, shuffle=True)
#以batch方式获取图片
i , (inputs,targets) = next(enumerate(train_dl))
#绘图,进行展示
for i in range(25):
pyplot.subplot(5,5,i+1)
pyplot.imshow(input[i][0],cmap='gray')
pyplot.show()
MNIST是官方提供的手写数据集,path即为下载保存的位置
trans在此处被定义成图像的预处理方式
Compose([ToTensor()])
Compose是一个顺序处理图像的方法,ToTensor是因为pytorch中只能处理Tensor数据,而不能直接处理图像数据
train = MNIST(path, train=True , download=True , transform=trans)
path是指定查找数据集或下载数据集的位置
train值为True则使用6万训练集,False则用1万的测试集
download则为是否要下载数据集
transform是预处理操作(常见变换Normalize() 方差归一化,RandomCrop() 图像大小调整)
train_dl = DataLoader(train, batch_size=32, shuffle=True)
#创建数据加载器
#1:数据集,2:每个队列的大小,3:是否打乱
i , (inputs,targets) = next(enumerate(train_dl))
#enumerate(train_dl)返回一个(index,(inputs, targets))元组
#inputs [32,1,28,28]:32张图片,1通道(灰度图),28*28像素
第二步
定义模型
需要导入3个概念:卷积层,池化层,全连接层
卷积层
是计算机认识一张图片最基础的步骤。把一张图像,转换为一系列由0和1组成的数据矩阵
池化层
把原来很大的图像矩阵压缩,变成一张更小更容易计算的图像矩阵
全连接层
把更小的图像展开成一段数据
第三步
训练模型
将训练数据输入模型,计算损失并调整更新模型参数,这样重复进行多个周期,直到模型很好的学习到了训练数据里的特征
训练过程中,我们需要定义损失函数和优化器
损失函数:用于衡量模型预测的准确性
优化器:用于调整模型的参数
第四步
评估模型
例如:现在是判断手写数字,我们需要用训练好的模型判断一下,是不是和我们写下的数字一致
第五步
做出预测
即使用模型
本博客所有文章除特别声明外,均采用 CC BY-NC-SA 4.0 许可协议。转载请注明来自 SUの小站!
评论