How to package/interface with a GPU kernel with dynamically sized tensors (dynamic LayoutTensor)

Hi everyone, thank you for the awesome work on GPU kernel development. A more generic and easy-to-distribute alternative to CUDA cannot come soon enough :joy: . I am new here, but I am experimenting with rewriting some of my CUDA kernels in Mojo to evaluate it. I am trying to understand the best way to structure this interface, especially when it comes to LayoutTensor.

I want to make these kernels into a reusable package that can be fed a tensor of any size and have that work. Given this, should the package create a new LayoutTensor alias for every call, given that the shapes can change at any point?

Also, when it comes to feeding tensors into and out of the package: it is a bit impractical to pass these data types around, since every caller then needs to import this alias, which also makes it hard to create them on the fly. How would you deal with passing this information into and out of the package, considering the sizes can change?

For example one could think of somehow doing this:

from my_pkg import awesome_kernel

def main():
    ...
    ctx.enqueue_function[awesome_kernel](
        arbitrary_size_tensor,
        grid_dim=1,
        block_dim=VECTOR_WIDTH,
    )

But I am not sure how to use LayoutTensor while still supporting arbitrarily sized tensors.

Another idea would be to have a wrapper function that serves as an entry point to the kernel. This function could then create the LayoutTensor on the fly on every call.

from my_pkg import awesome_kernel_wrapper

def main():
    ...
    awesome_kernel_wrapper(
        ctx, # So the wrapper can schedule the kernel
        arbitrary_size_tensor,
    )

I prefer this second option, but it implies passing the device context to the wrapper function in another package.
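
For illustration, here is a minimal sketch of what such a wrapper might look like inside my_pkg. The kernel body and all names here are made up, and this version still takes the width as a compile-time parameter, so it does not yet solve the dynamic-size problem by itself:

from gpu import thread_idx
from gpu.host import DeviceBuffer, DeviceContext
from layout import Layout, LayoutTensor

alias float_dtype = DType.float32


fn awesome_kernel[layout: Layout](
    tensor: LayoutTensor[float_dtype, layout, MutableAnyOrigin]
):
    # Hypothetical kernel body: double each element in place.
    tensor[thread_idx.x] = tensor[thread_idx.x] * 2


fn awesome_kernel_wrapper[width: Int](
    ctx: DeviceContext,  # passed in so the wrapper can schedule the kernel
    buffer: DeviceBuffer[float_dtype],
) raises:
    # Create the LayoutTensor on the fly for this call, then launch.
    alias layout = Layout.row_major(width)
    var tensor = LayoutTensor[float_dtype, layout](buffer)
    ctx.enqueue_function[awesome_kernel[layout]](
        tensor,
        grid_dim=1,
        block_dim=width,
    )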

I am curious to see what code structures others are considering, or how one of these could be achieved.


I think that, with a sufficient application of generics, you should be able to use LayoutTensor for everything. It has support for runtime dimensions, which gives you the best of both worlds: you can bake in compile-time-known dims while still supporting runtime dims. It's going to be a little more verbose, but no worse than doing it with templates in C++. This means that your exposed API to the user is "hand me a LayoutTensor and I'll use it", and they can make that runtime or compile-time as they desire. Now, if you want arbitrary-rank tensors, that may require some type gymnastics, but I think it still might be possible to do.

Passing around the ctx is fine if you want to stay in pure Mojo. However, I would strongly recommend making use of MAX custom ops. Since it's a JIT compiler, it has the ability to convert runtime values in your high-level setup code into compile-time values as far as the hot loop is concerned, and generally gives fairly large performance boosts thanks to kernel fusion and other techniques. Most people using Mojo kernels are using them from MAX, so this will also help you interoperate with the rest of the ecosystem. There should be ~250kloc of examples in the coming weeks as the MAX kernels are open-sourced, in addition to the existing examples in the MAX repo.

Thank you for the reply. As of now I am trying not to go into MAX territory. MAX is very focused on inference servers and LLMs, which is not really what I am interested in, and it seems daunting to even touch it: it is complex, and it is a bit messy to pull out the parts I care about from the inference engine and model serving aspects. All I need is very fine-grained control over GPU kernels (Triton fails here since it gives no control over shared memory) and Python interoperability.

I tried your first approach, but I am failing to find a way to get it working. Maybe I am missing something, but I am a bit lost, and without being able to look at the library code I hit a bit of a wall here.

Here is a test script (adapted from an example kernel from MAX). I tried to make a dynamic tensor using make_shape_unknown but so far this is not working out. It compiles and runs but it is not giving the right results.

EDIT: It seems to work now, but just for validation, here is the script. Does it make sense that we have to assign an arbitrary size (in this case size 2) and then make the shape unknown?

from gpu import thread_idx
from gpu.host import DeviceContext
from layout import Layout, LayoutTensor
from math import ceildiv
from sys import has_nvidia_gpu_accelerator, has_amd_gpu_accelerator

alias float_dtype = DType.float32
alias dynamic_layout = Layout.row_major(2).make_shape_unknown()
alias DynamicTensor = LayoutTensor[float_dtype, dynamic_layout, MutableAnyOrigin]


fn vector_addition(
    lhs_tensor: DynamicTensor,
    rhs_tensor: DynamicTensor,
    out_tensor: DynamicTensor,
):
    """The calculation to perform across the vector on the GPU."""
    var tid = thread_idx.x
    out_tensor[tid] = lhs_tensor[tid] + rhs_tensor[tid]


def main():
    constrained[
        has_nvidia_gpu_accelerator() or has_amd_gpu_accelerator(),
        "This example requires a supported GPU",
    ]()

    alias VECTOR_WIDTH = 10
    alias layout = Layout.row_major(VECTOR_WIDTH)

    # Get context for the attached GPU
    var ctx = DeviceContext()

    # Allocate data on the GPU address space
    var lhs_buffer = ctx.enqueue_create_buffer[float_dtype](VECTOR_WIDTH)
    var rhs_buffer = ctx.enqueue_create_buffer[float_dtype](VECTOR_WIDTH)
    var out_buffer = ctx.enqueue_create_buffer[float_dtype](VECTOR_WIDTH)

    # Fill in values across the entire width
    _ = lhs_buffer.enqueue_fill(1.25)
    _ = rhs_buffer.enqueue_fill(2.5)

    # Wrap the device buffers in tensors
    var lhs_tensor = LayoutTensor[float_dtype, layout](lhs_buffer)
    var rhs_tensor = LayoutTensor[float_dtype, layout](rhs_buffer)
    var out_tensor = LayoutTensor[float_dtype, layout](out_buffer)

    # Launch the vector_addition function as a GPU kernel
    ctx.enqueue_function[vector_addition](
        lhs_tensor,
        rhs_tensor,
        out_tensor,
        grid_dim=1,
        block_dim=VECTOR_WIDTH,
    )

    # Map to host so that values can be printed from the CPU
    with out_buffer.map_to_host() as host_buffer:
        var host_tensor = LayoutTensor[float_dtype, layout](host_buffer)
        print("Resulting vector:", host_tensor)

I am not sure if this is exactly what you had in mind, but am I doing something wrong? It is a bit weird to have to set an arbitrary dimension number and then make it unknown.

alias dynamic_layout = Layout.row_major(2).make_shape_unknown()
alias DynamicTensor = LayoutTensor[float_dtype, dynamic_layout, MutableAnyOrigin]

Not sure why, but I went for a break, came back, and now I get an error running that exact same script:

CUDA call failed: CUDA_ERROR_ILLEGAL_ADDRESS (an illegal memory access was encountered)
[1]    11458 IOT instruction (core dumped)  mojo run examples/test2.mojo

MAX is very focused on inference servers and LLMs

The marketing is very focused on that, true, but any computation which mostly does linear algebra of some sort already works quite well. @BradLarson has actually been using it for computer vision stuff: GitHub - BradLarson/max-cv: An image processing framework built upon MAX. The marketing focus on AI/ML is because that's where Modular expects to make most of their money, but it does work fine for general-purpose computation. As an example, I have a bunch of work related to using MAX for building a TCP/IP stack.

Also, if you want Python to call the kernels, for now you will need to use MAX.

You should be able to be generic over compile-time-known values by doing something like this:


from gpu import thread_idx
from gpu.host import DeviceContext
from layout import Layout, LayoutTensor
from buffer.dimlist import Dim
from math import ceildiv
from sys import has_nvidia_gpu_accelerator, has_amd_gpu_accelerator


fn vector_addition[
    dtype: DType,
    shape: Layout,
    lhs_origin: Origin,
    rhs_origin: Origin,
    output_origin: MutableOrigin,
](
    read lhs_tensor: LayoutTensor[dtype, shape, lhs_origin],
    read rhs_tensor: LayoutTensor[dtype, shape, rhs_origin],
    read out_tensor: LayoutTensor[dtype, shape, output_origin],
):
    """The calculation to perform across the vector on the GPU."""
    var tid = thread_idx.x
    out_tensor[tid] = lhs_tensor[tid] + rhs_tensor[tid]


fn vector_sum_for_dtype_and_shape[
    dtype: DType, layout: Layout
](
    ctx: DeviceContext, lhs_value: Scalar[dtype], rhs_value: Scalar[dtype]
) raises:
    """Perform a vector addition of two vectors and sum the result."""
    # Allocate data on the GPU address space
    var lhs_buffer = ctx.enqueue_create_buffer[dtype](layout.size())
    var rhs_buffer = ctx.enqueue_create_buffer[dtype](layout.size())
    var out_buffer = ctx.enqueue_create_buffer[dtype](layout.size())

    # Fill in values across the entire width
    _ = lhs_buffer.enqueue_fill(lhs_value)
    _ = rhs_buffer.enqueue_fill(rhs_value)

    # Wrap the device buffers in tensors
    var lhs_tensor = LayoutTensor[dtype, layout](lhs_buffer)
    var rhs_tensor = LayoutTensor[dtype, layout](rhs_buffer)
    var out_tensor = LayoutTensor[dtype, layout](out_buffer)

    # Launch the vector_addition function as a GPU kernel
    ctx.enqueue_function[
        vector_addition[
            dtype,
            layout,
            lhs_tensor.origin,
            rhs_tensor.origin,
            out_tensor.origin,
        ]
    ](
        lhs_tensor,
        rhs_tensor,
        out_tensor,
        grid_dim=1,
        block_dim=Dim(layout.size()),
    )

    # Map to host so that values can be printed from the CPU
    with out_buffer.map_to_host() as host_buffer:
        var host_tensor = LayoutTensor[dtype, layout](host_buffer)
        print("Resulting vector:", host_tensor)


def main():
    constrained[
        has_nvidia_gpu_accelerator() or has_amd_gpu_accelerator(),
        "This example requires a supported GPU",
    ]()

    alias float_dtype = DType.float32

    # Get context for the attached GPU
    var ctx = DeviceContext()

    vector_sum_for_dtype_and_shape[float_dtype, Layout.row_major(2)](ctx, 5.0, -7.0)
    vector_sum_for_dtype_and_shape[float_dtype, Layout.row_major(10)](ctx, 5.5, -7.0)

Allowing for RuntimeLayout in a performant way is going to need some inspection of the layout module to figure out when it can share a kernel with the static Layout. It does make the code a little more complex, but as long as you have compile-time-known layouts this should generalize well. If you actually need "load this vector/matrix/tensor from a file" behavior, then MAX is your best bet to keep getting good performance.

I think that by type-erasing the shape you may be walking off the end of the buffer, since you aren't checking the runtime shape. Since that's UB, it can manifest in many, many ways.

Thank you, I will check MAX out. For now I am just evaluating for fun, to see the potential of Mojo for kernel development. I wanted to just write some kernels, compare with CUDA, and find its quirks.

Yes, this generic approach works; however, I am afraid it might cause massive amounts of recompilation in some cases. For some kernels I work with (like Gaussian splatting), the sizes are expected to change on every call, so changing the generic parameters would cause a recompilation, right? So is there no way to have it fully agnostic of dimension sizes, without generics?

I am also fine using pointers and moving pointer locations around, just like I do in CUDA, but I haven't really worked much with Mojo, and I can't really find any examples of this. LayoutTensor is really nice in that it allows one to iterate so easily, but I understand if it is limited in this way.

Maybe this is more of a feature request for the future, then. Being able to have a dynamically sized LayoutTensor, or another way to iterate over raw pointers (even if unsafe; I am not sure if one already exists), would to me be an essential feature.

So is there no way to have it fully agnostic of dimension sizes, without generics?

If you don't see a lot of benefit from having static dimensions baked in, you could use RuntimeLayout and the corresponding constructors for LayoutTensor. The issues I described were with supporting both at the same time, unless Brad Larson has a way to do it that I haven't thought of.

Being able to have a dynamically sized LayoutTensor, or another way to iterate over raw pointers (even if unsafe; I am not sure if one already exists), would to me be an essential feature.

Part of Mojo's promise is portability, so using pointers may force you very deep into the depths of the device support or MAX, since every single device family will have different pointer types. The features you are describing do exist; they're just not as commonly used, because for the most part people are using this for either AI or more typical HPC things. For my own, network-oriented use cases, I do a lot of weird things, but those are very unergonomic and force me to split the program between Mojo and SYCL, bouncing some data between kernels via the CPU.

Thank you, RuntimeLayout seems to be exactly what I wanted to find! However, I am not managing to use it. If I try to use it in the kernel signature, it says "error: cannot use a dynamic value in type specification". Then if I try to make a LayoutTensor from it, I get "error: cannot use a dynamic value in type parameter".

Same example as above, but with this changed:

alias float_dtype = DType.float32
var dynamic_layout = RuntimeLayout[Layout.row_major(1)]()
alias DynamicTensor = LayoutTensor[float_dtype, dynamic_layout, MutableAnyOrigin]

I get an error on this last line saying "error: cannot use a dynamic value in type parameter".

I am probably doing something wrong here. I cannot find any usage of this RuntimeLayout in any code online to reference.

You need to have a Layout of the correct rank with everything made unknown, and then there's a different constructor for LayoutTensor with a RuntimeLayout.

Sorry to bother you, but could you give me a hint? I am very stuck trying to figure this out with just some documentation, without examples or open code. I find many initialisers, many Tensor structures, and many Layout structures, with little to no info on them besides some generic API documentation, and I keep getting hit with compiler errors.

I am now even trying to work with MAX and ManagedTensorSlice, but this also creates many annoyances, since one has to call .store and .load and, for all of these, compute the index with IndexList[](). I would really like to use the Layouts, but I am not understanding them, and maybe I am not deep enough into it to understand what you mean.

I am experimenting with dynamic LayoutTensors as well. Here is what I have at the moment:

from layout import Layout, LayoutTensor
from layout.tensor_builder import dynamic, LayoutTensorBuild
from memory import UnsafePointer

def tensor(
    dim0: Int, dim1: Int
) -> LayoutTensor[DType.float32, Layout.row_major(-1, -1), MutableAnyOrigin]:
    var ptr = UnsafePointer[Scalar[DType.float32]].alloc(dim0 * dim1)
    return (
        LayoutTensorBuild[DType.float32]()
        .row_major(dynamic(dim0), dynamic(dim1))
        .view(ptr)
    )

def main():
    var x = tensor(2, 3).fill(1.0)
    print(x.runtime_layout) # ((2, 3):(3, 1))
    print(x) # correctly prints out a 2x3 matrix

    # However, the following fails at runtime with the error:
    # 'constraint failed: Requires fully static layout'
    var y = tensor(2, 3).fill(1.0)
    print(x + y)

The initial creation of the dynamic LayoutTensors works, but working with them (e.g. doing arithmetic) still fails, since it requires fully static shapes, as shown in the comment. Why? Can we get something like this working in some way?

After some further investigation, I believe that doing anything with dynamic shapes that are not known at compile time is not supported yet. Am I right about this? I clearly still don't understand MAX's usage of a RuntimeLayout compared to a regular Layout. Would love to learn more about this all. :slight_smile:

Well, I am in the same boat. The more I investigate, the more confused I get by the million things one can do, and how none of them seem to really work the way I want. I just saw their new release of kernels, and they seem to mostly be using NDBuffer instead of any Tensor/Layout type…
(e.g. modular/mojo/kernels/src/nn/cumsum.mojo at main Ā· modular/modular Ā· GitHub)

I also asked the bot on Discord about the differences between Tensor, LayoutTensor and NDBuffer and got this:

The main distinction is that Tensor owns and manages its memory, while both LayoutTensor and NDBuffer are views over existing memory. LayoutTensor provides more sophisticated control over memory layout and access patterns compared to NDBuffer, making it particularly useful for performance-critical applications like GPU programming.

You can also convert between these types. For example, you can create a LayoutTensor view from a Tensor using the to_layout_tensor() method, as shown in GPU programming examples.

I'll try to dig a bit and implement an example with NDBuffer and see how far I can get…

Dynamic shapes are supported, but at the moment it is (as you've found) a little counter-intuitive. A LayoutTensor has to have a compile-time layout parameter. Because it's a parameter, that layout is basically static: it's part of the tensor's type, in the same way that a List[Int] is not interchangeable with List[String]. You can specify that one or more dimensions are unknown using UNKNOWN_VALUE in place of an actual dimension. (Or as @TilliFe showed, you can use LayoutTensorBuild, but I'll stick with the manual process for the sake of explanation.)

To specify a (dynamic) runtime layout, you need to create a RuntimeLayout object, which lets you specify a dynamic shape. The runtime layout also has a static layout parameter. The dynamic layout can have different dimensions from the static layout, but it has to have the same number of dimensions.
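
In short, the core pattern looks something like this. (A minimal host-memory sketch; the rows/cols values are made up, and it assumes LayoutTensor's pointer-plus-RuntimeLayout constructor. The device-buffer equivalent appears in the full example below.)

from layout import Layout, LayoutTensor, RuntimeLayout, UNKNOWN_VALUE
from memory import UnsafePointer
from utils import Index


def main():
    # Static layout: rank 2 is fixed at compile time, dimensions unknown.
    alias layout = Layout.row_major(UNKNOWN_VALUE, UNKNOWN_VALUE)

    # The actual dimensions, known only at runtime.
    var rows = 4
    var cols = 8

    # Runtime layout: same rank as the static layout, with the real
    # dimensions filled in.
    var runtime_layout = RuntimeLayout[layout].row_major(Index(rows, cols))

    # Wrap existing memory together with the runtime layout.
    var ptr = UnsafePointer[Scalar[DType.float32]].alloc(rows * cols)
    var tensor = LayoutTensor[DType.float32, layout](ptr, runtime_layout)
    tensor[0, 0] = 1.0
    print(tensor.runtime_layout.dim(0), tensor.runtime_layout.dim(1))
    ptr.free()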

It looks like some elementwise operations (like tensor_a + tensor_b) don't work with dynamic tensor shapes at the moment. (I can't shed any light on that, but certainly file an issue.) But you can do a lot of basic operations, like accessing individual elements. These dynamic layout tensors are used in some of the MAX AI kernels.

(For example: modular/max/kernels/src/nn/mla.mojo at b06993f719c128f0ef528cafdb2b11102186893e Ā· modular/modular Ā· GitHub)

Here's a simple example of invoking a kernel with a dynamically-sized layout tensor.

from layout import Layout, LayoutTensor, UNKNOWN_VALUE, RuntimeLayout
from sys import has_accelerator
from gpu.host import DeviceContext
from gpu import thread_idx, block_idx, global_idx, grid_dim, block_dim, barrier
from utils import Index


fn dynamic_layout_example():
    alias dtype = DType.int32
    alias in_size = 128
    alias block_size = 16
    alias num_blocks = in_size // block_size
    alias input_layout = Layout.row_major(UNKNOWN_VALUE, UNKNOWN_VALUE)

    fn kernel(tensor: LayoutTensor[dtype, input_layout, MutableAnyOrigin]):
        var width: Int = tensor.runtime_layout.dim(0)
        var height = tensor.runtime_layout.dim(1)
        # extract a tile from the input tensor.
        var tile = tensor.tile[block_size, block_size](block_idx.x, block_idx.y)
        if global_idx.x < width and global_idx.y < height:
            tile[thread_idx.x, thread_idx.y] = (
                global_idx.y + global_idx.x * grid_dim.x * block_dim.x
            )

    try:
        var ctx = DeviceContext()
        var width: Int = 128
        var height: Int = 128
        var host_buf = ctx.enqueue_create_host_buffer[dtype](width * height)
        var dev_buf = ctx.enqueue_create_buffer[dtype](width * height)
        ctx.synchronize()
        for i in range(width * height):
            host_buf[i] = 1
        ctx.enqueue_copy(dev_buf, host_buf)
        var runtime_layout = RuntimeLayout[input_layout].row_major(Index(width, height))
        var tensor = LayoutTensor[dtype, input_layout](dev_buf, runtime_layout)

        ctx.enqueue_function[kernel](
            tensor,
            grid_dim=(num_blocks, num_blocks),
            block_dim=(block_size, block_size),
        )
        ctx.synchronize()
        ctx.enqueue_copy(host_buf, dev_buf)
        ctx.synchronize()
        print(host_buf)
    except error:
        print(error)


def main():
    if has_accelerator():
        dynamic_layout_example()
    else:
        print("No accelerator")

I think the compute graphs are much happier with tensors that have static dimensions, but if you need dynamic dimensions, there are things you can do. (You can also just pass a DeviceBuffer[dtype] to enqueue_function, which is translated to an UnsafePointer[Scalar[dtype]] in the kernel signature. But as @owenhilyard said, you may be sacrificing portability.)
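
For completeness, here is a rough, untested sketch of that DeviceBuffer route (scale_kernel and all sizes here are made up):

from gpu import global_idx
from gpu.host import DeviceContext
from math import ceildiv
from memory import UnsafePointer

alias dtype = DType.float32
alias block_size = 256


fn scale_kernel(data: UnsafePointer[Scalar[dtype]], size: Int):
    # A raw pointer carries no shape information, so bounds are checked
    # manually against the explicitly passed size.
    var i = Int(global_idx.x)
    if i < size:
        data[i] = data[i] * 2


def main():
    var size = 1000
    var ctx = DeviceContext()
    var buf = ctx.enqueue_create_buffer[dtype](size)
    _ = buf.enqueue_fill(1.0)

    # The DeviceBuffer argument arrives in the kernel as an
    # UnsafePointer[Scalar[dtype]]; the size has to travel separately.
    ctx.enqueue_function[scale_kernel](
        buf,
        size,
        grid_dim=ceildiv(size, block_size),
        block_dim=block_size,
    )

    with buf.map_to_host() as host_buf:
        print(host_buf[0])  # expect 2.0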
