Viewing File: /home/ubuntu/codegamaai-test/voice_clone/vocie_ai/lib/python3.10/site-packages/omegaconf/_utils.py

import copy
import os
import re
import string
import sys
import warnings
from enum import Enum
from textwrap import dedent
from typing import Any, Dict, List, Match, Optional, Tuple, Type, Union

import yaml

from .errors import (
    ConfigIndexError,
    ConfigTypeError,
    ConfigValueError,
    KeyValidationError,
    OmegaConfBaseException,
    ValidationError,
)

try:
    import dataclasses

except ImportError:  # pragma: no cover
    dataclasses = None  # type: ignore # pragma: no cover

try:
    import attr

except ImportError:  # pragma: no cover
    attr = None  # type: ignore # pragma: no cover


# source: https://yaml.org/type/bool.html
# Exact (case-sensitive) string spellings from the YAML bool type page above.
# `yaml_is_bool` tests membership in this list, and
# `OmegaConfDumper.str_representer` single-quotes any string found here when
# dumping.
YAML_BOOL_TYPES = [
    "y",
    "Y",
    "yes",
    "Yes",
    "YES",
    "n",
    "N",
    "no",
    "No",
    "NO",
    "true",
    "True",
    "TRUE",
    "false",
    "False",
    "FALSE",
    "on",
    "On",
    "ON",
    "off",
    "Off",
    "OFF",
]


class OmegaConfDumper(yaml.Dumper):  # type: ignore
    """YAML dumper with a string representer that quotes ambiguous strings."""

    # Set to True by `get_omega_conf_dumper` once the representer is registered.
    str_representer_added = False

    @staticmethod
    def str_representer(dumper: yaml.Dumper, data: str) -> yaml.ScalarNode:
        """
        Represent `data` as a YAML string scalar.

        Strings that look like a YAML boolean, an int or a float are emitted
        in single quotes; every other string uses the default style.

        :param dumper: dumper used to build the scalar node
        :param data: string to represent
        :return: the scalar node produced by `dumper.represent_scalar`
        """
        needs_quotes = yaml_is_bool(data) or is_int(data) or is_float(data)
        quote_style = "'" if needs_quotes else None
        string_tag = yaml.resolver.BaseResolver.DEFAULT_SCALAR_TAG
        return dumper.represent_scalar(string_tag, data, style=quote_style)


def get_omega_conf_dumper() -> Type[OmegaConfDumper]:
    """
    Return the `OmegaConfDumper` class, registering its `str` representer
    the first time this function is called.

    :return: the `OmegaConfDumper` class itself (not an instance)
    """
    if OmegaConfDumper.str_representer_added:
        return OmegaConfDumper

    # First call: hook the custom string representer and remember that we did.
    OmegaConfDumper.add_representer(str, OmegaConfDumper.str_representer)
    OmegaConfDumper.str_representer_added = True
    return OmegaConfDumper


def yaml_is_bool(b: str) -> bool:
    """
    Check whether `b` is one of the spellings listed in `YAML_BOOL_TYPES`.

    :param b: string to test (the comparison is case-sensitive)
    :return: True if `b` appears in `YAML_BOOL_TYPES`
    """
    is_listed = b in YAML_BOOL_TYPES
    return is_listed


