from abc import ABC, abstractmethod
from numba.core.registry import DelayedRegistry, CPUDispatcher
from numba.core.decorators import jit
from numba.core.errors import InternalTargetMismatchError, NumbaValueError
from threading import local as tls
_active_context = tls()
_active_context_default = 'cpu'
class _TargetRegistry(DelayedRegistry):
def __getitem__(self, item):
try:
return super().__getitem__(item)
except KeyError:
msg = "No target is registered against '{}', known targets:\n{}"
known = '\n'.join([f"{k: <{10}} -> {v}"
for k, v in target_registry.items()])
raise NumbaValueError(msg.format(item, known)) from None
# Registry mapping target name strings to Target classes
target_registry = _TargetRegistry()
# Registry mapping Target classes the @jit decorator for that target
jit_registry = DelayedRegistry()
class target_override(object):
"""Context manager to temporarily override the current target with that
prescribed."""
def __init__(self, name):
self._orig_target = getattr(_active_context, 'target',
_active_context_default)
self.target = name
def __enter__(self):
_active_context.target = self.target
def __exit__(self, ty, val, tb):
_active_context.target = self._orig_target
def current_target():
"""Returns the current target
"""
return getattr(_active_context, 'target', _active_context_default)
def get_local_target(context):
"""
Gets the local target from the call stack if available and the TLS
override if not.
"""
# TODO: Should this logic be reversed to prefer TLS override?
if len(context.callstack._stack) > 0:
target = context.callstack[0].target
else:
target = target_registry.get(current_target(), None)
if target is None:
msg = ("The target found is not registered."
"Given target was {}.")
raise ValueError(msg.format(target))
else:
return target
def resolve_target_str(target_str):
"""Resolves a target specified as a string to its Target class."""
return target_registry[target_str]
def resolve_dispatcher_from_str(target_str):
"""Returns the dispatcher associated with a target string"""
target_hw = resolve_target_str(target_str)
return dispatcher_registry[target_hw]
def _get_local_target_checked(tyctx, hwstr, reason):
"""Returns the local target if it is compatible with the given target
name during a type resolution; otherwise, raises an exception.
Parameters
----------
tyctx: typing context
hwstr: str
target name to check against
reason: str
Reason for the resolution. Expects a noun.
Returns
-------
target_hw : Target
Raises
------
InternalTargetMismatchError
"""
# Get the class for the target declared by the function
hw_clazz = resolve_target_str(hwstr)
# get the local target
target_hw = get_local_target(tyctx)
# make sure the target_hw is in the MRO for hw_clazz else bail
if not target_hw.inherits_from(hw_clazz):
raise InternalTargetMismatchError(reason, target_hw, hw_clazz)
return target_hw
class JitDecorator(ABC):
@abstractmethod
def __call__(self):
return NotImplemented
class Target(ABC):
""" Implements a target """
@classmethod
def inherits_from(cls, other):
"""Returns True if this target inherits from 'other' False otherwise"""
return issubclass(cls, other)
class Generic(Target):
"""Mark the target as generic, i.e. suitable for compilation on
any target. All must inherit from this.
"""
class CPU(Generic):
"""Mark the target as CPU.
"""
class GPU(Generic):
"""Mark the target as GPU, i.e. suitable for compilation on a GPU
target.
"""
class CUDA(GPU):
"""Mark the target as CUDA.
"""
class NPyUfunc(Target):
"""Mark the target as a ufunc
"""
target_registry['generic'] = Generic
target_registry['CPU'] = CPU
target_registry['cpu'] = CPU
target_registry['GPU'] = GPU
target_registry['gpu'] = GPU
target_registry['CUDA'] = CUDA
target_registry['cuda'] = CUDA
target_registry['npyufunc'] = NPyUfunc
dispatcher_registry = DelayedRegistry(key_type=Target)
# Register the cpu target token with its dispatcher and jit
cpu_target = target_registry['cpu']
dispatcher_registry[cpu_target] = CPUDispatcher
jit_registry[cpu_target] = jit