mxnet.test_utils.check_symbolic_forward
mxnet.test_utils.check_symbolic_forward(sym, location, expected, rtol=0.0001, atol=None, aux_states=None, ctx=None, equal_nan=False, dtype=<class 'numpy.float32'>)
Compares a symbol’s forward results with the expected ones. Prints error messages if the forward results are not the same as the expected ones.
Parameters
sym (Symbol) – output symbol
location (list of np.ndarray or dict of str to np.ndarray) –
The evaluation point
- if type is list of np.ndarray
Contains all the numpy arrays corresponding to sym.list_arguments().
- if type is dict of str to np.ndarray
Contains the mapping between argument names and their values.
expected (list of np.ndarray or dict of str to np.ndarray) –
The expected output value
- if type is list of np.ndarray
Contains arrays corresponding to exe.outputs.
- if type is dict of str to np.ndarray
Contains the mapping between sym.list_outputs() and exe.outputs.
rtol (float, optional) – Relative error to check to. Defaults to 0.0001.
atol (float, optional) – Absolute error to check to. Defaults to None.
aux_states (list of np.ndarray or dict of str to np.ndarray, optional) –
- if type is list of np.ndarray
Contains all the NumPy arrays corresponding to sym.list_auxiliary_states().
- if type is dict of str to np.ndarray
Contains the mapping between names of auxiliary states and their values.
ctx (Context, optional) – running context
dtype (np.float16 or np.float32 or np.float64) – Datatype for mx.nd.array.
equal_nan (bool) – If True, NaN is treated as a valid value when checking equivalence (i.e., nan == nan).
Example
>>> import mxnet as mx
>>> import numpy as np
>>> from mxnet.test_utils import check_symbolic_forward
>>> shape = (2, 2)
>>> lhs = mx.symbol.Variable('lhs')
>>> rhs = mx.symbol.Variable('rhs')
>>> sym_dot = mx.symbol.dot(lhs, rhs)
>>> mat1 = np.array([[1, 2], [3, 4]])
>>> mat2 = np.array([[5, 6], [7, 8]])
>>> ret_expected = np.array([[19, 22], [43, 50]])
>>> check_symbolic_forward(sym_dot, [mat1, mat2], [ret_expected])
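The rtol, atol and equal_nan arguments in the signature control how closely the forward results must match. The sketch below uses NumPy only and assumes the comparison behaves like numpy.allclose, that is, an element passes when |actual - expected| <= atol + rtol * |expected|. The helper name forward_matches and the atol=0.0 default are hypothetical and chosen for illustration; they are not part of mxnet.test_utils, and this is not the library's internal implementation.

```python
import numpy as np


def forward_matches(actual, expected, rtol=1e-4, atol=0.0, equal_nan=False):
    """Hypothetical helper (not an MXNet API).

    Returns True when every element satisfies
    |actual - expected| <= atol + rtol * |expected|.
    atol=0.0 is an assumption made for this sketch.
    """
    return bool(np.allclose(actual, expected, rtol=rtol, atol=atol,
                            equal_nan=equal_nan))


# Expected output of the dot-product example above, as floats.
expected = np.array([[19.0, 22.0], [43.0, 50.0]])

# A result whose relative error (5e-5) is inside rtol=1e-4.
close = expected * (1 + 5e-5)
# A result whose relative error (1e-3) is outside rtol=1e-4.
far = expected * (1 + 1e-3)

result_close = forward_matches(close, expected)
result_far = forward_matches(far, expected)

# NaN compares unequal to itself unless equal_nan=True.
with_nan = np.array([1.0, np.nan])
nan_strict = forward_matches(with_nan, with_nan)
nan_relaxed = forward_matches(with_nan, with_nan, equal_nan=True)
```

With these inputs, only the result with the larger relative error and the strict NaN comparison are rejected. Loosening rtol or passing equal_nan=True changes which mismatches count as failures.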