Skip to content

Built-in Hooks

Base Hooks

trainer_tools.hooks.BaseHook

Base class for hooks. Hooks can interact with the Trainer at various points.

Source code in trainer_tools/hooks/base.py
class BaseHook:
    """Base class for trainer hooks.

    Subclasses override whichever event methods they need; the Trainer calls
    them at the matching points of the fit loop. ``ord`` decides the order in
    which hooks are run.
    """

    ord: int = 0  # Execution order: hooks with larger values run later

    def store(self, trainer, key, value):
        """Write ``value`` to ``trainer.state`` under the key ``"<ThisClassName>.<key>"``."""
        namespace = self.__class__.__name__
        trainer.state[f"{namespace}.{key}"] = value

    def get(self, trainer, hook_cls, key, default=None):
        """Read another hook's value from ``trainer.state``.

        ``hook_cls`` may be the hook class itself or its name as a string. The
        lookup key is ``"<name>.<key>"``; ``default`` is returned when absent.
        """
        if isinstance(hook_cls, type):
            hook_cls = hook_cls.__name__
        return trainer.state.get(f"{hook_cls}.{key}", default)

    def before_fit(self, trainer):
        """Fires once, before the first epoch.
        Guaranteed attributes: trainer.model, trainer.train_dl, trainer.valid_dl, trainer.opt, trainer.train_step, trainer.eval_step, trainer.epochs, trainer.device, trainer.config, trainer.accelerator, trainer.step_state, trainer.state, trainer.result
        """
        return None

    def before_epoch(self, trainer):
        """Fires at the start of every epoch (which covers both train and val).
        Guaranteed attributes (in addition to above): trainer.start_epoch, trainer.training, trainer.dl
        """
        return None

    def before_step(self, trainer):
        """Fires before a batch is processed.
        Guaranteed attributes: trainer.batch
        """
        return None

    def after_pred(self, trainer):
        """Legacy post-forward hook. Not invoked by the built-in loop any more (predict was removed); kept for user hooks."""
        return None

    def after_loss(self, trainer):
        """Legacy post-loss hook. Not invoked by the built-in loop any more (get_loss was removed); kept for user hooks."""
        return None

    def after_backward(self, trainer):
        """Fires once loss.backward() has run.
        Guaranteed attributes (in addition to before_step): trainer.result (has 'loss')
        """
        return None

    def after_step(self, trainer):
        """Fires after opt.step() and opt.zero_grad(), before per-batch cleanup.
        Guaranteed attributes: trainer._did_opt_step (True if optimizer stepped)
        """
        return None

    def before_valid(self, trainer):
        """Fires inside an epoch, between the train and the val dataloader.
        Guaranteed attributes: trainer.training is False, trainer.dl is valid_dl
        """
        return None

    def after_epoch(self, trainer):
        """Fires when a whole epoch (train + valid) has completed."""
        return None

    def after_fit(self, trainer):
        """Fires once the fit block has completely finished."""
        return None

    def after_cancel(self, trainer):
        """Fires when training is interrupted (e.g. KeyboardInterrupt)."""
        return None

before_fit(trainer)

Called before training starts. Guaranteed attributes: trainer.model, trainer.train_dl, trainer.valid_dl, trainer.opt, trainer.train_step, trainer.eval_step, trainer.epochs, trainer.device, trainer.config, trainer.accelerator, trainer.step_state, trainer.state, trainer.result

Source code in trainer_tools/hooks/base.py
def before_fit(self, trainer):
    """Fires once, before the first epoch.
    Guaranteed attributes: trainer.model, trainer.train_dl, trainer.valid_dl, trainer.opt, trainer.train_step, trainer.eval_step, trainer.epochs, trainer.device, trainer.config, trainer.accelerator, trainer.step_state, trainer.state, trainer.result
    """
    return None

before_epoch(trainer)

Called before each epoch (train + val). Guaranteed attributes (in addition to above): trainer.start_epoch, trainer.training, trainer.dl

Source code in trainer_tools/hooks/base.py
def before_epoch(self, trainer):
    """Fires at the start of every epoch (which covers both train and val).
    Guaranteed attributes (in addition to above): trainer.start_epoch, trainer.training, trainer.dl
    """
    return None

before_step(trainer)

Called before processing a batch. Guaranteed attributes: trainer.batch

Source code in trainer_tools/hooks/base.py
def before_step(self, trainer):
    """Fires before a batch is processed.
    Guaranteed attributes: trainer.batch
    """
    return None

after_pred(trainer)

Called after forward pass. Note: no longer called natively since predict is removed, keep for legacy or user hooks.

Source code in trainer_tools/hooks/base.py
def after_pred(self, trainer):
    """Legacy post-forward hook. Not invoked by the built-in loop any more (predict was removed); kept for user hooks."""
    return None

after_loss(trainer)

Called after loss calculation. Note: no longer called natively since get_loss is removed, keep for legacy or user hooks.

Source code in trainer_tools/hooks/base.py
def after_loss(self, trainer):
    """Legacy post-loss hook. Not invoked by the built-in loop any more (get_loss was removed); kept for user hooks."""
    return None

after_backward(trainer)

Called after loss.backward(). Guaranteed attributes (in addition to before_step): trainer.result (has 'loss')

Source code in trainer_tools/hooks/base.py
def after_backward(self, trainer):
    """Fires once loss.backward() has run.
    Guaranteed attributes (in addition to before_step): trainer.result (has 'loss')
    """
    return None

after_step(trainer)

