Metaprogramming with Python in Mojo

Here’s a bit of a fun proposal for metaprogramming.

Mojo parameters are a really neat take on metaprogramming – and the ability to pass functions as parameters is quite interesting (allowing e.g. other functions to provide decoration / functionality around compile-time-passed functions).

This feature reminds me of some of the features of systems like JAX, and I’d like to take this post to propose an extension.

JAX is an eDSL embedded in Python whose object language is more restricted than arbitrary Python. Still, "much of Python" can be used as a metaprogramming language for JAX. Chris' latest post (Modular: Democratizing AI Compute, Part 7: What about Triton and Python eDSLs?) talks a bit about eDSLs – but doesn't go into much detail about how JAX works, so I'll describe JAX a bit below.

JAX takes "JAX-compatible Python functions" and converts them into a linear IR called a Jaxpr. Once you have a Jaxpr, you can do lots of things with it. JAX supports lowering it to XLA, or one can write an interpreter which interprets the Jaxpr (and that interpreter can itself be staged out to a Jaxpr, etc). So JAX builds a very powerful metaprogramming system into Python. One can think of JAX as a two-language system with "compile-time-known Python computations" as the metalanguage, and "Jaxpr" as the object language (whose terms actually denote computations which can be executed by JAX's execution model).
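
For concreteness, here's roughly what staging looks like with JAX's public API (a minimal sketch using jax.make_jaxpr):

```python
import jax
import jax.numpy as jnp

def f(x):
    return jnp.sin(x) * 2.0

# Stage the Python function out to a Jaxpr without running it on real data.
print(jax.make_jaxpr(f)(1.0))
# Prints something like:
# { lambda ; a:f32[]. let b:f32[] = sin a; c:f32[] = mul b 2.0 in (c,) }
```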

Transformation interpreters – which drive things like AD and other program transformations – can be written in the metalanguage, while the object language stays more restricted!
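
A sketch of what such an interpreter looks like, following the "writing custom interpreters" pattern from the JAX docs (module layout like jax.core has shifted across JAX versions, so treat the imports as illustrative):

```python
import jax
import jax.numpy as jnp
from jax import core

def eval_jaxpr(jaxpr, consts, *args):
    env = {}

    def read(v):
        # Literals carry their value inline; variables live in the environment.
        return v.val if isinstance(v, core.Literal) else env[v]

    def write(v, val):
        env[v] = val

    for var, const in zip(jaxpr.constvars, consts):
        write(var, const)
    for var, arg in zip(jaxpr.invars, args):
        write(var, arg)

    for eqn in jaxpr.eqns:
        invals = [read(v) for v in eqn.invars]
        # A transformation interpreter (AD, vectorization, lowering, ...)
        # would dispatch on eqn.primitive here; this one just re-executes it.
        outvals = eqn.primitive.bind(*invals, **eqn.params)
        if not eqn.primitive.multiple_results:
            outvals = [outvals]
        for var, val in zip(eqn.outvars, outvals):
            write(var, val)

    return [read(v) for v in jaxpr.outvars]

closed = jax.make_jaxpr(lambda x: jnp.sin(x) * 2.0)(1.0)
print(eval_jaxpr(closed.jaxpr, closed.consts, 1.0))
```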

It feels like some of the features of JAX could be interesting to explore with Mojo's metaprogramming. For instance, I've been working on and off on a project called juju (GitHub - femtomc/juju: From JAX computations to MAX graphs.) which takes JAX's metaprogramming model and pairs it with an interpreter that converts a JAX computation into a MAX graph via the MAX Python API. Suppose I were able to invoke Python computations as compile-time parameters to Mojo functions – then one thing I could do is write "permissive ML code" in Python using a JAX-like system and trace it out to a computation in a "Mojo-compatible" object language (more realistic in the short term: a MAX computation object).

See, here I don't need to support arbitrary Python dynamism at the Mojo object level! I'm allowing Python to be used as a permissive metaprogramming language – but the contract is that the execution of the Python code needs to result in an object which Mojo / MAX can understand. This is very similar to JAX's model (one can't return arbitrary objects or Python functions from JAX computations; the object level is strongly typed with JAX-specific types).
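
To make that contract concrete: a juju-style lowering can reuse the interpreter skeleton above, but emit graph ops instead of executing primitives. This is a hypothetical sketch – LOWERING_RULES and builder are made-up names standing in for a real graph-building API (e.g. the MAX Python API), not actual MAX calls:

```python
# Hypothetical sketch: walk a Jaxpr and emit one graph op per equation.
def lower_jaxpr(jaxpr, builder, *graph_inputs):
    env = dict(zip(jaxpr.invars, graph_inputs))

    def read(v):
        # Literals carry their value directly; everything else is in env.
        return v.val if hasattr(v, "val") else env[v]

    for eqn in jaxpr.eqns:
        # Map each JAX primitive to a graph op instead of executing it.
        rule = LOWERING_RULES[eqn.primitive]
        outvals = rule(builder, *[read(v) for v in eqn.invars], **eqn.params)
        env.update(zip(eqn.outvars, outvals))
    return [env[v] for v in jaxpr.outvars]
```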

This type of capability feels like one approach to easing the compatibility between Python and Mojo. It also feels complementary to allowing the invocation of CPython in Mojo at the object level – for lots of ML-specific code, we don't actually want any Python at the object level anyway. Still, the implementation of ML systems like JAX is often significantly easier with dynamism … so this tension has to be confronted if one wants to support a Mojo-native ML library. There are various other arguments for why a Mojo-native library would be nice (related to sharp edges of eDSLs)! Happy to discuss those in this thread.

On the other hand, if I could invoke Python at compile time and create native Mojo objects – well, no reinvention of the wheel is required, and it'd be quite easy to transition experienced Python ML folks over to Mojo gradually (for transparency: I think we'd need to think carefully about e.g. PyTorch, whose model for ML computation differs from JAX's, but which is moving increasingly towards JAX with recent developments).

There are likely tons of sharp edges with allowing "invocation of an arbitrary dynamic, interpreted language at compile time" – especially if this invocation is repeated for code intended to be shared – but I'm hoping that more experienced hands can comment here 🙂


The biggest issue I see with this is the potential security holes, and the difficulty of making it into a nice API. In general, metaprogramming will want two-way communication with the compiler, which might mean that all we would get is the ability to write a compiler pass in Python via the MLIR bindings. If we allow "arbitrary code at compile time", it opens up a very large can of worms from a security standpoint.

My preference would actually be to allow running WASI modules, since those can be very heavily sandboxed. Those are, of course, a big hammer, so I'd want something like Zig's reflection API as the "normal path" for things like serialization. This gives Mojo compile-time capabilities that are better than what Rust has, without the security issues that come from proc macros. It might also allow shipping the WASI modules embedded inside the .mojopackage file, which would solve some of the compile-time complaints that people have about Rust's proc macros.

Python is very difficult to sandbox effectively, which I think is an issue, since it wasn't designed with a permissions system in mind. WASI fixes that, so that your gRPC library can't also access the internet, or so that a Swagger API client generator can't go take a look at ~/.ssh.

This would, of course, require Mojo to support WASI, but I don't think that will be too difficult since it's very POSIX-shaped, and for this use-case we don't really need async IO.

If one is allowed to execute arbitrary code at build time (via CPython) today, what is the difference between build time and compile time such that compile-time execution requires sandboxing?

For what it's worth, I'm coming from an AI background – so I've never been super concerned about the security of compute, and my questions are coming from a place of genuine curiosity.

On the topic of WASI: I don't think of WASI as a viable target for acceleration (which is perhaps a more specific subset of code than "a valid Mojo computation" from the proposal – for instance, a MAX graph), but feel free to send me some references to update my perspective, if available.

Build-time Python means that the .mojopackage file is still an inert library, which means it's safe to handle. If the LSP loading a .mojopackage library can run Python code, it can grab SSH keys or browser cookies.

WASI isn't really something you would use for acceleration; it's a box you put code in when you want to be able to put security controls on what the code can do. It's also more amenable to being embedded inside of a compiler than Python is. If we want the full "generate source code or an AST by running fully arbitrary computation" power, like Rust's proc macros, then the code that is run needs two-way communication with the compiler.

Mojo is perfectly capable of assembling a MAX graph without Python being involved at all, and I think we should strive for that. Mojo's more powerful type system means that we can actually give compiler errors for problems that would only show up once you had deployed Python code – if we know the dimensions of a tensor at compile time, it's much easier to type things. ML libs like JAX may be easier to build in Python, but that's what we have MAX for; Modular has already done the hard part of making an ML compiler.

Thanks for the context on the first two points.

I disagree with you in several ways. For context on experience: I've been an active user of machine learning frameworks and languages for about 8 years now.

Striving for this picture requires understanding why ML libraries in Python have achieved mindshare in the first place, and the notion that ML / AI researchers are just soldiering on with Python until a "strongly typed array language" comes along is not accurate. Furthermore, JAX does report shape errors during tracing. My overwhelming experience is that ML / AI researchers don't care if "it is much easier to type things".

I don’t think of MAX as being analogous to JAX. I think of MAX as being like XLA, but infinitely more programmable (along several dimensions). JAX isn’t XLA, it’s a convenient & dynamic programming environment with a pretty-good array programming model, plus a model for composable program transformations which provide automation for user-directed vectorization and AD with lowering to XLA.

The dynamism is a feature, not a bug, for ML / AI researchers. “Making an ML compiler” is not the only hard part of constructing a useful machine learning framework.

What would such a framework look like for Mojo? Well, it'd probably be based on def – are we going to replicate JAX in def? What parts of JAX? Do we want PyTrees (probably, because that's a very convenient way to bundle up neural network parameters)? Do we want vmap and grad automation? Is grad going to be implemented using forward mode + YOLO ("you only linearize once": derive reverse mode by linearizing and transposing forward mode – convenient for people who want to define their own forward-mode rules)? These are all questions you'd have to answer in the design of a new ML framework in native Mojo.
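
For reference, here's what those pieces look like on the JAX side today (standard public API): parameters bundled as a PyTree, with grad and vmap composing over it.

```python
import jax
import jax.numpy as jnp

# Parameters bundled as a PyTree: just nested dicts / lists of arrays.
params = {"w": jnp.ones((3, 3)), "b": jnp.zeros(3)}

def loss(params, x):
    return jnp.sum(jnp.tanh(params["w"] @ x + params["b"]))

x = jnp.ones(3)
grads = jax.grad(loss)(params, x)  # gradients arrive with the same tree structure
per_example = jax.vmap(loss, in_axes=(None, 0))(params, jnp.ones((8, 3)))
```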

It may seem trivial to make such a thing on top of an ML compiler, but it is not at all trivial. After all, an entire team is dedicated to JAX.

I would be happy to sketch out implementations of several of the above pieces if object were significantly refactored! Of course, then you'd still have to wait for the libraries which build the things that people actually want to use on top of the framework … it's easy to say "just replicate the thing and make it better in a new language" (look at the series of missteps in Julia AD for a brief and brutal history of exactly this type of thinking), but one needs serious people power to pull that off well.


My overwhelming experience is that ML / AI researchers don’t care if “it is much easier to type things”.

Strong typing is a two-way street. If the API is strongly typed, then it's much, much easier to learn the API, since the question is "which variant of this enum do I want to put in here?" and not "what was that magic string again?". I know of a lot of researchers who refuse to move from PyTorch to JAX simply because learning the API is annoying, but these same people will write data cleaning or postprocessing code in Haskell.

Additionally, right now production training clusters spend massive amounts of CPU on data loading. This is something that we want to optimize as much as possible. For me, the ideal end state for ML in MAX is that you use Python or Mojo to inspect the environment and talk to the cluster manager, then you set up a huge MAX graph which does everything else and is optimized for your environment, and you don’t leave that MAX graph for however many weeks training runs for.

I don’t think of MAX as being analogous to JAX. I think of MAX as being like XLA, but infinitely more programmable (along several dimensions). JAX isn’t XLA, it’s a convenient & dynamic programming environment with a pretty-good array programming model, plus a model for composable program transformations which provide automation for user-directed vectorization and AD with lowering to XLA.

From my perspective, MAX is programmable enough – especially considering the plan to be able to use arbitrary types with MAX – that it's worthwhile to re-examine how necessary some Python framework features are. Correct me if I'm mistaken, but from what I know vmap is mostly a workaround to express for loops in a functional style, and it's equivalent to Mojo's algorithm.functional.vectorize, possibly with some parallelize thrown in (a minimal example of the loop equivalence is sketched below). grad is a bit harder, and might involve talking to a MAX device to help calculate it, but that should still be doable. I'm not sure I understand why PyTrees are desirable vs structs.

If we have a good way to convert a well-typed function into a custom op, once single source works (which Brad has mentioned is a planned feature), I think that drastically lowers the amount of stuff which needs to be "framework ops". Python ML frameworks put in a herculean effort to make Python fast, but Mojo is already fast, so some of those design challenges should go away. In my opinion, MAX is a good chance to start with a blank sheet of paper and force each convention and operator to justify itself. Does vmap make sense in a world where vectorize exists? Do we need special bundles to pass around NN parameters instead of just passing around structs and arrays of structs? I think there are things from JAX that make a lot of sense to try to bring into a MAX context, such as parallelism with sharding constraints.
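
Here's that loop equivalence as a small, self-contained JAX example:

```python
import jax
import jax.numpy as jnp

xs = jnp.arange(12.0).reshape(4, 3)
ys = jnp.ones((4, 3))

def dot(a, b):
    return jnp.dot(a, b)

# The explicit loop ...
looped = jnp.stack([dot(xs[i], ys[i]) for i in range(4)])
# ... and the batched version, mapped over the leading axis of both arguments.
mapped = jax.vmap(dot)(xs, ys)
assert jnp.allclose(looped, mapped)
```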

Also, it's important to consider that the more of the framework that exists in fn, the easier it will be to have the LSP and debugger help people. The internals of object are fairly nasty to inspect with a debugger, so people might not want to deal with them.


Great comments!

you don’t leave that MAX graph for however many weeks training runs for.

Agree with everything else in your first two paragraphs. I may be overreaching, but I've thought of this type of workflow as overspecialized to the current state of AI: as test-time compute has become more popular, fewer people are spending time training "the big models" and more are playing around with structured search processes on top of them.

Correct me if I’m mistaken, but from what I know vmap is mostly a workaround to express for loops in a functional style, and it’s equivalent to Mojo’s algorithm.functional.vectorize

I think this might be about right – the language vmap is defined over is pure, so it's a bit more restrictive. Management of memory is out of the hands of the user of vmap (and of JAX generally – although Pallas starts to put things back into the user's grasp).

grad is a bit harder, and might involve talking to a MAX device to help calculate it

Yes, unfortunately grad is a very subtle piece of automation. JAX does well by restricting attention to a pure, first-order language with array operations. There, grad is very convenient to express, and covers the entire language.

I'd argue separately (in another discussion) that this is the right model for grad – see Julia for several years of attempts to break out of this model, which, in my opinion, failed, especially compared to JAX. In short, "grad on a DSL" feels like the best bang for the buck.
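
As a concrete instance of the "forward mode + YOLO" point above, in JAX a user defines a forward-mode (JVP) rule once, and reverse mode falls out via linearization and transposition (a minimal sketch using JAX's public custom_jvp API):

```python
import jax
import jax.numpy as jnp

@jax.custom_jvp
def softplus(x):
    return jnp.log1p(jnp.exp(x))

@softplus.defjvp
def softplus_jvp(primals, tangents):
    (x,), (dx,) = primals, tangents
    # Only the forward-mode rule is written by hand; reverse mode is
    # derived by linearizing and transposing.
    return softplus(x), dx * jax.nn.sigmoid(x)

print(jax.grad(softplus)(1.0))  # reverse mode for free from the JVP rule
```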

I’m not sure I understand why PyTrees are desirable vs structs

The Pytree stuff is related to JAX's restricted object language. You want to feel expressive when you're using Python to metaprogram JAX, but JAX only understands computations on lists of arrays – so Pytree lets you break structs down into lists of arrays and rebuild them from lists of arrays.

By virtue of participating in this “deconstruct / reconstruct” interface, Pytrees are “structs-of-arrays” by default.
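
The deconstruct / reconstruct interface is directly visible in JAX's public tree_util API:

```python
import jax.numpy as jnp
from jax import tree_util

params = {"layer1": {"w": jnp.ones((3, 3)), "b": jnp.zeros(3)}}

# Deconstruct: a flat list of arrays plus a static description of the structure.
leaves, treedef = tree_util.tree_flatten(params)
# Reconstruct: rebuild the original nested structure from the flat leaves.
rebuilt = tree_util.tree_unflatten(treedef, leaves)
```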

I’d almost recommend a similar metaprogramming trick in Mojo for vectorize. I’d like to be able to write a structure whose actual data components can be either float or array[float, ...] (pseudo-types) – the data elements would be batched out by vectorize as required. This eliminates the distinction between struct and struct-of-arrays when you’re programming in the “vectorizable” model.
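
In JAX terms, that trick already exists: vmap treats a PyTree of scalars and a PyTree of arrays uniformly, so the same "struct" code runs per-element or batched. A small illustration:

```python
import jax
import jax.numpy as jnp

def step(state):
    # Written as if each field were a single scalar.
    return {"x": state["x"] + state["v"], "v": state["v"] * 0.9}

one = {"x": 1.0, "v": 2.0}                       # "struct" of scalars
many = {"x": jnp.arange(4.0), "v": jnp.ones(4)}  # struct-of-arrays layout

step(one)             # per-element
jax.vmap(step)(many)  # same code, batched over every field
```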

Python ML frameworks put in a herculean effort to make Python fast, but Mojo is already fast, so some of those design challenges should go away. In my opinion, MAX is a good chance to start with a blank sheet of paper and force each convention and operator to justify itself.

I strongly agree with this sentiment, and thanks for reminding me of it.

Also, it's important to consider that the more of the framework that exists in fn, the easier it will be to have the LSP and debugger help people. The internals of object are fairly nasty to inspect with a debugger, so people might not want to deal with them.

Great point – and strongly agree. JAX, like PyTorch, has notoriously nasty tracebacks (an endemic problem for eDSLs especially).