
Implementing MNIST Handwritten Digit Recognition with PyTorch

2020-02-15 21:29:29
Source: reprint

Contributed by: netizen

Environment

Windows 10 + Anaconda + Jupyter Notebook

PyTorch 1.1.0

Python 3.7

GPU (optional)

About the MNIST Dataset

MNIST consists of 60,000 28x28 training samples and 10,000 test samples; it is the "Hello World" of computer vision. The CNN used in this article pushes the MNIST recognition accuracy to 99%. Let's get started.

Imports

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
torch.__version__

Hyperparameters

BATCH_SIZE = 512
EPOCHS = 20
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

Dataset

We use the MNIST dataset built into torchvision and read the training and test data with separate DataLoaders. If you have already downloaded the dataset, you can set download=False here.

train_loader = torch.utils.data.DataLoader(
    datasets.MNIST('data', train=True, download=True,
                   transform=transforms.Compose([
                       transforms.ToTensor(),
                       transforms.Normalize((0.1307,), (0.3081,))
                   ])),
    batch_size=BATCH_SIZE, shuffle=True)

test_loader = torch.utils.data.DataLoader(
    datasets.MNIST('data', train=False,
                   transform=transforms.Compose([
                       transforms.ToTensor(),
                       transforms.Normalize((0.1307,), (0.3081,))
                   ])),
    batch_size=BATCH_SIZE, shuffle=True)
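The constants 0.1307 and 0.3081 are the mean and standard deviation of the MNIST pixel values after ToTensor scales them to [0, 1]. As a quick sketch of what Normalize does to a single pixel (pure arithmetic, no PyTorch required):

```python
# Sketch: how transforms.Normalize((0.1307,), (0.3081,)) maps one pixel value
mean, std = 0.1307, 0.3081

pixel = 1.0                        # a fully white pixel after ToTensor
normalized = (pixel - mean) / std  # roughly 2.82
print(normalized)
```

Normalizing with the dataset's own statistics keeps the inputs roughly zero-mean and unit-variance, which tends to stabilize training.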

Defining the Network

The network consists of two convolutional layers and two linear (fully connected) layers, with a final 10-dimensional output representing the ten digits 0-9.

class ConvNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 10, 5)   # input: (1, 28, 28)  output: (10, 24, 24)
        self.conv2 = nn.Conv2d(10, 20, 3)  # input: (10, 12, 12) output: (20, 10, 10)
        self.fc1 = nn.Linear(20 * 10 * 10, 500)
        self.fc2 = nn.Linear(500, 10)

    def forward(self, x):
        in_size = x.size(0)
        out = self.conv1(x)
        out = F.relu(out)
        out = F.max_pool2d(out, 2, 2)
        out = self.conv2(out)
        out = F.relu(out)
        out = out.view(in_size, -1)
        out = self.fc1(out)
        out = F.relu(out)
        out = self.fc2(out)
        out = F.log_softmax(out, dim=1)
        return out
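The shape comments in the class follow from simple arithmetic: with stride 1 and no padding, a convolution shrinks each spatial dimension by kernel_size - 1, and the 2x2 max pool halves it. A minimal sketch of that bookkeeping (the helper `conv_out` is ours, just for illustration):

```python
# Sketch of the spatial-size arithmetic behind the ConvNet shape comments
# (assumes stride 1 and no padding: out = in - kernel + 1; 2x2 max pool halves the size)
def conv_out(size, kernel):
    return size - kernel + 1

s = conv_out(28, 5)  # conv1: 28 -> 24
s = s // 2           # max_pool2d(out, 2, 2): 24 -> 12
s = conv_out(s, 3)   # conv2: 12 -> 10
print(20 * s * s)    # 2000, the in_features of fc1
```

This is why `fc1` is declared as `nn.Linear(20*10*10, 500)`: the flattened conv output has 20 channels of 10x10 feature maps.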

Instantiating the Network

model = ConvNet().to(DEVICE)  # move the network to the GPU (if available)
optimizer = optim.Adam(model.parameters())  # use the Adam optimizer

Defining the Training Function

def train(model, device, train_loader, optimizer, epoch):
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = F.nll_loss(output, target)
        loss.backward()
        optimizer.step()
        if (batch_idx + 1) % 30 == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                100. * batch_idx / len(train_loader), loss.item()))
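The excerpt ends here, but with EPOCHS = 20 defined above, the article presumably finishes with an evaluation pass and a driver loop. A minimal sketch, assuming the definitions above; the `test()` function shown here is our reconstruction in the standard PyTorch MNIST style, not code from the excerpt:

```python
# Hypothetical sketch (not in the excerpt): evaluation function and training loop
def test(model, device, test_loader):
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():  # no gradients needed during evaluation
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            test_loss += F.nll_loss(output, target, reduction='sum').item()
            pred = output.argmax(dim=1, keepdim=True)  # index of the max log-probability
            correct += pred.eq(target.view_as(pred)).sum().item()
    test_loss /= len(test_loader.dataset)
    print('Test set: average loss: {:.4f}, accuracy: {}/{} ({:.0f}%)'.format(
        test_loss, correct, len(test_loader.dataset),
        100. * correct / len(test_loader.dataset)))

for epoch in range(1, EPOCHS + 1):
    train(model, DEVICE, train_loader, optimizer, epoch)
    test(model, DEVICE, test_loader)
```

Since the forward pass ends in log_softmax, nll_loss is the matching criterion for both training and evaluation; summing the per-sample losses before dividing by the dataset size gives a comparable average loss across epochs.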