Called after opt.step() and opt.zero_grad() but before batch logic cleanup. Guaranteed attributes: trainer._did_opt_step (True if optimizer stepped)

Source code in trainer_tools/hooks/base.py
def after_step(self, trainer):
    """Fires after opt.step() and opt.zero_grad(), before per-batch cleanup.
    Guaranteed attributes: trainer._did_opt_step (True if optimizer stepped)
    """
    return None

before_valid(trainer)

Will be called between train and val dataloaders within an epoch. Guaranteed attributes: trainer.training is False, trainer.dl is valid_dl

Source code in trainer_tools/hooks/base.py
def before_valid(self, trainer):
    """Fires inside an epoch, between the train and the val dataloader.
    Guaranteed attributes: trainer.training is False, trainer.dl is valid_dl
    """
    return None

after_epoch(trainer)

Called after an entire epoch is finished (train + valid).

Source code in trainer_tools/hooks/base.py
def after_epoch(self, trainer):
    """Fires when a whole epoch (train + valid) has completed."""
    return None

after_fit(trainer)

Called fully after the fit block finishes.

Source code in trainer_tools/hooks/base.py
def after_fit(self, trainer):
    """Fires once the fit block has completely finished."""
    return None

after_cancel(trainer)

Called when training is interrupted (e.g. KeyboardInterrupt).

Source code in trainer_tools/hooks/base.py
def after_cancel(self, trainer):
    """Fires when training is interrupted (e.g. KeyboardInterrupt)."""
    return None

trainer_tools.hooks.MainProcessHook

Bases: BaseHook

Marker base class for hooks that should only run on the main process.

Hooks that inherit from this class will be automatically skipped on non-main processes in distributed training, eliminating the need for 'if trainer.is_main' guards inside the hook implementation.

Typical use cases: metrics logging, checkpointing, progress bars, and any other I/O or console output.

Source code in trainer_tools/hooks/base.py
class MainProcessHook(BaseHook):
    """
    Marker base class: hooks deriving from it run only on the main process.

    In distributed training, hooks that inherit from this class are skipped
    automatically on every non-main process, so their implementations do not
    need their own 'if trainer.is_main' guards.

    Typical use cases:
    - Metrics logging
    - Checkpointing
    - Progress bars
    - Any I/O or console output
    """

trainer_tools.hooks.LambdaHook

Bases: BaseHook

Creates a hook from callables passed as keyword arguments.

Source code in trainer_tools/hooks/base.py
class LambdaHook(BaseHook):
    """Build a hook on the fly from callables passed as keyword arguments.

    Every ``name=callable`` pair becomes an instance attribute. Instance
    attributes are not bound methods, so ``LambdaHook(before_fit=fn)`` makes
    ``hook.before_fit(trainer)`` call ``fn(trainer)``.
    """

    def __init__(self, **callbacks):
        for hook_name, callback in callbacks.items():
            setattr(self, hook_name, callback)

Optimization Hooks

trainer_tools.hooks.AMPHook

Bases: BaseHook

A hook to seamlessly add Automatic Mixed Precision (AMP). It initializes a GradScaler at the beginning of training, wraps the forward pass (train_step and eval_step) in an autocast context, and replaces backward and the optimizer step with gradient-scaled versions.

Source code in trainer_tools/hooks/optimization.py
class AMPHook(BaseHook):
    """
    A hook to seamlessly add Automatic Mixed Precision (AMP).
    - Initializes a GradScaler at the beginning of training.
    - Wraps the forward pass (``trainer.train_step`` / ``trainer.eval_step``) in an autocast context.
    - Replaces backward and optimizer step with gradient-scaled versions.

    The GradScaler is only enabled when ``enabled`` is true, ``dtype`` is
    ``torch.float16`` and the model has no bfloat16 parameters; otherwise it is
    created with ``enabled=False``.

    Args:
        enabled: Turn autocast and loss scaling on or off.
        dtype: Autocast dtype (e.g. ``torch.float16`` or ``torch.bfloat16``).
        device_type: Device type passed to ``autocast``.
    """

    ord = -200  # Must run very early so its replacements are wrapped by others

    def __init__(self, enabled=True, dtype=torch.float16, device_type="cuda"):
        self.enabled, self.dtype, self.device_type = enabled, dtype, device_type

    def before_fit(self, trainer):
        """Called before training starts. Initialize the scaler and wrap operations."""
        has_bf16_params = any(p.dtype == torch.bfloat16 for p in trainer.model.parameters())
        use_scaler = self.enabled and self.dtype == torch.float16 and not has_bf16_params
        trainer.scaler = GradScaler(enabled=use_scaler)
        trainer.autocast = autocast(self.device_type, enabled=self.enabled, dtype=self.dtype)
        log.info(f"Mixed Precision Training: {'Enabled' if self.enabled else 'Disabled'}")

        # Forward pass: wrap the step functions so they run under autocast.
        original_train_step = trainer.train_step
        original_eval_step = trainer.eval_step

        def amp_train_step(batch, t):
            with trainer.autocast:
                return original_train_step(batch, t)

        def amp_eval_step(batch, t):
            with trainer.autocast:
                return original_eval_step(batch, t)

        trainer.train_step = amp_train_step
        trainer.eval_step = amp_eval_step

        # Backward and optimizer step are *replaced*, not wrapped: the scaler owns
        # them. Hooks with a higher ``ord`` wrap these replacements.
        def amp_backward():
            if "loss" in trainer.result:
                scaled_loss = trainer.scaler.scale(trainer.result["loss"])
                scaled_loss.backward()

        def amp_opt_step():
            # A missing loss defaults to 1 (=> step); an exactly-zero loss skips both
            # scaler.step and scaler.update.
            # NOTE(review): if GradClipHook already called scaler.unscale_() for this
            # step, skipping update() makes the next unscale_() raise -- confirm that
            # skipping on zero loss is intended.
            loss = trainer.result.get("loss", 1)
            if loss != 0:
                trainer.scaler.step(trainer.opt)
                trainer.scaler.update()
                return True
            return False

        trainer.do_backward = amp_backward
        trainer.do_opt_step = amp_opt_step