def get_yaml_loader() -> Any:
    """
    Build and return a customized `yaml.SafeLoader` subclass.

    The returned class differs from `yaml.SafeLoader` in three ways visible
    below: it registers an additional implicit float resolver, it drops every
    implicit resolver tagged as a timestamp, and it constructs mappings through
    a constructor that raises on duplicate keys.

    :return: the loader class (not an instance)
    """
    # Custom constructor that checks for duplicate keys
    # (from https://gist.github.com/pypt/94d747fe5180851196eb)
    def no_duplicates_constructor(
        loader: yaml.Loader, node: yaml.Node, deep: bool = False
    ) -> Any:
        # `mapping` is only used to detect keys that were already seen; the
        # value returned to the caller is built by `construct_mapping` below.
        mapping: Dict[str, Any] = {}
        for key_node, value_node in node.value:
            key = loader.construct_object(key_node, deep=deep)
            value = loader.construct_object(value_node, deep=deep)
            if key in mapping:
                raise yaml.constructor.ConstructorError(
                    "while constructing a mapping",
                    node.start_mark,
                    f"found duplicate key {key}",
                    key_node.start_mark,
                )
            mapping[key] = value
        return loader.construct_mapping(node, deep)

    class OmegaConfLoader(yaml.SafeLoader):  # type: ignore
        pass

    loader = OmegaConfLoader
    # Register a float pattern for plain scalars whose first character is one
    # of "-+0123456789.". The second alternative accepts an exponent without a
    # decimal point (e.g. a scalar shaped like 1e3).
    loader.add_implicit_resolver(
        "tag:yaml.org,2002:float",
        re.compile(
            """^(?:
         [-+]?(?:[0-9][0-9_]*)\\.[0-9_]*(?:[eE][-+]?[0-9]+)?
        |[-+]?(?:[0-9][0-9_]*)(?:[eE][-+]?[0-9]+)
        |\\.[0-9_]+(?:[eE][-+][0-9]+)?
        |[-+]?[0-9][0-9_]*(?::[0-5]?[0-9])+\\.[0-9_]*
        |[-+]?\\.(?:inf|Inf|INF)
        |\\.(?:nan|NaN|NAN))$""",
            re.X,
        ),
        list("-+0123456789."),
    )  # type : ignore
    # Rebuild the resolver table without any entry tagged as a timestamp, so
    # this loader never implicitly resolves a scalar to that tag.
    loader.yaml_implicit_resolvers = {
        key: [
            (tag, regexp)
            for tag, regexp in resolvers
            if tag != "tag:yaml.org,2002:timestamp"
        ]
        for key, resolvers in loader.yaml_implicit_resolvers.items()
    }
    # Route every default-tagged mapping through the duplicate-key check.
    loader.add_constructor(
        yaml.resolver.BaseResolver.DEFAULT_MAPPING_TAG, no_duplicates_constructor
    )
    return loader


def _get_class(path: str) -> type:
    from importlib import import_module

    module_path, _, class_name = path.rpartition(".")
    mod = import_module(module_path)
    try:
        klass: type = getattr(mod, class_name)
    except AttributeError:
        raise ImportError(f"Class {class_name} is not in module {module_path}")
    return klass


def _is_union(type_: Any) -> bool:
    return getattr(type_, "__origin__", None) is Union


def _resolve_optional(type_: Any) -> Tuple[bool, Any]:
    if getattr(type_, "__origin__", None) is Union:
        args = type_.__args__
        if len(args) == 2 and args[1] == type(None):  # noqa E721
            return True, args[0]
    if type_ is Any:
        return True, Any

    return False, type_


def _resolve_forward(type_: Type[Any], module: str) -> Type[Any]:
    import typing  # lgtm [py/import-and-import-from]

    forward = typing.ForwardRef if hasattr(typing, "ForwardRef") else typing._ForwardRef  # type: ignore
    if type(type_) is forward:
        return _get_class(f"{module}.{type_.__forward_arg__}")
    else:
        return type_


def _raise_missing_error(obj: Any, name: str) -> None:
    """
    Raise a ValueError reporting that field `name` has no default value.

    :param obj: class or instance that owns the field
    :param name: name of the field without a default
    :raises ValueError: always
    """
    owner_name = get_type_of(obj).__name__
    message = (
        f"Missing default value for {owner_name}.{name}, to indicate "
        "default must be populated later use OmegaConf.MISSING"
    )
    raise ValueError(message)


def get_attr_data(obj: Any) -> Dict[str, Any]:
    """
    Collect the fields of an attrs class or instance as wrapped values.

    For an instance the current attribute values are used. For a class the
    declared defaults are used; a field without a default either falls back to
    its own type (when that type is an attrs class, with a warning) or raises.

    :param obj: attrs class or attrs instance
    :return: dict mapping field name to the result of `_maybe_wrap`
    :raises ValueError: via `_raise_missing_error`, when a non-nested field of
        a class has no default value
    """
    from omegaconf.omegaconf import _maybe_wrap

    d = {}
    is_type = isinstance(obj, type)
    obj_type = obj if is_type else type(obj)
    for name, attrib in attr.fields_dict(obj_type).items():
        is_optional, type_ = _resolve_optional(attrib.type)
        is_nested = is_attr_class(type_)
        type_ = _resolve_forward(type_, obj.__module__)
        if not is_type:
            value = getattr(obj, name)
        else:
            value = attrib.default
            # NOTHING is a sentinel: compare by identity so a default whose
            # type overrides __eq__ is never asked to compare against it.
            if value is attr.NOTHING:
                if is_nested:
                    msg = dedent(
                        f"""
                        The field `{name}` of type '{type_str(type_)}' does not have a default value.
                        The behavior of OmegaConf for such cases is changing in OmegaConf 2.1.
                        See https://github.com/omry/omegaconf/issues/412 for more details.
                        """
                    )
                    warnings.warn(category=UserWarning, message=msg, stacklevel=8)
                    value = type_
                else:
                    _raise_missing_error(obj, name)
                    # Unreachable: the call above always raises.
                    assert False
        if _is_union(type_):
            e = ConfigValueError(
                f"Union types are not supported:\n{name}: {type_str(type_)}"
            )
            format_and_raise(node=None, key=None, value=value, cause=e, msg=str(e))

        d[name] = _maybe_wrap(
            ref_type=type_, is_optional=is_optional, key=name, value=value, parent=None
        )
    return d


