MNIST神经网络

360影视 2024-12-27 18:16 3

摘要:在 Pytorch官网给出了一个类似于 Hello World 的图片识别神经网络程序,使用的数据集合为 FASHION MNIST。下面将这个程序修改为 MNIST的训练集合,用于识别鼠标手写体识别。

一、前言

  在 Pytorch官网给出了一个类似于 Hello World 的图片识别神经网络程序,使用的数据集合为 FASHION MNIST。下面将这个程序修改为 MNIST的训练集合,用于识别鼠标手写体识别。

  对于程序的修改,将其中的训练数据由原来的 FASHION MNIST,修改为 EMNIST,使用其中的 MNIST 数据集合。神经网络依然是双隐层全连接BP网络。下面对该网络进行训练。

  现在进行50个周期的训练。最后,识别精度达到了93.7% 。花费了155秒的训练,效果还算不错。

  后面,又经过100轮的训练,最后的精度提高到 96.3%。

  将训练好的模型存储在磁盘中保存。下面选取测试集合中的第一个样本进行测试。可以看到最后一行,给出了预测与实际标号是一致的。至此,说明了模型工作正常了。

  观察一下该模型识别错误的数字,可以看到,在所有203个识别错误的字符中,大部分都属于书写奇怪的数字。如果人工识别,也未见得能够识别正确。所以,由此可以看到,这个神经网络模型的精度还是非常合理的。对比识别正确的数字,这与人工识别基本保持一致。

▲ 图1.3.1 识别正确的数字

▲ 图1.3.2 模型中识别错误的数字

四、测试印刷体

  测试一下,印刷体的10个数字,应用刚才的人工神经网络识别的效果。将它们转换成 28×28 的图像。存储在文件中。

  经过测试可以看到,其中百分之50的数字都识别错误了。

▲ 图1.4.1 识别结果,错误率达到了50%

  将字体改成粗黑体,识别率提高到80%,只有数字 1 和 5 出现了错误,都被识别成8 了。

▲ 图1.4.2 粗黑体字,识别效果

五、测试手写数字

  使用鼠标写出十个数字,查看一下识别结果。正确率只有 60% 。

  将鼠标手写体加粗,识别的正确率提升了。只有 8 和 0 识别错误。

  将其中的8和0,重新书写一下。现在,所有的数字都识别正确了。

※总  结 ※

  本文测试MNIST训练了一个双隐层,全连接的人工神经网络,对于鼠标手写体识别效果不错。字体越粗,效果越好。

[1]

下载EMNIST数据库:

[2]

下载 Pytorch中 TorchVision中的图像数据:

[3]

在Windows11中安装 Python+CU118: https://blog.csdn.net/zhuoqingjoking97298/article/details/144734305?sharetype=blogdetail&sharerId=144734305&sharerefer=PC&sharesource=zhuoqingjoking97298&spm=1011.2480.3001.8118

来源:APPLE频道

相关推荐