__init__(enabled=True, dtype=torch.float16, device_type='cuda')

Source code in trainer_tools/hooks/optimization.py
def __init__(self, enabled=True, dtype=torch.float16, device_type="cuda"):
    """Remember the AMP settings; nothing is built until ``before_fit``."""
    self.enabled = enabled
    self.dtype = dtype
    self.device_type = device_type

trainer_tools.hooks.GradientAccumulationHook

Bases: BaseHook

Accumulates gradients over multiple steps.

Source code in trainer_tools/hooks/optimization.py
class GradientAccumulationHook(BaseHook):
    """
    Accumulates gradients over multiple steps.

    The loss is divided by ``steps`` before backward. The optimizer step and
    ``zero_grad`` only run when ``trainer.step_state.is_grad_accum_boundary``
    reports a boundary; the last batch of the dataloader is flagged to it.

    Args:
        steps: Number of micro-batches per optimizer update; must be >= 1.

    Raises:
        ValueError: If ``steps`` is less than 1.
    """

    ord = -10

    def __init__(self, steps: int = 1):
        # Fail fast: steps == 0 would otherwise only surface as a ZeroDivisionError
        # on the first backward, and negative values would flip the loss sign.
        if steps < 1:
            raise ValueError(f"steps must be >= 1, got {steps}")
        self.steps = steps

    def before_fit(self, trainer):
        """Configure StepState and wrap optimizer operations."""
        trainer.step_state.grad_accum_steps = self.steps

        original_backward = trainer.do_backward
        original_opt_step = trainer.do_opt_step
        original_zero_grad = trainer.do_zero_grad

        def at_boundary():
            # The last batch is flagged so StepState can presumably close a trailing
            # partial window -- TODO confirm in StepState.is_grad_accum_boundary.
            is_last_batch = (trainer.step_state.batch_idx + 1) >= len(trainer.dl)
            return trainer.step_state.is_grad_accum_boundary(is_last_batch)

        def grad_accum_backward():
            # Scales trainer.result["loss"] in place, so anything reading it after
            # backward sees the divided (per-micro-batch) value.
            if "loss" in trainer.result:
                trainer.result["loss"] = trainer.result["loss"] / self.steps
            original_backward()

        def grad_accum_opt_step():
            if at_boundary():
                return original_opt_step()
            return False  # Skipped

        def grad_accum_zero_grad():
            if at_boundary():
                original_zero_grad()

        trainer.do_backward = grad_accum_backward
        trainer.do_opt_step = grad_accum_opt_step
        trainer.do_zero_grad = grad_accum_zero_grad

__init__(steps=1)

Source code in trainer_tools/hooks/optimization.py
def __init__(self, steps: int = 1):
    """Store the number of micro-batches per optimizer update.

    Raises:
        ValueError: If ``steps`` is less than 1.
    """
    # Fail fast instead of a ZeroDivisionError at the first backward.
    if steps < 1:
        raise ValueError(f"steps must be >= 1, got {steps}")
    self.steps = steps

trainer_tools.hooks.GradClipHook

Bases: BaseHook

Hook to clip gradients after backward pass.

Source code in trainer_tools/hooks/optimization.py
class GradClipHook(BaseHook):
    """Hook to clip gradients (by global norm) right before the optimizer steps.

    If an ``AMPHook`` is registered and its scaler is enabled, gradients are
    unscaled first so ``max_norm`` applies to the true gradient values. If a
    ``GradientAccumulationHook`` is registered, clipping only happens on
    accumulation boundaries, i.e. when the optimizer is actually going to step.

    Args:
        max_norm: Maximum total gradient norm passed to ``clip_grad_norm_``.
    """

    ord = -5

    def __init__(self, max_norm=1.0):
        self.max_norm = max_norm

    def before_fit(self, trainer: Trainer):
        """Wrap ``trainer.do_opt_step`` so gradients are clipped before stepping."""
        original_opt_step = trainer.do_opt_step

        def at_accum_boundary():
            # With ord=-5 this wrapper sits *outside* GradientAccumulationHook's
            # opt-step (ord=-10), so it is called on every micro-batch. Unscaling or
            # clipping a partially accumulated gradient corrupts the sum, and with
            # AMP a second unscale_() before scaler.update() raises a RuntimeError.
            # Mirror the accumulation hook's boundary test to avoid both.
            if trainer.get_hook(GradientAccumulationHook, None) is None:
                return True
            is_last_batch = (trainer.step_state.batch_idx + 1) >= len(trainer.dl)
            return trainer.step_state.is_grad_accum_boundary(is_last_batch)

        def clip_opt_step():
            if at_accum_boundary():
                # Unscale first so the norm is measured on real (not loss-scaled) grads.
                if trainer.get_hook(AMPHook, None) and hasattr(trainer, "scaler") and trainer.scaler.is_enabled():
                    trainer.scaler.unscale_(trainer.opt)
                nn.utils.clip_grad_norm_(trainer.model.parameters(), self.max_norm)
            return original_opt_step()

        trainer.do_opt_step = clip_opt_step