def get_dataclass_data(obj: Any) -> Dict[str, Any]:
    """
    Collect the fields of a dataclass or dataclass instance as wrapped values.

    When `obj` exposes an attribute for the field, that value is used. Otherwise
    the field's `default_factory` is called; a field with neither falls back to
    its own type (when that type is a structured config, with a warning) or
    raises.

    :param obj: dataclass type or dataclass instance
    :return: dict mapping field name to the result of `_maybe_wrap`
    :raises ValueError: via `_raise_missing_error`, when a field has no usable
        default value
    """
    from omegaconf.omegaconf import _maybe_wrap

    d = {}
    for field in dataclasses.fields(obj):
        name = field.name
        is_optional, type_ = _resolve_optional(field.type)
        type_ = _resolve_forward(type_, obj.__module__)
        is_nested = is_structured_config(type_)

        if hasattr(obj, name):
            value = getattr(obj, name)
            # MISSING is a sentinel: compare by identity so a value whose type
            # overrides __eq__ is never asked to compare against it.
            if value is dataclasses.MISSING:
                _raise_missing_error(obj, name)
                # Unreachable: the call above always raises.
                assert False
        else:
            if field.default_factory is dataclasses.MISSING:  # type: ignore
                if is_nested:
                    msg = dedent(
                        f"""
                        The field `{name}` of type '{type_str(type_)}' does not have a default value.
                        The behavior of OmegaConf for such cases is changing in OmegaConf 2.1.
                        See https://github.com/omry/omegaconf/issues/412 for more details.
                        """
                    )
                    warnings.warn(category=UserWarning, message=msg, stacklevel=8)

                    value = type_
                else:
                    _raise_missing_error(obj, name)
                    # Unreachable: the call above always raises.
                    assert False
            else:
                value = field.default_factory()  # type: ignore

        if _is_union(type_):
            e = ConfigValueError(
                f"Union types are not supported:\n{name}: {type_str(type_)}"
            )
            format_and_raise(node=None, key=None, value=value, cause=e, msg=str(e))
        d[name] = _maybe_wrap(
            ref_type=type_, is_optional=is_optional, key=name, value=value, parent=None
        )
    return d


def is_dataclass(obj: Any) -> bool:
    """
    Check whether `obj` is a dataclass (type or instance).

    :param obj: object to test
    :return: False when the `dataclasses` module is unavailable or `obj` is a
        `Node`; otherwise the result of `dataclasses.is_dataclass(obj)`
    """
    from omegaconf.base import Node

    if dataclasses is not None and not isinstance(obj, Node):
        return dataclasses.is_dataclass(obj)
    return False


def is_attr_class(obj: Any) -> bool:
    """
    Check whether `obj` is recognized by attrs.

    :param obj: object to test
    :return: False when the `attr` package is unavailable or `obj` is a
        `Node`; otherwise the result of `attr.has(obj)`
    """
    from omegaconf.base import Node

    if attr is not None and not isinstance(obj, Node):
        return attr.has(obj)
    return False


def is_structured_config(obj: Any) -> bool:
    """
    Check whether `obj` is an attrs class or a dataclass.

    :param obj: object to test
    :return: True if either `is_attr_class` or `is_dataclass` accepts `obj`
    """
    attr_result = is_attr_class(obj)
    return attr_result if attr_result else is_dataclass(obj)


def is_dataclass_frozen(type_: Any) -> bool:
    """
    Report the `frozen` flag stored in a dataclass's `__dataclass_params__`.

    :param type_: dataclass type (or anything exposing `__dataclass_params__`)
    :return: value of `__dataclass_params__.frozen`
    """
    params = type_.__dataclass_params__  # type: ignore
    return params.frozen  # type: ignore


