mxnet.ndarray.contrib.while_loop
mxnet.ndarray.contrib.while_loop(cond, func, loop_vars, max_iterations=None)
Run a while loop with user-defined computation and loop condition.
This operator simulates a while loop that iteratively performs a user-defined computation as long as the condition is satisfied.
loop_vars is a list of NDArrays that the computation operates on.
cond is a user-defined function used as the loop condition. It consumes loop_vars and produces a scalar MXNet NDArray that indicates whether the loop should continue. The loop ends when cond returns false (zero). cond is variadic, and its signature should be cond(*loop_vars) => NDArray.
func is a user-defined function used as the loop body. It also consumes loop_vars, and at each step it produces step_output and new_loop_vars. step_output must contain the same number of elements in every step, and across all steps the i-th element of step_output must have the same shape and dtype. Likewise, new_loop_vars must contain the same number of elements as loop_vars, and corresponding elements must have the same shape and dtype. func is variadic, and its signature should be func(*loop_vars) => (NDArray or nested List[NDArray] step_output, NDArray or nested List[NDArray] new_loop_vars).
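The consistency rules for func can be restated as explicit checks. The sketch below is a plain-Python illustration, not MXNet code: flatten and check_step are hypothetical helpers, nested lists are assumed to be compared after flattening them to their leaves, and plain Python values stand in for NDArrays, so the shape and dtype checks are left out.

```python
def flatten(nested):
    """Return the leaves of a value that is either a leaf or a nested list/tuple."""
    if not isinstance(nested, (list, tuple)):
        return [nested]
    leaves = []
    for item in nested:
        leaves.extend(flatten(item))
    return leaves


def check_step(loop_vars, step_output, new_loop_vars, num_outputs=None):
    """Hypothetical check of the func contract for a single step.

    Returns the number of flattened step outputs so the caller can pass it
    back in on the next step and verify that it stays constant.
    """
    if len(flatten(new_loop_vars)) != len(flatten(loop_vars)):
        raise ValueError("new_loop_vars must have as many elements as loop_vars")
    outputs = flatten(step_output)
    if num_outputs is not None and len(outputs) != num_outputs:
        raise ValueError("step_output must have the same number of elements in every step")
    return len(outputs)


# Two consecutive steps that respect the contract.
n = check_step([0, 1], [[1]], [1, 1])
check_step([1, 1], [[2]], [2, 2], n)
```

Threading the returned count from one step into the next is what enforces "the same number of elements in every step"; a real implementation would additionally compare shapes and dtypes element by element.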
max_iterations is a scalar that defines the maximum number of iterations allowed.
This function returns two lists. The first list has length |step_output|; its i-th element stacks the i-th elements of step_output from all steps along axis 0. The second list has length |loop_vars| and holds the final states of the loop variables.
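To make the two return values concrete, here is a plain-Python model of the semantics described above. It is an illustrative sketch, not MXNet's implementation: while_loop_reference is a hypothetical name, loop_vars and step_output are assumed to be flat lists of Python numbers rather than NDArrays, max_iterations is assumed to be a positive int, and None stands in for the undefined padding rows that the warning below describes.

```python
def while_loop_reference(cond, func, loop_vars, max_iterations):
    """Hypothetical plain-Python model of while_loop's two return values.

    Assumes loop_vars is a flat list and func returns two flat lists
    (step_output, new_loop_vars); Python numbers stand in for NDArrays.
    """
    loop_vars = list(loop_vars)
    steps = []  # steps[t] is the step_output produced at step t
    # Stop when cond is false or when max_iterations steps have run.
    while len(steps) < max_iterations and cond(*loop_vars):
        step_output, new_loop_vars = func(*loop_vars)
        steps.append(list(step_output))
        loop_vars = list(new_loop_vars)
    # "Stacked along axis 0": the i-th result gathers the i-th element of every step.
    outputs = [list(column) for column in zip(*steps)]
    # Model the padding of axis 0 to max_iterations; None marks undefined rows.
    for column in outputs:
        column.extend([None] * (max_iterations - len(column)))
    return outputs, loop_vars


# The values from the Examples section, with Python ints instead of NDArrays.
cond = lambda i, s: i <= 5
func = lambda i, s: ([i + s], [i + 1, s + i])
outputs, states = while_loop_reference(cond, func, [0, 1], max_iterations=10)
print(outputs)  # [[1, 2, 4, 7, 11, 16, None, None, None, None]]
print(states)   # [6, 16]
```

Run on the values from the Examples section, the model gives six real rows followed by four placeholders and final states 6 and 16, matching the NDArray output shown there. If cond is false on entry, zip(*steps) yields nothing, so the first list comes back empty.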
Warning
For now, the size of axis 0 of every NDArray in the first list is max_iterations, due to the lack of dynamic shape inference; rows after the last executed step contain undefined values.
Warning
When cond is never satisfied, we assume step_output is empty, because its structure cannot be inferred without running func. This behavior differs from the symbolic version.
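The never-satisfied case can be seen in a minimal plain-Python loop. run_loop is a hypothetical helper that models only the control flow (plain Python values instead of NDArrays, no stacking or padding); it shows that when cond is false on entry, func never runs, so there is no step_output from which to infer the number, shapes, or dtypes of the outputs.

```python
def run_loop(cond, func, loop_vars, max_iterations):
    """Hypothetical helper: return (list of per-step outputs, final loop_vars)."""
    steps = []
    while len(steps) < max_iterations and cond(*loop_vars):
        step_output, loop_vars = func(*loop_vars)
        steps.append(step_output)
    return steps, loop_vars


# cond is false on entry, so func is never called.
steps, states = run_loop(lambda i: i > 100, lambda i: ([i * 2], [i + 1]), [0], 10)
# steps is [] because no step_output was ever produced;
# states is [0] because the loop variables were never updated.
```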
Parameters
cond (a Python function) – The loop condition.
func (a Python function) – The loop body.
loop_vars (an NDArray or nested lists of NDArrays) – The initial values of the loop variables.
max_iterations (a Python int) – Maximum number of iterations.
Returns
outputs (an NDArray or nested lists of NDArrays) – Stacked output from each step.
states (an NDArray or nested lists of NDArrays) – Final state of the loop variables.
Examples
>>> import mxnet as mx
>>> cond = lambda i, s: i <= 5
>>> func = lambda i, s: ([i + s], [i + 1, s + i])
>>> loop_vars = (mx.nd.array([0], dtype="int64"), mx.nd.array([1], dtype="int64"))
>>> outputs, states = mx.nd.contrib.while_loop(cond, func, loop_vars, max_iterations=10)
>>> outputs
[
[[ 1]
 [ 2]
 [ 4]
 [ 7]
 [11]
 [16]
 [...]  # undefined value
 [...]
 [...]
 [...]]
<NDArray 10x1 @cpu(0)>]
>>> states
[
[6]
<NDArray 1 @cpu(0)>,
[16]
<NDArray 1 @cpu(0)>]