__init__(max_norm=1.0)

Source code in trainer_tools/hooks/optimization.py
def __init__(self, max_norm=1.0):
    """Record the gradient-norm ceiling used by the clipping wrapper."""
    setattr(self, "max_norm", max_norm)

trainer_tools.hooks.LRSchedulerHook

Bases: BaseHook

A hook to integrate a PyTorch learning rate scheduler into the training loop.

Source code in trainer_tools/hooks/optimization.py
class LRSchedulerHook(BaseHook):
    """A hook to integrate a PyTorch learning rate scheduler into the training loop.

    Args:
        sched_fn: Either a ready-made ``torch.optim.lr_scheduler.LRScheduler``, or a
            callable that receives ``trainer.opt`` and returns a scheduler.

    The scheduler is stepped once after every training step on which the
    optimizer actually stepped (``trainer._did_opt_step``).
    """

    # Run early -- in particular *before* CheckpointHook (ord=-3) -- so that
    # ``self.sched`` already exists when a checkpoint's scheduler state is loaded.
    ord = -100

    def __init__(self, sched_fn):
        self.sched_fn = sched_fn

    @property
    def lr(self):
        """Most recent learning rate of the first param group (available after ``before_fit``)."""
        return self.sched.get_last_lr()[0]

    def before_fit(self, trainer):
        """Use ``sched_fn`` directly if it is a scheduler, else build one from the optimizer."""
        if isinstance(self.sched_fn, torch.optim.lr_scheduler.LRScheduler):
            self.sched = self.sched_fn
        else:
            self.sched = self.sched_fn(trainer.opt)

    def after_step(self, trainer):
        """Advance the schedule only for training steps that performed an optimizer update."""
        if trainer.training and trainer._did_opt_step:
            self.sched.step()

__init__(sched_fn)

Source code in trainer_tools/hooks/optimization.py
def __init__(self, sched_fn):
    """Keep the scheduler (or scheduler factory); it is resolved in ``before_fit``."""
    setattr(self, "sched_fn", sched_fn)

Checkpointing and Utilities

trainer_tools.hooks.CheckpointHook

Bases: MainProcessHook

Saves model, optimizer, scheduler, scaler, and RNG states. Can resume training from a checkpoint.

Works transparently in both single-device and distributed (Accelerate) setups.

