Developer documentation for Scan

Context

This document is meant to act as reference material for developers working on Theano’s loop mechanism. This mechanism is called Scan and its internals are highly complex, hence the need for a centralized repository of knowledge regarding its inner workings.

The theano.scan() function is the public-facing interface for looping in Theano. Under the hood, this function performs some processing on its inputs and instantiates the Scan op class, which implements the looping mechanism. The op achieves this by compiling its own Theano function representing the computation to be done at every iteration of the loop and calling it as many times as necessary.

The correspondence between the parameters and behaviors of the function and the op is not always simple, since the former is meant for usability and the latter for performance. Since this document is intended to be used by developers working inside Scan itself, it will mostly discuss things from the point of view of the Scan op class. Nonetheless, it will attempt to link those elements to their corresponding concepts in the scan function as often as is reasonably practical.

Pre-requisites

The following sections assume the reader is familiar with the following:

Relevant code files

The implementation of Scan is spread over several files in theano/scan_module. The different files, and the sections of the code they deal with, are:

  • scan.py implements the scan function. The scan function arranges the arguments of scan correctly, constructs the scan op and afterwards calls the constructed scan op on the arguments. This function takes care of figuring out missing inputs and shared variables.
  • scan_op.py implements the Scan op class. The Scan op respects the Op interface and contains most of the logic of the scan operator.
  • scan_utils.py contains several helper functions, specific to the scan operator, that are used throughout the other files.
  • scan_views.py contains different views of the scan op that have simpler signatures, to be used in specific cases.
  • scan_opt.py contains the list of all Theano graph optimizations for the scan operator.

Notation

Scan being a sizeable and complex module, it has its own naming convention for functions and variables, which this section will attempt to introduce.

A scan op contains a Theano function representing the computation that is done in a single iteration of the loop represented by the scan op (in other words, the computation given by the function provided as value to theano.scan’s fn argument). Whenever we discuss a scan op, the outer function refers to the Theano function that contains the scan op, whereas the inner function refers to the Theano function that is contained inside the scan op.

In the same spirit, the inputs and outputs of the Apply node wrapping the scan op (or scan node for short) are referred to as outer inputs and outer outputs, respectively, because these inputs and outputs are variables in the outer function graph. The inputs and outputs of scan’s inner function are designated inner inputs and inner outputs, respectively.

Scan variables

The following are the different types of variables that Scan has the capacity to handle, along with their various characteristics.

Sequence: A sequence is a Theano variable which Scan will iterate over and give sub-elements to its inner function as input. A sequence has no associated output. For a sequence variable X, at timestep t, the inner function will receive as input the sequence element X[t]. These variables are used through the argument sequences of the theano.scan() function.

Non-sequence: A non-sequence is a Theano variable which Scan will provide as-is to its inner function. Like a sequence, a non-sequence has no associated output. For a non-sequence variable X, at timestep t, the inner function will receive as input the variable X. These variables are used through the argument non_sequences of the theano.scan() function.

Nitsot (no input tap, single output tap): A nitsot is an output variable of the inner function that is not fed back as an input to the next iteration of the inner function. Nitsots are typically encountered in situations where Scan is used to perform a ‘map’ operation (every element in a tensor is independently altered using a given operation to produce a new tensor) such as squaring every number in a vector.
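
The ‘map’ case can be pictured with a plain-Python emulation of the loop. This is only a sketch of the semantics described above, not Theano code: emulate_nitsot_map and square are hypothetical names, and a Python list stands in for the tensor.

```python
def emulate_nitsot_map(inner_fn, sequence):
    """Emulate a scan whose only output is a nitsot (assumed semantics)."""
    outer_output = []
    for x_t in sequence:            # inner input: the sequence element X[t]
        y_t = inner_fn(x_t)         # inner output at timestep t
        outer_output.append(y_t)    # stored, but never fed back as an input
    return outer_output


def square(x_t):
    """Hypothetical inner function: square one element."""
    return x_t ** 2


squares = emulate_nitsot_map(square, [1, 2, 3, 4])
```

Each iteration is independent of the others, which is what distinguishes a nitsot from the recurrent output types.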

Sitsot (single input tap, single output tap): A sitsot is an output variable of the inner function that is fed back as an input to the next iteration of the inner function. A typical setting where a sitsot might be encountered is the case where Scan is used to compute the cumulative sum over the elements of a vector and a sitsot output is employed to act as an accumulator.
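
The cumulative-sum case can be sketched in plain Python as follows. This emulates the behavior described above rather than calling Theano: emulate_sitsot and accumulate are hypothetical names, and the initial value plays the role of the value at timestep -1.

```python
def emulate_sitsot(inner_fn, sequence, initial_value):
    """Emulate a scan with one sequence and one sitsot (assumed semantics)."""
    previous = initial_value        # outer input: value at timestep -1
    outer_output = []
    for x_t in sequence:
        current = inner_fn(x_t, previous)   # inner output at timestep t
        outer_output.append(current)
        previous = current          # fed back as input to the next iteration
    return outer_output


def accumulate(x_t, acc_tm1):
    """Hypothetical inner function: add the element to the accumulator."""
    return acc_tm1 + x_t


cumulative = emulate_sitsot(accumulate, [1, 2, 3, 4], 0)
```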

Mitsot (multiple input taps, single output tap): A mitsot is an output variable of the inner function that is fed back as an input to future iterations of the inner function (either multiple future iterations or a single one that isn’t the immediate next one). For example, a mitsot might be used in the case where Scan is used to compute the Fibonacci sequence, one term of the sequence at every timestep, since every computed term needs to be reused to compute the two next terms of the sequence.
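
The Fibonacci case can be emulated in plain Python as below. This is a sketch under assumptions, not Theano code: emulate_mitsot and fib_step are hypothetical names, and taps are written as negative offsets relative to the current timestep.

```python
def emulate_mitsot(inner_fn, initial_values, taps, n_steps):
    """Emulate a scan with a single mitsot (assumed semantics).

    initial_values holds the values for the timesteps t < 0, oldest first.
    taps are negative offsets such as (-2, -1).
    """
    history = list(initial_values)
    n_initial = len(history)
    for _ in range(n_steps):
        # One inner input per tap, read from previously computed values.
        inner_inputs = [history[tap] for tap in taps]
        history.append(inner_fn(*inner_inputs))
    return history[n_initial:]      # outer output excludes the initial values


def fib_step(f_tm2, f_tm1):
    """Hypothetical inner function: next Fibonacci term."""
    return f_tm2 + f_tm1


fib = emulate_mitsot(fib_step, [0, 1], (-2, -1), 6)
```

Every computed term is read twice: once through tap -1 at the next iteration and once through tap -2 at the one after.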

Mitmot (multiple input taps, multiple output taps): These outputs exist but they cannot be directly created by the user. They can appear in a Theano graph as a result of taking the gradient of the output of a Scan with respect to its inputs: this will result in the creation of a new scan node used to compute the gradients of the first scan node. If the original Scan had sitsot or mitsot variables, the new Scan will use mitmots to compute the gradients through time for these variables.

To summarize:

| Type of scan variable | Corresponding outer input | Corresponding inner input at timestep t (indexed from 0) | Corresponding inner output at timestep t (indexed from 0) | Corresponding outer output | Corresponding argument of the theano.scan() function |
|---|---|---|---|---|---|
| Sequence | Sequence of elements X | Individual sequence element X[t] | No corresponding inner output | No corresponding outer output | sequences |
| Non-sequence | Any variable X | Variable identical to X | No corresponding inner output | No corresponding outer output | non_sequences |
| Non-recurring output (nitsot) | No corresponding outer input | No corresponding inner input | Output value at timestep t | Concatenation of the values of the output at all timesteps | outputs_info |
| Singly-recurrent output (sitsot) | Initial value (value at timestep -1) | Output value at previous timestep (t-1) | Output value at timestep t | Concatenation of the values of the output at all timesteps | outputs_info |
| Multiply-recurrent output (mitsot) | Initial values for the required timesteps where t<0 | Output value at previous required timesteps | Output value at timestep t | Concatenation of the values of the output at all timesteps | outputs_info |
| Multiply-recurrent multiple outputs (mitmot) | Initial values for the required timesteps where t<0 | Output value at previous required timesteps | Output values for current and multiple future timesteps | Concatenation of the values of the output at all timesteps | No corresponding argument |

Optimizations

remove_constants_and_unused_inputs_scan

This optimization serves two purposes. The first is to remove a scan op’s unused inputs. The second is to take a scan op’s constant inputs and remove them, instead injecting the constants directly into the graph of the scan op’s inner function. This will allow constant folding to happen inside the inner function.

PushOutNonSeqScan

This optimization pushes, out of Scan’s inner function and into the outer function, computation that depends only on non-sequence inputs. Such computation ends up being done at every iteration on the same values, so moving it to the outer function to be executed only once, before the scan op, reduces the amount of computation that needs to be performed.
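
The effect of this rewrite can be pictured with two plain-Python loops. This is an illustration of the idea only, not the optimizer’s code: nonseq_loop_before and nonseq_loop_after are hypothetical names, and math.exp stands in for any computation that depends only on the non-sequence w.

```python
import math


def nonseq_loop_before(sequence, w):
    """Loop-invariant work is repeated at every iteration."""
    outputs = []
    for x_t in sequence:
        scale = math.exp(w)         # depends only on the non-sequence w
        outputs.append(x_t * scale)
    return outputs


def nonseq_loop_after(sequence, w):
    """Same result, with the invariant work done once before the loop."""
    scale = math.exp(w)             # pushed out to the "outer function"
    outputs = []
    for x_t in sequence:
        outputs.append(x_t * scale)
    return outputs


before = nonseq_loop_before([1.0, 2.0, 3.0], 0.5)
after = nonseq_loop_after([1.0, 2.0, 3.0], 0.5)
```

Both versions return the same values; the second evaluates math.exp(w) once instead of once per timestep.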

PushOutSeqScan

This optimization resembles PushOutNonSeqScan but it tries to push, out of the inner function, the computation that only relies on sequence and non-sequence inputs. The idea behind this optimization is that, when it is possible to do so, it is generally more computationally efficient to perform a single operation on a large tensor rather than perform that same operation many times on many smaller tensors. In many cases, this optimization can increase memory usage but, in some specific cases, it can also decrease it.
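
A plain-Python picture of this rewrite follows. It is a sketch of the idea, not the optimizer’s code: seq_loop_before and seq_loop_after are hypothetical names, and the list comprehension stands in for a single operation applied to the whole sequence tensor.

```python
def seq_loop_before(xs, w):
    """The scaling is done inside the loop, one element at a time."""
    acc, outputs = 0, []
    for x_t in xs:
        scaled_t = x_t * w          # depends only on a sequence and a non-sequence
        acc = acc + scaled_t        # the recurrent part must stay in the loop
        outputs.append(acc)
    return outputs


def seq_loop_after(xs, w):
    """The scaling is done once on the whole sequence, before the loop."""
    scaled = [x * w for x in xs]    # one "large" operation; needs an extra buffer
    acc, outputs = 0, []
    for scaled_t in scaled:
        acc = acc + scaled_t
        outputs.append(acc)
    return outputs


seq_before = seq_loop_before([1, 2, 3], 2)
seq_after = seq_loop_after([1, 2, 3], 2)
```

The intermediate list scaled is the kind of additional buffer that can increase memory usage after the rewrite.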

PushOutScanOutput

This optimization attempts to push out some of the computation at the end of the inner function to the outer function, to be executed after the scan node. Like PushOutSeqScan, this optimization aims to replace many operations on small tensors by few operations on large tensors. It can also lead to increased memory usage.
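
The same idea applied to the end of the inner function can be sketched as follows. This is an illustration in plain Python, not the optimizer’s code: out_loop_before and out_loop_after are hypothetical names.

```python
def out_loop_before(xs, w):
    """The final scaling of the state is done inside the loop."""
    h, ys = 0, []
    for x_t in xs:
        h = h + x_t                 # recurrent state
        ys.append(h * w)            # last step of the inner function
    return ys


def out_loop_after(xs, w):
    """The loop only produces the states; the scaling runs after it."""
    h, hs = 0, []
    for x_t in xs:
        h = h + x_t
        hs.append(h)                # every state must now be stored
    return [h_t * w for h_t in hs]  # one operation on the whole output


out_before = out_loop_before([1, 2, 3], 10)
out_after = out_loop_after([1, 2, 3], 10)
```

Storing all the states in hs is what can raise memory usage compared with the original loop.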

PushOutDot1

This is another optimization that attempts to detect certain patterns of computation in a scan op’s inner function and move this computation to the outer graph.

ScanInplaceOptimizer

This optimization attempts to make Scan compute its recurrent outputs in place on the input tensors that contain their initial states. This optimization can improve runtime performance as well as reduce memory usage.

ScanSaveMem

This optimization attempts to determine whether a scan node can, during its execution and for any of its outputs, get away with allocating a memory buffer that is large enough to contain some of the computed timesteps of that output but not all of them.

By default, during the execution of a scan node, memory buffers will be allocated to store the values computed for every output at every iteration. However, in some cases, there are outputs for which there is only really a need to store the most recent N values, not all of them.

For instance, if a scan node has a sitsot output (last computed value is fed back as an input at the next iteration) and only the last timestep of that output is ever used in the outer function, the ScanSaveMem optimization could determine that there is no need to store all computed timesteps for that sitsot output. Only the most recently computed timestep ever needs to be kept in memory.
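
The sitsot example above can be sketched in plain Python by comparing a full buffer with a one-slot buffer. This is an illustration only: run_full_buffer and run_reduced_buffer are hypothetical names, and the buffer layout (initial state followed by one slot per timestep) is an assumption made for the sketch.

```python
def run_full_buffer(xs, initial):
    """Keep every computed timestep (assumed default behavior)."""
    buffer = [initial] + [None] * len(xs)
    for t, x_t in enumerate(xs):
        buffer[t + 1] = buffer[t] + x_t
    return buffer[-1], len(buffer)  # last value, number of slots allocated


def run_reduced_buffer(xs, initial):
    """Keep only the most recent value, overwriting it at each step."""
    buffer = [initial]              # a single slot is enough
    for x_t in xs:
        buffer[0] = buffer[0] + x_t
    return buffer[0], len(buffer)


full_last, full_slots = run_full_buffer([1, 2, 3, 4], 0)
small_last, small_slots = run_reduced_buffer([1, 2, 3, 4], 0)
```

Both versions yield the same final value, while the second allocates one slot regardless of the number of timesteps.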

ScanMerge

This optimization attempts to fuse distinct scan ops into a single scan op that performs all the computation. The main advantage of merging scan ops together comes from the possibility of both original ops having some computation in common. In such a setting, this computation ends up being done twice. The fused scan op, however, would only need to do it once and could therefore be more computationally efficient. Also, since every scan node involves a certain overhead at runtime, reducing the number of scan nodes in the graph can improve performance.
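
The benefit of merging can be pictured with two plain-Python loops that share a sub-computation. This is a sketch of the idea, not the optimizer’s code: run_separately and run_merged are hypothetical names, and the calls counter only exists to make the saving visible.

```python
def run_separately(xs):
    """Two loops, each recomputing the shared term x * x."""
    calls = 0
    plus_one, doubled = [], []
    for x in xs:
        shared = x * x
        calls += 1
        plus_one.append(shared + 1)
    for x in xs:
        shared = x * x              # same computation, done a second time
        calls += 1
        doubled.append(shared * 2)
    return plus_one, doubled, calls


def run_merged(xs):
    """One loop producing both outputs from a single shared term."""
    calls = 0
    plus_one, doubled = [], []
    for x in xs:
        shared = x * x
        calls += 1
        plus_one.append(shared + 1)
        doubled.append(shared * 2)
    return plus_one, doubled, calls


separate_result = run_separately([1, 2, 3])
merged_result = run_merged([1, 2, 3])
```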

scan_merge_inouts

This optimization attempts to merge a scan op’s identical outer inputs as well as merge its identical outer outputs (outputs that perform the same computation on the same inputs). This can reduce the amount of computation as well as result in a simpler graph for both the inner function and the outer function.

Helper classes and functions

Because of the complexity involved in dealing with Scan, a large number of helper classes and functions have been developed over time to implement operations commonly needed when dealing with the scan op. The scan op itself defines a large number of them and others can be found in the file scan_utils.py. This section aims to point out the most useful ones, sorted by usage.

Accessing/manipulating Scan’s inputs and outputs by type

Declared in scan_utils.py, the class scan_args handles the parsing of the inputs and outputs (both inner and outer) into a format that is easier to analyse and manipulate. Without this class, analysing Scan’s inputs and outputs often requires convoluted logic, which makes for code that is hard to read and to maintain. Because of this, you should favor using scan_args when it is practical and appropriate to do so.

The scan op also defines a few helper functions for this purpose, such as inner_nitsot_outs() or mitmot_out_taps(), but they are often poorly documented and easy to misuse. These should be used with great care.

Navigating between outer inputs/outputs and inner inputs/outputs

Navigation between these four sets of variables can be done in two ways, depending on the type of navigation that is required.

If the goal is to navigate between variables that are associated with the same states (e.g. going from an outer sequence input to the corresponding inner sequence input, going from an inner output associated with a recurrent state to the inner input(s) associated with that same recurrent state, etc.), then the var_mappings attribute of the scan op can be used.

This attribute is a dictionary with 12 key/value pairs. The keys are listed below:

  • “outer_inp_from_outer_out”
  • “inner_inp_from_outer_out”
  • “inner_out_from_outer_out”
  • “inner_inp_from_outer_inp”
  • “inner_out_from_outer_inp”
  • “outer_out_from_outer_inp”
  • “outer_inp_from_inner_inp”
  • “inner_out_from_inner_inp”
  • “outer_out_from_inner_inp”
  • “outer_inp_from_inner_out”
  • “inner_inp_from_inner_out”
  • “outer_out_from_inner_out”

Every corresponding value is a dictionary detailing a mapping from one set of variables to another. For each of those dictionaries, the keys are indices of variables in one set and the values are the indices of the corresponding variables in another set. For mappings to outer variables, the values are individual indices, or -1 if there is no corresponding outer variable. For mappings to inner variables, the values are lists of indices because multiple inner variables may be associated with the same state.
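
The shape of these dictionaries can be illustrated with a toy construction in plain Python. This is not the code Scan uses to build var_mappings: TOY_STATES, indices_of and build_toy_mappings are hypothetical names, and every index value below is made up for a scan with one sequence and one mitsot (taps -2 and -1). Only the key names and the int, -1 or list conventions follow the description above.

```python
# Hypothetical description of the states; index values are made up.
TOY_STATES = [
    # a sequence: one outer input, one inner input, no outputs
    {"outer_inp": 0, "inner_inp": [0], "inner_out": [], "outer_out": -1},
    # a mitsot: one outer input, one inner input per tap, one output
    {"outer_inp": 1, "inner_inp": [1, 2], "inner_out": [0], "outer_out": 0},
]

VARIABLE_SETS = ("outer_inp", "inner_inp", "inner_out", "outer_out")


def indices_of(entry):
    """Return the indices held by an entry as a list (-1 means none)."""
    if isinstance(entry, list):
        return entry
    return [] if entry == -1 else [entry]


def build_toy_mappings(states):
    """Build the 12 '<destination>_from_<source>' dictionaries."""
    mappings = {}
    for dst in VARIABLE_SETS:
        for src in VARIABLE_SETS:
            if src == dst:
                continue
            table = {}
            for state in states:
                target = state[dst]
                for src_idx in indices_of(state[src]):
                    # Inner targets are lists; outer targets are an int or -1.
                    if isinstance(target, list):
                        table[src_idx] = list(target)
                    else:
                        table[src_idx] = target
            mappings[dst + "_from_" + src] = table
    return mappings


toy_mappings = build_toy_mappings(TOY_STATES)
```

In this toy layout, toy_mappings["inner_inp_from_outer_inp"][1] is the list [1, 2] because the mitsot has two inner inputs, while toy_mappings["outer_out_from_outer_inp"][0] is -1 because the sequence has no outer output.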

If the goal is to navigate between variables that are connected (meaning that one of them is used to compute the other), the methods connection_pattern() and inner_connection_pattern() can be used. The method connection_pattern() returns a list of lists detailing, for every pair of outer input and outer output, whether they are connected or not. The method inner_connection_pattern() accomplishes the same goal but for every possible pair of inner output and inner input.
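
A list-of-lists connection pattern can be pictured with the toy helper below. This is a sketch and not the implementation of either method: toy_connection_pattern is a hypothetical name, and the sketch assumes that the outer list is indexed by input and each inner list by output.

```python
def toy_connection_pattern(n_inputs, n_outputs, dependencies):
    """Build a boolean table indexed as pattern[input_idx][output_idx].

    dependencies maps an output index to the set of input indices that
    are used to compute that output (hypothetical description format).
    """
    return [
        [inp in dependencies.get(out, set()) for out in range(n_outputs)]
        for inp in range(n_inputs)
    ]


# Output 0 is computed from inputs 0 and 2; output 1 from input 1 only.
pattern = toy_connection_pattern(3, 2, {0: {0, 2}, 1: {1}})
inputs_of_output_0 = [i for i, row in enumerate(pattern) if row[0]]
```

Reading a column of the table gives the inputs connected to one output; reading a row gives the outputs connected to one input.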