混合精度

概述

混合精度训练方法通过混合使用单精度和半精度数据格式来加速深度神经网络训练过程,同时保持了单精度训练所能达到的网络精度。混合精度训练能够加速计算过程,同时减少内存使用和存取,并在特定的硬件上可以训练更大的模型或batch size。

计算流程

MindSpore混合精度典型的计算流程如下图所示:

mix precision

  • 参数以FP32存储;

  • 正向计算过程中,遇到FP16算子,需要把算子输入和参数从FP32 cast成FP16进行计算;

  • 将Loss层设置为FP32进行计算;

  • 反向计算过程中,首先乘以Loss Scale值,避免反向梯度过小而产生下溢;

  • FP16参数参与梯度计算,其结果将被cast回FP32;

  • 除以Loss scale值,还原被放大的梯度;

  • 判断梯度是否存在溢出,如果溢出则跳过更新,否则优化器以FP32对原始参数进行更新。

本文通过自动混合精度和手动混合精度的样例来讲解计算流程。

自动混合精度

使用自动混合精度,需要调用相应的接口,将待训练网络和优化器作为输入传进去;该接口会将整张网络的算子转换成FP16算子(除BatchNorm算子和Loss涉及到的算子外)。另外要注意:使用混合精度后,一般要用上Loss Scale,避免数值计算溢出。

具体的实现步骤为:

  • 引入MindSpore的混合精度的接口amp;

  • 定义网络:该步骤和普通的网络定义没有区别(无需手动配置某个算子的精度);

  • 使用amp.build_train_network()接口封装网络模型和优化器,在该步骤中MindSpore会将有需要的算子自动进行类型转换。

代码样例如下:

  1. Copy# The interface of Auto_mixed precision
  2. from mindspore.train import amp
  3.  
  4. # Define network
  5. class LeNet5(nn.Cell):
  6. def __init__(self):
  7. super(LeNet5, self).__init__()
  8. self.conv1 = nn.Conv2d(1, 6, 5, pad_mode='valid')
  9. self.conv2 = nn.Conv2d(6, 16, 5, pad_mode='valid')
  10. self.fc1 = nn.Dense(16 * 5 * 5, 120)
  11. self.fc2 = nn.Dense(120, 84)
  12. self.fc3 = nn.Dense(84, 10)
  13. self.relu = nn.ReLU()
  14. self.max_pool2d = nn.MaxPool2d(kernel_size=2)
  15. self.flatten = P.Flatten()
  16.  
  17. def construct(self, x):
  18. x = self.max_pool2d(self.relu(self.conv1(x)))
  19. x = self.max_pool2d(self.relu(self.conv2(x)))
  20. x = self.flatten(x)
  21. x = self.relu(self.fc1(x))
  22. x = self.relu(self.fc2(x))
  23. x = self.fc3(x)
  24. return x
  25.  
  26. # Initialize network
  27. net = LeNet5()
  28.  
  29. # Define training data, label and sens
  30. predict = Tensor(np.ones([1, 1, 32, 32]).astype(np.float32) * 0.01)
  31. label = Tensor(np.zeros([1, 10]).astype(np.float32))
  32. scaling_sens = Tensor(np.full((1), 1.0), dtype=mstype.float32)
  33.  
  34. # Define Loss and Optimizer
  35. net.set_train()
  36. loss = MSELoss()
  37. optimizer = Momentum(params=net.trainable_params(), learning_rate=0.1, momentum=0.9)
  38. net_with_loss = WithLossCell(net, loss)
  39. train_network = amp.build_train_network(net_with_loss, optimizer, level="O2")
  40.  
  41. # Run training
  42. output = train_network(inputs, label, scaling_sens)

手动混合精度

MindSpore还支持手动混合精度。假定在网络中只有一个Dense Layer要用FP32计算,其他Layer都用FP16计算。混合精度配置以Cell为粒度,Cell默认是FP32类型。

以下是一个手动混合精度的实现步骤:

  • 定义网络: 该步骤与自动混合精度中的步骤2类似;注意:在LeNet中的fc3算子,需要手动配置成FP32;

  • 配置混合精度: LeNet通过net.add_flags_recursive(fp16=True),把该Cell及其子Cell中所有的算子都配置成FP16;

  • 使用TrainOneStepWithLossScaleCell封装网络模型和优化器。

代码样例如下:

  1. Copy# Define network
  2. class LeNet5(nn.Cell):
  3. def __init__(self):
  4. super(LeNet5, self).__init__()
  5. self.conv1 = nn.Conv2d(1, 6, 5, pad_mode='valid')
  6. self.conv2 = nn.Conv2d(6, 16, 5, pad_mode='valid')
  7. self.fc1 = nn.Dense(16 * 5 * 5, 120)
  8. self.fc2 = nn.Dense(120, 84)
  9. self.fc3 = nn.Dense(84, 10).add_flags_recursive(fp32=True)
  10. self.relu = nn.ReLU()
  11. self.max_pool2d = nn.MaxPool2d(kernel_size=2)
  12. self.flatten = P.Flatten()
  13.  
  14. def construct(self, x):
  15. x = self.max_pool2d(self.relu(self.conv1(x)))
  16. x = self.max_pool2d(self.relu(self.conv2(x)))
  17. x = self.flatten(x)
  18. x = self.relu(self.fc1(x))
  19. x = self.relu(self.fc2(x))
  20. x = self.fc3(x)
  21. return x
  22.  
  23. # Initialize network and set mixing precision
  24. net = LeNet5()
  25. net.add_flags_recursive(fp16=True)
  26.  
  27. # Define training data, label and sens
  28. predict = Tensor(np.ones([1, 1, 32, 32]).astype(np.float32) * 0.01)
  29. label = Tensor(np.zeros([1, 10]).astype(np.float32))
  30. scaling_sens = Tensor(np.full((1), 1.0), dtype=mstype.float32)
  31.  
  32. # Define Loss and Optimizer
  33. net.set_train()
  34. loss = MSELoss()
  35. optimizer = Momentum(params=net.trainable_params(), learning_rate=0.1, momentum=0.9)
  36. net_with_loss = WithLossCell(net, loss)
  37. train_network = TrainOneStepWithLossScaleCell(net_with_loss, optimizer)
  38.  
  39. # Run training
  40. output = train_network(inputs, label, scaling_sens)