Source code in trainer_tools/hooks/checkpoint.py
class CheckpointHook(MainProcessHook):
    """
    Saves model, optimizer, scheduler, scaler, and RNG states.
    Can resume training from a checkpoint.

    Works transparently in both single-device and distributed (Accelerate)
    setups.

    Files written to ``save_dir``:
        - ``config.yaml`` once, if ``trainer.config`` is set
        - ``checkpoint_step_<N>.pt`` every ``save_every_steps`` optimizer steps,
          keeping only the ``keep_last`` most recent
        - ``checkpoint_best_step_<N>.pt`` (``save_strategy="best"`` only), keeping
          only the file with the lowest ``metric_name`` seen so far
        - ``checkpoint_interrupted.pt`` on cancel; ``model_final.pt`` plus a
          ``save_pretrained`` export after fit

    Args:
        save_dir: Output directory; created if it does not exist.
        save_every_steps: Optimizer-step interval between rolling checkpoints.
        keep_last: Number of rolling checkpoints kept on disk.
        resume_path: Checkpoint restored in ``before_fit``; a missing path is
            logged and training starts fresh.
        save_strategy: ``"best"`` additionally tracks the best checkpoint by
            ``metric_name`` (lower is better); ``"latest"`` keeps only rolling ones.
        metric_name: Key looked up in ``MetricsHook.step_data``, falling back to
            ``MetricsHook.epoch_data``.
    """

    ord = -3

    def __init__(
        self,
        save_dir: str,
        save_every_steps: int = 1000,
        keep_last: int = 3,
        resume_path: Optional[str] = None,
        save_strategy: Literal["best", "latest"] = "best",
        metric_name: str = "valid_loss",
    ):
        self.save_dir, self.every, self.keep_last = Path(save_dir), save_every_steps, keep_last
        self.resume_path, self.save_strategy, self.metric = resume_path, save_strategy, metric_name
        self.save_dir.mkdir(parents=True, exist_ok=True)
        self.saved_checkpoints: list[Path] = []
        self.config_saved = False
        self._best_metric = float("inf")
        self._best_ckpt_path: Optional[Path] = None

    # ------------------------------------------------------------------
    # Helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _unwrap_model(trainer: Trainer, model: nn.Module | None = None):
        """Return the raw model, stripping any DDP/FSDP wrapper."""
        # Explicit None check: container modules such as an empty nn.Sequential
        # define __len__ and are therefore falsy, which ``model or ...`` mishandled.
        if model is None:
            model = trainer.model
        if trainer.is_distributed:
            return trainer.accelerator.unwrap_model(model)
        return model

    @staticmethod
    def _get_scaler(trainer: Trainer):
        """Return the GradScaler from AMPHook or Accelerate, or None."""
        if hasattr(trainer, "scaler"):
            return trainer.scaler
        accel = getattr(trainer, "accelerator", None)
        if accel is not None:
            scaler = getattr(accel, "scaler", None)
            if scaler is not None and scaler.is_enabled():
                return scaler
        return None

    def _save_config(self, trainer: Trainer):
        """Write ``trainer.config`` to ``config.yaml`` once per hook instance."""
        if self.config_saved:
            return
        if not (config := getattr(trainer, "config", None)):
            return
        config_path = self.save_dir / "config.yaml"
        OmegaConf.save(config, config_path, resolve=True)
        log.info(f"Saved config: {config_path}")
        self.config_saved = True

    # ------------------------------------------------------------------
    # Save / load
    # ------------------------------------------------------------------

    def _save(self, trainer: Trainer, filename: str, *, is_best: bool = False, sync: bool = True):
        """Write a full training-state checkpoint to ``save_dir / filename`` and rotate old files.

        Interrupted/final checkpoints are never rotated. A ``is_best`` save replaces
        the previous best file; other saves join the ``keep_last`` rolling window.
        """
        # Synchronise all processes so every rank has finished the step.
        # Skipped for interrupt saves where a barrier could deadlock.
        # NOTE(review): this is a MainProcessHook, documented as skipped on non-main
        # processes. If that applies to this hook, only rank 0 reaches this barrier
        # and it can hang -- confirm how the Trainer dispatches CheckpointHook.
        if trainer.is_distributed and sync:
            trainer.accelerator.wait_for_everyone()

        path = self.save_dir / filename
        state = {
            "model": self._unwrap_model(trainer).state_dict(),
            "opt": trainer.opt.state_dict(),
            "epoch": trainer.step_state.epoch,
            "optimizer_step": trainer.step_state.optimizer_step,
            "samples_seen": trainer.step_state.samples_seen,
            "rng_torch": torch.get_rng_state(),
            "rng_numpy": np.random.get_state(),
        }
        if torch.cuda.is_available():
            state["rng_cuda"] = torch.cuda.get_rng_state()
        if (scaler := self._get_scaler(trainer)) is not None:
            state["scaler"] = scaler.state_dict()

        from .optimization import LRSchedulerHook

        if sched_hook := trainer.get_hook(LRSchedulerHook, None):
            state["scheduler"] = sched_hook.sched.state_dict()
        if (ema_hook := trainer.get_hook(EMAHook, None)) and ema_hook.ema_model is not None:
            state["ema"] = ema_hook.ema_model.state_dict()

        torch.save(state, path)
        log.info(f"Saved checkpoint: {path}")

        if "interrupted" in filename or "final" in filename:
            return

        if is_best:
            if self._best_ckpt_path and self._best_ckpt_path.exists() and self._best_ckpt_path != path:
                self._best_ckpt_path.unlink()
            self._best_ckpt_path = path
            return

        # Keep one entry per path: if the same file is written twice, a duplicate
        # entry would eventually be popped as "oldest" and unlink the live file.
        if path in self.saved_checkpoints:
            self.saved_checkpoints.remove(path)
        self.saved_checkpoints.append(path)
        if len(self.saved_checkpoints) <= self.keep_last:
            return
        oldest = self.saved_checkpoints.pop(0)
        if oldest.exists():
            oldest.unlink()

    def before_fit(self, trainer: Trainer):
        """Save the config and, if ``resume_path`` exists, restore training state from it."""
        self._save_config(trainer)

        if not self.resume_path:
            return
        if Path(self.resume_path).exists():
            self.load_checkpoint(trainer, self.resume_path)
            log.info(
                f"Resumed training from checkpoint: {self.resume_path}, optimizer_step {trainer.step_state.optimizer_step}"
            )
        else:
            log.info(f"Resume path {self.resume_path} does not exist. Starting fresh training.")

    def after_step(self, trainer: Trainer):
        """Every ``save_every_steps`` optimizer steps, save a rolling (and maybe best) checkpoint."""
        # Only react to real optimizer updates (same gate as LRSchedulerHook). With
        # gradient accumulation after_step also fires on micro-batches that did not
        # step; optimizer_step presumably does not advance there (TODO confirm in
        # StepState), so the same checkpoint would otherwise be re-saved repeatedly.
        if not (trainer.training and trainer._did_opt_step and trainer.step_state.optimizer_step > 0):
            return

        if trainer.step_state.optimizer_step % self.every != 0:
            return

        self._save(trainer, f"checkpoint_step_{trainer.step_state.optimizer_step}.pt", is_best=False)

        if self.save_strategy == "best":
            metrics_hook = trainer.get_hook(MetricsHook, None)
            if metrics_hook is None:
                log.warning("No MetricsHook found, switching save_strategy to 'latest'.")
                self.save_strategy = "latest"
                return

            stats = metrics_hook.step_data if self.metric in metrics_hook.step_data else metrics_hook.epoch_data
            try:
                current_metric = stats[self.metric]
            except KeyError:
                # e.g. a validation metric requested before any validation pass has
                # populated it; previously this KeyError aborted training.
                log.warning(f"Metric '{self.metric}' not available yet; skipping best-checkpoint check.")
                return
            if current_metric < self._best_metric:
                self._best_metric = current_metric
                self._save(trainer, f"checkpoint_best_step_{trainer.step_state.optimizer_step}.pt", is_best=True)

    def after_cancel(self, trainer: Trainer):
        """Save ``checkpoint_interrupted.pt`` without a cross-process barrier."""
        self._save(trainer, "checkpoint_interrupted.pt", sync=False)

    def after_fit(self, trainer: Trainer):
        """Save ``model_final.pt`` and export the (EMA, if present) model via ``save_pretrained``."""
        self._save(trainer, "model_final.pt")

        model_to_save = self._unwrap_model(trainer)
        if (ema_hook := trainer.get_hook(EMAHook, None)) and ema_hook.ema_model is not None:
            model_to_save = ema_hook.ema_model
            log.info("Using EMA model for pretrained export")

        save_pretrained(model_to_save, self.save_dir, config=getattr(trainer, "config", None))

    def load_checkpoint(self, trainer: Trainer, path: str):
        """Restore training state.  Runs on **all** processes in distributed mode.

        NOTE(review): ``before_fit`` of this MainProcessHook calls this; if the
        Trainer skips MainProcessHooks on non-main ranks, only rank 0 is restored.
        Confirm that every rank loads the checkpoint.

        Raises:
            FileNotFoundError: If ``path`` does not exist.
        """
        if not os.path.exists(path):
            raise FileNotFoundError(f"{path} not found")
        log.info(f"Loading checkpoint from {path}...")
        # weights_only=False unpickles arbitrary objects (numpy RNG state needs it):
        # only load checkpoints from trusted sources.
        checkpoint = torch.load(path, map_location=trainer.device, weights_only=False)

        self._unwrap_model(trainer).load_state_dict(checkpoint["model"])
        trainer.opt.load_state_dict(checkpoint["opt"])
        trainer.step_state.epoch = checkpoint.get("epoch", 0)
        trainer.step_state.optimizer_step = checkpoint.get(
            "optimizer_step", checkpoint.get("step", 0)
        )  # Backward compat
        trainer.step_state.samples_seen = checkpoint.get("samples_seen", 0)

        torch.set_rng_state(checkpoint["rng_torch"].cpu())
        if torch.cuda.is_available() and "rng_cuda" in checkpoint:
            torch.cuda.set_rng_state(checkpoint["rng_cuda"].cpu())
        np.random.set_state(checkpoint["rng_numpy"])

        if "scaler" in checkpoint and (scaler := self._get_scaler(trainer)) is not None:
            scaler.load_state_dict(checkpoint["scaler"])

        if "scheduler" in checkpoint:
            from .optimization import LRSchedulerHook

            try:
                trainer.get_hook(LRSchedulerHook).sched.load_state_dict(checkpoint["scheduler"])
            except KeyError:
                log.warning("Checkpoint has scheduler state but no LRSchedulerHook found.")
        if "ema" in checkpoint:
            # Parked on the trainer; EMAHook.before_fit picks it up after building its copy.
            trainer._ema_state_buffer = checkpoint["ema"]
        log.info(f"Resumed at Epoch {trainer.step_state.epoch}, OptimizerStep {trainer.step_state.optimizer_step}")

