# Source code for ray.tune.integration.pytorch_lightning

import inspect
import logging
import os
import tempfile
import warnings
from contextlib import contextmanager
from typing import Dict, List, Optional, Type, Union

from ray import train
from ray.train import Checkpoint
from ray.util import log_once
from ray.util.annotations import Deprecated, PublicAPI

try:
    from lightning import Callback, LightningModule, Trainer
except ModuleNotFoundError:
    from pytorch_lightning import Callback, LightningModule, Trainer


logger = logging.getLogger(__name__)

# Collect the name of every ``on_*`` function defined on the installed
# PyTorch Lightning ``Callback`` class, so the set of valid hooks follows
# whichever PTL version is being used.
_allowed_hooks = {
    hook_name
    for hook_name, _ in inspect.getmembers(Callback, inspect.isfunction)
    if hook_name.startswith("on_")
}


def _override_ptl_hooks(callback_cls: Type["TuneCallback"]) -> Type["TuneCallback"]:
    """Overrides all allowed PTL Callback hooks with our custom handle logic.

    Every hook named in ``_allowed_hooks`` is replaced on ``callback_cls`` by a
    thin wrapper. The wrapper calls ``self._handle(trainer=..., pl_module=...)``
    when the hook's name is listed in ``self._on`` and does nothing otherwise.

    Args:
        callback_cls: The callback class to patch in place.

    Returns:
        The same ``callback_cls`` object, so this function can be used as a
        class decorator.
    """

    def generate_overridden_hook(fn_name):
        # A factory is used so that each wrapper closes over its own
        # ``fn_name`` instead of the loop variable below.
        def overridden_hook(
            self,
            trainer: Trainer,
            *args,
            pl_module: Optional[LightningModule] = None,
            **kwargs,
        ):
            if fn_name not in self._on:
                return

            # ``pl_module`` is keyword-only in this signature, so a module
            # passed positionally right after ``trainer`` (the documented
            # ``Callback`` hook signature) ends up in ``args``. Recover it so
            # that ``_handle`` does not always receive ``None``.
            if pl_module is None and args and isinstance(args[0], LightningModule):
                pl_module = args[0]

            self._handle(trainer=trainer, pl_module=pl_module)

        return overridden_hook

    # Set the overridden hook to all the allowed hooks in TuneCallback.
    for fn_name in _allowed_hooks:
        setattr(callback_cls, fn_name, generate_overridden_hook(fn_name))

    return callback_cls


