MAX kernel launch overhead

I had a question about MAX graphs: I enqueued the same function several times in a MAX graph (see code snippet below) and profiled the execution with nsys (see attached image). There appears to be non-trivial CPU overhead for the launch of each function on the GPU. Does MAX have a way to internally create and execute CUDA graphs when executing its own graphs?

gpu_ctx.enqueue_function[
    conv1d_kernel[
        in_layout, out_layout, conv_layout, input_size, conv_size
    ]
](
    output_tensor,
    input_tensor,
    kernel_tensor,
    grid_dim=BLOCKS_PER_GRID,
    block_dim=(TPB, 1),
)

gpu_ctx.enqueue_function[
    conv1d_kernel[
        out_layout, out_layout, conv_layout, input_size, conv_size
    ]
](
    output_tensor,
    output_tensor,
    kernel_tensor,
    grid_dim=BLOCKS_PER_GRID,
    block_dim=(TPB, 1),
)

gpu_ctx.enqueue_function[
    conv1d_kernel[
        out_layout, out_layout, conv_layout, input_size, conv_size
    ]
](
    output_tensor,
    output_tensor,
    kernel_tensor,
    grid_dim=BLOCKS_PER_GRID,
    block_dim=(TPB, 1),
)

gpu_ctx.enqueue_function[
    conv1d_kernel[
        out_layout, out_layout, conv_layout, input_size, conv_size
    ]
](
    output_tensor,
    output_tensor,
    kernel_tensor,
    grid_dim=BLOCKS_PER_GRID,
    block_dim=(TPB, 1),
)

The first time you run a particular kernel via gpu_ctx.enqueue_function, it gets JIT-compiled. As far as I am aware, there isn’t a hidden cache for function-level compilation, so you’re recompiling the function on every enqueue. You should instead compile it up front with gpu_ctx.compile_function_checked and enqueue the resulting DeviceFunction, as shown on the DeviceFunction | Modular docs page.
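A rough sketch of that compile-once pattern, using the names from your snippet (I may be misremembering the exact type parameters of compile_function_checked / enqueue_function_checked, so check them against the DeviceFunction docs rather than taking this as verified code):

```mojo
# Compile each kernel variant once, outside the hot path. Your snippet has
# two variants: one reading input_tensor (in_layout) and one running in
# place on output_tensor (out_layout).
var conv_first = gpu_ctx.compile_function_checked[
    conv1d_kernel[in_layout, out_layout, conv_layout, input_size, conv_size],
    conv1d_kernel[in_layout, out_layout, conv_layout, input_size, conv_size],
]()
var conv_next = gpu_ctx.compile_function_checked[
    conv1d_kernel[out_layout, out_layout, conv_layout, input_size, conv_size],
    conv1d_kernel[out_layout, out_layout, conv_layout, input_size, conv_size],
]()

# Enqueue the pre-compiled functions; no JIT compilation on this path.
gpu_ctx.enqueue_function_checked(
    conv_first,
    output_tensor,
    input_tensor,
    kernel_tensor,
    grid_dim=BLOCKS_PER_GRID,
    block_dim=(TPB, 1),
)
# The remaining three launches are identical, so a loop suffices.
for _ in range(3):
    gpu_ctx.enqueue_function_checked(
        conv_next,
        output_tensor,
        output_tensor,
        kernel_tensor,
        grid_dim=BLOCKS_PER_GRID,
        block_dim=(TPB, 1),
    )
```

With the compilation hoisted out, what remains in the nsys trace is the per-launch enqueue cost itself.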

Once you compile the functions beforehand, you should be able to see the actual launch overhead. Last I checked, it was in the neighborhood of a few dozen microseconds per launch.