【轻松学 Pytorch】构建浅层神经网络

文章来源:原创 gloomyfish OpenCV学堂@微信公众号

关键知识点

前面我们刚刚组队完毕,更新了第一篇【轻松学 Pytorch - 环境搭建与基本语法】,我说我会坚持写下去,这个是我的第二篇,使用pytorch实现简单神经网络完成手写数字识别。这个是所有深度学习框架入门标配的例子,但是从这个例子上我们可以学到pytorch的很多基础知识点,我罗列一下,大致有如下:

1.开始用torch.nn包里面的函数搭建网络
2.模型保存为pt文件与加载调用
3.Torchvision.transofrms来做数据预处理
4.DataLoader简单调用处理数据集

只有理解和看清以上四点才算入门了这个例子。


数据集:

Mnist数据集,数字为0~9、大小为28x28的灰度图像。

加载数据集代码实现:

train_ts = tv.datasets.MNIST(root='./data', train=True, download=True, transform=transform)

展示如下:
file

预处理数据方式

file
其中

Totensor表示把灰度图像素值从0~255转化为0~1之间

Normalize表示对输入的减去0.5, 除以0.5

网络结构如下:

输入层:784个神经元

隐藏层:100个神经元

输出层:10个神经元

file
定义损失函数与优化函数
file
开启训练

for s in range(5):    print("run in step : %d"%s)    for i, (x_train, y_train) in enumerate(train_dl):        x_train = x_train.view(x_train.shape[0], -1)        y_pred = model(x_train)        train_loss = loss_fn(y_pred, y_train)        if (i + 1) % 100 == 0:            print(i + 1, train_loss.item())        model.zero_grad()        train_loss.backward()        optimizer.step()

file

测试模型准确率

total = 0;correct_count = 0for test_images, test_labels in test_dl:    for i in range(len(test_labels)):        image = test_images[i].view(1, 784)        with t.no_grad():            pred_labels = model(image)        plabels = t.exp(pred_labels)        probs = list(plabels.numpy()[0])        pred_label = probs.index(max(probs))        true_label = test_labels.numpy()[i]        if pred_label == true_label:            correct_count += 1        total += 1

展示如下:
file
打印准确率与保存模型
file

完整演示代码

import torch as tfrom torch.utils.data import DataLoaderimport torchvision as tvtransform = tv.transforms.Compose([tv.transforms.ToTensor(),                                  tv.transforms.Normalize((0.5,), (0.5,)),                             ])train_ts = tv.datasets.MNIST(root='./data', train=True, download=True, transform=transform)test_ts = tv.datasets.MNIST(root='./data', train=False, download=True, transform=transform)train_dl = DataLoader(train_ts, batch_size=32, shuffle=True, drop_last=False)test_dl = DataLoader(test_ts, batch_size=64, shuffle=True, drop_last=False)model = t.nn.Sequential(   t.nn.Linear(784, 100),   t.nn.ReLU(),   t.nn.Linear(100, 10),   t.nn.LogSoftmax(dim=1))loss_fn = t.nn.NLLLoss(reduction="mean")optimizer = t.optim.Adam(model.parameters(), lr=1e-3)for s in range(5):   print("run in step : %d"%s)   for i, (x_train, y_train) in enumerate(train_dl):       x_train = x_train.view(x_train.shape[0], -1)       y_pred = model(x_train)       train_loss = loss_fn(y_pred, y_train)       if (i + 1) % 100 == 0:           print(i + 1, train_loss.item())       model.zero_grad()       train_loss.backward()       optimizer.step()total = 0;correct_count = 0for test_images, test_labels in test_dl:   for i in range(len(test_labels)):       image = test_images[i].view(1, 784)       with t.no_grad():           pred_labels = model(image)       plabels = t.exp(pred_labels)       probs = list(plabels.numpy()[0])       pred_label = probs.index(max(probs))       true_label = test_labels.numpy()[i]       if pred_label == true_label:           correct_count += 1       total += 1print("total acc : %.2f\n"%(correct_count / total))t.save(model, './nn_mnist_model.pt')

展示如下:
file
file


运行结果:
file

 推荐阅读 

轻松学Pytorch–环境搭建与基本语法


pytorch 实用工具总结
PyTorch trick 集锦
https://bbs.cvmart.net/topics/1401

微信公众号: 极市平台(ID: extrememart )
每天推送最新CV干货