GPU kernel compilation error

I’m trying to test a GPU kernel whose implementation is exactly the same as the example at this link: https://github.com/modular/modular/blob/main/examples/custom_ops/kernels/fused_attention.mojo. However, I ran into a compilation error that appears to be a shape mismatch.

import torch

from pathlib import Path
from max.torch import CustomOpLibrary

# Load the Mojo custom ops from the local "operations" directory.
mojo_kernels = Path(__file__).parent / "operations"
op_library = CustomOpLibrary(mojo_kernels)

# Bind the op with its compile-time tile-size parameters.
mojo_fused_attention = op_library.fused_attention_custom[
    {
        "BN": 8,
        "BD": 32,
    }
]


def fused_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
    m, _ = q.shape
    _, n = v.shape
    # The custom op writes into a pre-allocated output tensor.
    result = torch.zeros((m, n), dtype=q.dtype, device=q.device)
    mojo_fused_attention(result, q, k, v)
    return result


def main():
    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"device: {device}")
    q = torch.randn((40, 128), dtype=torch.float32, device=device)
    k = torch.randn((40, 128), dtype=torch.float32, device=device)
    v = torch.randn((40, 128), dtype=torch.float32, device=device)
    print(fused_attention(q, k, v))
    
if __name__ == "__main__":
    main()
Here is the full error output:

ValueError: Failed to run the MOToMGP pass manager:
open-source/max/max/kernels/src/Mogg/MOGGKernelAPI:1:1: error: failed to run the pass manager for offload functions
/content/modular-practice/custom_op/operations/attentions.mojo:314:4: error: call expansion failed
/content/modular-practice/custom_op/operations/attentions.mojo:314:4: note: function instantiation failed
/content/modular-practice/custom_op/operations/attentions.mojo:361:50: note: call expansion failed
note: call expansion failed
open-source/max/max/kernels/src/layout/layout_tensor.mojo:958:8: note: function instantiation failed
open-source/max/max/kernels/src/layout/layout_tensor.mojo:1027:18: note: call expansion failed
note: constraint failed: _elementwise_binary_with_broadcast requires shape to be the same for tensors of the same rank
open-source/max/max/kernels/src/Mogg/MOGGKernelAPI:1:1: error: Could not elaborate the provided code: failed to run the pass manager
error: The graph compiler tried to JIT compile the provided kernels but failed during elaboration
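
The constraint that fails is _elementwise_binary_with_broadcast, which suggests the runtime tensor shapes don’t line up with the compile-time block sizes. A quick host-side check can rule that out before the graph compiler ever runs. This is a minimal sketch, and I’m inferring the parameter semantics (BN tiling the sequence dimension, BD tiling the head dimension) from the example kernel, so treat the exact constraints as hypothetical:

def check_block_sizes(q, k, v, bn=8, bd=32):
    # Hypothetical sanity check: assumes BN must evenly tile the sequence
    # length and BD the head dimension. Adjust to whatever constraints
    # fused_attention.mojo actually enforces.
    n, d = q.shape
    assert k.shape == (n, d) and v.shape == (n, d), "q, k, v must share one shape"
    assert n % bn == 0, f"sequence length {n} is not a multiple of BN={bn}"
    assert d % bd == 0, f"head dim {d} is not a multiple of BD={bd}"

With the shapes above (40 x 128), 40 divides evenly by BN=8 and 128 by BD=32, so if this check passes, the mismatch likely originates inside the kernel itself.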

The GPU matrix multiplication implementation in the linked code lacks robustness, and I am currently working on replacing it.
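
While the matmul is being replaced, a plain-PyTorch reference makes it easy to validate the kernel’s output. A minimal sketch, assuming standard scaled dot-product attention (if the example kernel omits the 1/sqrt(d) scale, drop the division):

import math
import torch

def attention_reference(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
    # softmax(q @ k^T / sqrt(d)) @ v, computed with stock PyTorch ops.
    scores = (q @ k.transpose(-2, -1)) / math.sqrt(q.shape[-1])
    return torch.softmax(scores, dim=-1) @ v

Once the custom op compiles, torch.testing.assert_close(fused_attention(q, k, v), attention_reference(q, k, v)) gives a quick correctness check.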

This one seems to be a PyTorch custom ops issue rather than a problem with the kernel itself. Could you please create an issue here?