@_override_ptl_hooks
class TuneCallback(Callback):
    """Base class for Tune's PyTorch Lightning callbacks.

    Subclasses implement ``_handle`` with the logic to run for the selected
    hooks.

    Args:
        on: When to trigger the callback. Must be one of
            the PyTorch Lightning event hooks (less the ``on_``), e.g.
            "train_batch_start", or "train_end". A list or tuple of hook
            names selects several hooks. Defaults to "validation_end".

    Raises:
        ValueError: If a selected hook, once prefixed with ``on_``, is not in
            ``_allowed_hooks``.
    """

    def __init__(self, on: Union[str, List[str]] = "validation_end"):
        # Normalize to a list. A tuple of names is treated like a list; any
        # other single value (normally a str) is wrapped in a list.
        if isinstance(on, tuple):
            on = list(on)
        elif not isinstance(on, list):
            on = [on]

        for hook in on:
            if f"on_{hook}" not in _allowed_hooks:
                raise ValueError(
                    f"Invalid hook selected: {hook}. Must be one of "
                    f"{_allowed_hooks}"
                )

        # Add back the "on_" prefix for internal consistency.
        self._on = [f"on_{hook}" for hook in on]

    def _handle(self, trainer: Trainer, pl_module: Optional[LightningModule]):
        """Run the callback logic for a selected hook; subclasses override this.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError


@PublicAPI
class TuneReportCheckpointCallback(TuneCallback):
    """PyTorch Lightning report and checkpoint callback

    Saves a checkpoint whenever one of the hooks selected by ``on`` fires
    (``validation_end`` by default). Also reports metrics to Tune, which is
    needed for checkpoint registration.

    Args:
        metrics: Metrics to report to Tune. If this is a list,
            each item describes the metric key reported to PyTorch Lightning,
            and it will be reported under the same name to Tune. If this is a
            dict, each key will be the name reported to Tune and the respective
            value will be the metric key reported to PyTorch Lightning.
            If empty or None (default), every entry of
            ``trainer.callback_metrics`` is reported.
        filename: Filename of the checkpoint within the checkpoint
            directory. Defaults to "checkpoint".
        save_checkpoints: If True (default), checkpoints will be saved and
            reported to Ray. If False, only metrics will be reported.
        on: When to trigger checkpoint creations and metric reports. Must be one of
            the PyTorch Lightning event hooks (less the ``on_``), e.g.
            "train_batch_start", or "train_end". Defaults to "validation_end".


    Example:

    .. code-block:: python

        import pytorch_lightning as pl
        from ray.tune.integration.pytorch_lightning import (
            TuneReportCheckpointCallback)

        # Save a checkpoint and report metrics at the end of each
        # validation run.
        trainer = pl.Trainer(callbacks=[TuneReportCheckpointCallback(
            metrics={"loss": "val_loss", "mean_accuracy": "val_acc"},
            filename="trainer.ckpt", on="validation_end")])


    """

    def __init__(
        self,
        metrics: Optional[Union[str, List[str], Dict[str, str]]] = None,
        filename: str = "checkpoint",
        save_checkpoints: bool = True,
        on: Union[str, List[str]] = "validation_end",
    ):
        super().__init__(on=on)
        # A single metric name is shorthand for a one-element list.
        if isinstance(metrics, str):
            metrics = [metrics]
        self._save_checkpoints = save_checkpoints
        self._filename = filename
        self._metrics = metrics

    def _get_report_dict(self, trainer: Trainer, pl_module: LightningModule):
        """Build the metrics dict to report.

        Returns:
            ``None`` while the trainer is sanity checking. Otherwise a dict
            mapping the Tune metric name to the ``.item()`` value of the
            matching entry in ``trainer.callback_metrics``. Requested metrics
            that are missing are skipped with a warning.
        """
        # Don't report if just doing initial validation sanity checks.
        if trainer.sanity_checking:
            return None

        if not self._metrics:
            report_dict = {k: v.item() for k, v in trainer.callback_metrics.items()}
        else:
            report_dict = {}
            for key in self._metrics:
                # For a dict, ``key`` is the Tune name and the value is the
                # PTL metric key; for a list both names are the same.
                if isinstance(self._metrics, dict):
                    metric = self._metrics[key]
                else:
                    metric = key
                if metric in trainer.callback_metrics:
                    report_dict[key] = trainer.callback_metrics[metric].item()
                else:
                    logger.warning(
                        f"Metric {metric} does not exist in "
                        "`trainer.callback_metrics`."
                    )

        return report_dict

    @contextmanager
    def _get_checkpoint(self, trainer: Trainer) -> Optional[Checkpoint]:
        """Context manager yielding a checkpoint to report, or ``None``.

        Yields ``None`` when checkpointing is disabled. Otherwise the trainer
        state is saved as ``self._filename`` inside a temporary directory and
        a ``Checkpoint`` built from that directory is yielded. The directory
        is removed when the context exits, so the checkpoint has to be
        consumed inside the ``with`` block.
        """
        if not self._save_checkpoints:
            yield None
            return

        with tempfile.TemporaryDirectory() as checkpoint_dir:
            trainer.save_checkpoint(os.path.join(checkpoint_dir, self._filename))
            checkpoint = Checkpoint.from_directory(checkpoint_dir)
            yield checkpoint

    def _handle(self, trainer: Trainer, pl_module: Optional[LightningModule]):
        """Report metrics (and a checkpoint, if enabled) for a selected hook.

        Nothing is reported while sanity checking or when the metrics dict is
        empty.
        """
        if trainer.sanity_checking:
            return

        report_dict = self._get_report_dict(trainer, pl_module)
        if not report_dict:
            return

        with self._get_checkpoint(trainer) as checkpoint:
            train.report(report_dict, checkpoint=checkpoint)


class _TuneCheckpointCallback(TuneCallback):
    """Deprecated placeholder: constructing it always raises."""

    def __init__(self, *args, **kwargs):
        message = (
            "`ray.tune.integration.pytorch_lightning._TuneCheckpointCallback` "
            "is deprecated."
        )
        raise DeprecationWarning(message)


@Deprecated
class TuneReportCallback(TuneReportCheckpointCallback):
    """Deprecated metrics-only variant of ``TuneReportCheckpointCallback``.

    Forwards ``metrics`` and ``on`` to the parent class with
    ``save_checkpoints=False``, and emits a deprecation warning gated by
    ``log_once``.
    """

    def __init__(
        self,
        metrics: Optional[Union[str, List[str], Dict[str, str]]] = None,
        on: Union[str, List[str]] = "validation_end",
    ):
        should_warn = log_once("tune_ptl_report_deprecated")
        if should_warn:
            warnings.warn(
                "`ray.tune.integration.pytorch_lightning.TuneReportCallback` "
                "is deprecated. Use "
                "`ray.tune.integration.pytorch_lightning.TuneReportCheckpointCallback`"
                " instead."
            )

        # Metrics only: checkpoint saving is always disabled for this class.
        super().__init__(metrics=metrics, save_checkpoints=False, on=on)