# Source file: /home/ubuntu/combine_ai/combine/lib/python3.10/site-packages/torch/_inductor/triton_helpers.py
import triton
import triton.language as tl
@triton.jit
def promote_to_tensor(x):
    """Return ``x`` added to a one-element ``tl.int1`` block of zeros.

    The sum is what callers use when they need tensor attributes (such as
    ``.dtype``) on a value that may not already be a tensor.
    """
    # The addition performs the promotion to a tensor for us.
    zero_block = tl.zeros((1,), tl.int1)
    return x + zero_block
@triton.jit
def is_floating(x):
    """Return whether the dtype of ``x``, after tensor promotion, is floating point."""
    as_tensor = promote_to_tensor(x)
    return as_tensor.dtype.is_floating()
@triton.jit
def _prod_accumulate(a, b):
    """Combine function for ``prod``: multiply two partial results."""
    product = a * b
    return product
@triton.jit
def prod(input, axis):
    """Reduce ``input`` along ``axis`` by multiplication."""
    reduced = tl.reduce(input, axis, _prod_accumulate)
    return reduced
@triton.jit
def minimum(a, b):
    """Elementwise minimum of ``a`` and ``b`` that selects ``a`` when it is NaN.

    ``a`` is chosen where ``a < b``; for floating-point inputs it is also
    chosen where ``a`` is NaN. Everywhere else ``b`` is returned.
    """
    take_a = a < b
    if is_floating(a):
        # NaN is the only value that compares unequal to itself.
        a_is_nan = a != a
        take_a = take_a | a_is_nan
    return tl.where(take_a, a, b)
@triton.jit
def maximum(a, b):
    """Elementwise maximum of ``a`` and ``b`` that selects ``a`` when it is NaN.

    ``a`` is chosen where ``a > b``; for floating-point inputs it is also
    chosen where ``a`` is NaN. Everywhere else ``b`` is returned.
    """
    take_a = a > b
    if is_floating(a):
        # NaN is the only value that compares unequal to itself.
        a_is_nan = a != a
        take_a = take_a | a_is_nan
    return tl.where(take_a, a, b)
@triton.jit
def min2(a, dim):
    """Reduce ``a`` along ``dim`` using ``minimum`` as the combine function."""
    reduced = tl.reduce(a, dim, minimum)
    return reduced
@triton.jit
def max2(a, dim):
    """Reduce ``a`` along ``dim`` using ``maximum`` as the combine function."""
    reduced = tl.reduce(a, dim, maximum)
    return reduced
@triton.jit
def minimum_with_index(a_value, a_index, b_value, b_index):
    """Pick the smaller of two (value, index) candidates.

    Candidate ``a`` is selected when its value is smaller, or, for
    floating-point values, when it is NaN and ``b`` is not. When the values
    are equal (two NaNs count as equal) the candidate with the lower index
    is selected.

    Returns:
        A ``(value, index)`` pair taken elementwise from the selected
        candidate.
    """
    mask = a_value < b_value
    equal = a_value == b_value
    if is_floating(a_value):
        a_isnan = a_value != a_value
        b_isnan = b_value != b_value
        # A value equals itself exactly when it is not NaN.
        b_not_nan = b_value == b_value
        # Combine the masks with elementwise `&`, matching the tie-break
        # below, instead of Python's `and` / `not` keywords.
        mask |= a_isnan & b_not_nan
        # Consider NaNs as equal
        equal |= a_isnan & b_isnan

    # Prefer lowest index if values are equal
    mask |= equal & (a_index < b_index)
    return tl.where(mask, a_value, b_value), tl.where(mask, a_index, b_index)
@triton.jit
def maximum_with_index(a_value, a_index, b_value, b_index):
    """Pick the larger of two (value, index) candidates.

    Candidate ``a`` is selected when its value is larger, or, for
    floating-point values, when it is NaN and ``b`` is not. When the values
    are equal (two NaNs count as equal) the candidate with the lower index
    is selected.

    Returns:
        A ``(value, index)`` pair taken elementwise from the selected
        candidate.
    """
    mask = a_value > b_value
    equal = a_value == b_value
    if is_floating(a_value):
        a_isnan = a_value != a_value
        b_isnan = b_value != b_value
        # A value equals itself exactly when it is not NaN.
        b_not_nan = b_value == b_value
        # Combine the masks with elementwise `&`, matching the tie-break
        # below, instead of Python's `and` / `not` keywords.
        mask |= a_isnan & b_not_nan
        # Consider NaNs as equal
        equal |= a_isnan & b_isnan

    # Prefer lowest index if values are equal
    mask |= equal & (a_index < b_index)
    return tl.where(mask, a_value, b_value), tl.where(mask, a_index, b_index)
@triton.jit
def min_with_index(value, index, dim):
    """Jointly reduce ``value`` and ``index`` along ``dim`` with ``minimum_with_index``.

    Returns the reduced ``(value, index)`` pair.
    """
    min_value, min_index = tl.reduce((value, index), dim, minimum_with_index)
    return min_value, min_index
