LeNet在手写数字识别上的应用

LeNet网络的实现代码如下:

  1. # 导入需要的包
  2. import paddle
  3. import paddle.fluid as fluid
  4. import numpy as np
  5. from paddle.fluid.dygraph.nn import Conv2D, Pool2D, Linear
  6. # 定义 LeNet 网络结构
  class LeNet(fluid.dygraph.Layer):
      """LeNet-style convolutional classifier for single-channel images.

      Three sigmoid-activated convolutions (the first two each followed by a
      2x2 max pooling with stride 2) feed two fully connected layers.

      Args:
          name_scope: Name forwarded to ``fluid.dygraph.Layer.__init__``.
          num_classes: Output width of the final fully connected layer,
              i.e. the number of classification labels.
      """

      def __init__(self, name_scope, num_classes=1):
          super(LeNet, self).__init__(name_scope)
          sigmoid = 'sigmoid'
          pool_cfg = dict(pool_size=2, pool_stride=2, pool_type='max')

          # Convolution + pooling stages. Attributes are assigned in forward
          # order so that sublayers() lists the layers in execution order.
          self.conv1 = Conv2D(num_channels=1, num_filters=6, filter_size=5, act=sigmoid)
          self.pool1 = Pool2D(**pool_cfg)
          self.conv2 = Conv2D(num_channels=6, num_filters=16, filter_size=5, act=sigmoid)
          self.pool2 = Pool2D(**pool_cfg)
          # Third convolution; it is not followed by pooling.
          self.conv3 = Conv2D(num_channels=16, num_filters=120, filter_size=4, act=sigmoid)

          # Fully connected head: 120 -> 64 -> num_classes. The input width of
          # 120 matches conv3's flattened output when that output is 1x1
          # spatially, which is the case for 28x28 inputs.
          self.fc1 = Linear(input_dim=120, output_dim=64, act=sigmoid)
          self.fc2 = Linear(input_dim=64, output_dim=num_classes)

      def forward(self, x):
          """Run the forward pass and return the raw output of ``fc2``."""
          feature_stages = (self.conv1, self.pool1, self.conv2, self.pool2, self.conv3)
          for stage in feature_stages:
              x = stage(x)
          # Flatten everything except the batch dimension before the FC layers.
          flat = fluid.layers.reshape(x, [x.shape[0], -1])
          hidden = self.fc1(flat)
          return self.fc2(hidden)

