Initial support for writing PyTorch custom ops in Mojo

In the latest MAX nightlies, we’ve started shipping an initial interface for defining PyTorch custom ops using Mojo. For anyone experienced with PyTorch, this provides a familiar entrypoint to experiment with new Mojo-based algorithms or optimizations that target GPUs. It also allows for progressive introduction of MAX and Mojo into existing PyTorch models.

This support is provided through the CustomOpLibrary class in the max.torch Python module. You start from a Mojo package containing custom operations defined exactly as you would for a MAX Graph (in fact, you can copy and paste code between the two). You then register that directory of Mojo operation code with CustomOpLibrary, which handles the rest of the PyTorch integration for you. The Mojo compiler is invoked automatically to compile these operations when the Python script runs.
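To give a sense of the shape of this API, here's a minimal sketch of calling a Mojo custom operation from Python. The kernel directory and the operation name add_ten are placeholders for whatever your Mojo package defines, and the destination-passing calling convention (a pre-allocated output tensor passed first) follows the pattern in our shipped examples:

```python
from pathlib import Path

import torch
from max.torch import CustomOpLibrary

# Load a directory of Mojo code defining custom operations.
# The "operations" directory and the op name "add_ten" are hypothetical;
# substitute your own Mojo package and operation names.
ops = CustomOpLibrary(Path(__file__).parent / "operations")

# These examples currently require a MAX-compatible GPU.
x = torch.rand(1024, device="cuda")
out = torch.empty_like(x)

# Ops are exposed as attributes of the library and called in
# destination-passing style: the pre-allocated output comes first.
ops.add_ten(out, x)
```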

New in today’s nightly release are a couple of examples of extending PyTorch using Mojo: a “hello world” that demonstrates how to use a Mojo custom operation in a basic PyTorch calculation, and an implementation of the Whisper audio model with one layer replaced by a Mojo operation. These examples are currently limited to running on systems with a MAX-compatible GPU.

We have work planned to improve the performance of this feature, so expect near-term improvements on that front. We've been working hard to address the issues that early users have found (special thanks to @bertaveira for identifying a few of these!), but please report them here and/or file GitHub issues for any others you encounter.


As an update to this functionality: @stef recently added the ability to provide entire MAX graphs as PyTorch custom operators. This significantly expands what you can use from MAX within a PyTorch custom operator: everything from our highly optimized built-in MAX kernels all the way to full graphs representing large portions of ML models.

To provide a MAX graph to a PyTorch custom operator, wrap a function that describes the MAX graph with the new @graph_op decorator. You can see this in action in a new example in the modular repository.
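As a rough sketch of the idea, the decorated function builds its body out of max.graph operations rather than eager PyTorch calls. The function name, its body, and the destination-passing call at the end are all illustrative assumptions rather than the exact API surface; see the example in the modular repository for the authoritative version:

```python
import torch
from max.graph import TensorValue, ops
from max.torch import graph_op

# Hypothetical graph: the name scaled_residual and its specific
# operations are illustrative only. The body uses max.graph ops,
# so MAX can compile and optimize the whole subgraph.
@graph_op
def scaled_residual(x: TensorValue, weight: TensorValue) -> TensorValue:
    return ops.matmul(x, weight) + x

x = torch.rand(64, 64, device="cuda")
w = torch.rand(64, 64, device="cuda")
out = torch.empty_like(x)

# Assumes the same destination-passing convention as CustomOpLibrary ops:
# the pre-allocated output tensor is passed first.
scaled_residual(out, x, w)
```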

We hope this enables an even easier path to start experimenting with MAX inside a familiar PyTorch context, and allows portions of a PyTorch model to be progressively replaced with optimized MAX operations and graphs.
