mxnet.executor_manager.DataParallelExecutorManager
class mxnet.executor_manager.DataParallelExecutorManager(symbol, ctx, train_data, arg_names, param_names, aux_names, work_load_list=None, logger=None, sym_gen=None)
Helper class to manage multiple executors for data parallelism.
Parameters
symbol (Symbol) – Output symbol.
ctx (list of Context) – Devices to run on.
param_names (list of str) – Names of all trainable parameters of the network.
arg_names (list of str) – Names of all arguments of the network.
aux_names (list of str) – Names of all auxiliary states of the network.
train_data (DataIter) – Training data iterator.
work_load_list (list of float or int, optional) – Relative workload for each device, in the same order as ctx.
logger (logging.Logger, optional) – Logger to use. When not specified, the default logger is used.
sym_gen (function, optional) – A function that generates new Symbols depending on different input shapes. Used only for bucketing.
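
To make the role of work_load_list concrete, the following is a minimal sketch of how a batch could be divided among devices in proportion to their workloads. It is an illustration under stated assumptions, not the library's own code: the helper name split_batch is hypothetical, and the rounding and remainder handling shown here are one plausible choice.

```python
def split_batch(batch_size, work_load_list):
    """Hypothetical helper: return one slice of range(batch_size) per device.

    Each device receives a share of the batch proportional to its entry
    in work_load_list (same order as the device list).
    """
    total = sum(work_load_list)
    # Samples per device, proportional to its share of the total workload.
    counts = [round(w * batch_size / total) for w in work_load_list]
    # Rounding can leave a few samples unassigned; give them to the last device.
    shortfall = batch_size - sum(counts)
    if shortfall > 0:
        counts[-1] += shortfall

    slices = []
    end = 0
    for count in counts:
        # Clamp to the batch so rounding up never runs past the last sample.
        begin = min(end, batch_size)
        end = min(begin + count, batch_size)
        if begin >= end:
            raise ValueError("Some device would receive an empty slice.")
        slices.append(slice(begin, end))
    return slices
```

With a batch of 12 samples and workloads [1, 2], this sketch assigns samples 0 to 3 to the first device and samples 4 to 11 to the second. Equal workloads give an even split.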
__init__(symbol, ctx, train_data, arg_names, param_names, aux_names, work_load_list=None, logger=None, sym_gen=None)
Initialize self. See help(type(self)) for accurate signature.
Methods
__init__(symbol, ctx, train_data, arg_names, …) – Initialize self.
backward() – Run backward on the current executor.
copy_to(arg_params, aux_params) – Copy data from each executor to arg_params and aux_params.
forward([is_train]) – Run forward on the current executor.
install_monitor(monitor) – Install monitor on all executors.
load_data_batch(data_batch) – Load data and labels into arrays.
set_params(arg_params, aux_params) – Set parameter and aux values.
update_metric(metric, labels[, pre_sliced]) – Update metric with the current executor.
Attributes
aux_arrays – Shared aux states.
grad_arrays – Shared gradient arrays.
param_arrays – Shared parameter arrays.