def is_attr_frozen(type_: type) -> bool:
    """
    Detect a frozen attrs class by inspecting its `__setattr__`.

    :param type_: attrs class to inspect
    :return: True if `type_.__setattr__` equals attrs' private frozen setter
    """
    # This is very hacky and probably fragile as well.
    # Unfortunately currently there isn't an official API in attr that can detect that.
    # noinspection PyProtectedMember
    class_setattr = type_.__setattr__
    frozen_setattr = attr._make._frozen_setattrs  # type: ignore
    return class_setattr == frozen_setattr  # type: ignore


def get_type_of(class_or_object: Any) -> Type[Any]:
    """
    Return the class of the argument.

    :param class_or_object: a class, or an instance of one
    :return: the argument itself when it is already a class, otherwise its type
    """
    if isinstance(class_or_object, type):
        type_ = class_or_object
    else:
        type_ = type(class_or_object)
    assert isinstance(type_, type)
    return type_


def is_structured_config_frozen(obj: Any) -> bool:
    """
    Report whether the structured config behind `obj` is frozen.

    :param obj: class or instance to inspect
    :return: the frozen flag for a dataclass or attrs class; False for
        anything else
    """
    cls = get_type_of(obj)

    # Dataclasses are checked first, then attrs classes.
    if is_dataclass(cls):
        return is_dataclass_frozen(cls)
    return is_attr_frozen(cls) if is_attr_class(cls) else False


def get_structured_config_data(obj: Any) -> Dict[str, Any]:
    """
    Extract the field data of a structured config.

    :param obj: dataclass or attrs class / instance
    :return: the dict built by `get_dataclass_data` or `get_attr_data`
    :raises ValueError: if `obj` is neither a dataclass nor an attrs class
    """
    if is_dataclass(obj):
        return get_dataclass_data(obj)
    if is_attr_class(obj):
        return get_attr_data(obj)
    raise ValueError(f"Unsupported type: {type(obj).__name__}")


class ValueKind(Enum):
    """Classification of a config value, as returned by `get_value_kind`."""

    # A regular value: a non-string, or a string with no interpolation in it.
    VALUE = 0
    # The mandatory-missing marker, i.e. the value "???".
    MANDATORY_MISSING = 1
    # A string that consists of exactly one interpolation, e.g. "${foo}".
    INTERPOLATION = 2
    # A string that contains interpolation(s) but is not a single one,
    # e.g. "ftp://${host}/path".
    STR_INTERPOLATION = 3


def get_value_kind(value: Any, return_match_list: bool = False) -> Any:
    """
    Determine the kind of a value
    Examples:
    MANDATORY_MISSING : "???"
    VALUE : "10", "20", True,
    INTERPOLATION: "${foo}", "${foo.bar}"
    STR_INTERPOLATION: "ftp://${host}/path"

    :param value: value (or node) to classify
    :param return_match_list: True to return the match list as well
    :return: the ValueKind, or a (ValueKind, match list) tuple when
        return_match_list is True; the match list is None unless the
        interpolation pattern was actually searched
    """
    from .base import Container

    key_prefix = r"\${(\w+:)?"
    legal_characters = r"([\w\.%_ \\/:,-]*?)}"
    matches: Optional[List[Match[str]]] = None

    def _result(
        kind: ValueKind,
    ) -> Union[ValueKind, Tuple[ValueKind, Optional[List[Match[str]]]]]:
        # Reads `matches` at call time, so it reflects the search below.
        return (kind, matches) if return_match_list else kind

    # Containers that are themselves an interpolation or missing are
    # classified by their raw underlying value.
    if isinstance(value, Container) and (
        value._is_interpolation() or value._is_missing()
    ):
        value = value._value()

    value = _get_value(value)
    if value == "???":
        return _result(ValueKind.MANDATORY_MISSING)

    if not isinstance(value, str):
        return _result(ValueKind.VALUE)

    pattern = key_prefix + legal_characters
    matches = list(re.finditer(pattern, value))
    if len(matches) == 0:
        return _result(ValueKind.VALUE)

    # A lone match spanning the entire string is a pure interpolation.
    is_pure = len(matches) == 1 and value == matches[0].group(0)
    return _result(
        ValueKind.INTERPOLATION if is_pure else ValueKind.STR_INTERPOLATION
    )


def is_bool(st: str) -> bool:
    """
    Check whether `st` spells "true" or "false", ignoring case.

    :param st: string to test
    :return: True if the lower-cased string is "true" or "false"
    """
    lowered = str.lower(st)
    return lowered in ("true", "false")


