Using Mojo to build JAX GPU kernels

Hi! I’m very excited to explore Mojo for GPU kernels and to contribute to the ecosystem. I’m looking for basic examples of how to build kernels that can be called from existing Python frameworks. Are there any examples with a proper FFI binding to JAX specifically? I’m looking for something like this:

https://docs.jax.dev/en/latest/ffi.html#ffi-calls-on-a-gpu