Note: this is a question that was submitted for the February 3rd community meeting – we weren’t able to answer this question live so we’re answering it here instead.
When working on a neural net, a significant portion of iteration time is spent waiting for the model to load its weights and warm up on the GPU, plus compile time if you're using a compiled language like Rust. All ML frameworks are pretty slow at this. Any chance MAX can do better on this front?
I’ll give a preliminary answer while you wait for the team to get to you.
A decent portion of actually loading a model is spent reading the weights off disk and then moving them to the GPU. If you are already maxing out the bandwidth of the drive, there's not much Modular can do. If you aren't, GPUDirect Storage becomes an option for getting the weights to the GPU without a round-trip through the CPU, although it comes with a number of compatibility requirements. If you have a pre-compiled MAX model, GPUDirect Storage is probably the best path to get the model onto the GPU(s) quickly.
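As a rough sketch of what that path can look like from Python, here's how you might read a weight file straight into GPU memory with RAPIDS kvikio, which wraps the cuFile/GPUDirect Storage API. The file name and buffer size are made up, and whether the transfer truly bypasses the CPU depends on your driver, filesystem, and hardware support.

```python
# Sketch: reading weights straight into GPU memory via GPUDirect Storage.
# Assumes kvikio and CuPy are installed and the platform supports cuFile;
# "weights.bin" and the element count are hypothetical.
import cupy as cp
import kvikio

buf = cp.empty(1_000_000, dtype=cp.float16)  # destination buffer on the GPU

with kvikio.CuFile("weights.bin", "r") as f:
    # read() fills the device buffer; with full GDS support the data moves
    # NVMe -> GPU without a bounce through host memory, otherwise kvikio
    # falls back to a host-staged copy.
    n_bytes = f.read(buf)

print(f"read {n_bytes} bytes onto the GPU")
```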
Warmup is something that’s generally unnecessary unless you’re benchmarking, so that can be tossed out if you’re iterating on architecture.
MAX is, at its core, a graph compiler, similar to torch.compile, so there will likely be some compilation overhead. Modular could introduce an "interpreted" mode, which would work more like normal step-by-step PyTorch.
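For a concrete sense of that overhead, here's a small PyTorch sketch (arbitrary model and shapes, assuming a working torch.compile setup): the first call to a compiled module pays the graph-tracing and compilation cost, later calls reuse the compiled graph, and plain eager mode has no compile step at all.

```python
import time
import torch

model = torch.nn.Sequential(
    torch.nn.Linear(1024, 1024),
    torch.nn.ReLU(),
    torch.nn.Linear(1024, 1024),
)
x = torch.randn(64, 1024)

compiled = torch.compile(model)

t0 = time.perf_counter()
compiled(x)  # first call: traces and compiles the graph
t1 = time.perf_counter()
compiled(x)  # later calls reuse the compiled graph
t2 = time.perf_counter()

print(f"first (compile) call: {t1 - t0:.3f}s, steady-state call: {t2 - t1:.3f}s")
print("eager mode has no compile step; every call runs op by op")
```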
Mojo is a compiled language, but if you are using the JIT mode in the REPL or a notebook, as you would when iterating, it's very fast, since it skips the work LLVM normally spends producing an executable, such as object file layout and linking.
Files that can be zero-copied, like .safetensors, are close to the best possible case for disk performance. If I were you, I would try using a zero-copy format and making sure you have enough RAM to hold the model, so the file stays resident in the OS page cache. That effectively removes I/O time and makes loading a matter of how fast the GPU can pull data out of CPU memory, which is usually quite fast. I've found this works well for me unless I'm making major modifications to the model every time, at which point I do run into limitations around MAX's lack of a "don't try that hard" mode.
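As an example of the zero-copy route, the safetensors library can memory-map a checkpoint and hand out tensors lazily; the file path below is a placeholder, and the speedup assumes the file is already resident in the page cache.

```python
# Sketch: lazily loading tensors from a memory-mapped .safetensors file,
# then moving them to the GPU. "model.safetensors" is a placeholder path.
from safetensors import safe_open

gpu_weights = {}
with safe_open("model.safetensors", framework="pt", device="cpu") as f:
    for name in f.keys():
        # get_tensor() reads from the memory-mapped file; if the file is
        # already in the page cache, no disk I/O happens here.
        gpu_weights[name] = f.get_tensor(name).to("cuda", non_blocking=True)
```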
Compilation and startup time are a constant area of emphasis for us on the performance front. Since the MAX 24.6 release, we've introduced a series of compilation improvements that reduced compile times by up to 1.5-4X for certain graphs that target GPUs. In the MAX nightlies, we've also shifted from just-in-time compilation of PTX for NVIDIA GPUs to an ahead-of-time compilation model. This all but eliminates the need for extensive warmup runs of AI models on GPU, giving you the best performance on the first run.
As a point of comparison, some of our most complex model graphs currently take on the order of 1-2 minutes to compile for GPU, whereas we've heard that other graph compilers may take tens of minutes to over an hour for the same graph.
When it comes to weight loading, we currently do not embed the weights in the compiled graph for an AI model; we found this yielded performance benefits versus inlining the weights. For our models, we can load weights in the Safetensors or GGUF formats, and the compiled graph maps to the weights in those files, which are loaded at run time. Already-compiled graphs load very quickly, and we're continuing to work on improving both graph and weight loading speeds.
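As a rough picture of that pattern (this is an illustrative stub, not MAX's actual API): the compiled artifact records only weight names, and at load time those names are resolved against tensors mapped out of a Safetensors or GGUF file.

```python
# Illustrative stub only -- not MAX's API. It shows the pattern described
# above: a compiled graph records weight *names*, and the actual tensors are
# mapped out of the weight file when the graph is loaded, not at compile time.
from safetensors import safe_open

class CompiledGraphStub:
    """Hypothetical stand-in for an already-compiled graph artifact."""

    def __init__(self, weight_names):
        self.weight_names = weight_names  # names are baked in, values are not

    def bind_weights(self, registry):
        # Resolve each recorded name against the run-time weight registry.
        return {name: registry[name] for name in self.weight_names}

# Hypothetical weight names and file path.
graph = CompiledGraphStub(["embed.weight", "lm_head.weight"])

registry = {}
with safe_open("model.safetensors", framework="pt", device="cpu") as f:
    for name in f.keys():
        registry[name] = f.get_tensor(name)

bound = graph.bind_weights(registry)  # weights attached at run time
```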
In the MAX nightlies, we’ve also shifted from just-in-time compilation of PTX for NVIDIA GPUs to an ahead-of-time compilation model.
Could you expand on that? I may be misreading it, but that sounds like you are compiling directly for the GPU ISA instead of going through the PTX compiler in the driver, which I didn’t think was possible. AOT graph compilation is also something I’m very much interested in, especially if single source comes with it or follows.
Instead of compiling to PTX, we now go all the way to cubin on initial compilation. This doesn't change when the graph itself is compiled, but it makes the graph's performance at run time much more reliable by avoiding hard-to-predict JIT compilation times, which often led to the need for warmup runs to stabilize performance.
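To illustrate the distinction outside of MAX, here's a CuPy sketch with hypothetical file and kernel names: handing the CUDA driver PTX means it JIT-compiles to SASS for your GPU at module-load time, whereas a cubin built ahead of time for the target architecture loads without that step.

```python
import cupy as cp

# Hypothetical file names; the point is what the CUDA driver does with each.

# PTX path: the driver JIT-compiles the PTX to SASS for the local GPU when
# the module is loaded -- the hard-to-predict cost that warmup runs used to hide.
jit_module = cp.RawModule(path="kernels.ptx")

# cubin path: SASS was generated ahead of time for a specific architecture
# (e.g. sm_90), so loading is just a copy into the context -- no JIT step.
aot_module = cp.RawModule(path="kernels.sm_90.cubin")

kernel = aot_module.get_function("my_kernel")  # hypothetical kernel name
```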