"No active MLIR context" with new `CustomOpLibrary` torch integration

Hi, I know it is very new, but I was trying out the new CustomOpLibrary torch interface. However, I cannot get it to work with any example; I always get the same error:

File "/.../.venv/lib/python3.12/site-packages/max/graph/type.py", line 805, in to_mlir
    self.shape.to_mlir(), self.dtype, self.device.to_mlir()
    ^^^^^^^^^^^^^^^^^^^^
  File "/.../.venv/lib/python3.12/site-packages/max/graph/type.py", line 494, in to_mlir
    shape_type = mosh.ShapeType()
                 ^^^^^^^^^^^^^^^^
RuntimeError: No active MLIR context

As suggested on Discord, I made a simple reproducible example to post here. Here is my setup. The folder structure is:

pyproject.toml
example.py
kernels/
    __init__.mojo (empty)
    kernel.mojo

pyproject.toml:

[project]
name = "example"
version = "0.0.0"
readme = "README.md"
requires-python = ">=3.10"
dependencies = [
    "torch>=2.6.0",
    "pillow>=11.2.1, <12",
    "modular>=25.4.0.dev2025052105",
]

[tool.uv]
[[tool.uv.index]]
url = "https://dl.modular.com/public/nightly/python/simple/"

example.py

from pathlib import Path
import torch
from max.torch import CustomOpLibrary

TILE_SIZE = 16
# Register Mojo kernels in Torch
mojo_kernels = Path(__file__).parent / "kernels"
op_library = CustomOpLibrary(mojo_kernels)
add_const_kernel = op_library.add_const[
    {
        "const": 10
    }
]

def add_const(x: torch.Tensor) -> torch.Tensor:
    result = torch.zeros_like(x)
    add_const_kernel(result, x)
    return result

if __name__ == "__main__":
    x = torch.randn(10).cuda()

    print(add_const(x))

kernel.mojo

import compiler
from gpu import thread_idx, block_idx, barrier
from layout import Layout, LayoutTensor, UNKNOWN_VALUE
from runtime.asyncrt import DeviceContextPtr
from math import ceildiv
from tensor import InputTensor, OutputTensor

alias BLOCK_SIZE = 32
alias Dyn1DLayout = Layout.row_major(UNKNOWN_VALUE)
alias dtype = DType.float32

@compiler.register("add_const")
struct AddConst:
    @staticmethod
    fn execute[
        const: Int,
        target: StaticString,
    ](
        # Outputs
        result: OutputTensor[type = DType.float32, rank=1],
        # Inputs
        x: InputTensor[type = DType.float32, rank=1],
        # Context
        ctx: DeviceContextPtr,
    ) raises:
        x_tensor = x.to_layout_tensor()
        result_tensor = result.to_layout_tensor()

        @parameter
        if target == "cpu":
            raise Error("Rasterize3DGS CPU target not implemented yet.")
        elif target == "gpu":
            # Get GPU context
            var gpu_ctx = ctx.get_device_context()

            # Define grid and block dimensions for the kernel launch
            var grid = (ceildiv(x.dim_size(0), BLOCK_SIZE))
            var block = (BLOCK_SIZE)

            gpu_ctx.enqueue_function[add_const_kernel[const]](
                x_tensor,
                result_tensor,
                grid_dim=grid,
                block_dim=block,
            )

        else:
            raise Error("Unsupported target:", target)


fn add_const_kernel[
    const: Int
](
    x: LayoutTensor[dtype, Dyn1DLayout, MutableAnyOrigin],
    result: LayoutTensor[dtype, Dyn1DLayout, MutableAnyOrigin],
):
    i = block_idx.x * BLOCK_SIZE + thread_idx.x
    result[i] = x[i] + const

I also tried without using UNKNOWN_VALUE, but I always get the same issue. Anything I might be doing wrong here?


So this worked for me in a modified version of your example:

operations/
    __init__.mojo
    add_one.mojo
example.py
mojoproject.toml

example.py

from pathlib import Path
import torch
from max.torch import CustomOpLibrary

# Register Mojo kernels in Torch
mojo_kernels = Path(__file__).parent / "operations"
op_library = CustomOpLibrary(mojo_kernels)
add_const_kernel = op_library.add_constant_custom[
    {
        "value": 10
    }
]

def add_const(x: torch.Tensor) -> torch.Tensor:
    result = torch.zeros_like(x)
    add_const_kernel(result, x)
    return result

if __name__ == "__main__":
    x = torch.randn(10).cuda()

    print(add_const(x))

add_one.mojo

import compiler
from runtime.asyncrt import DeviceContextPtr
from tensor_internal import (
    InputTensor,
    ManagedTensorSlice,
    OutputTensor,
    foreach,
)

from utils.index import IndexList


@compiler.register("add_constant_custom")
struct AddConstantCustom[value: Int]:
    @staticmethod
    fn execute[
        target: StaticString,
    ](
        out: OutputTensor,
        x: InputTensor[type = out.type, rank = out.rank],
        ctx: DeviceContextPtr,
    ) raises:
        @parameter
        @always_inline
        fn add_constant[
            width: Int
        ](idx: IndexList[x.rank]) -> SIMD[x.type, width]:
            return x.load[width](idx) + value

        foreach[add_constant, target=target](out, ctx)

    @staticmethod
    fn shape(
        x: InputTensor,
    ) raises -> IndexList[x.rank]:
        raise "NotImplemented"

mojoproject.toml

[project]
authors = ["Modular, Inc. <hello@modular.com>"]
channels = ["https://conda.modular.com/max-nightly", "https://conda.modular.com/max", "https://repo.prefix.dev/modular-community", "conda-forge", "pytorch"]
name = "pytorch-test"
platforms = ["linux-64"]
version = "0.1.0"

[tasks]
example = "python example.py"

[dependencies]
max = "*"
pytorch = {version = ">=2.5.0,<=2.7.0", channel = "pytorch"}

Now, that uses foreach rather than direct GPU kernel launches, but it does seem to compile and run correctly. We’re looking into what’s going on in your specific functions, though.

I’ve tried it with uv, and a simplified version of it failed with CUDA call failed: CUDA_ERROR_ILLEGAL_ADDRESS (an illegal memory access was encountered).

Repro: uv init example && cd example. Then, using this pyproject.toml:

[project]
name = "example"
version = "0.0.0"
readme = "README.md"
requires-python = ">=3.10"
dependencies = [
    "torch>=2.6.0",
    "pillow>=11.2.1, <12",
    "modular>=25.4.0.dev2025052105",
]

[tool.uv]
[[tool.uv.index]]
url = "https://dl.modular.com/public/nightly/python/simple/"

and kernel.mojo

import compiler
from gpu import thread_idx, block_idx, block_dim, barrier
from layout import Layout, LayoutTensor, UNKNOWN_VALUE
from runtime.asyncrt import DeviceContextPtr
from math import ceildiv
from gpu.host import DeviceBuffer
from tensor import InputTensor, OutputTensor
from memory import UnsafePointer

alias BLOCK_SIZE = 32
alias Dyn1DLayout = Layout.row_major(32)
alias dtype = DType.float32

@compiler.register("add_const")
struct AddConst:
    @staticmethod
    fn execute[
        target: StaticString,
    ](
        # Outputs
        result: OutputTensor[type = DType.float32, rank=1],
        # Inputs
        x: InputTensor[type = DType.float32, rank=1],
        # Context
        ctx: DeviceContextPtr,
    ) raises:
        x_tensor = x.to_layout_tensor()
        result_tensor = result.to_layout_tensor()

        @parameter
        if target == "cpu":
            raise Error("Rasterize3DGS CPU target not implemented yet.")
        elif target == "gpu":
            # Get GPU context
            var gpu_ctx = ctx.get_device_context()

            # Define grid and block dimensions for the kernel launch
            var grid = (ceildiv(x.dim_size(0), BLOCK_SIZE))
            var block = (BLOCK_SIZE)

            gpu_ctx.enqueue_memset(
                DeviceBuffer[result.type](
                    gpu_ctx,
                    rebind[UnsafePointer[Scalar[result.type]]](result_tensor.ptr),
                    x.dim_size(0),
                    owning=False,
                ),
                0,
            )

            gpu_ctx.enqueue_function[add_const_kernel](
                x_tensor,
                result_tensor,
                grid_dim=grid,
                block_dim=block,
            )

        else:
            raise Error("Unsupported target:", target)


fn add_const_kernel(
    x: LayoutTensor[dtype, Dyn1DLayout, MutableAnyOrigin],
    result: LayoutTensor[dtype, Dyn1DLayout, MutableAnyOrigin],
):
    i = block_idx.x * block_dim.x + thread_idx.x
    if i < x.dim[0]():
        result[i] = x[i] + 10

and main.py

from pathlib import Path
import torch
from max.torch import CustomOpLibrary

mojo_kernels = Path(__file__).parent / "kernels"
op_library = CustomOpLibrary(mojo_kernels)
add_const_kernel = op_library.add_const

def add_const(x: torch.Tensor) -> torch.Tensor:
    result = torch.zeros_like(x, dtype=x.dtype, device=x.device)
    add_const_kernel(result, x)
    return result

if __name__ == "__main__":
    x = torch.randn(10).cuda()
    print(add_const(x))

then run uv run python main.py.


Hi Bernardo, thanks so much for the easy repro! I was able to reproduce and fix the issue; it might not make it into today’s nightly, but if it doesn’t it will be in tomorrow’s.

In the meantime you can work around by adding

from max import mlir
mlir.Context().__enter__()

to your script anywhere before the add_const_kernel = op_library.add_const[... line.
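
For clarity, here is a minimal sketch of where those two lines fit in the example.py from the first post (assuming the rest of the script stays unchanged):

from pathlib import Path
import torch
from max import mlir
from max.torch import CustomOpLibrary

# Workaround for "No active MLIR context": enter a context before building the op library.
mlir.Context().__enter__()

# Everything below is the original example, unchanged.
mojo_kernels = Path(__file__).parent / "kernels"
op_library = CustomOpLibrary(mojo_kernels)
add_const_kernel = op_library.add_const[{"const": 10}]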


Combining Mojo & MAX with PyTorch is fascinating. I tried a simple custom pattern match for Inductor, and it worked. By the way, can Mojo kernels support AOT (ahead-of-time) compilation?

import compiler
from max.tensor import InputTensor, OutputTensor, foreach
from runtime.asyncrt import DeviceContextPtr

from utils.index import IndexList


@compiler.register("custom_pow2_add")
struct CustomPow2Add:
    @staticmethod
    def execute[
        target: StaticString
    ](
        output: OutputTensor,
        x: InputTensor[type = output.type, rank = output.rank],
        y: InputTensor[type = output.type, rank = output.rank],
        ctx: DeviceContextPtr,
    ):
        @parameter
        @always_inline
        fn run[width: Int](idx: IndexList[x.rank]) -> SIMD[x.type, width]:
            return x.load[width](idx) ** 2 + y.load[width](idx)

        foreach[run, target=target](output, ctx)

import torch
from torch._inductor.pattern_matcher import (
    fwd_only,
    PatternMatcherPass,
    register_replacement,
)
from typing import Callable, Iterable

from pathlib import Path
from max.torch import CustomOpLibrary


mojo_kernels = Path(__file__).parent / "mojo_kernels"
op_library = CustomOpLibrary(mojo_kernels)
custom_pow2_add = op_library.custom_pow2_add


def custom_op(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    print("custum_op")
    result = torch.zeros_like(a)
    custom_pow2_add(result, a, b)
    return result


def pattern(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    c = a**2
    return c + b


patterns = PatternMatcherPass()
inputs = (torch.randn(10, 10), torch.randn(10, 10))
register_replacement(pattern, custom_op, inputs, fwd_only, patterns)

count = 0


def custom_pass(graph: torch.fx.graph):
    global count
    count = patterns.apply(graph)


def custom_backend(
    graph: torch.fx.GraphModule, example_inputs: Iterable[torch.Tensor]
) -> Callable:
    from torch._inductor import config

    current_config = config.get_config_copy()
    from torch._inductor.compile_fx import compile_fx

    current_config["post_grad_custom_post_pass"] = custom_pass
    return compile_fx(graph, example_inputs, config_patches=current_config)


@torch.compile(backend=custom_backend)
def f_mojo(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    return x**2 + y


def f_torch(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    return x**2 + y


if __name__ == "__main__":
    inp1 = torch.rand(3, 5)
    inp2 = torch.rand(3, 5)
    print(f_mojo(inp1, inp2))
    print(f_torch(inp1, inp2))
    print(count)
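
As a quick sanity check (a minimal sketch reusing f_mojo, f_torch, and count from the script above), you can confirm that the replacement actually fired and that both paths agree numerically:

x1 = torch.rand(3, 5)
x2 = torch.rand(3, 5)
out_mojo = f_mojo(x1, x2)
out_torch = f_torch(x1, x2)
# count > 0 means the Inductor pattern matcher rewrote x**2 + y into the Mojo op
assert count > 0, "pattern was not replaced"
assert torch.allclose(out_mojo, out_torch)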

Thank you all for the help. The nightly seems to have fixed the No MLIR issue, but then, as pointed out by @Ehsan, there was an illegal CUDA access issue. I tried running your fixed example with uv and now I get a new issue:

File "/.../.venv/lib/python3.12/site-packages/max/engine/api.py", line 526, in load
    _model = self._impl.compile_from_object(
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ValueError: Failed to run the MOToMGP pass manager:
open-source/max/max/kernels/src/Mogg/MOGGKernelAPI:1:1: error: failed to run the pass manager for offload functions
/.../kernels/kernel.mojo:67:20: error: call expansion failed
open-source/max/max/kernels/src/layout/layout_tensor.mojo:2337:8: note: function instantiation failed
open-source/max/max/kernels/src/layout/layout_tensor.mojo:2378:10: note: call expansion failed
note: constraint failed: This method only works with tensors that have depth-1 layouts (no nested shapes).
open-source/max/max/kernels/src/Mogg/MOGGKernelAPI:1:1: error: Could not elaborate the provided code: failed to run the pass manager
error: The graph compiler tried to JIT compile the provided kernels but failed during elaboration

The code is an exact replica of the one posted by @Ehsan, except that I changed to the latest nightly to fix the No MLIR issue: modular>=25.4.0.dev2025052116

Seems like the issue is now with this line:

if i < x.dim[0]():

Feeding this size into the function as an argument fixes it, but I wonder why this is a problem.

So, as mentioned before, I could not check the LayoutTensor dimension inside the kernel function for some reason. Nevertheless, I was now experimenting with adding one more dimension, and I cannot make it work without getting:

CUDA call failed: CUDA_ERROR_ILLEGAL_ADDRESS (an illegal memory access was encountered)

I am wondering if I have bumped into a new bug or if I am doing something wrong. I modified this simple example and verified that the same thing happens on this very simple example. In the other examples I was experimenting with, this seems to happen any time I try to index a LayoutTensor with more than one dimension.

kernel.mojo

import compiler
from gpu import thread_idx, block_idx, block_dim, barrier
from layout import Layout, LayoutTensor, UNKNOWN_VALUE
from runtime.asyncrt import DeviceContextPtr
from math import ceildiv
from gpu.host import DeviceBuffer
from tensor import InputTensor, OutputTensor
from memory import UnsafePointer

alias BLOCK_SIZE = 32
alias Dyn2DLayout = Layout.row_major(UNKNOWN_VALUE, UNKNOWN_VALUE)
alias dtype = DType.float32

@compiler.register("add_const")
struct AddConst:
    @staticmethod
    fn execute[
        target: StaticString,
    ](
        # Outputs
        result: OutputTensor[type = DType.float32, rank=2],
        # Inputs
        x: InputTensor[type = DType.float32, rank=2],
        # Context
        ctx: DeviceContextPtr,
    ) raises:
        x_tensor = x.to_layout_tensor()
        result_tensor = result.to_layout_tensor()

        @parameter
        if target == "cpu":
            raise Error("Rasterize3DGS CPU target not implemented yet.")
        elif target == "gpu":
            # Get GPU context
            var gpu_ctx = ctx.get_device_context()

            # Define grid and block dimensions for the kernel launch
            var grid = (ceildiv(x.dim_size(0), BLOCK_SIZE), ceildiv(x.dim_size(1), BLOCK_SIZE))
            var block = (BLOCK_SIZE, BLOCK_SIZE)

            gpu_ctx.enqueue_function[add_const_kernel](
                x_tensor,
                result_tensor,
                x.dim_size(0),
                grid_dim=grid,
                block_dim=block,
            )

        else:
            raise Error("Unsupported target:", target)


fn add_const_kernel(
    x: LayoutTensor[dtype, Dyn2DLayout, MutableAnyOrigin],
    result: LayoutTensor[dtype, Dyn2DLayout, MutableAnyOrigin],
    size: Int,
):
    i = block_idx.x * block_dim.x + thread_idx.x
    j = block_idx.y * block_dim.y + thread_idx.y
    if i < size and j < size:
        result[i, j] = x[i, j] + 10

example.py

from pathlib import Path
import torch
from max.torch import CustomOpLibrary

mojo_kernels = Path(__file__).parent / "kernels"
op_library = CustomOpLibrary(mojo_kernels)
add_const_kernel = op_library.add_const

def add_const(x: torch.Tensor) -> torch.Tensor:
    result = torch.zeros_like(x, dtype=x.dtype, device=x.device)
    add_const_kernel(result, x)
    return result

if __name__ == "__main__":
    x = torch.randn(10,10).cuda()
    print(add_const(x))

pyproject.toml

[project]
name = "example"
version = "0.0.0"
readme = "README.md"
requires-python = ">=3.10"
dependencies = [
    "torch>=2.6.0",
    "pillow>=11.2.1, <12",
    "modular>=25.4.0.dev2025052116",
]

[tool.uv]
[[tool.uv.index]]
url = "https://dl.modular.com/public/nightly/python/simple/"

Good news: with the new nightly (rm uv.lock just in case) it works as expected!

import compiler
from gpu import thread_idx, block_idx, block_dim, barrier
from layout import Layout, LayoutTensor, UNKNOWN_VALUE
from runtime.asyncrt import DeviceContextPtr
from math import ceildiv
from gpu.host import DeviceBuffer
from tensor import InputTensor, OutputTensor
from memory import UnsafePointer

alias BLOCK_SIZE = 32
alias Dyn1DLayout = Layout.row_major(UNKNOWN_VALUE)
alias dtype = DType.float32

@compiler.register("add_const")
struct AddConst:
    @staticmethod
    fn execute[
        target: StaticString,
    ](
        # Outputs
        result: OutputTensor[type = DType.float32, rank=1],
        # Inputs
        x: InputTensor[type = DType.float32, rank=1],
        # Context
        ctx: DeviceContextPtr,
    ) raises:
        x_tensor = x.to_layout_tensor()
        result_tensor = result.to_layout_tensor()

        @parameter
        if target == "cpu":
            raise Error("Rasterize3DGS CPU target not implemented yet.")
        elif target == "gpu":
            # Get GPU context
            var gpu_ctx = ctx.get_device_context()

            # Define grid and block dimensions for the kernel launch
            var grid = (ceildiv(x.dim_size(0), BLOCK_SIZE))
            var block = (BLOCK_SIZE)

            gpu_ctx.enqueue_memset(
                DeviceBuffer[result.type](
                    gpu_ctx,
                    rebind[UnsafePointer[Scalar[result.type]]](result_tensor.ptr),
                    x.dim_size(0),
                    owning=False,
                ),
                0,
            )

            gpu_ctx.enqueue_function[add_const_kernel](
                x_tensor,
                result_tensor,
                grid_dim=grid,
                block_dim=block,
            )

        else:
            raise Error("Unsupported target:", target)


fn add_const_kernel(
    x: LayoutTensor[dtype, Dyn1DLayout, MutableAnyOrigin],
    result: LayoutTensor[dtype, Dyn1DLayout, MutableAnyOrigin],
):
    i = block_idx.x * block_dim.x + thread_idx.x
    if i < x.dim[0]():
        result[i] = x[i] + 10

However, I get:

ValueError: Failed to run the MOToMGP pass manager:
open-source/max/max/kernels/src/Mogg/MOGGKernelAPI:1:1: error: failed to run the pass manager for offload functions
/home/ubuntu/workspace/tmp/example/kernels/kernel.mojo:63:4: error: call expansion failed
/home/ubuntu/workspace/tmp/example/kernels/kernel.mojo:63:4: note: function instantiation failed
/home/ubuntu/workspace/tmp/example/kernels/kernel.mojo:68:20: note: call expansion failed
open-source/max/max/kernels/src/layout/layout_tensor.mojo:2337:8: note: function instantiation failed
open-source/max/max/kernels/src/layout/layout_tensor.mojo:2378:10: note: call expansion failed
note: constraint failed: This method only works with tensors that have depth-1 layouts (no nested shapes).
open-source/max/max/kernels/src/Mogg/MOGGKernelAPI:1:1: error: Could not elaborate the provided code: failed to run the pass manager
error: The graph compiler tried to JIT compile the provided kernels but failed during elaboration

for the original case with the const parameter.

@bertaveira it turns out the main issue is in the dynamic layout: since UNKNOWN_VALUE is represented as -1, ceildiv(x.dim_size(0), BLOCK_SIZE) makes the grid_dim zero, which is an invalid value. We’ll be working on showing much better error messages. In the meantime, dynamic tensors can’t work like that; you need to specify the grid_dim directly.
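
To make the failure mode described here concrete, here is the arithmetic in plain Python (a minimal sketch; -1 stands in for the UNKNOWN_VALUE sentinel, and Mojo’s ceildiv gives the same result for this input):

import math

BLOCK_SIZE = 32
dim = -1  # what a static shape query reports for an UNKNOWN_VALUE dimension
grid_dim = math.ceil(dim / BLOCK_SIZE)
print(grid_dim)  # 0: launching a kernel with a zero-sized grid is invalid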

But then there is no way to get the size of a dynamic tensor at runtime? We always have to pass the sizes as arguments? Doesn’t that somewhat defeat the niceness of using layout tensors to begin with? Layout tensors must know their sizes at runtime, so why can’t we access that information?

Any developments or ideas on how to solve this, i.e. getting the shape of an InputTensor with an UNKNOWN_VALUE shape?

I was waiting to see if it was added in a nightly, and I also tried to feed it as an int argument. But apparently with custom ops one can only have InputTensors and no other arguments. The only way I see is to make it a parameter, but that defeats the whole point of using UNKNOWN_VALUE.

Am I missing something here? It seems like a pretty huge gap that makes it impossible to truly have runtime-sized tensors with custom ops, since it even prevents us from choosing the grid size: there is no way to know the dimensions of the UNKNOWN_VALUE inputs.

These are great questions! I’m delegating to @stef, who has dug into this before.

@bertaveira thanks for following up! I think some wires got crossed in our response here.

As you point out, using a dynamic shape value at runtime is a really common use case, and is the entire point of a dynamic layout!

  • The dim functions on LayoutTensor provide access to dynamic dimension values
  • Be careful to note that x.shape[dim]() is static and does not know the dynamic layout value! This is what @Ehsan was referring to in his previous post; it will return -1 for an UNKNOWN_VALUE.
  • ManagedTensorSlice works a bit differently, where x.shape() returns an IndexList, x.dim_size[dim]() returns statically-known dimension sizes, and x.dim_size(dim) returns dynamically-known dimension sizes.

In the code you posted this is not a problem, but it was something I happened to stumble on during reproduction. As long as you use one of the dimension functions that reports the dynamic size (which you were!), you won’t run into this problem. We’re also merging better error checking for this, so it will fail with a good error message rather than reporting a generic GPU memory error :slight_smile:

The specific issue you were encountering in this example was a subtle bug in our compiler stack related to unsafely capturing values which had no GPU representation and trying to send them across the device boundary, which then caused a GPU memory error when running the function. This has been fixed and your example should work in the latest nightly. Please update if you’re still having trouble!
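
Once you’re on a nightly with the fix, a quick way to confirm the 2D example from earlier behaves as intended (a minimal sketch reusing that example’s add_const wrapper; the kernel adds 10 elementwise, so the output should match x + 10):

x = torch.randn(10, 10).cuda()
y = add_const(x)
# Every element should have been incremented by the constant 10.
assert torch.allclose(y, x + 10)
print("ok")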

FYI, we also recently open sourced our entire kernel library with thousands of high-performance examples :smiley: You can see plenty of usages of many of these features on GitHub; for instance, here are usages of grid_dim, here’s an example of a multi-head attention kernel with a dynamic batch size as part of the grid_dim, and here’s an example of a fused attention implementation which uses a parameterized kernel to get “dynamic” static shape info for grid_dim; it will compile a specialized implementation per input layout.


(Also, keep an eye out for scalar inputs to custom ops coming to a nightly near you soon! :wink: though you definitely don’t need it for passing dynamic shape info!)