dmlcloud

This is the API reference for the dmlcloud package.

Pipeline([config, name])

A training pipeline that consists of multiple stages.

Stage([name, epochs])

A single stage of a training pipeline. Stages define hook points that subclasses can override.

current_pipe()

Returns the currently running pipeline, or None if no pipeline is running.

current_stage()

Returns the currently running stage, or None if no stage is running.
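A common way to implement accessors like current_pipe() and current_stage() is a context variable that is set while a pipeline or stage runs and reset afterwards. The following is a minimal sketch under that assumption, not dmlcloud's actual implementation; `_Pipeline`, `_Stage` and `_activate` are hypothetical stand-ins.

```python
import contextvars
from contextlib import contextmanager

# Hypothetical context variables; None means "nothing is running".
_current_pipe = contextvars.ContextVar("current_pipe", default=None)
_current_stage = contextvars.ContextVar("current_stage", default=None)


def current_pipe():
    """Return the running pipeline, or None."""
    return _current_pipe.get()


def current_stage():
    """Return the running stage, or None."""
    return _current_stage.get()


@contextmanager
def _activate(var, obj):
    # Set the variable for the duration of the block, then restore it.
    token = var.set(obj)
    try:
        yield obj
    finally:
        var.reset(token)


class _Pipeline:
    def __init__(self, stages):
        self.stages = stages

    def run(self):
        with _activate(_current_pipe, self):
            for stage in self.stages:
                with _activate(_current_stage, stage):
                    stage.run()


class _Stage:
    def __init__(self, name):
        self.name = name
        self.seen = None

    def run(self):
        # Inside a stage, both accessors return the active objects.
        self.seen = (current_pipe(), current_stage())


stage = _Stage("train")
pipe = _Pipeline([stage])
pipe.run()
```

After `pipe.run()` returns, both accessors yield None again, because each context variable is reset when its block exits.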

log_metric(name, value[, reduction, prefixed])

Logs a metric on the currently running stage; shorthand for current_stage().log().

torch.distributed Helpers

dmlcloud provides a set of helper functions to simplify the use of torch.distributed.

init([kind])

Initializes the torch.distributed framework.

seed([seed, group])

Shares the seed from the root rank with all ranks in the group and seeds the random number generators.

deinitialize_torch_distributed([fail_silently])

Deinitializes the torch distributed framework.

is_root([group])

Check if the current rank is the root rank (rank 0).

root_only(fn[, group, synchronize, timeout])

Decorator for methods that should only be called on the root rank.
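The idea behind a root-only decorator can be sketched in a few lines. This is a simplified illustration, assuming the rank is read from the RANK environment variable (as set by launchers such as torchrun); the helper `_rank` and the example function `save_checkpoint` are hypothetical, and the real decorator's `group`, `synchronize` and `timeout` options are not modeled.

```python
import functools
import os


def _rank():
    # Assumption: rank comes from the RANK variable; 0 for a single process.
    return int(os.environ.get("RANK", "0"))


def root_only(fn):
    """Run fn only on rank 0; every other rank gets None."""
    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        if _rank() == 0:
            return fn(*args, **kwargs)
        return None
    return wrapper


@root_only
def save_checkpoint(path):
    # Stand-in for I/O that should happen exactly once per job.
    return f"saved {path}"
```

Returning None on non-root ranks keeps call sites identical on every rank, so the same training script can call `save_checkpoint` unconditionally.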

root_first([group])

Context manager that ensures the root rank executes the enclosed code before all other ranks.
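The ordering guarantee can be sketched with a barrier: non-root ranks wait at the barrier before entering the block, while the root runs the block first and only then reaches the barrier. In the sketch below, threads and a `threading.Barrier` stand in for processes and a distributed barrier; the `rank` and `barrier` parameters and the `simulate` helper are assumptions for illustration, not dmlcloud's signature.

```python
import threading
from contextlib import contextmanager


@contextmanager
def root_first(rank, barrier):
    """Let rank 0 run the block before every other rank."""
    if rank == 0:
        try:
            yield
        finally:
            # Release the waiting ranks only after the root has finished.
            barrier.wait()
    else:
        barrier.wait()
        yield


def simulate(world_size=4):
    barrier = threading.Barrier(world_size)
    order = []
    lock = threading.Lock()

    def worker(rank):
        with root_first(rank, barrier):
            with lock:
                order.append(rank)

    threads = [threading.Thread(target=worker, args=(r,)) for r in range(world_size)]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
    return order


order = simulate()
```

A typical use is downloading or preprocessing a dataset on the root rank so that the other ranks can read it from a shared cache afterwards.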

rank()

Returns the rank of the current process.

world_size()

Returns the total number of processes.

local_rank()

Returns the local rank of the current process, i.e. its rank within its node.

local_world_size()

Returns the local world size, i.e. the number of processes on the current node.
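The difference between global and local values is easiest to see with the environment variables that torchrun exports: on a job with 2 nodes running 4 processes each, the world size is 8 while the local world size is 4, and local ranks run from 0 to 3 on each node. The sketch below reads only these torchrun variables; dmlcloud may also consult SLURM or MPI, which is not shown, and `device_name` is a hypothetical helper.

```python
import os


def local_rank():
    # torchrun sets LOCAL_RANK for each worker; assume 0 when unset.
    return int(os.environ.get("LOCAL_RANK", "0"))


def local_world_size():
    # torchrun sets LOCAL_WORLD_SIZE to the number of workers on this node.
    return int(os.environ.get("LOCAL_WORLD_SIZE", "1"))


def device_name():
    # Typical one-GPU-per-process layout: the local rank selects the GPU.
    return f"cuda:{local_rank()}"
```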

local_node()

Returns the node id of the current process.

all_gather_object(obj[, group])

Gathers objects from all ranks in the group and makes the full list available on every rank.

gather_object(obj[, dst, group])

Gathers objects from all ranks in the group to the destination rank.

broadcast_object([obj, src, group, device])

Broadcasts an object from the source rank to all other ranks in the group.

has_slurm()

Check if the program was started using srun (SLURM).

has_environment()

Check if the environment variables used by the "env://" initialization method are set.
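A check like this can be sketched as a test for the variables documented for torch.distributed's "env://" initialization. This sketch assumes all four variables must be present, which may be stricter than the real check; the `environ` parameter is a hypothetical addition that makes the function easy to test.

```python
import os

# Variables read by the "env://" init method of torch.distributed.
ENV_INIT_VARS = ("MASTER_ADDR", "MASTER_PORT", "RANK", "WORLD_SIZE")


def has_environment(environ=None):
    """Return True if every env:// variable is set."""
    environ = os.environ if environ is None else environ
    return all(name in environ for name in ENV_INIT_VARS)
```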

has_mpi()

Check if MPI is available.

Logging

dmlcloud provides a set of logging utilities to simplify logging in a distributed environment. In particular, it lazily sets up a logger ("dmlcloud") that only logs on the root process. Users are encouraged to use the provided log functions instead of print statements to prevent duplicated logs.
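Root-only logging can be sketched with the standard library alone: attach a `logging.Filter` that drops records unless the process is rank 0. This is an illustration of the technique, not dmlcloud's setup code; `RootOnlyFilter` and `make_logger` are hypothetical, the rank is passed in explicitly, and the loggers use their own names so the real "dmlcloud" logger is left untouched. Note that the message is formatted as 'msg % args', the same convention the log functions below use.

```python
import io
import logging


class RootOnlyFilter(logging.Filter):
    """Drop every log record unless this process is rank 0."""

    def __init__(self, rank):
        super().__init__()
        self.rank = rank

    def filter(self, record):
        return self.rank == 0


def make_logger(rank, name, stream):
    # Hypothetical helper: one handler, filtered so only rank 0 emits output.
    logger = logging.getLogger(name)
    logger.setLevel(logging.INFO)
    handler = logging.StreamHandler(stream)
    handler.addFilter(RootOnlyFilter(rank))
    logger.addHandler(handler)
    logger.propagate = False  # don't also pass records to the root logger
    return logger


# Simulate two ranks; each writes to its own buffer.
buf0, buf1 = io.StringIO(), io.StringIO()
make_logger(0, "sketch.rank0", buf0).info("loss=%.2f", 0.5)
make_logger(1, "sketch.rank1", buf1).info("loss=%.2f", 0.5)
```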

logger

The dmlcloud logger, a standard logging.Logger instance representing a single logging channel.

log(level, msg, *args[, exc_info, ...])

Log 'msg % args' with severity 'level' on the dmlcloud logger.

debug(msg, *args[, exc_info, stack_info, extra])

Log 'msg % args' with severity 'DEBUG' on the dmlcloud logger.

info(msg, *args[, exc_info, stack_info, extra])

Log 'msg % args' with severity 'INFO' on the dmlcloud logger.

warning(msg, *args[, exc_info, stack_info, ...])

Log 'msg % args' with severity 'WARNING' on the dmlcloud logger.

error(msg, *args[, exc_info, stack_info, extra])

Log 'msg % args' with severity 'ERROR' on the dmlcloud logger.

critical(msg, *args[, exc_info, stack_info, ...])

Log 'msg % args' with severity 'CRITICAL' on the dmlcloud logger.

print_worker(*values[, sep, end, file, ...])

Print the values to a stream, default sys.stdout, with additional information about the worker.

print_root(*values[, sep, end, file, flush])

Print the values to a stream if the current rank is the root rank.
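Both print helpers mirror the built-in print() signature and add rank awareness. The sketch below is a hedged approximation: the rank and world size are passed as explicit keyword arguments (in dmlcloud they would come from the process group), and the "[Worker r/n]" prefix format is a hypothetical choice.

```python
import io


def print_worker(*values, rank, world_size, sep=" ", end="\n", file=None, flush=False):
    # Hypothetical prefix format; the real one may differ.
    print(f"[Worker {rank}/{world_size}]", *values, sep=sep, end=end, file=file, flush=flush)


def print_root(*values, rank, sep=" ", end="\n", file=None, flush=False):
    # Only rank 0 prints; all other ranks stay silent.
    if rank == 0:
        print(*values, sep=sep, end=end, file=file, flush=flush)


out = io.StringIO()
print_worker("epoch done", rank=1, world_size=4, file=out)
print_root("only once", rank=1, file=out)  # suppressed on rank 1
print_root("only once", rank=0, file=out)
```

Passing `file=None` through to print() falls back to sys.stdout, matching the built-in default.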

setup_logger()

Set up the dmlcloud logger.

reset_logger()

Reset the dmlcloud logger to its initial state.

Metric Tracking

TrainingHistory()

Stores the training history of a model.

Tracker()

Keeps track of multiple metrics and reduces them at the end of each epoch.
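The per-epoch reduction can be sketched as a class that buffers values per metric and collapses them when the epoch ends. This single-process sketch assumes reduction names "mean", "sum", "min" and "max"; `MiniTracker`, `track` and `next_epoch` are hypothetical names, and the cross-rank reduction a distributed tracker would also perform is omitted.

```python
from collections import defaultdict
from statistics import mean

REDUCTIONS = {"mean": mean, "sum": sum, "min": min, "max": max}


class MiniTracker:
    """Collect per-step values and reduce them once per epoch."""

    def __init__(self):
        self._values = defaultdict(list)
        self._reductions = {}
        self.history = defaultdict(list)  # metric name -> one value per epoch

    def track(self, name, value, reduction="mean"):
        self._values[name].append(value)
        self._reductions[name] = reduction

    def next_epoch(self):
        # Reduce everything tracked during the epoch, then start fresh.
        for name, values in self._values.items():
            self.history[name].append(REDUCTIONS[self._reductions[name]](values))
        self._values.clear()


tracker = MiniTracker()
tracker.track("loss", 1.0)
tracker.track("loss", 3.0)
tracker.track("samples", 32, reduction="sum")
tracker.track("samples", 32, reduction="sum")
tracker.next_epoch()
```

Choosing the reduction per metric matters: a loss is usually averaged, while a sample count should be summed.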

Model Creation

scale_lr(base_lr[, world_size])

Scales the learning rate based on the world size.
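With data parallelism, the global batch size grows with the number of processes, and a common heuristic, the linear scaling rule, multiplies the learning rate by the same factor. The sketch below assumes that rule; the text above does not say whether scale_lr scales linearly or uses another scheme (such as square-root scaling), and here `world_size` is required rather than optional.

```python
def scale_lr(base_lr, world_size):
    """Linear scaling rule: multiply the base rate by the number of processes."""
    if world_size < 1:
        raise ValueError("world_size must be at least 1")
    return base_lr * world_size
```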

wrap_ddp(module, device[, sync_bn, ...])

Wraps a module with DistributedDataParallel.

count_parameters(module)

Returns the number of trainable parameters in a module.

Config Helpers

These functions can be used to create objects from configuration files.

import_object(object_path)

Imports an object from a module.
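Importing an object from a dotted path is a standard importlib pattern: split off the last component, import the module, and fetch the attribute. This is a minimal sketch of that pattern, not necessarily how import_object handles edge cases; for example, nested paths such as "pkg.Class.method" are not supported here.

```python
import importlib


def import_object(object_path):
    """Import 'package.module.attr' and return attr."""
    module_name, _, attr_name = object_path.rpartition(".")
    if not module_name:
        raise ValueError(f"expected a dotted path, got {object_path!r}")
    module = importlib.import_module(module_name)
    return getattr(module, attr_name)


sqrt = import_object("math.sqrt")
```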

factory_from_cfg(config, *args, **kwargs)

Creates a factory function from a configuration dictionary or a string.

obj_from_cfg(config, *args, **kwargs)

Creates an object from a configuration dictionary or a string.
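The relationship between the two functions can be sketched as follows: the factory binds the configured target and its options with functools.partial, and the object helper simply calls that factory. The config shape used here (a dotted-path string, or a dict whose hypothetical "type" key holds the path and whose other keys become keyword arguments) is an assumption for illustration, not dmlcloud's documented format.

```python
import functools
import importlib


def _import(path):
    module_name, _, attr = path.rpartition(".")
    return getattr(importlib.import_module(module_name), attr)


def factory_from_cfg(config, *args, **kwargs):
    """Return a callable that builds the configured object.

    Assumed (hypothetical) config shapes:
      "module.Class"                       -> no extra options
      {"type": "module.Class", **options}  -> options become kwargs
    """
    if isinstance(config, str):
        config = {"type": config}
    options = dict(config)
    target = _import(options.pop("type"))
    # Explicit kwargs override values taken from the config.
    return functools.partial(target, *args, **{**options, **kwargs})


def obj_from_cfg(config, *args, **kwargs):
    return factory_from_cfg(config, *args, **kwargs)()


frac = obj_from_cfg("fractions.Fraction", 3, 4)
make_delta = factory_from_cfg({"type": "datetime.timedelta", "days": 1})
delta = make_delta(hours=2)
```

Returning a factory rather than an object is useful when the same configuration must be instantiated several times, for example once per stage or per data loader worker.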