import collections
import dataclasses
import functools
import inspect
import itertools
import operator
import sys
import types
from typing import Dict, List

import torch._C
import torch._numpy as tnp
from .. import config, polyfill, variables
from ..bytecode_transformation import create_call_function, create_instruction
from ..exc import unimplemented
from ..guards import GuardBuilder, install_guard
from ..source import AttrSource, GetItemSource, ODictGetItemSource, TypeSource
from ..utils import (
    check_constant_args,
    identity,
    is_tensor_base_attr_getter,
    proxy_args_kwargs,
)
from .base import MutableLocal, VariableTracker
from .dicts import DefaultDictVariable
from .functions import (
    NestedUserFunctionVariable,
    UserFunctionVariable,
    UserMethodVariable,
)
from .user_defined import UserDefinedObjectVariable


class SuperVariable(VariableTracker):
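    # Models a super() call captured during tracing: `typevar` is the class given
    # (or implied) as the first argument and `objvar` is the bound second argument.
    # Attribute lookups walk the MRO of the object's type starting just past
    # `typevar` (see _resolved_getattr_and_source below).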
    def __init__(self, typevar, objvar=None, specialized=False, **kwargs):
        super().__init__(**kwargs)
        self.typevar = typevar
        self.objvar = objvar
        self.specialized = specialized  # directly get attr from self.typevar if true

    def reconstruct(self, codegen):
        codegen(variables.BuiltinVariable(super))
        codegen(self.typevar)
        if self.objvar is not None:
            codegen(self.objvar)
            return create_call_function(2, True)
        else:
            return create_call_function(1, True)

    def _resolved_getattr_and_source(self, tx, name):
        assert self.objvar, "1-arg super not implemented"
        if self.specialized:
            # No source is available in the specialized case.
            return getattr(self.typevar.as_python_constant(), name), None
        search_type = self.typevar.as_python_constant()

        # We default to the python type of the object. However, if this is
        # a `type` or subclass of `type`, then the original object represents
        # the user defined type.
        type_to_use = self.objvar.python_type()
        type_to_use_source = (
            TypeSource(self.objvar.source) if self.objvar.source else None
        )
        if issubclass(type_to_use, type):
            type_to_use = self.objvar.value
            type_to_use_source = self.objvar.source

        source = None
        if self.objvar.source is not None:
            # Walk the mro tuple to find out the actual class where the
            # attribute resides.
            search_mro = type_to_use.__mro__
            start_index = search_mro.index(search_type) + 1
            for index in range(start_index, len(search_mro)):
                if hasattr(search_mro[index], name):
                    # Equivalent of something like type(L['self']).__mro__[1].attr_name
                    source = AttrSource(
                        GetItemSource(AttrSource(type_to_use_source, "__mro__"), index),
                        name,
                    )
                    break

        # TODO(jansel): there is a small chance this could trigger user code, prevent that
        return getattr(super(search_type, type_to_use), name), source

    def var_getattr(self, tx, name: str) -> "VariableTracker":
        # Check if getattr is a constant. If not, delay the actual work by
        # wrapping the result in GetAttrVariable. Mostly super is called with a
        # method, so most of the work is delayed to call_function.
        #
        # We could have just implemented a const_getattr. However, super is
        # special when it comes to finding sources. Compared to other VTs, super
        # requires the attr name to walk the mro and find the actual source (and
        # not just AttrSource).
        value, source = self._resolved_getattr_and_source(self, name)
        if not variables.ConstantVariable.is_literal(value):
            return GetAttrVariable(self, name)
        if source:
            install_guard(source.make_guard(GuardBuilder.CONSTANT_MATCH))
            return variables.ConstantVariable.create(value, source=source)
        return variables.ConstantVariable.create(value)

    def call_method(
        self,
        tx,
        name,
        args: "List[VariableTracker]",
        kwargs: "Dict[str, VariableTracker]",
    ) -> "VariableTracker":
        inner_fn, source = self._resolved_getattr_and_source(self, name)

        if inner_fn is object.__init__:
            return LambdaVariable(identity)
        elif inner_fn is torch.nn.Module.__init__:
            objvar = self.objvar
            from ..side_effects import AttributeMutationNew

            if (
                isinstance(objvar, variables.UserDefinedObjectVariable)
                and isinstance(objvar.mutable_local, AttributeMutationNew)
                and not (args or kwargs)
            ):
                tx.output.side_effects.store_attr(
                    objvar,
                    "__call_nn_module_init",
                    variables.ConstantVariable.create(True),
                )
                return variables.ConstantVariable.create(None)
            else:
                unimplemented("super() nn.Module.__init__")
        elif isinstance(inner_fn, types.FunctionType):
            return variables.UserFunctionVariable(
                inner_fn, source=source
            ).call_function(tx, [self.objvar] + args, kwargs)
        elif isinstance(inner_fn, types.MethodType):
            return variables.UserMethodVariable(
                inner_fn.__func__, self.objvar, source=source
            ).call_function(tx, args, kwargs)
        elif (
            inner_fn is collections.OrderedDict.__getitem__
            and isinstance(self.objvar, variables.UserDefinedObjectVariable)
            and self.objvar.source
            and len(args) == 1
            and len(kwargs) == 0
            and args[0].is_python_constant()
        ):
            from .builder import VariableBuilder

            key = args[0].as_python_constant()
            return VariableBuilder(tx, ODictGetItemSource(self.objvar.source, key))(
                collections.OrderedDict.__getitem__(self.objvar.value, key)
            )
        elif (
            inner_fn in (collections.OrderedDict.__setitem__, object.__setattr__)
            and isinstance(self.objvar, variables.CustomizedDictVariable)
            and args
            and variables.ConstDictVariable.is_valid_key(args[0])
            and self.objvar.mutable_local
        ):
            assert not kwargs and len(args) == 2
            k = variables.ConstDictVariable.get_key(args[0])

            newval = dict(self.objvar.items)
            newval[k] = args[1]
            return tx.replace_all(
                self.objvar,
                self.objvar.modifed(newval),
            )
        else:
            unimplemented(f"non-function or method super: {inner_fn}")


class UnknownVariable(VariableTracker):
    """
    It could be anything!
    """


class DelayGraphBreakVariable(UnknownVariable):
    """
    Used to insert a dummy variable on the stack so the graph break happens at CALL_FUNCTION.
    """


class ComptimeVariable(VariableTracker):
    """
    This variable is special: it lets you execute arbitrary code at
    Dynamo compile time.
    """

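    # A minimal sketch of how this is reached from user code (illustrative only;
    # the convenience accessors live in torch._dynamo.comptime):
    #
    #   from torch._dynamo.comptime import comptime
    #
    #   @torch.compile
    #   def fn(x):
    #       comptime.print_graph()  # runs at Dynamo compile time, not at runtime
    #       return x + 1
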
    def reconstruct(self, codegen):
        raise NotImplementedError("comptime is special form")

    def var_getattr(self, tx, name: str) -> "VariableTracker":
        from ..comptime import comptime

        # To support the comptime.print_graph convenience accessors
        from .functions import UserFunctionVariable

        return UserFunctionVariable(
            getattr(comptime, name), source=AttrSource(self.source, name)
        )

    def call_function(
        self, tx, args: "List[VariableTracker]", kwargs: "Dict[str, VariableTracker]"
    ) -> "VariableTracker":
        from ..comptime import ComptimeContext

        # TODO: support an expression form as well

        assert not kwargs
        assert len(args) == 1
        fn = args[0]
        if isinstance(fn, UserFunctionVariable):
            fn.get_function()(ComptimeContext(tx))
        elif isinstance(fn, NestedUserFunctionVariable):
            # We have to manually bind the freevars ourselves
            code = fn.get_code()
            assert not fn.closure, (
                "comptime function must not have free variables, "
                f"but these variables were free: {code.co_freevars}"
            )
            func = types.FunctionType(
                code,
                fn.f_globals,
                fn.fn_name.as_python_constant(),
                tuple(fn.defaults.items) if fn.defaults else None,
                # We could automatically promote free variables into
                # ComptimeVar but this is confusing if you access
                # a free variable that we actually DO have the runtime
                # value for
                # tuple(make_cell(ComptimeVar(i)) for i in fn.closure.items)
                tuple(),
            )
            func(ComptimeContext(tx))
        else:
            raise RuntimeError(f"unsupported argument to comptime: {type(fn)}")

        return variables.ConstantVariable.create(None)


class ClosureVariable(UnknownVariable):
    def __init__(self, name, **kwargs):
        super().__init__(**kwargs)
        self.name = name

    def reconstruct(self, codegen):
        return [codegen.create_load_closure(self.name)]


# closure variable created by an inlined function
class InlinedClosureVariable(UnknownVariable):
    def __init__(self, name, **kwargs):
        super().__init__(**kwargs)
        self.name = name

    def reconstruct(self, codegen):
        return [codegen.create_load_closure(self.name)]


class NewCellVariable(VariableTracker):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)


class NewGlobalVariable(VariableTracker):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)


class InspectSignatureVariable(VariableTracker):
    """represents inspect.signature(...)"""

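    # Only the `.parameters` mapping is modeled: keys come from the inspected
    # callable, values are opaque InspectParameterVariable placeholders (see
    # var_getattr below); other attributes fall back to the default handling.
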
    @staticmethod
    def create(callable, **kwargs):
        if kwargs:
            unimplemented(f"inspect.signature with {kwargs}")
        return InspectSignatureVariable(callable)

    def __init__(self, inspected: VariableTracker, **kwargs):
        super().__init__(**kwargs)
        self.inspected = inspected

    def var_getattr(self, tx, name: str) -> "VariableTracker":
        if name == "parameters":
            return variables.ConstDictVariable(
                {
                    name: InspectParameterVariable()
                    for name in self.inspected.inspect_parameter_names()
                },
                user_cls=dict,
            )
        return super().var_getattr(tx, name)


class InspectParameterVariable(VariableTracker):
    """This is not implemented; if used, it will graph break."""

    pass


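# Each produce_trampoline_autograd_* helper below returns a plain function that
# closes over an autograd.Function subclass and forwards to its forward /
# backward / apply. The `_origin` attribute tags each trampoline so other parts
# of dynamo can recognize which producer created it.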
def produce_trampoline_autograd_fwd(fn_cls):
    def trampoline_autograd_fwd(*args, **kwargs):
        return fn_cls.forward(*args, **kwargs)

    trampoline_autograd_fwd._origin = produce_trampoline_autograd_fwd
    return trampoline_autograd_fwd


def produce_trampoline_autograd_bwd(fn_cls):
    def trampoline_autograd_bwd(*args, **kwargs):
        return fn_cls.backward(*args, **kwargs)

    trampoline_autograd_bwd._origin = produce_trampoline_autograd_bwd
    return trampoline_autograd_bwd


def produce_trampoline_autograd_apply(fn_cls):
    def trampoline_autograd_apply(*args, **kwargs):
        return fn_cls.apply(*args, **kwargs)

    trampoline_autograd_apply._origin = produce_trampoline_autograd_apply
    return trampoline_autograd_apply


class AutogradFunctionVariable(VariableTracker):
    """represents a torch.autograd.Function subclass"""

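    # A minimal sketch of the kind of user code this models (hypothetical example,
    # not taken from this file):
    #
    #   class MyRelu(torch.autograd.Function):
    #       @staticmethod
    #       def forward(ctx, x):
    #           ctx.save_for_backward(x)
    #           return x.clamp(min=0)
    #
    #       @staticmethod
    #       def backward(ctx, grad_out):
    #           (x,) = ctx.saved_tensors
    #           return grad_out * (x > 0)
    #
    #   MyRelu.apply(x)  # handled by call_apply / call_method("apply", ...) below
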
    def __init__(self, fn_cls, **kwargs):
        super().__init__(**kwargs)
        self.fn_cls = fn_cls

    def call_apply(self, tx, args, kwargs):
        requires_grad = False

        def visit(node):
            nonlocal requires_grad
            if isinstance(node, variables.TensorVariable):
                if node.requires_grad is not False:
                    requires_grad = True
            if isinstance(node, variables.NNModuleVariable):
                if node.is_training(tx):
                    requires_grad = True
            return node

        VariableTracker.apply(visit, (args, kwargs))

        ctx = AutogradFunctionContextVariable.create(tx)
        args = [ctx, *args]

        if (
            requires_grad
            and torch.is_grad_enabled()
            and config.capture_autograd_function
        ):
            # Note - this is the same check used in autograd/function.py, except inverted.
            # If we want to support functorch transforms here, we will need to enable this.
            if (
                self.fn_cls.setup_context
                != torch.autograd.function._SingleLevelFunction.setup_context
            ):
                unimplemented(
                    "NYI - autograd.Function with custom setup_context method"
                )

            vjp_fn = self.fn_cls.vjp  # type: ignore[attr-defined]
            if vjp_fn is not torch.autograd.Function.vjp:
                unimplemented("NYI - User defined vjp")

            jvp_fn = self.fn_cls.jvp  # type: ignore[attr-defined]
            if jvp_fn is not torch.autograd.Function.jvp:
                unimplemented("NYI - User defined jvp")

            from .higher_order_ops import (
                safe_or_raise_always_restore,
                TorchHigherOrderOperatorVariable,
            )

            trampoline_autograd_apply = produce_trampoline_autograd_apply(self.fn_cls)
            trampoline_autograd_fwd = produce_trampoline_autograd_fwd(self.fn_cls)
            trampoline_autograd_bwd = produce_trampoline_autograd_bwd(self.fn_cls)

            # NOTE [On Tracing autograd.Function w/ grad]
            # The machinery here evaluates the soundness of tracing an autograd.Function
            # in PyTorch. The strategy has three steps: trace the forward, trace the
            # backward, and, if both are sound, potentially record an "apply" call into
            # the graph. We trace the forward first and evaluate its soundness.
            # Soundness, in this context, means the absence of side effects, no lifting
            # of new arguments into the trace, the production of a single tensor output,
            # and an input scope limited to contexts, tensors, and constants. If the
            # forward trace is sound, we install any guards accumulated while tracing;
            # if not, we graph break. We then trace the backward and evaluate its
            # soundness by the same criteria, but with more strictness: we enable a
            # strict mode on the tx and reject certain ops while running under it. If
            # the backward trace is sound, we discard it by restoring the checkpoint;
            # otherwise, we raise.

            # If both the forward and backward traces are sound, we write the autograd
            # function's apply into the graph.

            # For tracing forward and backward we use UserFunctionVariable. Although it
            # does not directly contribute to the soundness evaluation, combined with a
            # GlobalSource it ensures we can produce valid guards and inline properly
            # here. Inlining is required for the soundness evaluation described above
            # to work.
            graph_checkpoint, checkpoint = tx.output.graph, tx.copy_graphstate()

            module_source = AttrSource(
                tx.import_source(self.fn_cls.__module__), self.fn_cls.__name__
            )
            fwd_bwd_tracer = torch._dynamo.output_graph.SubgraphTracer(
                tx.output,
                parent=tx.output.current_tracer,
                source_target="autograd.Function",
            )
            higher_order_autograd_fn = TorchHigherOrderOperatorVariable.make(
                trampoline_autograd_fwd,
                source=AttrSource(module_source, "forward"),
                fwd_bwd_tracer=fwd_bwd_tracer,
            )
            speculated_fwd_result = higher_order_autograd_fn.call_function(
                tx, args, kwargs
            )

            if isinstance(speculated_fwd_result, variables.TupleVariable):
                bwd_args = [ctx, *speculated_fwd_result.items]
            else:
                bwd_args = [ctx, speculated_fwd_result]
            safe_or_raise_always_restore(
                tx,
                graph_checkpoint,
                checkpoint,
                TorchHigherOrderOperatorVariable.make(
                    trampoline_autograd_bwd,
                    source=AttrSource(module_source, "backward"),
                    fwd_bwd_tracer=fwd_bwd_tracer,
                ),
                bwd_args,
            )
            # If forward and backward are both sound, we record apply() in the graph;
            # we deliberately do not record the backward trace itself.
            args = args[1:]
            return TorchHigherOrderOperatorVariable.make(
                trampoline_autograd_apply,
                fwd_bwd_tracer=None,
            ).call_function(tx, args, kwargs)

        if self.source:
            source = AttrSource(AttrSource(self.source, "__class__"), "forward")
        else:
            source = None
        fn = self.fn_cls.forward
        if isinstance(fn, types.FunctionType):
            return variables.UserFunctionVariable(fn, source=source).call_function(
                tx, args, kwargs
            )
        elif isinstance(fn, types.MethodType):
            return variables.UserMethodVariable(
                fn.__func__,
                variables.UserDefinedClassVariable(self.fn_cls),
                source=source,
            ).call_function(tx, args, kwargs)
        else:
            unimplemented(
                f"non-function or method in subclass of torch.autograd.Function: {fn}"
            )

    def call_function(self, tx, args, kwargs):
        return AutogradFunctionVariable(self.fn_cls)

    def call_method(
        self,
        tx,
        name,
        args: "List[VariableTracker]",
        kwargs: "Dict[str, VariableTracker]",
    ):
        from ..allowed_functions import is_user_defined_allowed
        from .builder import wrap_fx_proxy

        if name == "apply":
            if is_user_defined_allowed(self.fn_cls):
                trampoline_autograd_apply = produce_trampoline_autograd_apply(
                    self.fn_cls
                )
                return wrap_fx_proxy(
                    tx=tx,
                    proxy=tx.output.create_proxy(
                        "call_function",
                        trampoline_autograd_apply,
                        *proxy_args_kwargs(args, kwargs),
                    ),
                )
            else:
                return self.call_apply(tx, args, kwargs)
        elif name == "backward":
            with tx.strict_translation_mode():
                if isinstance(self.fn_cls.backward, types.FunctionType):
                    backward = UserFunctionVariable(self.fn_cls.backward)
                elif isinstance(self.fn_cls.backward, types.MethodType):
                    backward = UserMethodVariable(
                        self.fn_cls.backward.__func__,
                        variables.UserDefinedClassVariable(self.fn_cls),
                    )
                    args = [backward.obj] + args
                else:
                    unimplemented(
                        f"backward is a non-function or method: {self.fn_cls.backward}"
                    )

                return tx.inline_call(tx, backward, args, kwargs)

        elif name == "forward":
            if isinstance(self.fn_cls.forward, types.FunctionType):
                forward = UserFunctionVariable(self.fn_cls.forward)
            elif isinstance(self.fn_cls.forward, types.MethodType):
                forward = UserMethodVariable(
                    self.fn_cls.forward.__func__,
                    variables.UserDefinedClassVariable(self.fn_cls),
                )
                args = [forward.obj] + args
            else:
                unimplemented(
                    f"forward is a non-function or method: {self.fn_cls.forward}"
                )

            return tx.inline_call(tx, forward, args, kwargs)

        else:
            unimplemented(f"Unsupported method: {name}")


@dataclasses.dataclass
class SavedTensorBox:
    tensors: List[VariableTracker] = dataclasses.field(default_factory=list)


class AutogradFunctionContextVariable(UserDefinedObjectVariable):
    """
    Tracks an autograd.Function() context using mutation tracking in side_effects.py
    """

    _nonvar_fields = {
        "proxy",
        "inference",
        *UserDefinedObjectVariable._nonvar_fields,
    }

    def __init__(
        self,
        value,
        value_type=None,
        inference=False,
        proxy=None,
        saved_tensors=None,
        **kwargs,
    ):
        super().__init__(value=value, value_type=value_type, **kwargs)
        self.inference = inference
        self.proxy = proxy
        self.saved_tensors = saved_tensors

    @staticmethod
    def create(tx):
        proxy = tx.output.create_proxy(
            "call_function", torch.autograd.function.FunctionCtx, tuple(), {}
        )
        out = tx.output.side_effects.track_object_new(
            None,
            torch.autograd.function.FunctionCtx,
            functools.partial(
                AutogradFunctionContextVariable,
                inference=True,
                proxy=proxy,
                saved_tensors=SavedTensorBox(),
            ),
            {},
        )
        proxy.node.meta["example_value"] = out.value
        return out

    def as_proxy(self):
        if self.proxy is None:
            unimplemented("proxy not set")
        return self.proxy

    def call_method(
        self,
        tx,
        name,
        args: "List[VariableTracker]",
        kwargs: "Dict[str, VariableTracker]",
    ) -> "VariableTracker":
        if name != "save_for_backward":
            unimplemented(f"autograd.Function context method: {name}")
        if self.saved_tensors is None:
            unimplemented(
                "save_for_backward only supported on a newly constructed FunctionCtx"
            )

        if not self.inference:
            assert self.source and not kwargs
            tx.output.side_effects.track_save_for_backward(self, args)

        for arg in args:
            self.saved_tensors.tensors.append(arg)
        return variables.ConstantVariable.create(None)

    def var_getattr(self, tx, name):
        if name == "save_for_backward":
            return LambdaVariable(
                lambda *args, **kwargs: self.call_method(tx, name, args, kwargs)
            )
        if name == "saved_tensors" and self.saved_tensors is not None:
            return variables.TupleVariable(list(self.saved_tensors.tensors))
        return super().var_getattr(tx, name)


class LambdaVariable(VariableTracker):
    def __init__(self, fn, **kwargs):
        super().__init__(**kwargs)
        self.fn = fn

    def call_function(
        self, tx, args: "List[VariableTracker]", kwargs: "Dict[str, VariableTracker]"
    ) -> "VariableTracker":
        return self.fn(*args, **kwargs)


class GetAttrVariable(VariableTracker):
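    # Represents a deferred getattr(obj, name) whose value could not be resolved to
    # a constant at trace time; the lookup is replayed through a proxy (as_proxy)
    # or at reconstruction.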
    def __init__(self, obj, name, **kwargs):
        super().__init__(**kwargs)
        assert isinstance(obj, VariableTracker)
        assert isinstance(name, str)
        self.obj = obj
        self.name = name

    def __str__(self):
        return f"{self.__class__.__name__}({self.obj}, {self.name})"

    @staticmethod
    def create_getattr_proxy(base_proxy: torch.fx.Proxy, attr):
        return getattr(base_proxy, attr)

    def as_proxy(self):
        return GetAttrVariable.create_getattr_proxy(self.obj.as_proxy(), self.name)

    def const_getattr(self, tx, name):
        if not isinstance(self.obj, variables.NNModuleVariable):
            raise NotImplementedError()
        step1 = tx.output.get_submodule(self.obj.module_key)
        if self.name not in step1.__dict__:
            raise NotImplementedError()
        step2 = inspect.getattr_static(step1, self.name)
        if name not in step2.__dict__:
            raise NotImplementedError()
        return inspect.getattr_static(step2, name)

    def reconstruct(self, codegen):
        codegen(self.obj)
        return codegen.create_load_attrs(self.name)

    def call_function(
        self, tx, args: "List[VariableTracker]", kwargs: "Dict[str, VariableTracker]"
    ) -> "VariableTracker":
        return self.obj.call_method(tx, self.name, args, kwargs)


class MethodWrapperVariable(VariableTracker):
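    # Wraps a CPython method-wrapper object. Tensor base-attribute getters
    # (e.g. the __get__ of a Tensor property) are redirected to the underlying
    # TensorVariable's var_getattr; anything else falls through to the default
    # (graph-breaking) handling.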
    def __init__(self, method_wrapper, **kwargs):
        super().__init__(**kwargs)
        self.method_wrapper = method_wrapper

    def call_function(
        self, tx, args: "List[VariableTracker]", kwargs: "Dict[str, VariableTracker]"
    ) -> "VariableTracker":
        if is_tensor_base_attr_getter(self.method_wrapper) and isinstance(
            args[0], variables.TensorVariable
        ):
            assert len(args) == 1 and len(kwargs) == 0

            return args[0].var_getattr(tx, self.method_wrapper.__self__.__name__)

        return super().call_function(tx, args, kwargs)

    def is_python_constant(self):
        return True

    def as_python_constant(self):
        return self.method_wrapper


class GetSetDescriptorVariable(VariableTracker):
    def __init__(self, desc, **kwargs):
        super().__init__(**kwargs)
        self.desc = desc

    def var_getattr(self, tx, name):
        if name == "__get__" and self.source:
            from .builder import VariableBuilder

            return VariableBuilder(tx, AttrSource(self.source, "__get__"))(
                self.desc.__get__
            )
        else:
            return super().var_getattr(tx, name)

    def is_python_constant(self):
        return True

    def as_python_constant(self):
        return self.desc


class PythonModuleVariable(VariableTracker):
    def __init__(self, value: types.ModuleType, **kwargs):
        super().__init__(**kwargs)
        self.value = value

    def python_type(self):
        return types.ModuleType


class SkipFilesVariable(VariableTracker):
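    # Wraps a callable defined in a "skip file" (a file dynamo will not trace
    # into). call_function below special-cases an allowlist of well-known callables
    # (collections, itertools, functools helpers); anything else graph breaks.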
    def __init__(self, value, reason=None, **kwargs):
        super().__init__(**kwargs)
        self.value = value
        self.reason = reason

    def python_type(self):
        return type(self.value)

    def as_python_constant(self):
        return self.value

    @classmethod
    def create_with_source(cls, value, source):
        install_guard(source.make_guard(GuardBuilder.FUNCTION_MATCH))
        return cls(
            value,
            source=source,
        )

    @staticmethod
    @functools.lru_cache(None)
    def fold_through_function_to_wrapper():
        return {
            collections.namedtuple: variables.UserDefinedClassVariable,
        }

    def call_function(
        self, tx, args: "List[VariableTracker]", kwargs: "Dict[str, VariableTracker]"
    ) -> "VariableTracker":
        from .builtin import BuiltinVariable

        if inspect.getattr_static(self.value, "_torchdynamo_disable", False):
            unimplemented(f"call torch._dynamo.disable() wrapped function {self.value}")
        # Allowlist calls to a few popular classes (e.g., collections.OrderedDict)
        # defined in skip files.
        elif self.value is collections.OrderedDict:
            return BuiltinVariable.call_custom_dict(
                tx, collections.OrderedDict, *args, **kwargs
            )
        elif (
            self.value is collections.defaultdict
            and len(args) <= 1
            and DefaultDictVariable.is_supported_arg(args[0])
        ):
            return DefaultDictVariable(
                {},
                collections.defaultdict,
                args[0],
                mutable_local=MutableLocal(),
            )
        # Fold through functions (e.g., collections.namedtuple) whose inputs and
        # outputs are all Python constants.
        elif (
            self.value in self.fold_through_function_to_wrapper().keys()
            and check_constant_args(args, kwargs)
        ):
            value = self.value(
                *[x.as_python_constant() for x in args],
                **{k: v.as_python_constant() for k, v in kwargs.items()},
            )
            return self.fold_through_function_to_wrapper().get(self.value)(
                value, mutable_local=MutableLocal()
            )
        elif (
            self.value is itertools.product
            and not kwargs
            and all(arg.has_unpack_var_sequence(tx) for arg in args)
        ):
            seqs = [arg.unpack_var_sequence(tx) for arg in args]
            items = []
            for item in itertools.product(*seqs):
                items.append(variables.TupleVariable(list(item)))
            return variables.ListIteratorVariable(items, mutable_local=MutableLocal())
        elif (
            self.value is itertools.chain
            and not kwargs
            and all(arg.has_unpack_var_sequence(tx) for arg in args)
        ):
            seqs = [arg.unpack_var_sequence(tx) for arg in args]
            items = []
            for item in itertools.chain(*seqs):
                items.append(item)
            return variables.ListIteratorVariable(items, mutable_local=MutableLocal())
        elif self.value is itertools.accumulate:
            from .builtin import BuiltinVariable

            if any(key not in ["initial", "func"] for key in kwargs.keys()):
                unimplemented(
                    "Unsupported kwargs for itertools.accumulate: "
                    f"{','.join(set(kwargs.keys()) - {'initial', 'func'})}"
                )

            acc = kwargs.get("initial")

            if len(args) in [1, 2] and args[0].has_unpack_var_sequence(tx):
                seq = args[0].unpack_var_sequence(tx)

                if "func" in kwargs and len(args) == 1:
                    func = kwargs["func"].call_function
                elif len(args) == 2:
                    func = args[1].call_function
                elif len(args) == 1:
                    # Default to operator.add
                    func = BuiltinVariable(operator.add).call_function
                else:
                    unimplemented(
                        "itertools.accumulate only accepts one of: a `func` kwarg or a second positional arg"
                    )
            else:
                unimplemented("Unsupported arguments for itertools.accumulate")

            items = []
            if acc is not None:
                items.append(acc)
            for item in seq:
                if acc is None:
                    acc = item
                else:
                    try:
                        acc = func(tx, [acc, item], {})
                    except Exception:
                        raise unimplemented(  # noqa: TRY200
                            f"Unexpected failure in invoking function during accumulate. Failed running func {func}({acc}, {item})"
                        )
                items.append(acc)

            return variables.ListIteratorVariable(items, mutable_local=MutableLocal())
        elif (
            self.value is itertools.combinations
            and not kwargs
            and len(args) == 2
            and args[0].has_unpack_var_sequence(tx)
            and args[1].is_python_constant()
        ):
            iterable = args[0].unpack_var_sequence(tx)
            r = args[1].as_python_constant()

            items = []
            for item in itertools.combinations(iterable, r):
                items.append(variables.TupleVariable(list(item)))
            return variables.ListIteratorVariable(items, mutable_local=MutableLocal())
        elif self.value is itertools.groupby:
            if any(kw != "key" for kw in kwargs.keys()):
                unimplemented(
                    "Unsupported kwargs for itertools.groupby: "
                    f"{','.join(set(kwargs.keys()) - {'key'})}"
                )

            def retrieve_const_key(key):
                if isinstance(key, variables.SymNodeVariable):
                    return key.evaluate_expr()
                elif isinstance(key, variables.ConstantVariable):
                    return key.as_python_constant()
                else:
                    raise unimplemented(
                        "Unsupported key type for itertools.groupby: " + str(type(key))
                    )

            if len(args) == 1 and args[0].has_unpack_var_sequence(tx):
                seq = args[0].unpack_var_sequence(tx)
                keyfunc = (
                    (
                        lambda x: (
                            retrieve_const_key(
                                kwargs.get("key").call_function(tx, [x], {})
                            )
                        )
                    )
                    if "key" in kwargs
                    else None
                )
            else:
                unimplemented("Unsupported arguments for itertools.groupby")

            result = []
            try:
                for k, v in itertools.groupby(seq, key=keyfunc):
                    result.append(
                        variables.TupleVariable(
                            [
                                variables.ConstantVariable.create(k)
                                if variables.ConstantVariable.is_literal(k)
                                else k,
                                variables.ListIteratorVariable(
                                    list(v), mutable_local=MutableLocal()
                                ),
                            ],
                            mutable_local=MutableLocal(),
                        )
                    )
            except Exception:
                raise unimplemented(  # noqa: TRY200
                    "Unexpected failure when calling itertools.groupby"
                )
            return variables.ListIteratorVariable(result, mutable_local=MutableLocal())
        elif (
            self.value is functools.wraps
            and not kwargs
            and len(args) == 1
            and (
                args[0].source is not None or args[0].can_reconstruct(tx.output.root_tx)
            )
        ):

            def wraps(fn):
                if isinstance(fn, variables.NestedUserFunctionVariable):
                    if args[0].source:
                        reconstructible = args[0].source
                    else:
                        reconstructible = args[0]
                    return fn.clone(wrapped_reconstructible=reconstructible)
                unimplemented(f"functools.wraps({fn})")

            return variables.LambdaVariable(wraps)
        elif self.value is collections.deque and not kwargs:
            if len(args) == 0:
                items = []
            elif len(args) == 1 and args[0].has_unpack_var_sequence(tx):
                items = args[0].unpack_var_sequence(tx)
            else:
                unimplemented("deque() with more than 1 arg not supported")
            return variables.lists.DequeVariable(items, mutable_local=MutableLocal())
        elif self.value is functools.partial:
            if not args:
                unimplemented("functools.partial malformed")
            # The first arg is a callable (the ctor below will assert on types).
            fn = args[0]
            rest_args = args[1:]
            # Guards for the produced FunctoolsPartialVariable are installed in its
            # ctor from the args and keywords.
            return variables.functions.FunctoolsPartialVariable(
                fn, args=rest_args, keywords=kwargs
            )
        elif self.value is itertools.repeat:
            if len(args) < 2:
                return variables.RepeatIteratorVariable(
                    *args, mutable_local=MutableLocal()
                )

            from .builder import SourcelessBuilder

            return tx.inline_user_function_return(
                SourcelessBuilder()(tx, polyfill.repeat), args, kwargs
            )
        elif self.value is itertools.count:
            return variables.CountIteratorVariable(*args, mutable_local=MutableLocal())
        elif self.value is itertools.cycle:
            return variables.CycleIteratorVariable(*args, mutable_local=MutableLocal())
        else:
            try:
                path = inspect.getfile(self.value)
            except TypeError:
                path = f"Builtin {self.value.__name__}"
            msg = f"'skip function {self.value.__qualname__} in file {path}'"
            msg += f"', {self.reason}'" if self.reason else ""
            unimplemented(msg)

    def call_method(
        self,
        tx,
        name,
        args: "List[VariableTracker]",
        kwargs: "Dict[str, VariableTracker]",
    ) -> "VariableTracker":
        if (
            self.value in {collections.OrderedDict, collections.defaultdict}
            and name == "fromkeys"
        ):
            from .builtin import BuiltinVariable

            return BuiltinVariable.call_custom_dict_fromkeys(
                tx, self.value, *args, **kwargs
            )
        return super().call_method(tx, name, args, kwargs)


class TypingVariable(VariableTracker):
    def __init__(self, value, **kwargs):
        super().__init__(**kwargs)
        self.value = value

    def call_method(
        self,
        tx,
        name,
        args: "List[VariableTracker]",
        kwargs: "Dict[str, VariableTracker]",
    ) -> "VariableTracker":
        if name == "__getitem__" and len(args) == 1:
            return variables.ConstantVariable.create(
                self.value[args[0].as_python_constant()],
            )
        unimplemented("typing")

    def python_type(self):
        return type(self.value)

    def as_python_constant(self):
        return self.value


@functools.lru_cache(maxsize=1)
def get_np_to_tnp_map():
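    # Build a mapping from NumPy functions to their torch._numpy (tnp) counterparts
    # by matching attribute names across the module pairs in NP_TO_TNP_MODULE,
    # e.g. np.add -> tnp.add (assuming both modules expose the name).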
    from ..utils import NP_TO_TNP_MODULE

    np_fn_to_tnp_fn = {}

    for np_mod, tnp_mod in NP_TO_TNP_MODULE.items():
        for fn_name, tnp_fn in tnp_mod.__dict__.items():
            if callable(tnp_fn):
                # Some internal details leak from tnp that are not
                # part of the numpy API.
                if np_fn := getattr(np_mod, fn_name, None):
                    np_fn_to_tnp_fn[np_fn] = tnp_fn

    return np_fn_to_tnp_fn


class NumpyVariable(VariableTracker):
    """
    Wrapper around `numpy.*`. Currently able to trace a small subset of numpy functions as well as numpy dtypes.
    """

    def __init__(self, value, **kwargs):
        super().__init__(**kwargs)
        self.value = value

    def call_function(
        self, tx, args: "List[VariableTracker]", kwargs: "Dict[str, VariableTracker]"
    ) -> "VariableTracker":
        if not config.trace_numpy:
            unimplemented(f"numpy.{self.value}()")

        from ..utils import numpy_to_tensor_wrapper

        from .tensor import NumpyNdarrayVariable

        # lookup method name in tnp. Things like np.dtype(float) are not supported yet.
        if self.value.__name__ == "dtype":
            unimplemented(
                f"numpy dtype function is not supported yet. Got type {type(self.value)}."
            )
        else:  # We are dealing with a callable.
            func = get_np_to_tnp_map().get(self.value)
            if func is None:
                unimplemented(
                    f"Can't find numpy function {self.value} in torch._numpy. "
                    "Please file an issue to request support for this function."
                )

            if (
                func.__module__ == "torch._numpy.random"
                and config.use_numpy_random_stream
            ):
                msg = f"delegate '{func.__qualname__}' to NumPy itself via "
                msg += f"config.use_numpy_random_stream={config.use_numpy_random_stream}"
                unimplemented(msg)

            # TODO(larryliu0820): currently assuming all numpy.* functions are returning a ndarray that can be
            #  wrapped by NumpyNdarrayVariable which is wrong!
            proxy = tx.output.create_proxy(
                "call_function",
                numpy_to_tensor_wrapper(func),
                *proxy_args_kwargs(args, kwargs),
            )
            return NumpyNdarrayVariable.create(tx, proxy)

    def call_method(
        self,
        tx,
        name,
        args: "List[VariableTracker]",
        kwargs: "Dict[str, VariableTracker]",
    ) -> "VariableTracker":
        unimplemented("numpy")

    def python_type(self):
        return type(self.value)

    def as_python_constant(self):
        return self.value

    def as_proxy(self):
        # this handles numpy dtype attribute such as np.float32. TODO(larryliu0820): we should split NumpyVariable
        #  into NumpyVariable for instances/objects and NumpyVariable for types.
        if config.trace_numpy and isinstance(self.value, type):
            # retrieve attribute str. E.g., "float32" if given np.float32

            attr = self.value.__name__
            # get tnp equivalent
            tnp_dtype = tnp.dtype(attr)
            # returning a string here because we are assuming all `dtype` kwargs for numpy
            # functions can take an equivalent string and the behavior of the function would
            # be the same as taking a numpy dtype.
            return tnp_dtype.name

        return super().as_proxy()


# Used to keep track of NULLs pushed on the stack for Python 3.11 function calls
class NullVariable(VariableTracker):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def __str__(self):
        return "NullVariable"

    def reconstruct(self, codegen):
        if sys.version_info < (3, 11):
            unimplemented("cannot reconstruct NullVariable in < Python 3.11")
        return [create_instruction("PUSH_NULL")]


class DeletedVariable(VariableTracker):
    """Marker used to implement delattr()"""