I'm trying to test a GPU kernel whose implementation is identical to the code at this link: https://github.com/modular/modular/blob/main/examples/custom_ops/kernels/fused_attention.mojo. However, compilation fails with an error that appears to be a shape mismatch. My Python driver script and the full error output are below.
import torch
from pathlib import Path
from max.torch import CustomOpLibrary
# Directory holding the custom-op kernel sources, located next to this script.
mojo_kernels = Path(__file__).parent.joinpath("operations")

# Load every custom op found in that directory.
op_library = CustomOpLibrary(mojo_kernels)

# Select the "fused_attention_custom" op, specialised with BN=8 and BD=32.
# NOTE(review): these are presumably the kernel's compile-time tile sizes --
# confirm against the kernel's parameter list.
mojo_fused_attention = op_library.fused_attention_custom[{"BN": 8, "BD": 32}]
def fused_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
    """Invoke the custom fused-attention op on ``q``, ``k`` and ``v``.

    A zero-filled output tensor is allocated with as many rows as ``q`` and as
    many columns as ``v``, using the dtype and device of ``q``. It is passed
    to the op as the first argument, followed by ``q``, ``k`` and ``v``.

    Args:
        q: 2-D query tensor.
        k: Key tensor, forwarded to the op unchanged.
        v: 2-D value tensor.

    Returns:
        The output tensor of shape ``(q.shape[0], v.shape[1])``.

    Raises:
        ValueError: If ``q`` or ``v`` does not have exactly two dimensions
            (the shape unpacking fails).
    """
    (num_rows, _), (_, num_cols) = q.shape, v.shape
    # new_zeros inherits dtype and device from q.
    out = q.new_zeros((num_rows, num_cols))
    # The op receives the pre-allocated output buffer as its first argument.
    mojo_fused_attention(out, q, k, v)
    return out
def main():
    """Run the fused-attention op once on random inputs and print the result.

    Uses CUDA when available and falls back to CPU otherwise. The query, key
    and value tensors are each 40 x 128 float32 tensors drawn from
    ``torch.randn``.
    """
    if torch.cuda.is_available():
        device = "cuda"
    else:
        device = "cpu"
    print(f"device: {device}")

    shape = (40, 128)
    # Drawn in q, k, v order so the random stream is consumed as before.
    q, k, v = (
        torch.randn(shape, dtype=torch.float32, device=device)
        for _ in range(3)
    )
    print(fused_attention(q, k, v))
# Run the demo only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
ValueError: Failed to run the MOToMGP pass manager:
open-source/max/max/kernels/src/Mogg/MOGGKernelAPI:1:1: error: failed to run the pass manager for offload functions
/content/modular-practice/custom_op/operations/attentions.mojo:314:4: error: call expansion failed
/content/modular-practice/custom_op/operations/attentions.mojo:314:4: note: function instantiation failed
/content/modular-practice/custom_op/operations/attentions.mojo:361:50: note: call expansion failed
note: call expansion failed
open-source/max/max/kernels/src/layout/layout_tensor.mojo:958:8: note: function instantiation failed
open-source/max/max/kernels/src/layout/layout_tensor.mojo:1027:18: note: call expansion failed
note: constraint failed: _elementwise_binary_with_broadcast requires shape to be the same for tensors of the same rank
open-source/max/max/kernels/src/Mogg/MOGGKernelAPI:1:1: error: Could not elaborate the provided code: failed to run the pass manager
error: The graph compiler tried to JIT compile the provided kernels but failed during elaboration