def is_float(st: str) -> bool:
    """
    Check whether `float(st)` succeeds.

    :param st: string to test
    :return: False if `float(st)` raises ValueError, True otherwise
    """
    try:
        float(st)
    except ValueError:
        return False
    else:
        return True


def is_int(st: str) -> bool:
    """
    Check whether `int(st)` succeeds.

    :param st: string to test
    :return: False if `int(st)` raises ValueError, True otherwise
    """
    try:
        int(st)
    except ValueError:
        return False
    else:
        return True


def decode_primitive(s: str) -> Any:
    """
    Convert a string to bool, int or float when it parses as one.

    Checks run in that order: bool first, then int, then float.

    :param s: string to decode
    :return: the decoded bool / int / float, or `s` unchanged if none applies
    """
    if is_bool(s):
        return str.lower(s) == "true"

    for matches, convert in ((is_int, int), (is_float, float)):
        if matches(s):
            return convert(s)

    return s


def is_primitive_list(obj: Any) -> bool:
    """
    Check whether `obj` is a plain list or tuple rather than a `Container`.

    :param obj: object to test
    :return: True for list / tuple instances that are not `Container`s
    """
    from .base import Container

    if isinstance(obj, Container):
        return False
    return isinstance(obj, (list, tuple))


def is_primitive_dict(obj: Any) -> bool:
    """
    Check whether `obj` is exactly `dict` or a direct instance of it.

    :param obj: class or object to test
    :return: True only if the type of `obj` is `dict` itself, not a subclass
    """
    return get_type_of(obj) is dict


def is_dict_annotation(type_: Any) -> bool:
    """
    Check whether `type_` is a dict type annotation.

    :param type_: annotation to inspect
    :return: on Python >= 3.7, True if the `__origin__` of `type_` is `dict`;
        on older versions, True if `type_` or its `__origin__` is `typing.Dict`
    """
    origin = getattr(type_, "__origin__", None)
    if sys.version_info >= (3, 7, 0):
        return origin is dict  # pragma: no cover
    return origin is Dict or type_ is Dict  # pragma: no cover


def is_list_annotation(type_: Any) -> bool:
    """
    Check whether `type_` is a list type annotation.

    :param type_: annotation to inspect
    :return: on Python >= 3.7, True if the `__origin__` of `type_` is `list`;
        on older versions, True if `type_` or its `__origin__` is `typing.List`
    """
    origin = getattr(type_, "__origin__", None)
    if sys.version_info >= (3, 7, 0):
        return origin is list  # pragma: no cover
    return origin is List or type_ is List  # pragma: no cover


def is_tuple_annotation(type_: Any) -> bool:
    """
    Check whether `type_` is a tuple type annotation.

    :param type_: annotation to inspect
    :return: on Python >= 3.7, True if the `__origin__` of `type_` is `tuple`;
        on older versions, True if `type_` or its `__origin__` is `typing.Tuple`
    """
    origin = getattr(type_, "__origin__", None)
    if sys.version_info >= (3, 7, 0):
        return origin is tuple  # pragma: no cover
    return origin is Tuple or type_ is Tuple  # pragma: no cover


def is_dict_subclass(type_: Any) -> bool:
    """
    Check whether `type_` is a class that subclasses `Dict`.

    :param type_: object to test
    :return: False for None and for non-class objects; otherwise the result
        of `issubclass(type_, Dict)`
    """
    if type_ is None or not isinstance(type_, type):
        return False
    return issubclass(type_, Dict)


def is_dict(obj: Any) -> bool:
    """
    Check whether `obj` is dict-like in any of the senses this module knows.

    :param obj: object, class or annotation to test
    :return: True if `obj` is a primitive dict, a dict annotation, or a class
        subclassing Dict
    """
    checks = (is_primitive_dict, is_dict_annotation, is_dict_subclass)
    return any(check(obj) for check in checks)


def is_primitive_container(obj: Any) -> bool:
    """
    Check whether `obj` is a primitive list or a primitive dict.

    :param obj: object to test
    :return: True if `is_primitive_list` or `is_primitive_dict` accepts `obj`
    """
    if is_primitive_list(obj):
        return True
    return is_primitive_dict(obj)


def get_list_element_type(ref_type: Optional[Type[Any]]) -> Optional[Type[Any]]:
    """
    Extract the element type from a list annotation.

    :param ref_type: annotation such as `List[int]`, or None
    :return: the first type argument, or None when `ref_type` is the bare
        `List`, has no `__args__`, or its first argument is `Any`
    :raises ValidationError: if the element type is not a supported value type
    """
    type_args = getattr(ref_type, "__args__", None)
    has_concrete_arg = (
        ref_type is not List and type_args is not None and type_args[0] is not Any
    )
    element_type = type_args[0] if has_concrete_arg else None

    if not valid_value_annotation_type(element_type):
        raise ValidationError(f"Unsupported value type : {element_type}")
    assert element_type is None or isinstance(element_type, type)
    return element_type


