Why does mojo build with matmul_gpu generate different kernels for different M at runtime?

I’m playing with bench_matmul (max/kernels/benchmarks/gpu/bench_matmul.mojo).

I’m trying to figure out which matmul GPU kernel is used when running the Mojo program. To investigate, I modified the relevant matmul_gpu.mojo source so that every enqueue_function call dumps the kernel assembly, as shown below:

@@ -1177,7 +1180,7 @@ fn multistage_gemm[
             config=config,
             elementwise_lambda_fn=elementwise_lambda_fn,
         ]
-        ctx.enqueue_function[gemm_kernel_type](
+        ctx.enqueue_function[gemm_kernel_type, dump_asm=Path("out3.asm")](
             tensor_c,
             tensor_a,
             tensor_b,

Then I ran the benchmark as below:

# In modular top folder
$ ./bazelw run //max/kernels/benchmarks/autotune:kbench bench_matmul.yaml -- --output-dir bench_out
# bench_matmul.yaml as below
name: bench_matmul
file: $KERNEL_BENCHMARKS_ROOT/gpu/bench_matmul.mojo

params:

- $M: [3500, 8192]
  N: 4096
  K: 4096

This benchmark run generates a compiled executable binary, bench_out/out_0/bench_matmul_N-4096_K-4096.

This executable can be reused to test different values of M by passing the --M option, for example

$ ./bench_out/out_0/bench_matmul_N-4096_K-4096 --M=3500
$ ./bench_out/out_0/bench_matmul_N-4096_K-4096 --M=8192

Both of these runs generate an assembly dump (out3.asm). Interestingly, I found that the two out3.asm files (for M=3500 and M=8192) turned out to be different.

Here are my questions:

  1. Why are different kernels generated at runtime?
    • Does mojo build generate multiple specialized variants and embed them in the executable binary?
  2. When does specialization happen?
    • Is the kernel specialized for some values of M at compile time (via metaprogramming)?

Hi @phybd44, the Python-to-Mojo path in the MAX engine and the kbench utility is JIT compiled, so you can change Python runtime values, which are then transformed into compile-time parameters on the Mojo side. When you change --M=3500 to --M=8192, it has to recompile the Mojo-generated binary with the parameter for that shape the first time it’s run, but on subsequent runs it retrieves the compiled binary from the cache.
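To make the “runtime value becomes a compile-time parameter” idea concrete, here is a minimal CPU-only Mojo sketch (hypothetical names, not the MAX or kbench internals). Each distinct value of the parameter produces its own compiled specialization of the function, which is the general mechanism being described for M:

# Minimal sketch, not the MAX API: M is a compile-time parameter, so
# row_sum[4] and row_sum[8] are two separately compiled specializations.
fn row_sum[M: Int](data: List[Float32]) -> Float32:
    # The loop bound is a compile-time constant in each instantiation,
    # so the compiler can fully unroll or vectorize it for that exact M.
    var total: Float32 = 0.0
    for i in range(M):
        total += data[i]
    return total

fn main():
    var xs = List[Float32]()
    for _ in range(8):
        xs.append(1.0)
    # Two instantiations exist in the compiled program, one per value of M.
    print(row_sum[4](xs))  # 4.0
    print(row_sum[8](xs))  # 8.0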


@jack.clayton thanks for the explanation

Just to confirm: do you mean that if I run

./bench_out/out_0/bench_matmul_N-4096_K-4096 --M=1234

for the first time, it will also JIT compile the corresponding GPU kernel assembly?

My initial understanding was that the executable bench_out/out_0/bench_matmul_N-4096_K-4096 produced by mojo build is already compiled, so there wouldn’t be any further JIT compilation — I thought JIT only happens when running via mojo run.

Compilers are best when they have full information. By delaying kernel compilation, Mojo gets two big wins.

First, the kernels can make use of new features of GPUs without sacrificing portability or size. The reason that PyTorch’s cuBLAS is multiple gigabytes is that it ships multiple copies of compiled PTX for different generations of GPUs. Even then, PyTorch doesn’t compile PTX for very old GPUs, so a given build of PyTorch may simply not work on your GPU. If we were to multiply this out across AMD GPUs, Nvidia GPUs, Apple Silicon, and CPUs for Mojo, it would easily be tens of gigabytes.

Second, MAX knows exactly what GPU is being used, and can know the exact input shapes of all non-dynamic tensors. This means that it can take maximum advantage of every new feature in a way that would require PyTorch to ship easily 5x the kernels just for Nvidia support. Imagine if PyTorch had to ship kernels for 2x2, 3x2, 4x2, …, 2x4, 3x4, …, 8192x8192 matrices; you’d be able to do inference by hand before the download finished. By instead compiling the kernel in a more generic way, Modular doesn’t need to write tons of kernels by hand, and you get the “perfect” kernels for your task. This also carries through to kernel fusion, since instead of dealing with assembly the compiler can look for for-loops over a layout dimension, which is much, much faster.
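As a toy illustration of that point (a hedged sketch, not the actual matmul_gpu code): because the shape is a compile-time parameter in Mojo, one generic function can choose its own tiling at compile time, and only the chosen path is emitted for a given instantiation:

# Toy sketch of shape-driven specialization; hypothetical names, scalar code.
fn sum_tiled[M: Int](data: List[Float32]) -> Float32:
    # Chosen at compile time from the static shape; a real kernel would pick
    # block shapes, vector widths, pipeline stages, etc. the same way.
    alias TILE = 8 if M % 8 == 0 else 1

    var total: Float32 = 0.0

    @parameter
    if TILE == 8:
        # Only this branch is compiled into instantiations whose M is a
        # multiple of 8; the fallback below does not exist in their code.
        for i in range(0, M, 8):
            for j in range(8):
                total += data[i + j]
    else:
        for i in range(M):
            total += data[i]
    return total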

If you have experience with kernel libraries, this sounds terrifying because of how long kernel libraries normally take to compile, but Mojo is actually a very fast-to-compile language because its type system is fairly straightforward to implement in a compiler, at least compared to C++, where you can go 20 functions deep and then “oops” your way back up because something didn’t quite work right, making the compiler try another option. Similarly, MAX doesn’t need to have SSE4, AVX, AVX2, AVX512, and AVX512-with-extensions versions of every single piece of CPU code.

Mojo does also use the JIT during mojo run, since that saves a lot of linking time and helps with the iteration speed of working on Mojo. At some point you’ll be able to AOT compile a graph for a particular target and ship that in your binary, but that’s less useful than supporting more GPUs, having more model pipelines, and expanding the Mojo language, so it’s lower priority, since things already work at full speed once the JIT is done. It’s annoying for experimentation, but it’s not a big deal for production.


Hi, sorry, I’ve been told it has nothing to do with JIT. Here’s a response from someone who works on that part of the stack:

The dispatch logic for _matmul_gpu may contain several GPU kernels that have been auto-tuned (for NVIDIA) or instantiated at compile time (for AMD) for a given static value of N and K. In this case, N=4096 and K=4096, and since this is a common Llama 3 matmul shape, auto-tuning exists for most architectures. Selecting different values of M at runtime (in this case via the --M switch) causes these different GPU kernels to be selected. The same base kernel may be used, but parameterized differently, with different block shapes, numbers of pipeline stages, etc.
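For anyone following along, here is a rough Mojo sketch of what “same base kernel, parameterized differently, selected by runtime M” can look like (hypothetical names and thresholds, not the real _matmul_gpu dispatch). All the parameterized instantiations are compiled ahead of the run; the runtime value of M only chooses which one gets launched, which is why the dumped assembly differs between --M=3500 and --M=8192:

# Hypothetical sketch of runtime dispatch over compile-time parameterizations.
# The real dispatch also keys on GPU architecture, dtype, and so on.
fn base_kernel[BLOCK_M: Int, NUM_STAGES: Int](m: Int, n: Int, k: Int):
    # In a real kernel, BLOCK_M / NUM_STAGES would shape shared-memory tiles
    # and the software pipeline; here we only report which variant was picked.
    print("BLOCK_M =", BLOCK_M, "NUM_STAGES =", NUM_STAGES,
          "for", m, "x", n, "x", k)

fn dispatch_matmul(m: Int, n: Int, k: Int):
    # Both instantiations below are compiled into the program; the runtime
    # value of m only selects which one to launch.
    if m <= 4096:
        base_kernel[64, 3](m, n, k)
    else:
        base_kernel[128, 4](m, n, k)

fn main():
    dispatch_matmul(3500, 4096, 4096)
    dispatch_matmul(8192, 4096, 4096)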
