Skip to content

Metrics Hooks

trainer_tools.hooks.metrics.MetricsHook

Bases: MainProcessHook

Aggregates data from multiple Metrics and logs to console/tracker. Only ONE instance of this hook is needed per Trainer.

Source code in trainer_tools/hooks/metrics/metrics_hook.py
class MetricsHook(MainProcessHook):
    """
    Aggregates data from multiple Metrics and logs to console/tracker.
    Only ONE instance of this hook is needed per Trainer.
    """

    ord = -10  # hook ordering hint; ordering semantics are defined by the Trainer

    def __init__(
        self,
        metrics: List[Metric],
        verbose: bool = True,
        tracker_type: Optional[str] = None,
        config: Union[dict, str, None] = None,
        freq: int = 1,
        log_file: Optional[str] = "metrics.jsonl",
        name: Optional[str] = None,
        project: Optional[str] = None,
        **tracker_kwargs: Any,
    ):
        """Aggregates Metric outputs and logs them to the console and/or a tracker.

        Args:
            metrics: List of :class:`Metric` instances to evaluate and aggregate.
            verbose: If ``True``, all metric keys are printed after each epoch.
                If ``False``, only keys containing ``"loss"`` are printed.
            tracker_type: Backend to use for logging. Supported values are
                ``"wandb"``, ``"trackio"``, ``"mlop"``, and ``"file"``.
                ``"file"`` writes metrics to *log_file* as JSONL.
                ``None`` logs to stdout only (via the standard logger).
            config: Hyperparameter config forwarded to the tracker on init.
                Accepts a ``dict`` or a JSON-encoded string.
            freq: Step interval at which step-level metrics are logged.
            log_file: Path to the JSONL file used when ``tracker_type`` is
                ``"file"`` (or ``None``). Defaults to ``"metrics.jsonl"``.
            name: Run name forwarded to the tracker. Convenience alias for
                passing ``name`` via *tracker_kwargs*.
            project: Project name forwarded to the tracker. Convenience alias
                for passing ``project`` via *tracker_kwargs*.
            **tracker_kwargs: Additional keyword arguments passed to the
                tracker's ``init`` call.
        """
        # Fold the convenience aliases into tracker_kwargs without clobbering
        # entries the caller passed explicitly (setdefault).
        if name is not None:
            tracker_kwargs.setdefault("name", name)
        if project is not None:
            tracker_kwargs.setdefault("project", project)
        self.verbose, self.tracker_kwargs = verbose, tracker_kwargs
        # config may arrive as a JSON-encoded string; decode it before flattening.
        self.config = flatten_config(json.loads(config) if isinstance(config, str) else config or {})
        self.freq, self.log_file = freq, Path(log_file) if log_file else None

        # Group metrics by the trainer hook phase in which they should run.
        self.metric_types = metrics
        self._phases: dict[str, list[Metric]] = defaultdict(list)
        for m in self.metric_types:
            self._phases[m.phase].append(m)

        # Buffers & History
        self.step_data = {}
        self.epoch_data = {}
        self.aggregators = defaultdict(float)
        self.counts = defaultdict(int)
        self._init_tracker(tracker_type)

    def _init_tracker(self, t_type):
        """Resolve the tracker backend selected by *t_type* (``None`` disables tracking)."""
        self.tracker, self.use_tracker, self.use_file = None, False, False
        if t_type is None:
            return
        if t_type == "file":
            self.use_file = True
            if self.log_file:
                # Start every run with a fresh JSONL file.
                self.log_file.parent.mkdir(parents=True, exist_ok=True)
                self.log_file.unlink(missing_ok=True)
            return
        if t_type not in _TRACKERS:
            log.warning(f"Tracker '{t_type}' not found or not installed.")
            return
        self.tracker, self.use_tracker = _TRACKERS[t_type], True
        if t_type == "trackio":
            self.tracker_kwargs.setdefault("embed", False)  # improves performance

    def _run_metrics(self, trainer, phase):
        """Evaluate every metric registered for *phase* and buffer its outputs."""
        prefix = "train" if trainer.training else "valid"
        for m in self._phases[phase]:
            if not m.should_run(trainer):
                continue

            data = m(trainer)
            if not data:
                continue
            p_data = data if not m.use_prefix else {f"{prefix}_{k}": v for k, v in data.items()}

            for k, v in p_data.items():
                # Detach/convert tensors so buffered values don't pin GPU memory
                # or the autograd graph.
                if isinstance(v, torch.Tensor):
                    v = v.detach().cpu().item() if v.numel() == 1 else v.detach().cpu().numpy().tolist()

                self.step_data[k] = v
                # Only scalars participate in the per-epoch running mean.
                if isinstance(v, (int, float)):
                    self.aggregators[k] += v
                    self.counts[k] += 1

    def before_fit(self, trainer):
        """Record steps-per-epoch and initialize the tracker backend, if any."""
        # BUG FIX: getattr evaluates its default eagerly, so the original
        # `getattr(trainer, "dl", getattr(trainer, "train_dl"))` raised
        # AttributeError whenever `train_dl` was missing — even when `dl`
        # existed. Resolve `dl` first and only then fall back to `train_dl`.
        dl = getattr(trainer, "dl", None)
        if dl is None:
            dl = trainer.train_dl
        self.steps_per_epoch = len(dl)
        if self.use_tracker:
            self.tracker.init(config=self.config, **self.tracker_kwargs)

    def before_epoch(self, trainer):
        """Reset the running-mean accumulators at the start of each epoch."""
        self.aggregators.clear()
        self.counts.clear()

    def after_pred(self, trainer):
        self._run_metrics(trainer, "after_pred")

    def after_loss(self, trainer):
        self._run_metrics(trainer, "after_loss")

    def after_backward(self, trainer):
        self._run_metrics(trainer, "after_backward")

    def after_step(self, trainer):
        """Run 'after_step' metrics and flush buffered step data during training."""
        self._run_metrics(trainer, "after_step")
        if not trainer.training:
            self.step_data.clear()
            return

        self.step_data["step"] = trainer.step_state.samples_seen
        if getattr(trainer, "_did_opt_step", False):
            # optimizer_step not yet incremented in step_state, so we add 1 for freq check
            if (trainer.step_state.optimizer_step + 1) % self.freq == 0:
                if self.use_tracker:
                    current_step = self.step_data.pop("step", trainer.step_state.samples_seen)
                    self.tracker.log(self.step_data, current_step)
                elif self.use_file:
                    with open(self.log_file, "a") as f:
                        f.write(json.dumps(self.step_data) + "\n")
            self.step_data.clear()

    def after_epoch(self, trainer):
        """Compute epoch means, forward validation stats, and print a summary line."""
        self.epoch_data = epoch_means = {k: self.aggregators[k] / self.counts[k] for k in self.aggregators}
        val_stats = {k: v for k, v in epoch_means.items() if k.startswith("valid_")}

        if self.use_tracker and val_stats:
            self.tracker.log(val_stats, trainer.step_state.samples_seen)
        elif self.use_file and val_stats:
            with open(self.log_file, "a") as f:
                f.write(json.dumps({"epoch": trainer.step_state.epoch, **epoch_means}) + "\n")

        logs = [f"Epoch {trainer.step_state.epoch + 1}/{trainer.epochs}"]

        for k in sorted(epoch_means.keys()):
            # verbose=False restricts console output to loss-like keys.
            if self.verbose or "loss" in k.lower():
                logs.append(f"{k}: {epoch_means[k]:.4f}")

        log.info(" | ".join(logs))

    def after_fit(self, trainer):
        """Close the tracker run, if one was opened."""
        if self.use_tracker:
            self.tracker.finish()

    def plot(self, axes=None, metrics=("loss",), show_epochs=False):
        """Plot logged metric histories from the JSONL log file.

        Args:
            axes: Optional matplotlib axes (or array of axes); created
                automatically when ``None``.
            metrics: Metric root names to plot (``train_``/``valid_`` prefixes
                stripped). ``None`` plots every available root.
                NOTE: default changed from the mutable ``["loss"]`` to an
                equivalent tuple; behavior is identical since the argument is
                never mutated.
            show_epochs: If ``True``, epoch-level series are overlaid on the
                step-level series.
        """
        self.history = load_metrics(self.log_file)
        all_roots = {k.replace("train_", "").replace("valid_", "") for cat in self.history for k in self.history[cat]}
        keys = sorted(all_roots) if metrics is None else [m for m in metrics if m in all_roots]

        if not keys:
            return

        if axes is None:
            fig, axes = plt.subplots(len(keys), 1, figsize=(8, 4 * len(keys)))
        axes = np.atleast_1d(axes)

        steps_per_epoch = getattr(self, "steps_per_epoch", None)

        for ax, root in zip(axes, keys):
            for cat in ["step", "epoch"]:
                if cat == "epoch" and not show_epochs:
                    continue
                for pre in ["train_", "valid_", ""]:
                    full_key = f"{pre}{root}"
                    if full_key in self.history[cat]:
                        data = self.history[cat][full_key]
                        if cat == "epoch" and steps_per_epoch is not None:
                            # Rescale epoch indices into step coordinates so
                            # both series share one x-axis.
                            data = data.copy()
                            data[0] = (data[0] + 1) * steps_per_epoch
                        fmt = "o-" if "valid" in pre or cat == "epoch" else "-"
                        ax.plot(*data, fmt, label=f"{cat} {pre}{root}".strip())

            ax.set_ylabel(root.replace("_", " ").title())
            ax.legend()
            ax.grid(True, alpha=0.3)

        axes[-1].set_xlabel("Step")
        plt.tight_layout()
