GPU kernel compilation error

I'm trying to test a GPU kernel. The implementation is identical to the example at this link: https://github.com/modular/modular/blob/main/examples/custom_ops/kernels/fused_attention.mojo. However, I'm hitting a compilation error that looks like a shape mismatch.

import torch

from pathlib import Path
from max.torch import CustomOpLibrary

mojo_kernels = Path(__file__).parent / "operations"
op_library = CustomOpLibrary(mojo_kernels)
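# BN and BD are compile-time parameters; the op is instantiated for these block sizes.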
mojo_fused_attention = op_library.fused_attention_custom[
    {
        "BN": 8,
        "BD": 32,
    }
]


def fused_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
    m, _ = q.shape
    _, n = v.shape
    result = torch.zeros((m, n), dtype=q.dtype, device=q.device)
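    # max.torch custom ops use destination-passing style: the preallocated output comes first.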
    mojo_fused_attention(result, q, k, v)
    return result


def main():
    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"device: {device}")
    q = torch.randn((40, 128), dtype=torch.float32, device=device)
    k = torch.randn((40, 128), dtype=torch.float32, device=device)
    v = torch.randn((40, 128), dtype=torch.float32, device=device)
    print(fused_attention(q, k, v))
    
if __name__ == "__main__":
    main()
The full error output:

ValueError: Failed to run the MOToMGP pass manager:
open-source/max/max/kernels/src/Mogg/MOGGKernelAPI:1:1: error: failed to run the pass manager for offload functions
/content/modular-practice/custom_op/operations/attentions.mojo:314:4: error: call expansion failed
/content/modular-practice/custom_op/operations/attentions.mojo:314:4: note: function instantiation failed
/content/modular-practice/custom_op/operations/attentions.mojo:361:50: note: call expansion failed
note: call expansion failed
open-source/max/max/kernels/src/layout/layout_tensor.mojo:958:8: note: function instantiation failed
open-source/max/max/kernels/src/layout/layout_tensor.mojo:1027:18: note: call expansion failed
note: constraint failed: _elementwise_binary_with_broadcast requires shape to be the same for tensors of the same rank
open-source/max/max/kernels/src/Mogg/MOGGKernelAPI:1:1: error: Could not elaborate the provided code: failed to run the pass manager
error: The graph compiler tried to JIT compile the provided kernels but failed during elaboration
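The constraint failure says that two tensors of the same rank ended up with different shapes inside an elementwise op in the kernel. Since the op was compiled with fixed BN/BD block sizes, one cheap thing to rule out is inputs that do not tile evenly. A minimal sketch of such a guard, assuming (unverified) that the row count must be a multiple of BN and the column count a multiple of BD:

import torch

BN, BD = 8, 32  # must match the parameters the op was instantiated with


def check_tiling(name: str, t: torch.Tensor) -> None:
    # Hypothetical precondition: rows tile by BN, columns tile by BD.
    n, d = t.shape
    if n % BN or d % BD:
        raise ValueError(
            f"{name} has shape {tuple(t.shape)}, which does not tile "
            f"evenly by BN={BN}, BD={BD}"
        )

For the shapes above (40 x 128 with BN=8, BD=32) this guard passes, which points away from the input shapes themselves and toward how the op is invoked, consistent with the suspicion below.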

The GPU matrix multiplication implementation in the linked code is not robust, and I am currently working on replacing it.
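What "robust" means here is my assumption: a common fragility in example matmuls is that they only handle dimensions that are exact multiples of the tile size. For illustration (in Python rather than Mojo), this is the edge-clamping a tiled matmul needs so that M, K, and N do not have to divide evenly:

import torch


def tiled_matmul(a: torch.Tensor, b: torch.Tensor, tile: int = 32) -> torch.Tensor:
    # a: (M, K), b: (K, N). Slices are clamped at the edges, so the
    # dimensions do not need to be multiples of the tile size.
    M, K = a.shape
    K2, N = b.shape
    assert K == K2, "inner dimensions must match"
    out = torch.zeros((M, N), dtype=a.dtype, device=a.device)
    for i in range(0, M, tile):
        for j in range(0, N, tile):
            for p in range(0, K, tile):
                i1, j1, p1 = min(i + tile, M), min(j + tile, N), min(p + tile, K)
                out[i:i1, j:j1] += a[i:i1, p:p1] @ b[p:p1, j:j1]
    return out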

This one seems to be a PyTorch custom ops issue rather than the kernel itself. Could you please create an issue here?
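While that gets sorted out, a plain PyTorch reference is handy for checking the kernel's output once it compiles. A minimal sketch, assuming the kernel computes standard scaled dot-product attention (drop the 1/sqrt(d) factor if the Mojo kernel omits it):

import math

import torch


def attention_reference(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
    # softmax(Q K^T / sqrt(d)) V with stock PyTorch ops.
    d = q.shape[-1]
    scores = (q @ k.transpose(-2, -1)) / math.sqrt(d)
    return torch.softmax(scores, dim=-1) @ v

Something like torch.testing.assert_close(fused_attention(q, k, v), attention_reference(q, k, v)) then gives a pass/fail check.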