Theano 条件语句¶

theano 中提供了两种条件语句,ifelseswitch,两者都是用于在符号变量上使用条件语句:

  • ifelse(condition, var1, var2)
    • 如果 conditiontrue,返回 var1,否则返回 var2
  • switch(tensor, var1, var2)
    • Elementwise ifelse 操作,更一般化
  • switch 会计算两个输出,而 ifelse 只会根据给定的条件,计算相应的输出。
    ifelse 需要从 theano.ifelse 中导入,而 switchtheano.tensor 模块中。

In [1]:

  1. import theano, time
  2. import theano.tensor as T
  3. import numpy as np
  4. from theano.ifelse import ifelse
  1. Using gpu device 1: Tesla K10.G2.8GB (CNMeM is disabled)

假设我们有两个标量参数:$a, b$,和两个矩阵 $\mathbf{x, y}$,定义函数为:

\mathbf z = f(a, b,\mathbf{x, y}) = \left{ \begin{aligned} \mathbf x & ,\ a <= b\ \mathbf y & ,\ a > b\end{aligned}\right.

定义变量:

In [2]:

  1. a, b = T.scalars('a', 'b')
  2. x, y = T.matrices('x', 'y')

ifelse 构造,小于等于用 T.lt(),大于等于用 T.gt()

In [3]:

  1. z_ifelse = ifelse(T.lt(a, b), x, y)
  2.  
  3. f_ifelse = theano.function([a, b, x, y], z_ifelse)

switch 构造:

In [4]:

  1. z_switch = T.switch(T.lt(a, b), x, y)
  2.  
  3. f_switch = theano.function([a, b, x, y], z_switch)

测试数据:

In [5]:

  1. val1 = 0.
  2. val2 = 1.
  3. big_mat1 = np.ones((10000, 1000), dtype=theano.config.floatX)
  4. big_mat2 = np.ones((10000, 1000), dtype=theano.config.floatX)

比较两者的运行速度:

In [6]:

  1. n_times = 10
  2.  
  3. tic = time.clock()
  4. for i in xrange(n_times):
  5. f_switch(val1, val2, big_mat1, big_mat2)
  6. print 'time spent evaluating both values %f sec' % (time.clock() - tic)
  7.  
  8. tic = time.clock()
  9. for i in xrange(n_times):
  10. f_ifelse(val1, val2, big_mat1, big_mat2)
  11. print 'time spent evaluating one value %f sec' % (time.clock() - tic)
  1. time spent evaluating both values 0.638598 sec
  2. time spent evaluating one value 0.461249 sec

原文: https://nbviewer.jupyter.org/github/lijin-THU/notes-python/blob/master/09-theano/09.06-conditions-in-theano.ipynb