下面的程序使用随机数作为输入,查看数据经过LeNet的每一层作用之后输出的形状:

  # Input layout is [N, 1, H, W]; random values are enough for a shape check.
  x = np.random.randn(3, 1, 28, 28)
  x = x.astype('float32')
  with fluid.dygraph.guard():
      # Create a LeNet instance with a model name and the number of classes.
      m = LeNet('LeNet', num_classes=10)
      # sublayers() is inherited from the base Layer class and returns the
      # child layers contained in LeNet.
      print(m.sublayers())
      x = fluid.dygraph.to_variable(x)
      for item in m.sublayers():
          # item is one child layer; feed the running tensor through it and
          # report the output shape.
          try:
              x = item(x)
          except Exception:
              # The layer rejected the current tensor; presumably this is the
              # first Linear layer receiving 4-D convolution output. Flatten
              # to [N, -1] and retry. Exception (rather than a bare except)
              # lets KeyboardInterrupt and SystemExit propagate.
              x = fluid.layers.reshape(x, [x.shape[0], -1])
              x = item(x)
          # Query the parameter list once instead of on every access.
          params = item.parameters()
          if len(params) == 2:
              # Convolution and fully connected layers: params[0] is the
              # weight w and params[1] is the bias b.
              print(item.full_name(), x.shape, params[0].shape, params[1].shape)
          else:
              # Pooling layers have no parameters.
              print(item.full_name(), x.shape)
  1. [<paddle.fluid.dygraph.nn.Conv2D object at 0x7f29858ebad0>, <paddle.fluid.dygraph.nn.Pool2D object at 0x7f29858f8110>, <paddle.fluid.dygraph.nn.Conv2D object at 0x7f29858f8230>, <paddle.fluid.dygraph.nn.Pool2D object at 0x7f29858f82f0>, <paddle.fluid.dygraph.nn.Conv2D object at 0x7f29858f8350>, <paddle.fluid.dygraph.nn.Linear object at 0x7f29858f8470>, <paddle.fluid.dygraph.nn.Linear object at 0x7f29858f85f0>]
  2. conv2d_0 [3, 6, 24, 24] [6, 1, 5, 5] [6]
  3. pool2d_0 [3, 6, 12, 12]
  4. conv2d_1 [3, 16, 8, 8] [16, 6, 5, 5] [16]
  5. pool2d_1 [3, 16, 4, 4]
  6. conv2d_2 [3, 120, 1, 1] [120, 16, 4, 4] [120]
  7. linear_0 [3, 64] [120, 64] [64]
  8. linear_1 [3, 10] [64, 10] [10]
  1. # -*- coding: utf-8 -*-
  2. # LeNet 识别手写数字
  3. import os
  4. import random
  5. import paddle
  6. import paddle.fluid as fluid
  7. import numpy as np
  8. # 定义训练过程
  def train(model):
      """Train ``model`` on MNIST for five epochs, validating after each one.

      Uses a Momentum optimizer (learning rate 0.001, momentum 0.9) and the
      Paddle built-in MNIST readers with a batch size of 10. The training
      loss is printed every 1000 batches, the mean validation accuracy and
      loss after every epoch, and the final parameters are saved under the
      name ``'mnist'``.

      Args:
          model: The dygraph layer to train; called as ``model(img)``.
      """
      print('start training ... ')
      model.train()
      epoch_num = 5
      optimizer = fluid.optimizer.Momentum(
          learning_rate=0.001, momentum=0.9, parameter_list=model.parameters())
      # Use the data readers that ship with Paddle.
      train_reader = paddle.batch(paddle.dataset.mnist.train(), batch_size=10)
      valid_reader = paddle.batch(paddle.dataset.mnist.test(), batch_size=10)

      def to_tensors(batch):
          # Reshape images to [N, 1, 28, 28] float32 and labels to [N, 1]
          # int64, then convert both numpy arrays to tensors.
          images = np.array([sample[0] for sample in batch], dtype='float32')
          labels = np.array([sample[1] for sample in batch], dtype='int64')
          img_var = fluid.dygraph.to_variable(images.reshape(-1, 1, 28, 28))
          label_var = fluid.dygraph.to_variable(labels.reshape(-1, 1))
          return img_var, label_var

      for epoch in range(epoch_num):
          # ---- training pass ----
          for batch_id, data in enumerate(train_reader()):
              img, label = to_tensors(data)
              logits = model(img)
              per_sample_loss = fluid.layers.softmax_with_cross_entropy(logits, label)
              avg_loss = fluid.layers.mean(per_sample_loss)
              if batch_id % 1000 == 0:
                  print("epoch: {}, batch_id: {}, loss is: {}".format(epoch, batch_id, avg_loss.numpy()))
              avg_loss.backward()
              optimizer.minimize(avg_loss)
              model.clear_gradients()

          # ---- validation pass ----
          model.eval()
          accuracies, losses = [], []
          for data in valid_reader():
              img, label = to_tensors(data)
              logits = model(img)
              pred = fluid.layers.softmax(logits)
              loss = fluid.layers.softmax_with_cross_entropy(logits, label)
              acc = fluid.layers.accuracy(pred, label)
              accuracies.append(acc.numpy())
              losses.append(loss.numpy())
          print("[validation] accuracy/loss: {}/{}".format(np.mean(accuracies), np.mean(losses)))
          # Switch back to training mode for the next epoch.
          model.train()

      # Save the model parameters.
      fluid.save_dygraph(model.state_dict(), 'mnist')
  if __name__ == '__main__':
      # Build the model and run training inside the dygraph guard context.
      with fluid.dygraph.guard():
          model = LeNet("LeNet", num_classes=10)
          # Start the training process.
          train(model)
  1. start training ...
  1. Cache file /home/aistudio/.cache/paddle/dataset/mnist/train-images-idx3-ubyte.gz not found, downloading https://dataset.bj.bcebos.com/mnist/train-images-idx3-ubyte.gz
  2. Begin to download
  3.  
  4. Download finished
  5. Cache file /home/aistudio/.cache/paddle/dataset/mnist/train-labels-idx1-ubyte.gz not found, downloading https://dataset.bj.bcebos.com/mnist/train-labels-idx1-ubyte.gz
  6. Begin to download
  7. ........
  8. Download finished
  9. Cache file /home/aistudio/.cache/paddle/dataset/mnist/t10k-images-idx3-ubyte.gz not found, downloading https://dataset.bj.bcebos.com/mnist/t10k-images-idx3-ubyte.gz
  10. Begin to download
  11.  
  12. Download finished
  13. Cache file /home/aistudio/.cache/paddle/dataset/mnist/t10k-labels-idx1-ubyte.gz not found, downloading https://dataset.bj.bcebos.com/mnist/t10k-labels-idx1-ubyte.gz
  14. Begin to download
  15. ..
  16. Download finished
  1. epoch: 0, batch_id: 0, loss is: [2.5567963]
  2. epoch: 0, batch_id: 1000, loss is: [2.2921207]
  3. epoch: 0, batch_id: 2000, loss is: [2.329089]
  4. epoch: 0, batch_id: 3000, loss is: [2.2760074]
  5. epoch: 0, batch_id: 4000, loss is: [2.2555802]
  6. epoch: 0, batch_id: 5000, loss is: [2.321007]
  7. [validation] accuracy/loss: 0.3555000126361847/2.2462358474731445
  8. epoch: 1, batch_id: 0, loss is: [2.2364979]
  9. epoch: 1, batch_id: 1000, loss is: [2.1558306]
  10. epoch: 1, batch_id: 2000, loss is: [2.1844604]
  11. epoch: 1, batch_id: 3000, loss is: [1.7957464]
  12. epoch: 1, batch_id: 4000, loss is: [1.341808]
  13. epoch: 1, batch_id: 5000, loss is: [1.6028554]
  14. [validation] accuracy/loss: 0.7293000221252441/1.0572129487991333
  15. epoch: 2, batch_id: 0, loss is: [0.85837966]
  16. epoch: 2, batch_id: 1000, loss is: [0.6425297]
  17. epoch: 2, batch_id: 2000, loss is: [0.6375253]
  18. epoch: 2, batch_id: 3000, loss is: [0.40348434]
  19. epoch: 2, batch_id: 4000, loss is: [0.37101394]
  20. epoch: 2, batch_id: 5000, loss is: [0.65031445]
  21. [validation] accuracy/loss: 0.8730000257492065/0.47411048412323
  22. epoch: 3, batch_id: 0, loss is: [0.35694075]
  23. epoch: 3, batch_id: 1000, loss is: [0.25489596]
  24. epoch: 3, batch_id: 2000, loss is: [0.29641074]
  25. epoch: 3, batch_id: 3000, loss is: [0.18106733]
  26. epoch: 3, batch_id: 4000, loss is: [0.1899938]
  27. epoch: 3, batch_id: 5000, loss is: [0.32796213]
  28. [validation] accuracy/loss: 0.9122999906539917/0.3133768141269684
  29. epoch: 4, batch_id: 0, loss is: [0.24354395]
  30. epoch: 4, batch_id: 1000, loss is: [0.16107734]
  31. epoch: 4, batch_id: 2000, loss is: [0.20161033]
  32. epoch: 4, batch_id: 3000, loss is: [0.09298491]
  33. epoch: 4, batch_id: 4000, loss is: [0.11935985]
  34. epoch: 4, batch_id: 5000, loss is: [0.19827338]
  35. [validation] accuracy/loss: 0.9312999844551086/0.23992861807346344

通过运行结果可以看出,经过5轮训练后,LeNet在手写数字识别MNIST验证数据集上的准确率达到了93%以上(最后一轮为93.13%)。那么对于其它数据集效果如何呢?我们通过眼疾识别数据集iChallenge-PM验证一下。