from warnings import warn

from numba.core import types, config, sigutils
from numba.core.errors import DeprecationError, NumbaInvalidConfigWarning
from numba.cuda.compiler import declare_device_function
from numba.cuda.dispatcher import CUDADispatcher
from numba.cuda.simulator.kernel import FakeCUDAKernel

_msg_deprecated_signature_arg = ("Deprecated keyword argument `{0}`. "
                                 "Signatures should be passed as the first "
                                 "positional argument.")


def jit(func_or_sig=None, device=False, inline=False, link=None, debug=None,
        opt=True, lineinfo=False, cache=False, **kws):
"""
JIT compile a Python function for CUDA GPUs.
:param func_or_sig: A function to JIT compile, or *signatures* of a
function to compile. If a function is supplied, then a
:class:`Dispatcher <numba.cuda.dispatcher.CUDADispatcher>` is returned.
Otherwise, ``func_or_sig`` may be a signature or a list of signatures,
and a function is returned. The returned function accepts another
function, which it will compile and then return a :class:`Dispatcher
<numba.cuda.dispatcher.CUDADispatcher>`. See :ref:`jit-decorator` for
more information about passing signatures.
.. note:: A kernel cannot have any return value.
    :param device: Indicates whether this is a device function.
    :type device: bool
    :param inline: Enables inlining of the function into its callers at the
       Numba IR level when ``True``. Defaults to ``False``.
    :type inline: bool
    :param link: A list of files containing PTX or CUDA C/C++ source to link
       with the function.
    :type link: list
:param debug: If True, check for exceptions thrown when executing the
kernel. Since this degrades performance, this should only be used for
debugging purposes. If set to True, then ``opt`` should be set to False.
Defaults to False. (The default value can be overridden by setting
environment variable ``NUMBA_CUDA_DEBUGINFO=1``.)
:param fastmath: When True, enables fastmath optimizations as outlined in
the :ref:`CUDA Fast Math documentation <cuda-fast-math>`.
:param max_registers: Request that the kernel is limited to using at most
this number of registers per thread. The limit may not be respected if
the ABI requires a greater number of registers than that requested.
Useful for increasing occupancy.
:param opt: Whether to compile from LLVM IR to PTX with optimization
enabled. When ``True``, ``-opt=3`` is passed to NVVM. When
``False``, ``-opt=0`` is passed to NVVM. Defaults to ``True``.
:type opt: bool
:param lineinfo: If True, generate a line mapping between source code and
assembly code. This enables inspection of the source code in NVIDIA
profiling tools and correlation with program counter sampling.
:type lineinfo: bool
:param cache: If True, enables the file-based cache for this function.
:type cache: bool
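
    For example, a kernel that increments each element of an array can be
    compiled eagerly for a single signature (a minimal sketch; ``increment``
    and ``an_array`` are illustrative names)::

        from numba import cuda

        @cuda.jit("void(int32[::1])")
        def increment(an_array):
            i = cuda.grid(1)
            if i < an_array.size:
                an_array[i] += 1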
"""
    if link is None:
        link = []

    if link and config.ENABLE_CUDASIM:
        raise NotImplementedError('Cannot link PTX in the simulator')

    if kws.get('boundscheck'):
        raise NotImplementedError("bounds checking is not supported for CUDA")

    if kws.get('argtypes') is not None:
        msg = _msg_deprecated_signature_arg.format('argtypes')
        raise DeprecationError(msg)
    if kws.get('restype') is not None:
        msg = _msg_deprecated_signature_arg.format('restype')
        raise DeprecationError(msg)
    if kws.get('bind') is not None:
        msg = _msg_deprecated_signature_arg.format('bind')
        raise DeprecationError(msg)

    debug = config.CUDA_DEBUGINFO_DEFAULT if debug is None else debug
    # fastmath and extensions are accepted via **kws rather than as named
    # parameters.
    fastmath = kws.get('fastmath', False)
    extensions = kws.get('extensions', [])

    if debug and opt:
        msg = ("debug=True with opt=True (the default) "
               "is not supported by CUDA. This may result in a crash"
               " - set debug=False or opt=False.")
        warn(NumbaInvalidConfigWarning(msg))

    if debug and lineinfo:
        msg = ("debug and lineinfo are mutually exclusive. Use debug to get "
               "full debug info (this disables some optimizations), or "
               "lineinfo for line info only with code generation unaffected.")
        warn(NumbaInvalidConfigWarning(msg))

    if device and link:
        raise ValueError("link keyword invalid for device function")
if sigutils.is_signature(func_or_sig):
signatures = [func_or_sig]
specialized = True
elif isinstance(func_or_sig, list):
signatures = func_or_sig
specialized = False
else:
signatures = None

    if signatures is not None:
        if config.ENABLE_CUDASIM:
            def jitwrapper(func):
                return FakeCUDAKernel(func, device=device, fastmath=fastmath)
            return jitwrapper

        def _jit(func):
            targetoptions = kws.copy()
            targetoptions['debug'] = debug
            targetoptions['lineinfo'] = lineinfo
            targetoptions['link'] = link
            targetoptions['opt'] = opt
            targetoptions['fastmath'] = fastmath
            targetoptions['device'] = device
            targetoptions['inline'] = inline
            targetoptions['extensions'] = extensions
            disp = CUDADispatcher(func, targetoptions=targetoptions)

            if cache:
                disp.enable_caching()

            # Eagerly compile the function for each requested signature.
            for sig in signatures:
                argtypes, restype = sigutils.normalize_signature(sig)

                if restype and not device and restype != types.void:
                    raise TypeError("CUDA kernel must have void return type.")

                if device:
                    from numba.core import typeinfer
                    with typeinfer.register_dispatcher(disp):
                        disp.compile_device(argtypes, restype)
                else:
                    disp.compile(argtypes)

            # Record whether the dispatcher is specialized to a single
            # signature, and prevent compilation of any further signatures.
            disp._specialized = specialized
            disp.disable_compile()

            return disp

        return _jit
    else:
        if func_or_sig is None:
            if config.ENABLE_CUDASIM:
                def autojitwrapper(func):
                    return FakeCUDAKernel(func, device=device,
                                          fastmath=fastmath)
            else:
                def autojitwrapper(func):
                    return jit(func, device=device, inline=inline, debug=debug,
                               opt=opt, lineinfo=lineinfo, link=link,
                               cache=cache, **kws)

            return autojitwrapper
# func_or_sig is a function
else:
if config.ENABLE_CUDASIM:
return FakeCUDAKernel(func_or_sig, device=device,
fastmath=fastmath)
            else:
                targetoptions = kws.copy()
                targetoptions['debug'] = debug
                targetoptions['lineinfo'] = lineinfo
                targetoptions['opt'] = opt
                targetoptions['link'] = link
                targetoptions['fastmath'] = fastmath
                targetoptions['device'] = device
                targetoptions['inline'] = inline
                targetoptions['extensions'] = extensions
                disp = CUDADispatcher(func_or_sig, targetoptions=targetoptions)

                if cache:
                    disp.enable_caching()

                return disp


def declare_device(name, sig):
    """
    Declare the signature of a foreign function. Returns a descriptor that can
    be used to call the function from a Python kernel.

:param name: The name of the foreign function.
:type name: str
:param sig: The Numba signature of the function.
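
    A sketch of typical usage (the name ``mul`` and the file ``mul.ptx`` are
    illustrative; the implementation must be supplied via the ``link`` kwarg
    of :func:`jit`)::

        from numba import cuda

        mul = cuda.declare_device('mul', 'float32(float32, float32)')

        @cuda.jit(link=['mul.ptx'])
        def multiply(result, a, b):
            i = cuda.grid(1)
            if i < result.size:
                result[i] = mul(a[i], b[i])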
"""
argtypes, restype = sigutils.normalize_signature(sig)
if restype is None:
msg = 'Return type must be provided for device declarations'
raise TypeError(msg)
return declare_device_function(name, restype, argtypes)