|
|
|
Callbacks |
|
Callbacks are objects that can customize the behavior of the training loop in the PyTorch [Trainer] (this feature is not yet implemented in TensorFlow). They can inspect the training loop state (for progress reporting, logging on TensorBoard or other ML platforms) and make decisions (like early stopping).
|
Callbacks are "read only" pieces of code, apart from the [TrainerControl] object they return, they |
|
cannot change anything in the training loop. For customizations that require changes in the training loop, you should |
|
subclass [Trainer] and override the methods you need (see trainer for examples). |
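For example, a callback cannot break out of the loop itself, but it can ask the [Trainer] to stop by flipping a flag on the [TrainerControl] object it receives. Here is a minimal sketch (the step threshold is an arbitrary illustration):

```python
from transformers import TrainerCallback


class StopAfterNStepsCallback(TrainerCallback):
    """Requests that training stop once a fixed number of steps is reached."""

    def __init__(self, max_steps=1000):
        self.max_steps = max_steps

    def on_step_end(self, args, state, control, **kwargs):
        # The callback cannot alter the loop directly; it only sets flags
        # on the TrainerControl object, which the Trainer then acts upon.
        if state.global_step >= self.max_steps:
            control.should_training_stop = True
        return control
```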
|
By default, TrainingArguments.report_to is set to "all", so a [Trainer] will use the following callbacks. |
|
|
|
- [DefaultFlowCallback] which handles the default behavior for logging, saving and evaluation.
- [PrinterCallback] or [ProgressCallback] to display progress and print the logs (the first one is used if you deactivate tqdm through the [TrainingArguments], otherwise it's the second one).
- [~integrations.TensorBoardCallback] if tensorboard is accessible (either through PyTorch >= 1.4 or tensorboardX).
- [~integrations.WandbCallback] if wandb is installed.
- [~integrations.CometCallback] if comet_ml is installed.
- [~integrations.MLflowCallback] if mlflow is installed.
- [~integrations.NeptuneCallback] if neptune is installed.
- [~integrations.AzureMLCallback] if azureml-sdk is installed.
- [~integrations.CodeCarbonCallback] if codecarbon is installed.
- [~integrations.ClearMLCallback] if clearml is installed.
- [~integrations.DagsHubCallback] if dagshub is installed.
- [~integrations.FlyteCallback] if flyte is installed.
- [~integrations.DVCLiveCallback] if dvclive is installed.
|
|
|
If a package is installed but you don't wish to use the accompanying integration, you can change TrainingArguments.report_to to a list of just those integrations you want to use (e.g. ["azure_ml", "wandb"]). |
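For example, to report only to Weights & Biases even when other logging packages are installed, you could configure the [TrainingArguments] like this (the output directory below is just a placeholder):

```python
from transformers import TrainingArguments

# Only the wandb integration callback will be added; other installed
# integrations (comet_ml, mlflow, ...) are ignored.
args = TrainingArguments(
    output_dir="output",  # placeholder path
    report_to=["wandb"],
)
```

Setting report_to="none" disables all of the integration callbacks.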
|
The main class that implements callbacks is [TrainerCallback]. It gets the |
|
[TrainingArguments] used to instantiate the [Trainer], can access that |
|
Trainer's internal state via [TrainerState], and can take some actions on the training loop via |
|
[TrainerControl]. |
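As a rough sketch of how these three objects fit together, the callback below reads progress information from the [TrainerState] whenever logs are produced and uses the [TrainerControl] to request an extra checkpoint at the end of every epoch (the behavior itself is just an illustration):

```python
from transformers import TrainerCallback


class InspectAndSaveCallback(TrainerCallback):
    """Reads the TrainerState on each log and forces a save at every epoch end."""

    def on_log(self, args, state, control, logs=None, **kwargs):
        # TrainerState exposes read-only progress information.
        print(f"step {state.global_step}/{state.max_steps}, epoch {state.epoch}: {logs}")

    def on_epoch_end(self, args, state, control, **kwargs):
        # TrainerControl is the only channel back into the training loop.
        control.should_save = True
        return control
```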
|
Available Callbacks |
|
Here is the list of the available [TrainerCallback] classes in the library (a usage sketch for [EarlyStoppingCallback] follows the list):
|
[[autodoc]] integrations.CometCallback |
|
- setup |
|
[[autodoc]] DefaultFlowCallback |
|
[[autodoc]] PrinterCallback |
|
[[autodoc]] ProgressCallback |
|
[[autodoc]] EarlyStoppingCallback |
|
[[autodoc]] integrations.TensorBoardCallback |
|
[[autodoc]] integrations.WandbCallback |
|
- setup |
|
[[autodoc]] integrations.MLflowCallback |
|
- setup |
|
[[autodoc]] integrations.AzureMLCallback |
|
[[autodoc]] integrations.CodeCarbonCallback |
|
[[autodoc]] integrations.NeptuneCallback |
|
[[autodoc]] integrations.ClearMLCallback |
|
[[autodoc]] integrations.DagsHubCallback |
|
[[autodoc]] integrations.FlyteCallback |
|
[[autodoc]] integrations.DVCLiveCallback |
|
- setup |
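Among these, [EarlyStoppingCallback] is typically passed to the [Trainer] explicitly. A minimal sketch, assuming the [TrainingArguments] are already set up to run evaluation and track a metric (for example with load_best_model_at_end=True and metric_for_best_model set):

```python
from transformers import EarlyStoppingCallback, Trainer

trainer = Trainer(
    model,
    args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    # Stop training if the tracked metric has not improved for 3 evaluations.
    callbacks=[EarlyStoppingCallback(early_stopping_patience=3)],
)
```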
|
TrainerCallback |
|
[[autodoc]] TrainerCallback |
|
Here is an example of how to register a custom callback with the PyTorch [Trainer]: |
|
```python
from transformers import Trainer, TrainerCallback


class MyCallback(TrainerCallback):
    "A callback that prints a message at the beginning of training"

    def on_train_begin(self, args, state, control, **kwargs):
        print("Starting training")


trainer = Trainer(
    model,
    args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    callbacks=[MyCallback],  # We can either pass the callback class this way or an instance of it (MyCallback())
)
```
|
Another way to register a callback is to call trainer.add_callback() as follows: |
|
```python
trainer = Trainer(model, args, train_dataset=train_dataset, eval_dataset=eval_dataset)

trainer.add_callback(MyCallback)
# Alternatively, we can pass an instance of the callback class
trainer.add_callback(MyCallback())
```
|
TrainerState |
|
[[autodoc]] TrainerState |
|
TrainerControl |
|
[[autodoc]] TrainerControl |