符号微分

图 D-1 展示了符号微分是如何运行在相当简单的函数上的,g(x,y) = 5 + xy。该函数的计算图如图的左边所示。通过符号微分,我们可得到图的右部分,它代表了 \frac{\partial g}{\partial x} = 0 + (0 \times x + y \times 1) = y,相似地也可得到关于y的导数。

D-1

概算法先获得叶子节点的偏导数。常数 5 返回常数 0,因为常数的导数总是 0。变量x返回常数 1,变量y返回常数 0,因为 \frac{\partial y}{\partial x} = 0(如果我们找关于y的偏导数,那它将反过来)。

现在我们移动到计算图的相乘节点处,代数告诉我们,uv相乘后的导数为 \frac{\partial (u \times v)}{\partial x} = \frac{\partial v}{\partial x} \times u + \frac{\partial u}{\partial x} \times v 。因此我们可以构造有图中大的部分,代表0 × x + y × 1

最后我们往上走到计算图的相加节点处,正如 5 条规则里提到的,和的导数等于导数的和。所以我们只需要创建一个相加节点,连接我们已经计算出来的部分。我们可以得到正确的偏导数,即:\frac{\partial g}{\partial x} = 0 + (0 \times x + y \times 1)

然而,这个过程可简化。对该图应用一些微不足道的剪枝步骤,可以去掉所有不必要的操作,然后我们可以得到一个小得多的只有一个节点的偏导计算图:\frac{\partial g}{\partial x} = y

在这个例子里,简化操作是相当简单的,但对更复杂的函数来说,符号微分会产生一个巨大的计算图,该图可能很难去简化,以导致次优的性能。更重要的是,符号微分不能处理由任意代码定义的函数,例如,如下已在第 9 章讨论过的函数:

  1. def my_func(a, b):
  2. z = 0
  3. for i in range(100):
  4. z = a * np.cos(z + i) + z * np.sin(b - i)
  5. return z