# Source path: /home/ubuntu/combine_ai/combine/lib/python3.10/site-packages/torch/distributed/c10d_logger.py

#!/usr/bin/env python3

# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import functools
import logging
import time
from typing import Any, Dict, List, Tuple

import torch
import torch.distributed as dist

from torch.distributed.logging_handlers import _log_handlers

__all__: List[str] = []


def _get_or_create_logger() -> logging.Logger:
    """Return the ``c10d-<HandlerClass>`` logger wired to the default handler.

    The handler obtained from ``_get_logging_handler()`` is given a formatter
    that records timestamp, source location, level, process and thread names.
    The logger is set to ``DEBUG`` and does not propagate to ancestor loggers.
    """
    handler, handler_name = _get_logging_handler()
    handler.setFormatter(
        logging.Formatter(
            "%(asctime)s %(filename)s:%(lineno)s %(levelname)s p:%(processName)s t:%(threadName)s: %(message)s"
        )
    )

    c10d_logger = logging.getLogger(f"c10d-{handler_name}")
    c10d_logger.setLevel(logging.DEBUG)
    # Keep records out of ancestor loggers; only the attached handler emits.
    c10d_logger.propagate = False
    c10d_logger.addHandler(handler)
    return c10d_logger


def _get_logging_handler(destination: str = "default") -> Tuple[logging.Handler, str]:
    """Look up the handler stored under ``destination`` in ``_log_handlers``.

    Returns:
        A ``(handler, name)`` pair where ``name`` is the handler's class name.
        Any lookup error raised by ``_log_handlers`` for an unknown
        ``destination`` is propagated unchanged.
    """
    handler = _log_handlers[destination]
    return handler, type(handler).__name__


# Module-wide logger, created once at import time. A ``global`` statement is
# not needed here: names assigned at module scope are already module globals.
_c10d_logger = _get_or_create_logger()


def _get_msg_dict(func_name, *args, **kwargs) -> Dict[str, Any]:
    if dist.is_initialized():
        msg_dict = {
            "func_name": f"{func_name}",
            "args": f"{args}, {kwargs}",
            "pg_name": f"{dist._get_process_group_name(kwargs.get('pg'))}",  # type: ignore[arg-type]
            "backend": f"{dist.get_backend(kwargs.get('group'))}",
            "world_size": f"{dist.get_world_size()}",
            "group_size": f"{dist.get_world_size(kwargs.get('group'))}",
            "global_rank": f"{dist.get_rank()}",
            "local_rank": f"{dist.get_rank(kwargs.get('group'))}",
        }
        if msg_dict["backend"] == "nccl":
            nccl_version = torch.cuda.nccl.version()
            msg_dict["nccl_version"] = ".".join(str(v) for v in nccl_version)
    else:
        msg_dict = {
            "func_name": f"{func_name}",
            "args": f"{args}, {kwargs}",
        }
    return msg_dict


def _exception_logger(func):
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        try:
            return func(*args, **kwargs)
        except Exception as error:
            msg_dict = _get_msg_dict(func.__name__, *args, **kwargs)
            msg_dict["error"] = f"{error}"
            _c10d_logger.debug(msg_dict)
            raise

    return wrapper


def _time_logger(func):
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        t1 = time.time_ns()
        func_return = func(*args, **kwargs)
        time_spent = time.time_ns() - t1

        msg_dict = _get_msg_dict(func.__name__, *args, **kwargs)
        msg_dict["time_spent"] = f"{time_spent}ns"
        _c10d_logger.debug(msg_dict)

        return func_return

    return wrapper
# (end of file; "Back to Directory File Manager" was file-viewer page residue)