torch.utils.checkpoint

torch.utils.checkpoint.checkpoint(function, *args)

checkpoint模型或模型的一部分

checkpoint通过交换计算内存来工作。而不是存储整个计算图的所有中间激活用于向后计算,checkpoint不会不保存中间激活部分,而是在反向传递中重新计算它们。它可以应用于模型的任何部分。

具体来说,在正向传递中,function将以torch.no_grad()方式运行 ,即不存储中间激活。相反,正向传递保存输入元组和 function参数。在向后计算中,保存的输入变量以及 function会被回收,并且正向计算被function再次计算 ,现在跟踪中间激活,然后使用这些激活值来计算梯度。

警告 checkpointtorch.autograd.grad()中不起作用,但仅适用于torch.autograd.backward()

警告 如果function在向后执行和前向执行都不同,例如由于某个全局变量,checkpoint版本将不等同,并且不幸的是无法检测到。

参数:

  • function - 描述在模型的正向传递或模型的一部分中运行的内容。它也应该知道如何处理作为元组传递的输入。例如,在LSTM中,如果用户通过 ,应正确使用第一个输入作为第二个输入(activation, hidden)functionactivationhidden
  • args - 包含输入的元组function

返回: attrfunction*args

返回类型:运行输出

torch.utils.checkpoint.checkpoint_sequential(functions, segments, *inputs)

用于checkpoint sequential模型的辅助函数。

sequential模型按顺序执行一系列模块/函数(按顺序)。因此,我们可以将这种模型分为不同的部分和checkpoint。除最后一个段以外的所有段都将以某种torch.no_grad()方式运行 ,即不存储中间活动。将保存每个checkpoint`段的输入,以便在向后传递中重新运行段。

关于checkpoint如何工作可以参考checkpoint()

警告 checkpointtorch.autograd.grad()中不起作用,但仅适用于torch.autograd.backward()

参数:

  • functions - A torch.nn.Sequential或按顺序运行的模块或函数(包括模型)的列表。
  • segments - 要在模型中创建的块数
  • segments - 输入的张量元组functions

返回:

functions按顺序运行的输出*inputs

例:

  1. >>> model = nn.Sequential(...)
  2. >>> input_var = checkpoint_sequential(model, chunks, input_var)

部分地方存在翻译错误,即将修复

译者署名

用户名 头像 职能 签名
Song torch.utils.checkpoint - 图1 翻译 人生总要追求点什么