__init__(metrics, verbose=True, tracker_type=None, config=None, freq=1, log_file='metrics.jsonl', name=None, project=None, **tracker_kwargs)

Aggregates Metric outputs and logs them to the console and/or a tracker.

Parameters:

Name Type Description Default
metrics List[Metric]

List of `Metric` instances to evaluate and aggregate.

required
verbose bool

If True, all metric keys are printed after each epoch. If False, only keys containing "loss" are printed.

True
tracker_type Optional[str]

Backend to use for logging. Supported values are "wandb", "trackio", "mlop", and "file". "file" writes metrics to log_file as JSONL. None logs to stdout only (via the standard logger).

None
config Union[dict, str, None]

Hyperparameter config forwarded to the tracker on init. Accepts a dict or a JSON-encoded string.

None
freq int

Step interval at which step-level metrics are logged.

1
log_file Optional[str]

Path to the JSONL file used when tracker_type is "file" (or None). Defaults to "metrics.jsonl".

'metrics.jsonl'
name Optional[str]

Run name forwarded to the tracker. Convenience alias for passing name via tracker_kwargs.

None
project Optional[str]

Project name forwarded to the tracker. Convenience alias for passing project via tracker_kwargs.

None
**tracker_kwargs Any

Additional keyword arguments passed to the tracker's init call.