def get_dict_key_value_types(ref_type: Any) -> Tuple[Any, Any]:
    """
    Extract the key and value types from a dict annotation.

    Type arguments are read from `ref_type.__args__`; when that attribute is
    absent they are read from the first entry of `ref_type.__orig_bases__`.
    `Any` is reported as None in either position.

    :param ref_type: annotation such as `Dict[str, int]`, a class, or None
    :return: tuple of (key_type, element_type); each is None when unknown
    :raises ValidationError: if the value type is not supported
    :raises KeyValidationError: if the key type is not supported
    """
    type_args = getattr(ref_type, "__args__", None)
    if type_args is None:
        orig_bases = getattr(ref_type, "__orig_bases__", None)
        if orig_bases is not None and len(orig_bases) > 0:
            type_args = getattr(orig_bases[0], "__args__", None)

    key_type: Any = None
    element_type: Any = None
    if ref_type is not None and type_args is not None:
        # None is the sentry for any type
        key_type = None if type_args[0] is Any else type_args[0]
        element_type = None if type_args[1] is Any else type_args[1]

    if not (
        valid_value_annotation_type(element_type)
        or is_structured_config(element_type)
    ):
        raise ValidationError(f"Unsupported value type : {element_type}")

    if not _valid_dict_key_annotation_type(key_type):
        raise KeyValidationError(f"Unsupported key type {key_type}")
    return key_type, element_type


def valid_value_annotation_type(type_: Any) -> bool:
    """
    Check whether `type_` is acceptable as a value annotation.

    :param type_: annotation to test
    :return: True for `Any`, for primitive types, and for structured configs
    """
    if type_ is Any:
        return True
    return is_primitive_type(type_) or is_structured_config(type_)


def _valid_dict_key_annotation_type(type_: Any) -> bool:
    return type_ is None or issubclass(type_, str) or issubclass(type_, Enum)


def is_primitive_type(type_: Any) -> bool:
    """
    Check whether `type_` (a class or an instance) is a primitive type.

    :param type_: class or object to test
    :return: True for Enum subclasses and for int, float, bool, str, NoneType
    """
    cls = get_type_of(type_)
    if issubclass(cls, Enum):
        return True
    return cls in (int, float, bool, str, type(None))


def _is_interpolation(v: Any) -> bool:
    """
    Check whether `v` is a string containing an interpolation.

    :param v: value to test
    :return: True if `v` is a str that `get_value_kind` classifies as
        INTERPOLATION or STR_INTERPOLATION; False for any non-string
    """
    if not isinstance(v, str):
        return False

    interpolation_kinds = (
        ValueKind.INTERPOLATION,
        ValueKind.STR_INTERPOLATION,
    )
    ret = get_value_kind(v) in interpolation_kinds
    assert isinstance(ret, bool)
    return ret


def _get_value(value: Any) -> Any:
    """
    Unwrap a node to the value it holds.

    :param value: a Container, a ValueNode, or any other object
    :return: None for a Container whose `_is_none()` is true; the result of
        `_value()` for a ValueNode; otherwise `value` unchanged
    """
    from .base import Container
    from .nodes import ValueNode

    if isinstance(value, Container) and value._is_none():
        return None
    return value._value() if isinstance(value, ValueNode) else value


