flops

paddle.flops ( self, input_size=None, dtype=None ) [源代码]

flops 函数能够打印网络的基础结构和参数信息。

参数

  • net (paddle.nn.Layer||paddle.static.Program) - 网络实例,必须是 paddle.nn.Layer 的子类或者静态图下的 paddle.static.Program。

  • input_size (list) - 输入张量的大小。注意:仅支持batch_size=1。

  • custom_ops (dict,可选) - 字典,用于实现对自定义网络层的统计。字典的key为自定义网络层的class,value为统计网络层flops的函数,函数实现方法见示例代码。此参数仅在 ‘net’ 为paddle.nn.Layer时生效。默认值:None。

  • print_detail (bool, 可选) - bool值,用于控制是否打印每个网络层的细节。默认值:False

返回

整型,网络模型的计算量。

代码示例

  1. import paddle
  2. import paddle.nn as nn
  3. class LeNet(nn.Layer):
  4. def __init__(self, num_classes=10):
  5. super(LeNet, self).__init__()
  6. self.num_classes = num_classes
  7. self.features = nn.Sequential(
  8. nn.Conv2D(
  9. 1, 6, 3, stride=1, padding=1),
  10. nn.ReLU(),
  11. nn.MaxPool2D(2, 2),
  12. nn.Conv2D(
  13. 6, 16, 5, stride=1, padding=0),
  14. nn.ReLU(),
  15. nn.MaxPool2D(2, 2))
  16. if num_classes > 0:
  17. self.fc = nn.Sequential(
  18. nn.Linear(400, 120),
  19. nn.Linear(120, 84),
  20. nn.Linear(
  21. 84, 10))
  22. def forward(self, inputs):
  23. x = self.features(inputs)
  24. if self.num_classes > 0:
  25. x = paddle.flatten(x, 1)
  26. x = self.fc(x)
  27. return x
  28. lenet = LeNet()
  29. # m 是 nn.Layer 的一个实类, x 是m的输入, y 是网络层的输出.
  30. def count_leaky_relu(m, x, y):
  31. x = x[0]
  32. nelements = x.numel()
  33. m.total_ops += int(nelements)
  34. FLOPs = paddle.flops(lenet, [1, 1, 28, 28], custom_ops= {nn.LeakyReLU: count_leaky_relu},
  35. print_detail=True)
  36. print(FLOPs)
  37. #+--------------+-----------------+-----------------+--------+--------+
  38. #| Layer Name | Input Shape | Output Shape | Params | Flops |
  39. #+--------------+-----------------+-----------------+--------+--------+
  40. #| conv2d_2 | [1, 1, 28, 28] | [1, 6, 28, 28] | 60 | 47040 |
  41. #| re_lu_2 | [1, 6, 28, 28] | [1, 6, 28, 28] | 0 | 0 |
  42. #| max_pool2d_2 | [1, 6, 28, 28] | [1, 6, 14, 14] | 0 | 0 |
  43. #| conv2d_3 | [1, 6, 14, 14] | [1, 16, 10, 10] | 2416 | 241600 |
  44. #| re_lu_3 | [1, 16, 10, 10] | [1, 16, 10, 10] | 0 | 0 |
  45. #| max_pool2d_3 | [1, 16, 10, 10] | [1, 16, 5, 5] | 0 | 0 |
  46. #| linear_0 | [1, 400] | [1, 120] | 48120 | 48000 |
  47. #| linear_1 | [1, 120] | [1, 84] | 10164 | 10080 |
  48. #| linear_2 | [1, 84] | [1, 10] | 850 | 840 |
  49. #+--------------+-----------------+-----------------+--------+--------+
  50. #Total Flops: 347560 Total Params: 61610