__init__(save_dir, save_every_steps=1000, keep_last=3, resume_path=None, save_strategy='best', metric_name='valid_loss')

Source code in trainer_tools/hooks/checkpoint.py
def __init__(
    self,
    save_dir: str,
    save_every_steps: int = 1000,
    keep_last: int = 3,
    resume_path: Optional[str] = None,
    save_strategy: Literal["best", "latest"] = "best",
    metric_name: str = "valid_loss",
):
    """Configure checkpointing and make sure ``save_dir`` exists.

    Args:
        save_dir: Output directory (created, including parents).
        save_every_steps: Optimizer-step interval between rolling checkpoints.
        keep_last: Number of rolling checkpoints kept on disk.
        resume_path: Optional checkpoint to resume from.
        save_strategy: ``"best"`` or ``"latest"``.
        metric_name: Metric used by the ``"best"`` strategy.
    """
    # Where and how often to save.
    self.save_dir = Path(save_dir)
    self.every = save_every_steps
    self.keep_last = keep_last
    # Resume / best-selection settings.
    self.resume_path = resume_path
    self.save_strategy = save_strategy
    self.metric = metric_name
    self.save_dir.mkdir(parents=True, exist_ok=True)
    # Mutable bookkeeping.
    self.saved_checkpoints: list[Path] = []
    self.config_saved = False
    self._best_metric = float("inf")
    self._best_ckpt_path: Optional[Path] = None

trainer_tools.hooks.EMAHook

Bases: BaseHook

Keeps an exponential moving average (EMA) of a model.