{}
Source code in trainer_tools/hooks/metrics/metrics_hook.py
def __init__(
    self,
    metrics: List[Metric],
    verbose: bool = True,
    tracker_type: Optional[str] = None,
    config: Union[dict, str, None] = None,
    freq: int = 1,
    log_file: Optional[str] = "metrics.jsonl",
    name: Optional[str] = None,
    project: Optional[str] = None,
    **tracker_kwargs: Any,
):
    """Aggregates Metric outputs and logs them to the console and/or a tracker.

    Args:
        metrics: List of :class:`Metric` instances to evaluate and aggregate.
        verbose: If ``True``, all metric keys are printed after each epoch.
            If ``False``, only keys containing ``"loss"`` are printed.
        tracker_type: Backend to use for logging. Supported values are
            ``"wandb"``, ``"trackio"``, ``"mlop"``, and ``"file"``.
            ``"file"`` writes metrics to *log_file* as JSONL.
            ``None`` logs to stdout only (via the standard logger).
        config: Hyperparameter config forwarded to the tracker on init.
            Accepts a ``dict`` or a JSON-encoded string.
        freq: Step interval at which step-level metrics are logged.
        log_file: Path to the JSONL file used when ``tracker_type`` is
            ``"file"`` (or ``None``). Defaults to ``"metrics.jsonl"``.
        name: Run name forwarded to the tracker. Convenience alias for
            passing ``name`` via *tracker_kwargs*.
        project: Project name forwarded to the tracker. Convenience alias
            for passing ``project`` via *tracker_kwargs*.
        **tracker_kwargs: Additional keyword arguments passed to the
            tracker's ``init`` call.
    """
    # Fold the convenience aliases into tracker_kwargs without clobbering
    # entries the caller passed explicitly (setdefault).
    if name is not None:
        tracker_kwargs.setdefault("name", name)
    if project is not None:
        tracker_kwargs.setdefault("project", project)
    self.verbose, self.tracker_kwargs = verbose, tracker_kwargs
    # config may arrive as a JSON-encoded string; decode it before flattening.
    self.config = flatten_config(json.loads(config) if isinstance(config, str) else config or {})
    self.freq, self.log_file = freq, Path(log_file) if log_file else None

    # Group metrics by the trainer hook phase in which they should run.
    self.metric_types = metrics
    self._phases: dict[str, list[Metric]] = defaultdict(list)
    for m in self.metric_types:
        self._phases[m.phase].append(m)

    # Buffers & History
    self.step_data = {}
    self.epoch_data = {}
    self.aggregators = defaultdict(float)
    self.counts = defaultdict(int)
    self._init_tracker(tracker_type)

