Here’s a bit of a fun proposal for metaprogramming.
Mojo parameters are a really neat take on metaprogramming – and the ability to pass functions as parameters is quite interesting (allowing e.g. other functions to provide decoration / functionality around compile-time passed functions).
This feature reminds me of some of the features of systems like JAX, and I’d like to take this post to propose an extension.
JAX is an eDSL embedded in Python – one whose object language is more restricted than arbitrary Python. Still, “much of Python” can be used as a metaprogramming language for JAX. Chris’ latest post (Modular: Democratizing AI Compute, Part 7: What about Triton and Python eDSLs?) talks a bit about eDSLs – but doesn’t go into much detail about how JAX works, so I’ll describe JAX a bit below.
JAX takes “JAX-compatible Python functions” and converts them into a linear IR called a Jaxpr. Once you have a Jaxpr, you can do lots of things with it: JAX supports lowering it to XLA, or one can write an interpreter which interprets the Jaxpr (and which can itself be staged to a Jaxpr, etc.). So JAX builds a very powerful metaprogramming system into Python. One can think of JAX as a 2-language system, with “compile-time-known Python computations” as the metalanguage and “Jaxpr” as the object language (whose terms actually denote computations which can be executed by JAX’s execution model).
Transformation interpreters – which drive things like AD and other program transformations – can be written in the metalanguage, while the object language remains more restricted!
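As a small illustration of writing one’s own pass over the IR, the equations of a Jaxpr can be walked directly (again assuming `jax` is available; the traversal below just collects primitive names rather than performing a real transformation):

```python
import jax
import jax.numpy as jnp

def g(x):
    return jnp.exp(x) + jnp.sin(x)

closed = jax.make_jaxpr(g)(0.5)

# A trivial "interpreter": walk the Jaxpr's equations and record
# which primitives the object-language term is built from. Real
# transformation interpreters (AD, vmap, ...) follow this same
# shape, but evaluate or rewrite each equation instead of just
# inspecting it.
primitive_names = [eqn.primitive.name for eqn in closed.jaxpr.eqns]
print(primitive_names)
```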
It feels like some of the features of JAX could be interesting to explore with Mojo’s metaprogramming. For instance, I’ve been working on and off on a project called juju
(GitHub - femtomc/juju: From JAX computations to MAX graphs.) which takes JAX’s metaprogramming model, and pairs it with an interpreter which converts a JAX computation into a MAX graph via the MAX Python API. Suppose I was able to invoke Python computations as a compile time parameter into Mojo functions – well, then, one thing I could do is write “permissive ML code” in Python using a JAX-like system, and trace them out to a computation in a “Mojo-compatible” object language (more realistic in short term: a MAX computation object).
See, here I don’t need to support arbitrary Python dynamism at the Mojo object level! I’m allowing Python to be used as a permissive metaprogramming language – but the contract is that the execution of the Python code needs to result in an object which Mojo / MAX can understand. This is very similar to JAX’s model (one can’t return arbitrary objects, or Python functions from JAX computations, the object level is strongly typed with JAX specific types).
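To illustrate the contract, here’s a toy stdlib-only tracer (everything here – `Tracer`, `Op`, `trace` – is hypothetical, invented purely to show the pattern; it is not juju’s or MAX’s actual API). Arbitrary Python runs at trace time, but only operations captured on tracers end up in the object-level program, and returning anything else is rejected:

```python
from dataclasses import dataclass

@dataclass
class Op:
    name: str
    args: tuple

@dataclass
class Tracer:
    # Each tracer records the op that produced it, plus the shared tape.
    op: Op
    tape: list

    def _bin(self, name, other):
        if not isinstance(other, Tracer):
            other = Tracer(Op("const", (other,)), self.tape)
        out = Tracer(Op(name, (self.op, other.op)), self.tape)
        self.tape.append(out.op)
        return out

    def __add__(self, other): return self._bin("add", other)
    def __mul__(self, other): return self._bin("mul", other)

def trace(fn):
    """Run fn on a symbolic input; the tape is the object-language program."""
    tape = []
    x = Tracer(Op("input", ()), tape)
    out = fn(x)
    # The contract: the metalanguage (Python) may do anything it likes,
    # but the final result must be an object-level value.
    if not isinstance(out, Tracer):
        raise TypeError("traced function must return an object-level value")
    return tape

# Arbitrary Python (a list comprehension, a loop) at the meta level,
# but only straight-line ops at the object level.
ops = trace(lambda x: sum([x * i for i in range(1, 3)], x))
print([op.name for op in ops])
```

A real system would of course emit a MAX graph or Mojo-understood object instead of a list of `Op`s, but the shape of the contract is the same: dynamism lives entirely at trace time.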
This type of capability feels like one approach to easing the compatibility between Python and Mojo. It also feels complementary to allowing the invocation of CPython in Mojo at the object level – for lots of ML-specific code, we don’t actually want any Python anyways. Still, the implementation of ML systems like JAX is often significantly easier with dynamism … so one has to confront this tension in Mojo if one wants to support a Mojo-native ML library. There are various other arguments for why a Mojo-native library would be nice (related to sharp edges of eDSLs)! Happy to discuss those in this thread.
On the other hand, if I could invoke Python at compile time, and create native Mojo objects – well, no reinvention of the wheel is required – and it’d be quite easy to transition experienced Python ML folks over to Mojo gradually (for transparency – I think we’d need to think carefully about e.g. PyTorch, whose model for ML computation differs from JAX, but which is moving increasingly towards JAX with recent developments).
There are likely tons of sharp edges with allowing “invocation of an arbitrary dynamic, interpreted language at compile time” – especially if this invocation is repeated for code intended to be shared – but I’m hoping that more experienced hands can comment here.