EarlyStopping

class paddle.callbacks.EarlyStopping ( monitor=’loss’, mode=’auto’, patience=0, verbose=1, min_delta=0, baseline=None, save_best_model=True )

在模型评估阶段,模型效果如果没有提升,EarlyStopping 会让模型提前停止训练。

参数:

  • monitor (str,可选) - 监控量。该量作为模型是否停止学习的监控指标。默认值:’loss’。

  • mode (str,可选) - 可以是’auto’、’min’或者’max’。在min模式下,模型会在监控量的值不再减少时停止训练;max模式下,模型会在监控量的值不再增加时停止训练;auto模式下,实际的模式会从 monitor 推断出来。如果 monitor 中有’acc’,将会认为是max模式,其它情况下,都会被推断为min模式。默认值:’auto’。

  • patience (int,可选) - 多少个epoch模型效果未提升会使模型提前停止训练。默认值:0。

  • verbose (int,可选) - 可以是0或者1。0代表不打印模型提前停止训练的日志,1代表打印日志。默认值:1。

  • min_delta (int|float,可选) - 监控量最小改变值。当evaluation的监控变量改变值小于 min_delta ,就认为模型没有变化。默认值:0。

  • baseline (int|float,可选) - 监控量的基线。如果模型在训练 patience 个epoch后效果对比基线没有提升,将会停止训练。如果是None,代表没有基线。默认值:None。

  • save_best_model (bool,可选) - 是否保存效果最好的模型(监控量的值最优)。文件会保存在 fit 中传入的参数 save_dir 下,前缀名为best_model,默认值: True。

代码示例

  1. import paddle
  2. from paddle import Model
  3. from paddle.static import InputSpec
  4. from paddle.vision.models import LeNet
  5. from paddle.vision.datasets import MNIST
  6. from paddle.metric import Accuracy
  7. from paddle.nn import CrossEntropyLoss
  8. import paddle.vision.transforms as T
  9. device = paddle.set_device('cpu')
  10. sample_num = 200
  11. save_dir = './best_model_checkpoint'
  12. transform = T.Compose(
  13. [T.Transpose(), T.Normalize([127.5], [127.5])])
  14. train_dataset = MNIST(mode='train', transform=transform)
  15. val_dataset = MNIST(mode='test', transform=transform)
  16. net = LeNet()
  17. optim = paddle.optimizer.Adam(
  18. learning_rate=0.001, parameters=net.parameters())
  19. inputs = [InputSpec([None, 1, 28, 28], 'float32', 'x')]
  20. labels = [InputSpec([None, 1], 'int64', 'label')]
  21. model = Model(net, inputs=inputs, labels=labels)
  22. model.prepare(
  23. optim,
  24. loss=CrossEntropyLoss(reduction="sum"),
  25. metrics=[Accuracy()])
  26. callbacks = paddle.callbacks.EarlyStopping(
  27. 'loss',
  28. mode='min',
  29. patience=1,
  30. verbose=1,
  31. min_delta=0,
  32. baseline=None,
  33. save_best_model=True)
  34. model.fit(train_dataset,
  35. val_dataset,
  36. batch_size=64,
  37. log_freq=200,
  38. save_freq=10,
  39. save_dir=save_dir,
  40. epochs=20,
  41. callbacks=[callbacks])