All of the code I want to enable follows this basic pattern:
# **********************************
# MAIN TIME EVOLUTION LOOP
# **********************************
for step in range(1, nsteps + 1):
    phi_old.fill_boundary(geometry)
    var update_mfi = phi_old.mfiter()
    while update_mfi.is_valid():
        var bx = update_mfi.validbox()
        var phi_old_array = phi_old.array(update_mfi)
        var phi_new_array = phi_new.array(update_mfi)
        var tile_dx = dx.copy()

        def advance_cell(
            i: Int, j: Int, k: Int
        ) raises {var phi_new_array^, var phi_old_array^, var tile_dx^, var dt}:
            phi_new_array[i, j, k] = phi_old_array[i, j, k] + dt * (
                (phi_old_array[i + 1, j, k] - 2.0 * phi_old_array[i, j, k] + phi_old_array[i - 1, j, k])
                / (tile_dx.x * tile_dx.x)
                + (phi_old_array[i, j + 1, k] - 2.0 * phi_old_array[i, j, k] + phi_old_array[i, j - 1, k])
                / (tile_dx.y * tile_dx.y)
                + (phi_old_array[i, j, k + 1] - 2.0 * phi_old_array[i, j, k] + phi_old_array[i, j, k - 1])
                / (tile_dx.z * tile_dx.z)
            )

        ParallelFor(advance_cell, bx)
        update_mfi.next()
    time = time + dt
    phi_old.copy_from(phi_new, 0, 0, 1)
In this example, advance_cell requires a value capture for dt. That is a common pattern, and in my code (and many others) value captures are absolutely necessary. My understanding of the analogous functionality in CUDA is that dt and other value captures (for CUDA extended lambdas) are automatically translated into device function arguments by the compiler and copied to device memory by the runtime.
In Kokkos (and other C++ performance-portability abstractions over CUDA), switching between CPU and GPU is hidden inside ParallelFor. There is a CPU version of ParallelFor and a GPU version of ParallelFor, and which one is used is selected by ifdef macros set by the build system. A comptime if selector would be perfectly fine in Mojo for my purposes.
The issue is that these codes are written in a way that makes value captures genuinely necessary: control flow happens in host code, which sets variables accordingly, and those variables then need to be copied implicitly for the device kernels to use.
The other reason the above pattern won’t work without captures is that almost every kernel uses a different set of work arrays. For instance, if the kernels were rewritten as free functions (that is, capturing nothing), the 89 kernels in my code would not conform to a single function signature (there are not quite 89 unique signatures, but there are at least 20-30). It is not feasible to write separate iterators/kernel launch helpers for every possible set of input arguments the kernels need. So, as far as I understand how Mojo works, DevicePassable closures are the only way to handle this.