type
status
date
slug
summary
tags
category
icon
password
MNIST 手写数字识别
本次测验需要学生基于MindSpore的API来快速实现一个简单的深度学习模型,完成手写数字识别的任务。该任务旨在考核学生对使用MindSpore完成深度学习全流程的掌握
学习资料
- 虚假的学习资料:
环境配置
- MindSpore 2.0, 安装教程:https://www.mindspore.cn/install
- download,可使用命令
pip install download
安装
如本练习以Notebook运行时,完成安装后需要重启kernel才能执行后续代码。
处理数据集
在本次练习中,我们使用Mnist数据集,自动下载完成后,使用
mindspore.dataset
提供的数据变换进行预处理。数据下载完成后,获得数据集对象。
【练习一】dataset预处理需指定在某个数据列进行操作,请通过
get_col_names
打印数据集中包含的数据列名,用于后续的数据预处理。MindSpore的dataset使用数据处理流水线(Data Processing Pipeline),需指定map、batch、shuffle等操作。这里我们使用map对图像数据及标签进行变换处理,然后将处理好的数据集打包为大小为64的batch。
【练习二】按照如下步骤,补完图像数据的预处理步骤。
- 图片数据处理:
- 原始图片中,每个像素的灰度值在0-255之间,我们需要通过
Rescale
将数值变为0-1之间; - 按照mean=0.1307,std=0.3081,通过
Normalize
对数据进行归一化; - 图像的shape为[height(H), width(W), channel(C)],通过
HWC2CHW
将shape变为[channel(C), height(H), width(W)];
- 标签数据处理:
- 将数据类型通过
TypeCast
转换为mindspore.int32
;
- 按照batch size进行批处理
使用
create_tuple_iterator
或create_dict_iterator
对数据集进行迭代。【练习三】打印第一个batch中图片的shape和dtype,以及标签的shape和dtype。如上述操作正确,图片的shape应为[64, 1, 28, 28], 标签的dtype应为Int32。
更多细节详见数据集 Dataset与数据变换 Transforms。
网络构建
mindspore.nn
类是构建所有网络的基类,也是网络的基本单元。当用户需要自定义网络时,可以继承nn.Cell
类,并重写__init__
方法和construct
方法。__init__
包含所有网络层的定义,construct
中包含数据(Tensor)的变换过程(即计算图的构造过程)。【练习四】参考如下步骤,完成深度学习网络的搭建
- (已完成) flattern层,将二维图像矩阵转换为一维向量
- 全连接层,输入为28x28图片转换为一维向量后的长度,输出为512
- 非线性激活函数 ReLU
- 全连接层,输入和输出维度相同
- 非线性层ReLU
- 全连接层,输入为上一全连接层的输出维度,输出为分类数
可以通过
print(network)
验证网络结构是否正确。更多细节详见网络构建。
模型训练
在模型训练中,一个完整的训练过程(step)需要实现以下三步:
- 正向计算:模型预测结果(logits),并与正确标签(label)求预测损失(loss)。
- 反向传播:利用自动微分机制,自动求模型参数(parameters)对于loss的梯度(gradients)。
- 参数优化:将梯度更新到参数上。
MindSpore使用函数式自动微分机制,因此针对上述步骤需要实现:
- 正向计算函数定义。
- 通过函数变换获得梯度计算函数。
- 训练函数定义,执行正向计算、反向传播和参数优化。
【练习五】参考如上步骤,完成训练一个step的代码。
除训练外,我们定义测试函数,用来评估模型的性能。
【练习六】补全代码,使每个epoch可以打印当前的averge loss及accuracy。
训练过程需多次迭代数据集,一次完整的迭代称为一轮(epoch)。在每一轮,遍历训练集进行训练,结束后使用测试集进行预测。打印每一轮的loss值和预测准确率(Accuracy),可以看到loss在不断下降,Accuracy在不断提高。
更多细节详见模型训练。
保存模型
模型训练完成后,需要将其参数进行保存。
加载模型
加载保存的权重分为两步:
- 重新实例化模型对象,构造模型。
- 加载模型参数,并将其加载至模型上。
【练习七】
- 通过调用
load_checkpoint
和load_param_into_net
接口,完成模型权重加载。
load_param_into_net
输出未加载权重的参数列表,为空时代表所有参数均加载成功。通过打印load_param_into_net
结果,验证是否成功加载权重。
加载后的模型可以直接用于预测推理。
【练习八】补全代码,完成模型推理。
【测验结果提交】打印推理预测结果,将结果截图提交。
更多细节详见保存与加载。
- 作者:王大卫
- 链接:https://tangly1024.com/article/note:mindspore
- 声明:本文采用 CC BY-NC-SA 4.0 许可协议,转载请注明出处。