def get_ref_type(obj: Any, key: Any = None) -> Optional[Type[Any]]:
    """
    Compute the reference type annotation describing `obj`.

    :param obj: a node, a container, or a plain Python object
    :param key: when given and `obj` is a Container, the child node stored
        under this key is inspected instead of `obj` itself
    :return: the reference type; wrapped in `Optional[...]` when the object is
        optional and the type is not `Any`
    """
    from omegaconf import DictConfig, ListConfig
    from omegaconf.base import Container, Node
    from omegaconf.nodes import ValueNode

    def none_as_any(t: Optional[Type[Any]]) -> Union[Type[Any], Any]:
        # Metadata stores "no restriction" as None; report it as Any.
        if t is None:
            return Any
        else:
            return t

    if isinstance(obj, Container) and key is not None:
        obj = obj._get_node(key)

    # Defaults used for plain Python objects, which carry no node metadata.
    is_optional = True
    ref_type = None
    if isinstance(obj, ValueNode):
        is_optional = obj._is_optional()
        ref_type = obj._metadata.ref_type
    elif isinstance(obj, Container):
        if isinstance(obj, Node):
            ref_type = obj._metadata.ref_type
        is_optional = obj._is_optional()
        kt = none_as_any(obj._metadata.key_type)
        vt = none_as_any(obj._metadata.element_type)
        if (
            ref_type is Any
            and kt is Any
            and vt is Any
            and not obj._is_missing()
            and not obj._is_none()
        ):
            # ref_type is already Any in this branch; the assignment keeps it.
            ref_type = Any  # type: ignore
        elif not is_structured_config(ref_type):
            # Synthesize a Dict[...] / List[...] annotation from the
            # container's key and element types.
            if kt is Any:
                kt = Union[str, Enum]
            if isinstance(obj, DictConfig):
                ref_type = Dict[kt, vt]  # type: ignore
            elif isinstance(obj, ListConfig):
                ref_type = List[vt]  # type: ignore
    else:
        # Plain Python objects: describe builtin containers generically and
        # everything else by its concrete type.
        if isinstance(obj, dict):
            ref_type = Dict[Union[str, Enum], Any]
        elif isinstance(obj, (list, tuple)):
            ref_type = List[Any]
        else:
            ref_type = get_type_of(obj)

    ref_type = none_as_any(ref_type)
    if is_optional and ref_type is not Any:
        ref_type = Optional[ref_type]  # type: ignore
    return ref_type


def _raise(ex: Exception, cause: Exception) -> None:
    # Set the environment variable OC_CAUSE=1 to get a stacktrace that includes the
    # causing exception.
    env_var = os.environ["OC_CAUSE"] if "OC_CAUSE" in os.environ else None
    debugging = sys.gettrace() is not None
    full_backtrace = (debugging and not env_var == "0") or (env_var == "1")
    if full_backtrace:
        ex.__cause__ = cause
    else:
        ex.__cause__ = None
    raise ex  # set end OC_CAUSE=1 for full backtrace


def format_and_raise(
    node: Any,
    key: Any,
    value: Any,
    msg: str,
    cause: Exception,
    type_override: Any = None,
) -> None:
    """
    Build a detailed exception from `cause` and raise it via `_raise`.

    :param node: node the error relates to, or None when there is no node
    :param key: key being accessed on `node`, or None
    :param value: value involved in the failing operation
    :param msg: message template; may reference $REF_TYPE, $OBJECT_TYPE, $KEY,
        $FULL_KEY, $VALUE, $VALUE_TYPE and $KEY_TYPE
    :param cause: the original exception
    :param type_override: exception type to raise instead of `type(cause)`
    :raises Exception: always; the type is `type_override` or `type(cause)`,
        with TypeError and IndexError mapped to ConfigTypeError and
        ConfigIndexError
    """
    from omegaconf import OmegaConf
    from omegaconf.base import Node

    # An OmegaConf exception that was already formatted is re-raised as-is,
    # or re-created as `type_override` carrying a deep copy of its attributes.
    if isinstance(cause, OmegaConfBaseException) and cause._initialized:
        ex = cause
        if type_override is not None:
            ex = type_override(str(cause))
            ex.__dict__ = copy.deepcopy(cause.__dict__)
        _raise(ex, cause)

    object_type: Optional[Type[Any]]
    object_type_str: Optional[str] = None
    ref_type: Optional[Type[Any]]
    ref_type_str: Optional[str]

    child_node: Optional[Node] = None
    if node is None:
        full_key = ""
        object_type = None
        ref_type = None
        ref_type_str = None
    else:
        if key is not None and not OmegaConf.is_none(node):
            child_node = node._get_node(key, validate_access=False)

        full_key = node._get_full_key(key=key)

        object_type = OmegaConf.get_type(node)
        object_type_str = type_str(object_type)

        ref_type = get_ref_type(node)
        ref_type_str = type_str(ref_type)

    # First pass: expand the placeholders inside the caller's message.
    msg = string.Template(msg).substitute(
        REF_TYPE=ref_type_str,
        OBJECT_TYPE=object_type_str,
        KEY=key,
        FULL_KEY=full_key,
        VALUE=value,
        VALUE_TYPE=f"{type(value).__name__}",
        KEY_TYPE=f"{type(key).__name__}",
    )

    template = """$MSG
\tfull_key: $FULL_KEY
\treference_type=$REF_TYPE
\tobject_type=$OBJECT_TYPE"""

    s = string.Template(template=template)

    # Second pass: wrap the expanded message with the key / type context.
    message = s.substitute(
        REF_TYPE=ref_type_str, OBJECT_TYPE=object_type_str, MSG=msg, FULL_KEY=full_key
    )
    exception_type = type(cause) if type_override is None else type_override
    # Builtin TypeError / IndexError are replaced by their Config* counterparts.
    if exception_type == TypeError:
        exception_type = ConfigTypeError
    elif exception_type == IndexError:
        exception_type = ConfigIndexError

    ex = exception_type(f"{message}")
    # Only OmegaConf exceptions receive the structured context attributes.
    if issubclass(exception_type, OmegaConfBaseException):
        ex._initialized = True
        ex.msg = message
        ex.parent_node = node
        ex.child_node = child_node
        ex.key = key
        ex.full_key = full_key
        ex.value = value
        ex.object_type = object_type
        ex.object_type_str = object_type_str
        ex.ref_type = ref_type
        ex.ref_type_str = ref_type_str

    _raise(ex, cause)


