Viewing File: /home/ubuntu/.local/lib/python3.10/site-packages/spacy_loggers/pytorch.py

"""
A logger that queries PyTorch metrics and passes that information to downstream loggers.
"""
from typing import Dict, Any, Optional, Tuple, IO
import re
import sys

from spacy import Language
from .util import LoggerT


def pytorch_logger_v1(
    prefix: str = "pytorch",
    device: int = 0,
    cuda_mem_pool: str = "all",
    cuda_mem_metric: str = "all",
) -> LoggerT:
    """Create a spaCy training logger that adds PyTorch CUDA memory statistics.

    The returned setup callable, when invoked, validates the configuration and
    returns a ``(log_step, finalize)`` pair. ``log_step`` adds entries from
    ``torch.cuda.memory_stats(device)`` to the ``info`` dict it receives (in
    place), keyed as ``"{prefix}.{stat}"``. Statistics whose name contains
    ``"_bytes"`` are converted to megabytes and renamed to ``"_megabytes"``.

    prefix: Key prefix for every statistic written into ``info``.
    device: Device passed to ``torch.cuda.memory_stats``.
    cuda_mem_pool: Pool whose statistics are kept: "all", "large_pool" or
        "small_pool". This is matched literally against the pool component of
        torch's stat names, so "all" selects torch's "all" pool rather than
        disabling the filter.
    cuda_mem_metric: Metric to keep: "all" (keep every metric), "current",
        "peak", "allocated" or "freed". "free" is accepted as an alias of
        "freed" for backwards compatibility.

    RETURNS: The logger setup callable.
    RAISES ImportError: If ``torch`` cannot be imported.
    """
    try:
        import torch
    except ImportError as e:
        raise ImportError(
            "The 'torch' library could not be found - did you install it? "
            "Alternatively, specify the 'ConsoleLogger' in the "
            "'training.logger' config section, instead of the 'PyTorchLogger'."
        ) from e

    def setup_logger(nlp: Language, stdout: IO = sys.stdout, stderr: IO = sys.stderr):
        """Validate the pool/metric options and return ``(log_step, finalize)``.

        RAISES ValueError: If ``cuda_mem_pool`` or ``cuda_mem_metric`` is not
            one of the supported values.
        """
        expected_cuda_mem_pool = ("all", "large_pool", "small_pool")
        # torch names this metric "freed"; "free" was the value this logger
        # originally advertised, so it stays accepted as an alias.
        expected_cuda_mem_metric = (
            "all",
            "current",
            "peak",
            "allocated",
            "freed",
            "free",
        )

        if cuda_mem_pool not in expected_cuda_mem_pool:
            raise ValueError(
                f"Got CUDA memory pool '{cuda_mem_pool}', but expected one of: '{expected_cuda_mem_pool}'"
            )
        elif cuda_mem_metric not in expected_cuda_mem_metric:
            raise ValueError(
                f"Got CUDA memory metric '{cuda_mem_metric}', but expected one of: '{expected_cuda_mem_metric}'"
            )

        # Metric name as it appears in torch's stat keys ("free" -> "freed").
        metric_filter = "freed" if cuda_mem_metric == "free" else cuda_mem_metric

        def metric_selected(metric: str) -> bool:
            """Return True if ``metric`` passes the configured metric filter."""
            return metric_filter == "all" or metric == metric_filter

        def normalize_mem_value_to_mb(name: str, value: int) -> Tuple[str, float]:
            """Convert a byte-valued stat to megabytes and rename it accordingly.

            Stats whose name does not contain "_bytes" are returned unchanged.
            """
            if "_bytes" in name:
                return name.replace("_bytes", "_megabytes"), value / (1024.0**2)
            return name, value

        def log_step(info: Optional[Dict[str, Any]]):
            """Add the selected CUDA memory statistics to ``info`` in place.

            Does nothing when ``info`` is None.
            """
            if info is None:
                return

            for stat, val in torch.cuda.memory_stats(device).items():
                splits = stat.split(".")
                if len(splits) == 3:
                    # "<name>.<pool>.<metric>"
                    name, pool, metric = splits
                    if pool != cuda_mem_pool or not metric_selected(metric):
                        continue
                    name, val = normalize_mem_value_to_mb(name, val)
                    info[f"{prefix}.{name}.{pool}.{metric}"] = val
                elif len(splits) == 2:
                    # "<name>.<metric>" (no pool component)
                    name, metric = splits
                    if not metric_selected(metric):
                        continue
                    name, val = normalize_mem_value_to_mb(name, val)
                    info[f"{prefix}.{name}.{metric}"] = val
                else:
                    # Either global statistic or something that we haven't accounted for,
                    # e.g: a newly added statistic. So, we'll just include it to be safe.
                    info[f"{prefix}.{stat}"] = val

        def finalize():
            """No-op: this logger holds no resources that need releasing."""
            pass

        return log_step, finalize

    return setup_logger
Back to Directory File Manager