Hi all!
I’m trying to express a particular idiom in a GPU kernel: read an index out of one tensor, then use it to index into another tensor.
The full code is below for reference; after it, I’ve pulled out the relevant segment.
```mojo
# GPU kernel for computing cross-covariance matrix elements
fn compute_cross_covariance_kernel[
    n: Int
](
    W_tensor: LayoutTensor[mut=True, float32, Layout.row_major(3, 3)],
    correspondences_tensor: LayoutTensor[
        mut=True, DType.int32, Layout.row_major(n)
    ],
    distances_tensor: LayoutTensor[mut=True, float32, Layout.row_major(n)],
    source_centered_tensor: LayoutTensor[
        mut=False, float32, Layout.row_major(n, 3)
    ],
    target_centered_tensor: LayoutTensor[
        mut=False, float32, Layout.row_major(n, 3)
    ],
    max_dist: Float32,
):
    # Each block computes one element of the 3x3 matrix
    row = block_idx.x
    col = block_idx.y
    if row >= 3 or col >= 3:
        return

    # Shared stack memory for partial sums
    var shared = stack_allocation[
        THREADS_PER_BLOCK,
        Scalar[float32],
        address_space = AddressSpace.SHARED,
    ]()

    tid = thread_idx.x
    var sum = 0.0

    # Each thread processes multiple correspondences
    var idx = tid
    while idx < n:
        # Check if correspondence is valid (within distance threshold)
        if distances_tensor[idx] < max_dist:
            var target_idx = correspondences_tensor[idx]
            s_val = source_centered_tensor[idx, row]
            t_val = target_centered_tensor[target_idx, col]
            sum += s_val * t_val
        idx += THREADS_PER_BLOCK

    # Store in shared memory and reduce
    shared[tid] = sum
    barrier()

    # Parallel reduction
    var stride = THREADS_PER_BLOCK // 2
    while stride > 0:
        if tid < stride:
            shared[tid] = shared[tid] + shared[tid + stride]
        barrier()
        stride //= 2

    # Thread 0 writes final result
    if tid == 0:
        W_tensor[row, col] = shared[0]
```
The relevant part is here:
```mojo
var target_idx: Int = correspondences_tensor[idx]
s_val = source_centered_tensor[idx, row]
t_val = target_centered_tensor[target_idx, col]
sum += s_val * t_val
```
Trying this as-is runs into what looks like a resolution error (as far as I can tell).
Is there a known idiom for this? I did a cursory search through the GPU puzzles, but I’m not experienced enough to find a good match.
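For comparison, here is a minimal CPU-side sketch of the same gather pattern using plain `List`s instead of `LayoutTensor`s. It assumes a recent Mojo release in which `Int(...)` converts a scalar `Int32` into an `Int`; `gather_dot` and the sample data are hypothetical. What it illustrates is the explicit `Int(...)` conversion of the stored index before it is used to index the second collection. Whether the same conversion resolves for a `LayoutTensor` element inside the kernel is not confirmed here.

```mojo
# Hypothetical CPU sketch of the gather pattern: read an index from one
# collection, convert it to Int, then use it to index another collection.
fn gather_dot(
    source: List[Float32], target: List[Float32], corr: List[Int32]
) -> Float32:
    var total: Float32 = 0.0
    for i in range(len(corr)):
        # Explicit Int32 -> Int conversion before using the value as an index
        var j = Int(corr[i])
        total += source[i] * target[j]
    return total


def main():
    var source = List[Float32](1.0, 2.0, 3.0)
    var target = List[Float32](10.0, 20.0, 30.0)
    # corr[i] holds the position in `target` that pairs with source[i]
    var corr = List[Int32](2, 0, 1)
    print(gather_dot(source, target, corr))
```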