mxnet.ndarray.contrib.cond
mxnet.ndarray.contrib.cond(pred, then_func, else_func)
Run an if-then-else using a user-defined condition and computation.
This operator simulates an if-like branch: based on the specified condition, it runs one of two user-defined computations.
pred is a scalar MXNet NDArray, indicating which branch of computation should be used.
then_func is a user-defined function that computes the then branch. It takes no arguments and produces outputs, which is an NDArray or a nested list of NDArrays. The signature of then_func should be then_func() => NDArray or nested List[NDArray].
else_func is a user-defined function that computes the else branch. It takes no arguments and produces outputs, which is an NDArray or a nested list of NDArrays. The signature of else_func should be else_func() => NDArray or nested List[NDArray].
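In imperative (NDArray) mode, the predicate can be reduced to a Python bool and only the selected branch is called. The plain-Python sketch below illustrates that behaviour under this assumption. `eager_cond` is a hypothetical name, not part of MXNet, and a one-element list stands in for a scalar NDArray such as `[1.]`.

```python
from typing import Any, Callable


def eager_cond(pred: Any, then_func: Callable[[], Any], else_func: Callable[[], Any]) -> Any:
    """Hypothetical sketch of eager if-then-else; not MXNet's implementation."""
    # Unwrap a one-element container, standing in for a scalar NDArray like [1.].
    if isinstance(pred, (list, tuple)):
        if len(pred) != 1:
            raise ValueError("pred must hold exactly one element")
        pred = pred[0]
    # Only the chosen zero-argument branch is called; the other never runs.
    return then_func() if bool(pred) else else_func()


calls = []
result = eager_cond([1.0], lambda: calls.append("then") or 42.0,
                    lambda: calls.append("else") or 0.0)
print(result, calls)
```

Because both branches take no arguments, they read their inputs from the enclosing scope through closures, as the lambdas in the example at the end of this page do.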
The outputs produced by then_func and else_func should have the same number of elements. Each pair of corresponding elements should have the same shape, dtype, and stype.
This function returns an NDArray or a nested list of NDArrays representing the computation result.
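The matching requirement can be expressed as a check that flattens both branch outputs and compares them element by element. The sketch below uses NumPy arrays in place of NDArrays and only compares shape and dtype, because NumPy has no storage type (stype). `branches_compatible` and `_flatten` are hypothetical helpers, not MXNet APIs.

```python
import numpy as np


def _flatten(out):
    """Flatten an array or an arbitrarily nested list/tuple of arrays."""
    if isinstance(out, (list, tuple)):
        flat = []
        for item in out:
            flat.extend(_flatten(item))
        return flat
    return [out]


def branches_compatible(then_out, else_out):
    """Hypothetical check: equal element count, matching shape and dtype pairwise."""
    a, b = _flatten(then_out), _flatten(else_out)
    if len(a) != len(b):
        return False
    return all(x.shape == y.shape and x.dtype == y.dtype for x, y in zip(a, b))


f32 = np.float32
print(branches_compatible([np.zeros(1, f32), [np.zeros((2, 3), f32)]],
                          [np.ones(1, f32), [np.ones((2, 3), f32)]]))
```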
- Parameters
pred (an MXNet NDArray representing a scalar) – The branch condition.
then_func (a Python function) – The computation to be executed if pred is true.
else_func (a Python function) – The computation to be executed if pred is false.
- Returns
outputs
- Return type
an NDArray or nested lists of NDArrays, representing the result of computation.
Examples
>>> import mxnet as mx
>>> a, b = mx.nd.array([1]), mx.nd.array([2])
>>> pred = a * b < 5
>>> then_func = lambda: (a + 5) * (b + 5)
>>> else_func = lambda: (a - 5) * (b - 5)
>>> outputs = mx.nd.contrib.cond(pred, then_func, else_func)
>>> outputs[0]
[42.]
<NDArray 1 @cpu(0)>
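To see which branch the example takes, and what the other branch would produce, the doctest's arithmetic can be replayed with plain Python floats in place of the one-element NDArrays. This is only an illustration of the arithmetic and makes no MXNet calls. `replay` is a hypothetical helper name.

```python
def then_func(a, b):
    return (a + 5) * (b + 5)


def else_func(a, b):
    return (a - 5) * (b - 5)


def replay(a, b):
    """Hypothetical helper: apply the doctest's predicate, then run one branch."""
    pred = a * b < 5
    return then_func(a, b) if pred else else_func(a, b)


print(replay(1.0, 2.0))  # 1*2 = 2 < 5 is True:  (1+5)*(2+5) -> 42.0
print(replay(3.0, 2.0))  # 3*2 = 6 < 5 is False: (3-5)*(2-5) -> 6.0
```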