Source code in trainer_tools/hooks/ema.py
class EMAHook(BaseHook):
    """Keeps an exponential moving average (EMA) of a model.

    After each optimizer update the EMA parameters are blended as
    ``ema = decay * ema + (1 - decay) * model``, and the model's buffers (e.g.
    BatchNorm running statistics) are copied into the EMA model. During
    validation the EMA model is swapped in for ``trainer.model`` and swapped
    back in ``after_epoch``.

    Args:
        decay: Weight kept on the previous EMA value at each update.
    """

    ord = 20

    def __init__(self, decay: float = 0.9999):
        self.decay = decay
        self.ema_model = None

    @staticmethod
    def _unwrap(trainer: Trainer):
        """Return the raw model, stripping any DDP/FSDP wrapper."""
        if trainer.is_distributed:
            return trainer.accelerator.unwrap_model(trainer.model)
        return trainer.model

    def before_fit(self, trainer: Trainer):
        """Create a frozen, eval-mode copy of the model, restoring EMA state from a checkpoint if present."""
        self.ema_model = deepcopy(self._unwrap(trainer))
        self.ema_model.eval()
        for p in self.ema_model.parameters():
            p.requires_grad_(False)

        # Set by CheckpointHook.load_checkpoint when resuming.
        if hasattr(trainer, "_ema_state_buffer"):
            log.info("Loading EMA state from checkpoint buffer...")
            self.ema_model.load_state_dict(trainer._ema_state_buffer)
            del trainer._ema_state_buffer

    def after_step(self, trainer: Trainer):
        """Blend the current weights into the EMA after each real optimizer update."""
        # Gate on an actual optimizer step (as LRSchedulerHook does): with gradient
        # accumulation after_step also fires on micro-batches whose weights did not
        # change, which would apply the decay several times per update.
        if not (trainer.training and trainer._did_opt_step):
            return
        model = self._unwrap(trainer)
        # NOTE(review): `t` presumably aliases torch in ema.py -- confirm.
        with t.no_grad():
            for p_ema, p_model in zip(self.ema_model.parameters(), model.parameters()):
                p_ema.data.mul_(self.decay).add_(p_model.data, alpha=1 - self.decay)
            # Buffers are not touched by the optimizer; without this copy the EMA
            # model keeps the running statistics it had at deepcopy time.
            for b_ema, b_model in zip(self.ema_model.buffers(), model.buffers()):
                b_ema.copy_(b_model)

    def before_valid(self, trainer: Trainer):
        """Validate with the EMA weights: stash the live model and swap the EMA one in."""
        self.temp_model = trainer.model
        trainer.model = self.ema_model

    def after_epoch(self, trainer: Trainer):
        """Restore the live model if it was swapped out for validation."""
        if hasattr(self, "temp_model"):
            trainer.model = self.temp_model
            del self.temp_model

__init__(decay=0.9999)

Source code in trainer_tools/hooks/ema.py
def __init__(self, decay: float = 0.9999):
    """Store the EMA decay; the averaged model itself is created in ``before_fit``."""
    self.ema_model = None  # populated in before_fit
    self.decay = decay

trainer_tools.hooks.accelerate.AccelerateHook

Bases: BaseHook

Integrates HF Accelerate into the training loop.

Handles distributed training (DDP/FSDP), mixed precision, gradient accumulation, gradient clipping, and device placement — all through a single hook.

When using AccelerateHook, do not add AMPHook, GradientAccumulationHook, or GradClipHook — their functionality is subsumed by Accelerate. LRSchedulerHook remains compatible and its scheduler will be prepared automatically.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `gradient_accumulation_steps` | `int` | Number of micro-batches to accumulate before an optimizer update. Accelerate handles loss scaling and gradient synchronisation suppression automatically. | `1` |
| `max_grad_norm` | `float \| None` | Maximum gradient norm for clipping (applied only on synchronisation/update steps). `None` disables clipping. | `None` |
| `**kwargs` | `Any` | Forwarded to `accelerate.Accelerator()`. Useful options include `mixed_precision` (`"fp16"`, `"bf16"`, `"no"`), `gradient_accumulation_plugin`, `log_with`, `project_dir`, etc. | `{}` |

Source code in trainer_tools/hooks/accelerate.py
class AccelerateHook(BaseHook):
    """
    Integrates HF Accelerate into the training loop.

    Handles distributed training (DDP/FSDP), mixed precision, gradient accumulation,
    gradient clipping, and device placement — all through a single hook.

    When using AccelerateHook, do **not** add ``AMPHook``, ``GradientAccumulationHook``,
    or ``GradClipHook`` (or subclasses of them) — their functionality is subsumed by
    Accelerate. ``LRSchedulerHook`` remains compatible and its scheduler will be
    prepared automatically.

    Args:
        gradient_accumulation_steps: Number of micro-batches to accumulate before
            an optimizer update. Accelerate handles loss scaling and gradient
            synchronisation suppression automatically.
        max_grad_norm: Maximum gradient norm for clipping (applied only on
            synchronisation/update steps). ``None`` disables clipping.
        **kwargs: Forwarded to ``accelerate.Accelerator()``.
            Useful options include ``mixed_precision`` (``"fp16"``, ``"bf16"``, ``"no"``),
            ``gradient_accumulation_plugin``, ``log_with``, ``project_dir``, etc.

    Raises:
        AssertionError: From ``before_fit`` when an incompatible hook is registered.
    """

    ord = -50

    def __init__(
        self,
        gradient_accumulation_steps: int = 1,
        max_grad_norm: float | None = None,
        **kwargs: Any,
    ):
        self.accelerator = Accelerator(
            gradient_accumulation_steps=gradient_accumulation_steps,
            **kwargs,
        )
        self.max_grad_norm = max_grad_norm
        self._accumulate_ctx = None  # open accelerator.accumulate() context, if any

    def before_fit(self, trainer: Trainer):
        """Prepare model/optimizer/dataloaders (and scheduler) and route backward through Accelerate."""
        incompatible = (AMPHook, GradientAccumulationHook, GradClipHook)
        # isinstance rather than an exact type match, so subclasses are caught too.
        found = [type(h).__name__ for h in trainer.hooks if isinstance(h, incompatible)]
        if found:
            # Explicit raise instead of ``assert`` so the check survives ``python -O``;
            # the AssertionError type is kept for backward compatibility.
            raise AssertionError(
                f"AccelerateHook is not compatible with {', '.join(found)}. "
                "Remove them and configure via AccelerateHook / Accelerator kwargs instead."
            )

        trainer.accelerator = self.accelerator
        trainer.device = self.accelerator.device

        trainer.model, trainer.opt, trainer.train_dl, trainer.valid_dl = self.accelerator.prepare(
            trainer.model, trainer.opt, trainer.train_dl, trainer.valid_dl
        )

        if hook := trainer.get_hook(LRSchedulerHook, None):
            hook.sched = self.accelerator.prepare(hook.sched)

        # Wrap operations to use accelerate's backward. The loss is read from
        # trainer.result["loss"], the same place AMPHook and GradientAccumulationHook
        # use (BaseHook.after_backward documents trainer.result as holding 'loss').
        def accelerate_backward():
            loss = trainer.result.get("loss")
            if loss is not None:
                self.accelerator.backward(loss)

        trainer.do_backward = accelerate_backward

        log.info(
            "AccelerateHook initialised — device: %s, mixed-precision: %s, grad-accum steps: %s, distributed: %s",
            self.accelerator.device,
            self.accelerator.mixed_precision,
            self.accelerator.gradient_accumulation_steps,
            self.accelerator.distributed_type,
        )

    def before_step(self, trainer: Trainer):
        """Enter ``accelerator.accumulate`` for the duration of a training step."""
        if trainer.training:
            self._accumulate_ctx = self.accelerator.accumulate(trainer.model)
            self._accumulate_ctx.__enter__()

    def after_backward(self, trainer: Trainer):
        """Clip gradients on steps where Accelerate synchronises them."""
        if self.max_grad_norm and self.accelerator.sync_gradients:
            self.accelerator.clip_grad_norm_(trainer.model.parameters(), self.max_grad_norm)

    def after_step(self, trainer: Trainer):
        """Leave the ``accumulate`` context opened in ``before_step``."""
        if trainer.training and self._accumulate_ctx is not None:
            self._accumulate_ctx.__exit__(None, None, None)
            self._accumulate_ctx = None

    def after_cancel(self, trainer: Trainer):
        """Close a still-open ``accumulate`` context when training is interrupted mid-step."""
        if self._accumulate_ctx is not None:
            self._accumulate_ctx.__exit__(None, None, None)
            self._accumulate_ctx = None