@triton.jit
def max_with_index(value, index, dim):
    """Jointly reduce ``value`` and ``index`` along ``dim`` with ``maximum_with_index``.

    Returns the reduced ``(value, index)`` pair.
    """
    max_value, max_index = tl.reduce((value, index), dim, maximum_with_index)
    return max_value, max_index
@triton.jit
def welford_reduce(value, mean, m2, weight):
    """Fold one new ``value`` into a running Welford state.

    Args:
        value: the new sample(s).
        mean: running mean before this sample.
        m2: running sum of squared deviations before this sample.
        weight: running sample count before this sample.

    Returns:
        The updated ``(mean, m2, weight)`` triple.
    """
    updated_weight = weight + 1
    diff = value - mean
    updated_mean = mean + diff / updated_weight
    # Uses the deviation from both the old and the new mean.
    updated_m2 = m2 + diff * (value - updated_mean)
    return updated_mean, updated_m2, updated_weight
@triton.jit
def welford_combine(mean_1, m2_1, weight_1, mean_2, m2_2, weight_2):
    """Merge two partial Welford states into one.

    Each state is a ``(mean, m2, weight)`` triple; the merged triple is
    returned in the same order.
    """
    total_weight = weight_1 + weight_2
    mean_gap = mean_2 - mean_1
    # Use a ratio of 0 where the combined weight is 0 instead of the
    # result of dividing by zero.
    ratio_2 = tl.where(total_weight == 0.0, 0.0, weight_2 / total_weight)
    merged_mean = mean_1 + mean_gap * ratio_2
    merged_m2 = m2_1 + m2_2 + mean_gap * mean_gap * weight_1 * ratio_2
    return merged_mean, merged_m2, total_weight
@triton.jit
def welford(mean, m2, weight, dim):
    """Reduce partial Welford states along ``dim`` with ``welford_combine``.

    Returns the reduced ``(mean, m2, weight)`` triple.
    """
    out_mean, out_m2, out_weight = tl.reduce(
        (mean, m2, weight), dim, welford_combine
    )
    return out_mean, out_m2, out_weight
@triton.jit
def device_assert_then(cond, msg, r):
    """Issue ``tl.device_assert(cond, msg)`` and then return ``r`` unchanged."""
    # The assertion must be emitted before the value is handed back.
    tl.device_assert(cond, msg)
    return r
@triton.jit
def randint64(seed, offset, low, high):
    """Generate 64-bit random integers offset by ``low`` and bounded by ``high - low``.

    Two of the four words produced by ``tl.randint4x(seed, offset)`` are
    joined into one 64-bit value, which is reduced modulo ``high - low`` and
    then shifted by ``low``.

    NOTE(review): presumably callers guarantee ``high > low``, which would
    put results in ``[low, high)`` -- confirm against callers.
    """
    word0, word1, word2, word3 = tl.randint4x(seed, offset)
    low_half = word0.to(tl.uint64)
    high_half = word1.to(tl.uint64)
    # First word supplies the low 32 bits, second word the high 32 bits.
    raw = low_half | (high_half << 32)
    span = high - low
    bounded = raw % span.to(tl.uint64)
    return bounded.to(tl.int64) + low
@triton.jit
def _any_combine(a, b):
    """Combine function for ``any``: bitwise OR of two partial results."""
    either = a | b
    return either
@triton.jit
def any(a, dim):
    """Reduce ``a`` along ``dim`` with bitwise OR."""
    reduced = tl.reduce(a, dim, _any_combine)
    return reduced
@triton.jit
def bucketize_binary_search(
    values,  # 1D tensor
    offsets_ptr,
    indexing_dtype,
    right,  # bool: if true, use intervals closed on the left; see [Note: Inductor bucketize op]
    OFFSETS_SIZE: int,
    BLOCK_SHAPE,  # tuple/list of block shape
):
    """
    See [Note: Inductor bucketize op]

    Binary-search the ``OFFSETS_SIZE`` boundaries stored at ``offsets_ptr``
    for every element of ``values`` and return the resulting lower bound,
    a block of shape ``BLOCK_SHAPE`` and dtype ``indexing_dtype``.

    NOTE(review): presumably the boundaries are sorted ascending -- confirm
    against the referenced note.
    """
    lo = tl.zeros(BLOCK_SHAPE, dtype=indexing_dtype)
    hi = tl.full(BLOCK_SHAPE, OFFSETS_SIZE, dtype=indexing_dtype)

    # The trip count depends only on OFFSETS_SIZE: the remaining span is
    # halved (rounding up) each iteration until it reaches 1.
    remaining = OFFSETS_SIZE + 1
    while remaining > 1:
        probe = (hi + lo) // 2
        # Only load boundaries at valid positions.
        in_bounds = probe < OFFSETS_SIZE
        boundary = tl.load(offsets_ptr + probe, mask=in_bounds)
        if right:
            goes_up = values >= boundary
        else:
            goes_up = values > boundary
        # Raise the lower bound only where the probed boundary was valid.
        lo = tl.where(goes_up & in_bounds, probe + 1, lo)
        hi = tl.where(goes_up, hi, probe)
        remaining = (remaining + 1) // 2

    return lo
# Back to Directory
# File Manager