Using Mojo to build JAX GPU kernels

Hi! I’m very excited to explore Mojo for GPU kernels and to contribute to the ecosystem. I’m looking for some basic examples of how to build kernels that can be called from existing Python frameworks. Are there any examples with proper FFI bindings to JAX specifically? I’m looking for something like this:

https://docs.jax.dev/en/latest/ffi.html#ffi-calls-on-a-gpu

This topic was automatically closed 180 days after the last reply. New replies are no longer allowed.