def type_str(t: Any) -> str:
    """
    Render a type or typing annotation as a short readable string.

    Optional types (as detected by `_resolve_optional`) are rendered as
    "Optional[...]"; type arguments are rendered recursively inside brackets.

    :param t: type, typing annotation, or None
    :return: the rendered name. An object that exposes no usable name
        (no `__name__`, `_name` or `__origin__`) is rendered as `str(t)`
        instead of failing.
    """
    is_optional, t = _resolve_optional(t)
    if t is None:
        return type(t).__name__
    if t is Any:
        return "Any"

    # False only on the last-resort path, where `str(t)` is already the
    # complete rendering and appending type arguments would duplicate them.
    append_args = True
    if hasattr(t, "__name__"):
        name = str(t.__name__)
    elif sys.version_info < (3, 7, 0):  # pragma: no cover
        # Python 3.6
        if getattr(t, "__origin__", None) is not None:
            name = type_str(t.__origin__)
        else:
            name = str(t)
            if name.startswith("typing."):
                name = name[len("typing.") :]
    else:  # pragma: no cover
        # Python >= 3.7
        alias_name = getattr(t, "_name", None)
        origin = getattr(t, "__origin__", None)
        if alias_name is not None:
            name = str(alias_name)
        elif origin is not None:
            name = type_str(origin)
        else:
            # Neither a name nor an origin is available; previously this left
            # `name` unbound (or raised AttributeError on a missing `_name`).
            name = str(t)
            append_args = False

    args = getattr(t, "__args__", None) if append_args else None
    if args is not None:
        rendered_args = ", ".join(type_str(arg) for arg in args)
        ret = f"{name}[{rendered_args}]"
    else:
        ret = name
    if is_optional:
        return f"Optional[{ret}]"
    else:
        return ret


def _ensure_container(target: Any) -> Any:
    """
    Convert `target` to an OmegaConf config object when it is not one already.

    :param target: primitive list / dict, structured config, or existing config
    :return: the result of `OmegaConf.create` for primitive containers, of
        `OmegaConf.structured` for structured configs, else `target` itself
    :raises AssertionError: if the result is not an OmegaConf config, or if a
        primitive container is neither a list nor a dict
    """
    from omegaconf import OmegaConf

    if is_primitive_container(target):
        assert isinstance(target, (list, dict))
        container = OmegaConf.create(target)
    elif is_structured_config(target):
        container = OmegaConf.structured(target)
    else:
        container = target
    assert OmegaConf.is_config(container)
    return container


def is_generic_list(type_: Any) -> bool:
    """
    Checks if a type is a generic list, for example:
    list returns False
    typing.List returns False
    typing.List[T] returns True

    :param type_: variable type
    :return: bool
    """
    if not is_list_annotation(type_):
        return False
    # A list annotation counts as generic only with a concrete element type.
    element_type = get_list_element_type(type_)
    return element_type is not None


def is_generic_dict(type_: Any) -> bool:
    """
    Checks if a type is a dict annotation whose key and value types validate.

    Non-dict annotations return False. For a dict annotation,
    `get_dict_key_value_types` is called (and may raise for unsupported key or
    value types); the (key_type, element_type) pair it returns is never empty,
    so the length check below passes for every dict annotation that validates.

    :param type_: variable type
    :return: bool
    """
    if not is_dict_annotation(type_):
        return False
    key_and_value_types = get_dict_key_value_types(type_)
    return len(key_and_value_types) > 0
Back to Directory File Manager