Included Metrics

Base Metric Class

trainer_tools.hooks.Metric

Bases: ABC

Base class for data collection strategies.

Parameters:

Name Type Description Default
name str

Identifier for the metric.

None
freq int

How often to collect (in steps) during TRAINING. Validation always collects every step.

1
phase str

The hook method name where collection occurs (e.g. 'after_loss').

'after_step'
use_prefix bool

If True (default), keys are prefixed with 'train_'/'valid_'. If False, keys are logged exactly as returned (e.g. 'grad_norm').

True
Source code in trainer_tools/hooks/metrics/base.py
class Metric(ABC):
    """Abstract base for metric data-collection strategies.

    Args:
        name: Identifier for the metric.
        freq: Collection interval (in optimizer steps) while TRAINING;
              validation collects on every step regardless.
        phase: The trainer hook method where collection happens
               (e.g. 'after_loss').
        use_prefix: When True (default), keys get a 'train_'/'valid_' prefix.
                    When False, keys are logged exactly as returned
                    (e.g. 'grad_norm').
    """

    def __init__(self, name: str = None, freq: int = 1, phase: str = "after_step", use_prefix: bool = True):
        self.use_prefix = use_prefix
        self.phase = phase
        self.freq = freq
        self.name = name

    def should_run(self, trainer) -> bool:
        # Validation always collects; training honors the `freq` interval.
        return (not trainer.training) or trainer.step_state.optimizer_step % self.freq == 0

    def get_value(self, trainer, key, fn=None):
        """Extract *key* from ``trainer.result``, or compute the value via *fn*."""
        if fn is not None:
            return fn(trainer)
        if key in trainer.result:
            return trainer.result[key]
        cb = "train_step" if trainer.model.training else "eval_step"
        raise KeyError(f"Metric requested key '{key}' but it was not returned by {cb}.")

    @abstractmethod
    def __call__(self, trainer) -> dict:
        """Return a dictionary of scalar metrics."""
__init__(name=None, freq=1, phase='after_step', use_prefix=True)

Source code in trainer_tools/hooks/metrics/base.py
def __init__(self, name: str = None, freq: int = 1, phase: str = "after_step", use_prefix: bool = True):
    """Store the metric's configuration; see the class docstring for argument semantics."""
    self.name = name
    self.freq = freq
    self.phase = phase
    self.use_prefix = use_prefix

__call__(trainer) abstractmethod

Return a dictionary of scalar metrics.

Source code in trainer_tools/hooks/metrics/base.py
# Concrete subclasses implement the actual data collection here.
@abstractmethod
def __call__(self, trainer) -> dict:
    """Return a dictionary of scalar metrics."""
    pass

should_run(trainer)

