mxnet.module.BucketingModule
class mxnet.module.BucketingModule(sym_gen, default_bucket_key=None, logger=logging, context=cpu(0), work_load_list=None, fixed_param_names=None, state_names=None, group2ctxs=None, compression_params=None)
This module helps to deal efficiently with varying-length inputs.
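To make "varying-length inputs" concrete, here is a minimal pure-Python sketch of the usual bucketing idea. Each sequence is assigned to the smallest bucket length that can hold it and is right-padded up to that length, so only a few fixed shapes (one per bucket key) ever occur. The helpers `assign_bucket`, `pad_to_bucket` and `group_by_bucket`, the bucket sizes, and the padding value 0 are hypothetical illustrations, not MXNet API. In MXNet itself, each `DataBatch` carries a `bucket_key` that tells the module which bucket to use.

```python
import bisect
from typing import Dict, List, Sequence

# Hypothetical bucket sizes (maximum sequence length per bucket), sorted ascending.
BUCKETS = [5, 10, 20]


def assign_bucket(length: int, buckets: Sequence[int] = BUCKETS) -> int:
    """Return the smallest bucket size that can hold a sequence of `length`."""
    i = bisect.bisect_left(buckets, length)
    if i == len(buckets):
        raise ValueError(f"length {length} exceeds largest bucket {buckets[-1]}")
    return buckets[i]


def pad_to_bucket(seq: List[int], pad_value: int = 0,
                  buckets: Sequence[int] = BUCKETS) -> List[int]:
    """Right-pad `seq` with `pad_value` up to the size of its bucket."""
    key = assign_bucket(len(seq), buckets)
    return seq + [pad_value] * (key - len(seq))


def group_by_bucket(seqs: List[List[int]],
                    buckets: Sequence[int] = BUCKETS) -> Dict[int, List[List[int]]]:
    """Group padded sequences by bucket key; every group has one fixed length."""
    groups: Dict[int, List[List[int]]] = {}
    for seq in seqs:
        key = assign_bucket(len(seq), buckets)
        groups.setdefault(key, []).append(pad_to_bucket(seq, buckets=buckets))
    return groups


if __name__ == "__main__":
    data = [[1, 2, 3], [4, 5, 6, 7, 8, 9], [1] * 5, [2] * 12]
    for key, batch in sorted(group_by_bucket(data).items()):
        print(key, batch)
```

Only three input shapes (lengths 5, 10 and 20 here) can occur, which is what makes one executor per bucket key practical. A common choice for `default_bucket_key` is the largest bucket; MXNet's `BucketSentenceIter`, for example, sets its default bucket key to the maximum bucket size.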
Parameters
sym_gen (function) – A function that, when called with a bucket key, returns a triple (symbol, data_names, label_names).
default_bucket_key (str or any Python object) – The key for the default bucket.
logger (Logger) – The logger to use. Defaults to the logging module.
context (Context or list of Context) – Defaults to mx.cpu().
work_load_list (list of number) – Defaults to None, indicating uniform workload.
fixed_param_names (list of str) – Defaults to None, indicating no network parameters are fixed.
state_names (list of str) – States are similar to data and label, but are not provided by the data iterator. Instead, they are initialized to 0 and can be set by set_states().
group2ctxs (dict of str to context or list of context, or list of dict of str to context) – Defaults to None. Maps the ctx_group attribute to the context assignment.
compression_params (dict) – Specifies the type of gradient compression and any additional arguments that type requires. For example, 2bit compression requires a threshold, so the arguments would be {'type': '2bit', 'threshold': 0.5}. See the mxnet.KVStore.set_gradient_compression method for more details on gradient compression.
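The threshold in compression_params controls how gradients are quantized. The following is a hedged, pure-Python sketch of threshold-based 2-bit quantization, not MXNet's actual kernel. It assumes an error-feedback design: each new gradient is added to a per-element residual, values at or above +threshold become +threshold, values at or below -threshold become -threshold, everything else becomes 0, and the quantization error stays in the residual for the next step. The function name `two_bit_compress` and the sample gradients are hypothetical.

```python
from typing import List, Tuple


def two_bit_compress(grad: List[float], residual: List[float],
                     threshold: float = 0.5) -> Tuple[List[float], List[float]]:
    """Hypothetical sketch of 2-bit quantization with error feedback.

    Returns (quantized, new_residual). Every quantized value is one of
    {-threshold, 0, +threshold}, so it fits in 2 bits.
    """
    if len(grad) != len(residual):
        raise ValueError("grad and residual must have the same length")
    quantized, new_residual = [], []
    for g, r in zip(grad, residual):
        acc = r + g  # fold the new gradient into the carried-over error
        if acc >= threshold:
            q = threshold
        elif acc <= -threshold:
            q = -threshold
        else:
            q = 0.0
        quantized.append(q)
        new_residual.append(acc - q)  # keep the error for the next step
    return quantized, new_residual


if __name__ == "__main__":
    # Same settings as the example dictionary in the parameter description.
    compression_params = {'type': '2bit', 'threshold': 0.5}
    t = compression_params['threshold']
    residual = [0.0] * 4
    q1, residual = two_bit_compress([0.7, -0.2, -0.6, 0.3], residual, t)
    q2, residual = two_bit_compress([0.1, -0.4, 0.0, 0.3], residual, t)
    print(q1, q2, residual)
```

Small gradients are not lost under this scheme: the -0.2 in the first step is sent as 0, but it accumulates with the next -0.4 and is transmitted as -0.5 in the second step.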
__init__(sym_gen, default_bucket_key=None, logger=logging, context=cpu(0), work_load_list=None, fixed_param_names=None, state_names=None, group2ctxs=None, compression_params=None)
Initialize self. See help(type(self)) for accurate signature.
Methods
__init__(sym_gen[, default_bucket_key, …]) – Initialize self.
backward([out_grads]) – Backward computation.
bind(data_shapes[, label_shapes, …]) – Binding for a BucketingModule means setting up the buckets and binding the executor for the default bucket key.
fit(train_data[, eval_data, eval_metric, …]) – Trains the module parameters.
forward(data_batch[, is_train]) – Forward computation.
forward_backward(data_batch) – A convenience function that calls both forward and backward.
get_input_grads([merge_multi_context]) – Gets the gradients with respect to the inputs of the module.
get_outputs([merge_multi_context]) – Gets outputs from a previous forward computation.
get_params() – Gets current parameters.
get_states([merge_multi_context]) – Gets states from all devices.
init_optimizer([kvstore, optimizer, …]) – Installs and initializes optimizers.
init_params([initializer, arg_params, …]) – Initializes parameters.
install_monitor(mon) – Installs a monitor on all executors.
iter_predict(eval_data[, num_batch, reset, …]) – Iterates over predictions.
load_params(fname) – Loads model parameters from file.
predict(eval_data[, num_batch, …]) – Runs prediction and collects the outputs.
prepare(data_batch[, sparse_row_id_fn]) – Prepares the module for processing a data batch.
save_params(fname) – Saves model parameters to file.
score(eval_data, eval_metric[, num_batch, …]) – Runs prediction on eval_data and evaluates the performance according to the given eval_metric.
set_params(arg_params, aux_params[, …]) – Assigns parameters and aux state values.
set_states([states, value]) – Sets value for states.
switch_bucket(bucket_key, data_shapes[, …]) – Switches to a different bucket.
update() – Updates parameters according to the installed optimizer and the gradient computed in the previous forward-backward cycle.
update_metric(eval_metric, labels[, pre_sliced]) – Evaluates and accumulates the evaluation metric on outputs of the last forward computation.
Attributes
data_names – A list of names for data required by this module.
data_shapes – Gets data shapes.
label_shapes – Gets label shapes.
output_names – A list of names for the outputs of this module.
output_shapes – Gets output shapes.
symbol – The symbol of the current bucket being used.
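The bind, switch_bucket and symbol entries describe lazy, per-bucket bookkeeping: binding sets up the default bucket, other buckets are created the first time their key is requested and then reused, and symbol always reflects the current bucket. The pure-Python toy below mimics only that control flow. ToyBucketingModule, its simplified argument-free bind(), and the use of plain strings in place of MXNet Symbols are hypothetical stand-ins, not MXNet classes; the sym_gen does follow the documented contract of returning (symbol, data_names, label_names).

```python
from typing import Callable, Dict, Hashable, List, Optional, Tuple

# The documented sym_gen contract: bucket key -> (symbol, data_names, label_names).
# A plain string stands in for an MXNet Symbol in this toy.
SymGen = Callable[[Hashable], Tuple[str, List[str], List[str]]]


class ToyBucketingModule:
    """Hypothetical stand-in that mimics BucketingModule's bucket bookkeeping only."""

    def __init__(self, sym_gen: SymGen, default_bucket_key: Hashable):
        self._sym_gen = sym_gen
        self._default_bucket_key = default_bucket_key
        self._buckets: Dict[Hashable, Tuple[str, List[str], List[str]]] = {}
        self._curr_key: Optional[Hashable] = None

    def bind(self) -> None:
        # Binding sets up the default bucket; other buckets are created on demand.
        key = self._default_bucket_key
        self._buckets[key] = self._sym_gen(key)
        self._curr_key = key

    def switch_bucket(self, bucket_key: Hashable) -> None:
        if self._curr_key is None:
            raise RuntimeError("call bind() before switch_bucket()")
        if bucket_key not in self._buckets:
            # First request for this key: generate and cache its symbol.
            self._buckets[bucket_key] = self._sym_gen(bucket_key)
        self._curr_key = bucket_key

    @property
    def symbol(self) -> str:
        # Like the `symbol` attribute: the symbol of the current bucket.
        return self._buckets[self._curr_key][0]


calls: List[Hashable] = []  # records which keys sym_gen was asked to build


def sym_gen(seq_len: int) -> Tuple[str, List[str], List[str]]:
    calls.append(seq_len)
    return f"unrolled_rnn_{seq_len}", ["data"], ["softmax_label"]


if __name__ == "__main__":
    mod = ToyBucketingModule(sym_gen, default_bucket_key=20)
    mod.bind()
    for key in [10, 20, 10, 5]:  # bucket keys of successive batches
        mod.switch_bucket(key)
        print(key, mod.symbol)
    print("sym_gen called for:", calls)
```

Running the main block shows sym_gen being called once each for 20, 10 and 5, even though key 10 is requested twice: in this toy, a data iterator cycling through a small fixed set of bucket keys pays the construction cost only once per key.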