Looking for examples of multi-GPU usage with Mojo

Thanks for the kind words about the development experience; the rest of the team and I were delighted to hear that. We just officially rolled out single-host, multi-GPU support in 25.2, so we're still filling in the documentation and examples around multi-GPU programming.

Starting from the basics, MAX Driver API Devices can be created for any MAX-compatible accelerator in your system, as in this example from our API documentation:

from max import driver
device = driver.Accelerator()
# Or specify GPU id
device = driver.Accelerator(id=0)  # First GPU
device = driver.Accelerator(id=1)  # Second GPU
# Get device id
device_id = device.id

Driver Tensors can be moved to any Device, or between Devices, to orchestrate data transfer outside of graphs in Python or Mojo. MAX Graphs themselves can be placed on any accelerator for execution, or on multiple devices.
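
As a quick sketch of what that looks like in Python, here's data moving from the host to one GPU, across to a second GPU, and back. This assumes the Tensor.from_numpy and Tensor.to helpers from the max.driver Python API, so double-check the exact signatures against the API docs for your release:

import numpy as np
from max import driver

gpu0 = driver.Accelerator(id=0)
gpu1 = driver.Accelerator(id=1)
cpu = driver.CPU()

# Create a driver Tensor on the host, then move copies of it across devices.
host_tensor = driver.Tensor.from_numpy(np.ones((4, 4), dtype=np.float32))
on_gpu0 = host_tensor.to(gpu0)  # host -> GPU 0
on_gpu1 = on_gpu0.to(gpu1)      # GPU 0 -> GPU 1
back_on_host = on_gpu1.to(cpu)  # GPU 1 -> host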

The new DistributedLlama3 MAX Graph architecture is a highly performant, tensor-parallel version of the LlamaForCausalLM model architecture. Inside its Python source code (all of which is available in our public MAX repository), you'll find layers like DistributedMLP, which in turn use an optimized allreduce operation to coordinate across multiple GPUs.
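
To illustrate the tensor-parallel pattern itself (this is just the math in NumPy, not the MAX implementation, and ReLU stands in for the model's actual activation): the first projection's weights are split by columns across GPUs, the second's by rows, and an allreduce sums the partial outputs so every GPU ends up with the full result.

import numpy as np

np.random.seed(0)
ngpus = 2
x = np.random.randn(8, 64)     # activations, replicated on every "GPU"
w1 = np.random.randn(64, 256)  # first projection, split by columns
w2 = np.random.randn(256, 64)  # second projection, split by rows

w1_shards = np.split(w1, ngpus, axis=1)
w2_shards = np.split(w2, ngpus, axis=0)

# Each device computes its partial output independently...
partials = [np.maximum(x @ w1_shards[i], 0.0) @ w2_shards[i] for i in range(ngpus)]

# ...and an allreduce sums the partials so every device holds the full output.
allreduced = sum(partials)
assert np.allclose(allreduced, np.maximum(x @ w1, 0.0) @ w2)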

These GPU communication Mojo functions live in the gpu.comm standard library module, which we haven't open-sourced yet, but the API docs for all the relevant functions are available and the functions themselves are directly usable. We're planning to document and expose more of this; as a preview, here's a naive implementation of that allreduce function in Mojo, without peer-to-peer communication:

@always_inline
fn _allreduce_naive[
    type: DType,
    rank: Int,
    ngpus: Int,
    outputs_lambda: elementwise_epilogue_type,
](
    list_of_in_bufs: InlineArray[NDBuffer[type, rank, MutableAnyOrigin], ngpus],
    list_of_out_bufs: InlineArray[
        NDBuffer[type, rank, MutableAnyOrigin], ngpus
    ],
    max_num_blocks: Int,
    ctxs: List[DeviceContext],
) raises:
    """Performs allreduce across GPUs without using peer-to-peer access.

    Implementation Steps (per GPU):
    1. Create accumulation buffer initialized to zero
    2. For each other GPU:
       a. Allocate temporary buffer on current GPU
       b. Copy remote GPU's data to temporary buffer
    3. Reduce all buffers into accumulation buffer:
       - Local buffer
       - All temporary buffers
    4. Apply output lambda to write accumulation buffer to final output

    Data Flow (3 GPU example):

    GPU0 Input  GPU1 Input  GPU2 Input
          |         |         |
          |         |         |
          v         v         v
    +---------------------------------+
    | Temporary Buffers per GPU       |
    | GPU0: [Temp01][Temp02]          |
    | GPU1: [Temp10][Temp12]          |
    | GPU2: [Temp20][Temp21]          |
    +---------------------------------+
                   |
                   v
    +---------------------------------+
    | Accumulation Buffer per GPU     |
    | GPU0: sum(Input0 + Temp01 + Temp02) |
    | GPU1: sum(Input1 + Temp10 + Temp12) |
    | GPU2: sum(Input2 + Temp20 + Temp21) |
    +---------------------------------+
                   |
                   v
    +---------------------------------+
    | Output Lambda Application       |
    | (Writes to final output buffers)|
    +---------------------------------+

    Parameters:
        type: The data type of tensor elements.
        rank: Number of dimensions in input tensors.
        ngpus: Number of GPUs participating in allreduce.
        outputs_lambda: An elementwise output lambda function.

    Args:
        list_of_in_bufs: Input buffers from each GPU.
        list_of_out_bufs: Output buffers for each GPU.
        max_num_blocks: Maximum number of thread blocks to launch.
        ctxs: List of device contexts for participating GPUs.

    This implementation copies all data to each GPU and performs local reduction.
    Used as fallback when P2P access is not available.
    """
    alias simd_width = simdwidthof[type, target = _get_gpu_target()]()
    var num_elements = list_of_in_bufs[0].num_elements()

    var device_buffers = List[DeviceBuffer[type]](capacity=ngpus)
    # Assemble input buffer structures from all devices
    for i in range(ngpus):
        device_buffers.append(
            DeviceBuffer(
                ctxs[i], list_of_in_bufs[i].data, num_elements, owning=False
            )
        )

    # Process each device
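    # The @parameter decorator below unrolls this loop at compile time, so the
    # work for each GPU is enqueued on that GPU's own DeviceContext in turn.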
    @parameter
    for device_idx in range(ngpus):
        var curr_ctx = ctxs[device_idx]

        # Create temporary accumulation buffer.
        var accum_buffer = curr_ctx.enqueue_create_buffer[type](num_elements)
        curr_ctx.enqueue_memset(accum_buffer, 0)  # Initialize to zero

        # Create temporary buffers for remote data.
        var tmp_buffers = List[DeviceBuffer[type]]()
        for i in range(ngpus):
            if i != device_idx:
                var tmp = curr_ctx.enqueue_create_buffer[type](num_elements)
                curr_ctx.enqueue_copy(tmp, device_buffers[i])
                tmp_buffers.append(tmp)

        # Reduce all buffers into accumulation buffer.
        alias BLOCK_SIZE = 256
        var grid_size = min(max_num_blocks, ceildiv(num_elements, BLOCK_SIZE))

        # First reduce local buffer.
        curr_ctx.enqueue_function[_naive_reduce_kernel[type]](
            accum_buffer.unsafe_ptr(),
            device_buffers[device_idx].unsafe_ptr(),
            num_elements,
            grid_dim=grid_size,
            block_dim=BLOCK_SIZE,
        )

        # Reduce remote buffers.
        for tmp in tmp_buffers:
            curr_ctx.enqueue_function[_naive_reduce_kernel[type]](
                accum_buffer.unsafe_ptr(),
                tmp[].unsafe_ptr(),
                num_elements,
                grid_dim=grid_size,
                block_dim=BLOCK_SIZE,
            )

        # Apply output lambda to final accumulated buffer.
        curr_ctx.enqueue_function[
            _naive_reduce_kernel_with_lambda[
                type,
                rank,
                my_rank=device_idx,
                width=simd_width,
                alignment = alignof[SIMD[type, simd_width]](),
                outputs_lambda=outputs_lambda,
            ]
        ](
            list_of_out_bufs[device_idx],
            accum_buffer.unsafe_ptr(),
            num_elements,
            grid_dim=grid_size,
            block_dim=BLOCK_SIZE,
        )

That's a sampling of different topics around multi-GPU programming, so if there's a specific area you'd like more help with, just let us know. We're constantly expanding the documentation and examples around GPU programming, and we really appreciate feedback on where to focus. These interfaces are also still evolving, so even better examples are on the way.