Source code in trainer_tools/hooks/metrics/base.py
def should_run(self, trainer) -> bool:
    """Decide whether this metric collects on the current step."""
    # Training honors the `freq` interval; validation collects every step.
    if trainer.training:
        return trainer.step_state.optimizer_step % self.freq == 0
    return True

get_value(trainer, key, fn=None)

Helper to extract a value from trainer.result or compute it via fn.

Source code in trainer_tools/hooks/metrics/base.py
def get_value(self, trainer, key, fn=None):
    """Extract *key* from ``trainer.result``, or compute the value via *fn*."""
    if fn is not None:
        return fn(trainer)
    if key in trainer.result:
        return trainer.result[key]
    cb = "train_step" if trainer.model.training else "eval_step"
    raise KeyError(f"Metric requested key '{key}' but it was not returned by {cb}.")

Loss Metric

trainer_tools.hooks.Loss

Bases: Metric

Source code in trainer_tools/hooks/metrics/base.py
class Loss(Metric):
    """Collects the loss value from ``trainer.result`` under the name 'loss'."""

    def __init__(self, freq=1, loss_key="loss", loss_fn=None):
        super().__init__("loss", freq, phase="after_step")
        self.loss_fn = loss_fn
        self.loss_key = loss_key

    def __call__(self, trainer: Trainer):
        value = self.get_value(trainer, self.loss_key, self.loss_fn)
        # Reduce tensors to plain Python numbers for logging.
        return {self.name: value.item() if isinstance(value, torch.Tensor) else value}

__init__(freq=1, loss_key='loss', loss_fn=None)

Source code in trainer_tools/hooks/metrics/base.py
def __init__(self, freq=1, loss_key="loss", loss_fn=None):
    """Configure loss collection.

    Args:
        freq: Collection interval in optimizer steps during training.
        loss_key: Key looked up in ``trainer.result`` when ``loss_fn`` is None.
        loss_fn: Optional callable taking the trainer and returning the loss value.
    """
    super().__init__("loss", freq, phase="after_step")
    self.loss_key = loss_key
    self.loss_fn = loss_fn

Accuracy Metric

trainer_tools.hooks.Accuracy

Bases: Metric

Source code in trainer_tools/hooks/metrics/base.py
class Accuracy(Metric):
    """Computes classification accuracy from predictions and targets in ``trainer.result``."""

    def __init__(self, name="accuracy", freq=1, pred_key="preds", target_key="targets", pred_fn=None, target_fn=None):
        super().__init__(name, freq, phase="after_step")
        self.pred_fn = pred_fn
        self.target_fn = target_fn
        self.pred_key = pred_key
        self.target_key = target_key

    def __call__(self, trainer: Trainer):
        if self.pred_fn is None:
            logits = self.get_value(trainer, self.pred_key)
            # Multi-dim logits -> argmax over the last axis; 1-D scores -> 0.5 threshold.
            preds = logits.argmax(dim=-1) if logits.ndim > 1 else (logits > 0.5)
        else:
            preds = self.pred_fn(trainer)

        target = self.get_value(trainer, self.target_key, self.target_fn)
        accuracy = (preds == target).float().mean().item()
        return {self.name: accuracy}

__init__(name='accuracy', freq=1, pred_key='preds', target_key='targets', pred_fn=None, target_fn=None)

Source code in trainer_tools/hooks/metrics/base.py
def __init__(self, name="accuracy", freq=1, pred_key="preds", target_key="targets", pred_fn=None, target_fn=None):
    """Configure accuracy collection.

    Args:
        name: Metric name used as the logged key.
        freq: Collection interval in optimizer steps during training.
        pred_key: Key in ``trainer.result`` holding logits/predictions (used when ``pred_fn`` is None).
        target_key: Key in ``trainer.result`` holding targets (used when ``target_fn`` is None).
        pred_fn: Optional callable taking the trainer and returning predictions directly.
        target_fn: Optional callable taking the trainer and returning targets.
    """
    super().__init__(name, freq, phase="after_step")
    self.pred_key = pred_key
    self.target_key = target_key
    self.pred_fn = pred_fn
    self.target_fn = target_fn