摘要:在 Pytorch官网给出了一个类似于 Hello World 的图片识别神经网络程序,使用的数据集合为 FASHION MNIST。下面将这个程序修改为 MNIST的训练集合,用于识别鼠标手写体识别。
一、前言
在 Pytorch官网给出了一个类似于 Hello World 的图片识别神经网络程序,使用的数据集合为 FASHION MNIST。下面将这个程序修改为 MNIST的训练集合,用于识别鼠标手写体识别。
对于程序的修改,将其中的训练数据由原来的 FASHION MNIST,修改为 EMNIST,使用其中的 MNIST 数据集合。神经网络依然是双隐层全连接BP网络。下面对该网络进行训练。
现在进行50个周期的训练。最后,识别精度达到了93.7% 。花费了155秒的训练,效果还算不错。
后面,又经过100轮的训练,最后的精度提高到 96.3%。
将训练好的模型存储在磁盘中保存。下面选取测试集合中的第一个样本进行测试。可以看到最后一行,给出了预测与实际标号是一致的。至此,说明了模型工作正常了。
观察一下该模型识别错误的数字,可以看到,在所有203个识别错误的字符中,大部分都属于书写奇怪的数字。如果人工识别,也未见得能够识别正确。所以,由此可以看到,这个神经网络模型的精度还是非常合理的。对比识别正确的数字,这与人工识别基本保持一致。
▲ 图1.3.1 识别正确的数字
▲ 图1.3.2 模型中识别错误的数字
四、测试印刷体测试一下,印刷体的10个数字,应用刚才的人工神经网络识别的效果。将它们转换成 28×28 的图像。存储在文件中。
经过测试可以看到,其中百分之50的数字都识别错误了。
▲ 图1.4.1 识别结果,错误率达到了50%
将字体改成粗黑体,识别率提高到80%,只有数字 1 和 5 出现了错误,都被识别成8 了。
▲ 图1.4.2 粗黑体字,识别效果
五、测试手写数字使用鼠标写出十个数字,查看一下识别结果。正确率只有 60% 。
将鼠标手写体加粗,识别的正确率提升了。只有 8 和 0 识别错误。
将其中的8和0,重新书写一下。现在,所有的数字都识别正确了。
※总 结 ※本文测试MNIST训练了一个双隐层,全连接的人工神经网络,对于鼠标手写体识别效果不错。字体越粗,效果越好。
[1]
下载EMNIST数据库:
[2]
下载 Pytorch中 TorchVision中的图像数据:
[3]
在Windows11中安装 Python+CU118: https://blog.csdn.net/zhuoqingjoking97298/article/details/144734305?sharetype=blogdetail&sharerId=144734305&sharerefer=PC&sharesource=zhuoqingjoking97298&spm=1011.2480.3001.8118
来源:APPLE频道