colossalai.trainer

class colossalai.trainer.Trainer(engine, schedule=None, timer=None, logger=None)

This class lets users deploy training and evaluation without writing their own scripts. It is similar to ignite.engine and keras.engine.

Parameters
  • engine (Engine) – Engine responsible for the process function

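  • schedule (BaseSchedule, optional) – Schedule responsible for the forward and backward steps

  • timer (MultiTimer, optional) – Timer used to monitor the whole training process

  • logger (DistributedLogger, optional) – Logger used to record the whole training process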
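The split of responsibilities between the trainer and the engine can be pictured with a pure-Python sketch. None of this is ColossalAI code: `ProcessEngine`, `ToyEngine`, and `ToyTrainer` are hypothetical names, batches are plain floats, and the toy engine keeps a running sum in place of a real forward and backward pass. The trainer owns the loop, and the engine decides what happens to each batch.

```python
import logging
from typing import Iterable, List, Optional, Protocol


class ProcessEngine(Protocol):
    """Hypothetical interface: anything that can process one training batch."""

    def train_step(self, batch: float) -> float: ...


class ToyEngine:
    """Hypothetical engine: 'training' just accumulates the batch values."""

    def __init__(self) -> None:
        self.total = 0.0

    def train_step(self, batch: float) -> float:
        self.total += batch
        return self.total  # stands in for a loss value


class ToyTrainer:
    """Hypothetical trainer: owns the loop and delegates each batch to the engine."""

    def __init__(self, engine: ProcessEngine, logger: Optional[logging.Logger] = None) -> None:
        self.engine = engine
        # Optional collaborator with a fallback, in the spirit of the optional `logger` argument.
        self.logger = logger or logging.getLogger("toy_trainer")

    def run_epoch(self, dataloader: Iterable[float]) -> List[float]:
        outputs: List[float] = []
        for batch in dataloader:
            outputs.append(self.engine.train_step(batch))
        self.logger.debug("processed %d batches", len(outputs))
        return outputs


trainer = ToyTrainer(ToyEngine())
print(trainer.run_epoch([1.0, 2.0, 3.0]))  # [1.0, 3.0, 6.0]
```

Because the toy trainer depends only on a `train_step` method, any engine providing that method can be swapped in, which mirrors the idea of handing Trainer a ready-made engine.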

property cur_epoch

Returns the index of the current epoch.

property cur_step

Returns how many iteration steps have been processed.
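One way to read the two properties: cur_epoch is the zero-based index of the epoch in progress, while cur_step keeps counting across epoch boundaries. The sketch below assumes exactly that (steps are cumulative and never reset per epoch); `CounterSketch` is a hypothetical class that only mirrors the read-only properties documented here.

```python
class CounterSketch:
    """Hypothetical bookkeeping behind read-only cur_epoch / cur_step properties."""

    def __init__(self) -> None:
        self._cur_epoch = 0
        self._cur_step = 0

    @property
    def cur_epoch(self) -> int:
        return self._cur_epoch

    @property
    def cur_step(self) -> int:
        return self._cur_step

    def run(self, epochs: int, batches_per_epoch: int) -> None:
        for _ in range(epochs):
            for _ in range(batches_per_epoch):
                self._cur_step += 1  # one processed iteration, never reset
            self._cur_epoch += 1  # move on to the next epoch index


c = CounterSketch()
c.run(epochs=2, batches_per_epoch=3)
print(c.cur_epoch, c.cur_step)  # 2 6
```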

fit(train_dataloader, epochs, max_steps=None, test_dataloader=None, test_interval=1, hooks=None, display_progress=False, return_output_label=True)

Trains the model to fit training data.

Parameters
  • train_dataloader (DataLoader) – DataLoader used for training

  • epochs (int) – Maximum number of epochs

  • max_steps (int, optional) – Maximum number of iterations to run

  • test_dataloader (DataLoader, optional) – DataLoader used for testing

  • test_interval (int, optional) – Interval, in epochs, between evaluations on test_dataloader

  • hooks (list, optional) – A list of hooks used during training

  • display_progress (bool, optional) – If True, the training progress will be printed

  • return_output_label (bool, optional) – If True, the model output and the label will be returned
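The interplay of epochs, max_steps, test_interval, and hooks can be sketched as a plain loop. This is an illustration under assumptions, not ColossalAI's implementation: training stops at whichever limit is reached first, evaluation runs after every epoch whose zero-based index is a multiple of test_interval, and hooks are modeled as objects with an optional `after_train_epoch` method. `fit_sketch` and `RecordingHook` are hypothetical names.

```python
from typing import Callable, List, Optional, Sequence


class RecordingHook:
    """Hypothetical hook that records each epoch it is notified about."""

    def __init__(self) -> None:
        self.epochs_seen: List[int] = []

    def after_train_epoch(self, epoch: int) -> None:
        self.epochs_seen.append(epoch)


def fit_sketch(
    train_data: Sequence[int],
    epochs: int,
    train_step: Callable[[int], None],
    max_steps: Optional[int] = None,
    test_fn: Optional[Callable[[], None]] = None,
    test_interval: int = 1,
    hooks: Optional[list] = None,
) -> int:
    """Hypothetical fit loop; returns the number of training steps taken."""
    hooks = hooks or []
    step = 0
    for epoch in range(epochs):
        for batch in train_data:
            # Assumption: max_steps caps the total number of iterations across all epochs.
            if max_steps is not None and step >= max_steps:
                return step
            train_step(batch)
            step += 1
        for hook in hooks:
            # Only notify hooks that implement this (hypothetical) event.
            callback = getattr(hook, "after_train_epoch", None)
            if callback is not None:
                callback(epoch)
        # Assumption: evaluate after every epoch whose index is a multiple of test_interval.
        if test_fn is not None and epoch % test_interval == 0:
            test_fn()
    return step


evals: List[str] = []
hook = RecordingHook()
steps = fit_sketch(
    train_data=[1, 2, 3],
    epochs=4,
    train_step=lambda batch: None,
    test_fn=lambda: evals.append("eval"),
    test_interval=2,
    hooks=[hook],
)
print(steps, len(evals), hook.epochs_seen)  # 12 2 [0, 1, 2, 3]
```

With test_interval=2, evaluation runs after epochs 0 and 2 only, while the hook sees every epoch.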

evaluate(test_dataloader, hooks=None, display_progress=False, return_output_label=True)

Evaluates the model with testing data.

Parameters
  • test_dataloader (DataLoader) – DataLoader used for testing

  • hooks (list, optional) – A list of hooks used during evaluation

  • display_progress (bool, optional) – If True, the evaluation progress will be printed

  • return_output_label (bool, optional) – If True, the model output and the label will be returned
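A sketch of what return_output_label controls during evaluation, assuming the model is a plain callable and each batch is an (input, label) pair. `evaluate_sketch` is hypothetical, and returning the collected pair from the function is this sketch's choice rather than a statement about what the real method returns; the real method works on tensors from a DataLoader.

```python
from typing import Callable, Iterable, List, Optional, Tuple


def evaluate_sketch(
    model: Callable[[float], float],
    test_data: Iterable[Tuple[float, float]],
    return_output_label: bool = True,
) -> Optional[Tuple[List[float], List[float]]]:
    """Hypothetical evaluation loop over (input, label) pairs."""
    outputs: List[float] = []
    labels: List[float] = []
    for x, y in test_data:
        out = model(x)  # forward pass only; no parameters are updated
        if return_output_label:
            outputs.append(out)
            labels.append(y)
    return (outputs, labels) if return_output_label else None


data = [(1.0, 2.0), (2.0, 4.0)]
print(evaluate_sketch(lambda x: 2 * x, data))  # ([2.0, 4.0], [2.0, 4.0])
print(evaluate_sketch(lambda x: 2 * x, data, return_output_label=False))  # None
```

In this sketch, setting the flag to False still runs the model on every batch; it only skips collecting outputs and labels.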

predict(data)

Uses the trained model to make a prediction for a tensor or a list of tensors.

Parameters
  • data (Union[Tensor, List[Tensor]]) – Input data

Returns

The model output, used as the prediction

Return type

Tensor
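Accepting Union[Tensor, List[Tensor]] typically means normalizing the input before the forward pass. The sketch below shows one way to do that, with plain floats standing in for tensors. `predict_sketch` and `add_model` are hypothetical, and unpacking a list into separate positional arguments is an assumption about how a multi-input model would be called.

```python
from typing import Callable, List, Union

Tensor = float  # stand-in for torch.Tensor in this sketch


def predict_sketch(model: Callable[..., Tensor], data: Union[Tensor, List[Tensor]]) -> Tensor:
    """Hypothetical predict: wrap a single input in a list, then unpack it."""
    inputs = data if isinstance(data, list) else [data]
    return model(*inputs)


def add_model(*xs: Tensor) -> Tensor:
    """Toy model: sums however many inputs it receives."""
    return sum(xs)


print(predict_sketch(add_model, 3.0))  # 3.0
print(predict_sketch(add_model, [1.0, 2.0]))  # 3.0
```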