Built-in Hooks
Base Hooks
trainer_tools.hooks.BaseHook
Base class for hooks. Hooks can interact with the Trainer at various points.
Source code in trainer_tools/hooks/base.py
before_fit(trainer)
Called before training starts. Guaranteed attributes: trainer.model, trainer.train_dl, trainer.valid_dl, trainer.opt, trainer.train_step, trainer.eval_step, trainer.epochs, trainer.device, trainer.config, trainer.accelerator, trainer.step_state, trainer.state, trainer.result
Source code in trainer_tools/hooks/base.py
before_epoch(trainer)
Called before each epoch (train + val). Guaranteed attributes (in addition to above): trainer.start_epoch, trainer.training, trainer.dl
before_step(trainer)
after_pred(trainer)
Called after the forward pass. Note: no longer invoked by the core loop since predict was removed; retained for legacy or user-defined hooks.
after_loss(trainer)
Called after the loss calculation. Note: no longer invoked by the core loop since get_loss was removed; retained for legacy or user-defined hooks.
after_backward(trainer)
Called after loss.backward(). Guaranteed attributes (in addition to before_step): trainer.result (has 'loss')
after_step(trainer)
Called after opt.step() and opt.zero_grad() but before batch logic cleanup. Guaranteed attributes: trainer._did_opt_step (True if optimizer stepped)
before_valid(trainer)
Called between the train and validation dataloaders within an epoch. Guaranteed attributes: trainer.training is False, trainer.dl is valid_dl
after_epoch(trainer)
after_fit(trainer)
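The lifecycle above can be sketched with minimal stand-ins (the real BaseHook lives in trainer_tools.hooks.base; LossRecorder here is a hypothetical user hook):

```python
class BaseHook:
    """Stand-in for trainer_tools.hooks.BaseHook: every callback is a
    no-op, so subclasses override only the points they care about."""
    def before_fit(self, trainer): pass
    def before_epoch(self, trainer): pass
    def before_step(self, trainer): pass
    def after_backward(self, trainer): pass
    def after_step(self, trainer): pass
    def before_valid(self, trainer): pass
    def after_epoch(self, trainer): pass
    def after_fit(self, trainer): pass

class LossRecorder(BaseHook):
    """Hypothetical hook: records the loss after every backward pass."""
    def __init__(self):
        self.losses = []

    def after_backward(self, trainer):
        # after_backward guarantees trainer.result carries 'loss'
        self.losses.append(trainer.result["loss"])
```

Overriding only `after_backward` leaves every other lifecycle point a no-op, which is the intended usage pattern.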
trainer_tools.hooks.MainProcessHook
Bases: BaseHook
Marker base class for hooks that should only run on the main process.
Hooks that inherit from this class will be automatically skipped on non-main processes in distributed training, eliminating the need for 'if trainer.is_main' guards inside the hook implementation.
Typical use cases: - Metrics logging - Checkpointing - Progress bars - Any I/O or console output
Source code in trainer_tools/hooks/base.py
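The skip-on-non-main behaviour can be sketched roughly as follows (stand-in classes; `run_hook` and `trainer.is_main` are assumptions about the dispatch logic, not the library's actual internals):

```python
class BaseHook:
    def after_epoch(self, trainer): pass

class MainProcessHook(BaseHook):
    """Marker class: the trainer skips these hooks off the main process."""

def run_hook(trainer, hook, event):
    # Sketch of the dispatch: main-process-only hooks are skipped on
    # non-main ranks, so the hook body needs no 'if trainer.is_main' guard.
    if isinstance(hook, MainProcessHook) and not trainer.is_main:
        return
    getattr(hook, event)(trainer)
```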
trainer_tools.hooks.LambdaHook
Bases: BaseHook
Creates a hook from callables passed as keyword arguments.
Source code in trainer_tools/hooks/base.py
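A minimal sketch of the idea, assuming each keyword argument names a lifecycle callback (stand-in classes, not the library's implementation):

```python
class BaseHook:
    def after_epoch(self, trainer): pass

class LambdaHook(BaseHook):
    """Sketch: keyword arguments (e.g. after_epoch=...) become the
    correspondingly named hook callbacks; everything else stays a no-op."""
    def __init__(self, **callbacks):
        for name, fn in callbacks.items():
            setattr(self, name, fn)
```

This makes one-off hooks cheap to write, e.g. `LambdaHook(after_epoch=lambda trainer: print(trainer.state))`, without defining a full subclass.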
Optimization Hooks
trainer_tools.hooks.AMPHook
Bases: BaseHook
A hook to seamlessly add Automatic Mixed Precision (AMP). - Initializes a GradScaler at the beginning of training. - Wraps the forward pass and loss computation in an autocast context. - Wraps backward and optimizer step with gradient scaling.
Source code in trainer_tools/hooks/optimization.py
trainer_tools.hooks.GradientAccumulationHook
Bases: BaseHook
Accumulates gradients over multiple steps.
Source code in trainer_tools/hooks/optimization.py
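The arithmetic such a hook implements can be sketched in plain Python: each micro-batch's loss is scaled by 1/N, and the optimizer steps only once every N micro-batches (`accumulate` is an illustrative function, not part of the library):

```python
def accumulate(micro_losses, accum_steps):
    """Sketch of gradient accumulation: each micro-batch contributes
    loss / accum_steps to the running gradient; the 'optimizer' steps
    only once every accum_steps micro-batches."""
    grad, steps = 0.0, []
    for i, loss in enumerate(micro_losses, start=1):
        grad += loss / accum_steps      # scaled backward pass
        if i % accum_steps == 0:        # time for opt.step()
            steps.append(grad)
            grad = 0.0                  # opt.zero_grad()
    return steps
```

The scaling keeps the effective gradient equal to the mean over the accumulated micro-batches, matching what a single large batch would produce.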
trainer_tools.hooks.GradClipHook
Bases: BaseHook
Hook to clip gradients after backward pass.
Source code in trainer_tools/hooks/optimization.py
trainer_tools.hooks.LRSchedulerHook
Bases: BaseHook
A hook to integrate a PyTorch learning rate scheduler into the training loop.
Source code in trainer_tools/hooks/optimization.py
Checkpointing and Utilities
trainer_tools.hooks.CheckpointHook
Bases: MainProcessHook
Saves model, optimizer, scheduler, scaler, and RNG states. Can resume training from a checkpoint.
Works transparently in both single-device and distributed (Accelerate) setups.
Source code in trainer_tools/hooks/checkpoint.py
__init__(save_dir, save_every_steps=1000, keep_last=3, resume_path=None, save_strategy='best', metric_name='valid_loss')
Source code in trainer_tools/hooks/checkpoint.py
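Given the signature above, a typical configuration might look like the following (the argument values other than the documented defaults are illustrative):

```python
from trainer_tools.hooks import CheckpointHook

ckpt = CheckpointHook(
    save_dir="ckpts",           # where checkpoints are written (arbitrary path)
    save_every_steps=500,       # save cadence in steps (default 1000)
    keep_last=3,                # how many recent checkpoints to retain
    resume_path=None,           # or a path to a checkpoint to resume from
    save_strategy="best",       # documented default
    metric_name="valid_loss",   # metric consulted by the 'best' strategy
)
```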
trainer_tools.hooks.EMAHook
Bases: BaseHook
Keeps an exponential moving average (EMA) of a model's parameters.
Source code in trainer_tools/hooks/ema.py
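The update an EMA hook applies each step is the standard recurrence ema ← decay · ema + (1 − decay) · params, shown here element-wise in plain Python (`ema_update` is illustrative, not the hook's API):

```python
def ema_update(ema, params, decay=0.999):
    """One EMA step: ema <- decay * ema + (1 - decay) * params,
    applied element-wise over the model's parameters."""
    return [decay * e + (1.0 - decay) * p for e, p in zip(ema, params)]
```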
trainer_tools.hooks.accelerate.AccelerateHook
Bases: BaseHook
Integrates HF Accelerate into the training loop.
Handles distributed training (DDP/FSDP), mixed precision, gradient accumulation, gradient clipping, and device placement — all through a single hook.
When using AccelerateHook, do not add AMPHook, GradientAccumulationHook,
or GradClipHook — their functionality is subsumed by Accelerate.
LRSchedulerHook remains compatible and its scheduler will be prepared automatically.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `gradient_accumulation_steps` | `int` | Number of micro-batches to accumulate before an optimizer update. Accelerate handles loss scaling and gradient synchronisation suppression automatically. | `1` |
| `max_grad_norm` | `float \| None` | Maximum gradient norm for clipping (applied only on synchronisation/update steps). | `None` |
| `**kwargs` | `Any` | Forwarded to | `{}` |
Source code in trainer_tools/hooks/accelerate.py
__init__(gradient_accumulation_steps=1, max_grad_norm=None, **kwargs)
Source code in trainer_tools/hooks/accelerate.py
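Given the signature above, a minimal configuration sketch (the values shown are illustrative, not defaults):

```python
from trainer_tools.hooks.accelerate import AccelerateHook

hook = AccelerateHook(
    gradient_accumulation_steps=4,  # micro-batches per optimizer update
    max_grad_norm=1.0,              # clipped only on synchronisation steps
    # further keyword arguments are forwarded (see the parameter table above)
)
```

Remember that this hook replaces AMPHook, GradientAccumulationHook, and GradClipHook; only LRSchedulerHook should be combined with it.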
trainer_tools.hooks.ProgressBarHook
Bases: MainProcessHook
A hook to display progress bars for epochs and batches.