__init__(gradient_accumulation_steps=1, max_grad_norm=None, **kwargs)

Source code in trainer_tools/hooks/accelerate.py
def __init__(
    self,
    gradient_accumulation_steps: int = 1,
    max_grad_norm: float | None = None,
    **kwargs: Any,
):
    """Build the ``Accelerator`` and record clipping settings.

    Args:
        gradient_accumulation_steps: Micro-batches per optimizer update.
        max_grad_norm: Clip threshold, or ``None`` to disable clipping.
        **kwargs: Passed straight through to ``accelerate.Accelerator()``.
    """
    self.accelerator = Accelerator(gradient_accumulation_steps=gradient_accumulation_steps, **kwargs)
    self.max_grad_norm = max_grad_norm
    self._accumulate_ctx = None  # set while an accumulate() context is open

trainer_tools.hooks.ProgressBarHook

Bases: MainProcessHook

A hook to display progress bars for epochs and batches.

Source code in trainer_tools/hooks/pbar.py
class ProgressBarHook(MainProcessHook):
    """A hook to display progress bars for epochs and batches.

    Shows an outer epoch bar and an inner per-phase bar (train, then valid)
    whose postfix is the running mean of ``trainer.loss``. In distributed runs
    the inner bar counts batches across all processes.

    Args:
        freq: Refresh the loss postfix every ``freq`` steps.
    """

    ord = 5

    def __init__(self, freq=10):
        self.freq = freq

    @staticmethod
    def _world_size(trainer):
        """Number of processes sharing the dataloader (1 when not distributed)."""
        return trainer.accelerator.num_processes if trainer.is_distributed else 1

    def before_fit(self, trainer):
        self.epoch_bar = tqdm(
            range(trainer.epochs),
            desc="Epoch",
            initial=trainer.step_state.epoch,
            total=trainer.epochs,
        )

    def before_epoch(self, trainer):
        self._init_pbar(
            trainer,
            desc=f"Epoch {trainer.step_state.epoch + 1}/{trainer.epochs} [Train]",
            initial=trainer.step_state.batch_idx,
        )

    def before_valid(self, trainer):
        self.bar.close()
        self._init_pbar(trainer, desc=f"Epoch {trainer.step_state.epoch + 1}/{trainer.epochs} [Valid]")

    def _init_pbar(self, trainer, desc, initial=0):
        """Open a fresh inner bar; ``initial`` is a per-process batch index."""
        world = self._world_size(trainer)
        self.running_loss, self.count = 0.0, 0
        self.bar = tqdm(
            trainer.dl,
            # total and every update are scaled by the process count, so the resume
            # offset must be too, or a resumed distributed run starts too far back.
            initial=initial * world,
            total=len(trainer.dl) * world,
            desc=desc,
            leave=False,
        )

    def after_step(self, trainer):
        self.running_loss += trainer.loss
        self.count += 1
        self.bar.update(self._world_size(trainer))
        if (self.count - 1) % self.freq == 0:
            self.bar.set_postfix(loss=f"{self.running_loss / self.count:.4f}", refresh=False)

    def after_epoch(self, trainer):
        self.epoch_bar.update(1)
        self.bar.close()

    def after_cancel(self, trainer):
        """Close any open bars so an interrupt does not leave dangling tqdm output."""
        for bar in (getattr(self, "bar", None), getattr(self, "epoch_bar", None)):
            if bar is not None:
                bar.close()

    def after_fit(self, _):
        self.epoch_bar.close()

__init__(freq=10)

Source code in trainer_tools/hooks/pbar.py
def __init__(self, freq=10):
    """Remember how often (in steps) the loss postfix is refreshed."""
    setattr(self, "freq", freq)