I had some Codex quota left this week, so I used it on a small experiment: can MAX be made to support PyTorch-like training behavior, and how fast is a compiled train step?
I built a prototype standalone package, max_training, with limited reverse-mode autograd, parameters, Linear, MSE loss, SGD-style updates, compiled train steps, MLIR inspection, and PyTorch comparison benchmarks.
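For anyone unfamiliar with the pieces listed above, here is a minimal pure-Python sketch of what reverse-mode autograd, a Linear layer, MSE loss, and an SGD-style update amount to. It is not the max_training API and does not use MAX at all. The names `Value`, `Linear`, `mse_loss`, and `train_step` are hypothetical and chosen only for illustration.

```python
class Value:
    """Scalar that records how it was computed so gradients can flow backward."""

    def __init__(self, data, parents=()):
        self.data = float(data)
        self.grad = 0.0
        self._parents = parents
        self._backward = lambda: None  # leaf nodes have nothing to propagate

    def __add__(self, other):
        other = other if isinstance(other, Value) else Value(other)
        out = Value(self.data + other.data, (self, other))

        def _backward():
            self.grad += out.grad
            other.grad += out.grad

        out._backward = _backward
        return out

    def __mul__(self, other):
        other = other if isinstance(other, Value) else Value(other)
        out = Value(self.data * other.data, (self, other))

        def _backward():
            self.grad += other.data * out.grad
            other.grad += self.data * out.grad

        out._backward = _backward
        return out

    def __sub__(self, other):
        other = other if isinstance(other, Value) else Value(other)
        return self + other * -1.0

    def backward(self):
        # Topologically sort the graph, then apply the chain rule in reverse.
        order, seen = [], set()

        def visit(node):
            if node not in seen:
                seen.add(node)
                for parent in node._parents:
                    visit(parent)
                order.append(node)

        visit(self)
        self.grad = 1.0
        for node in reversed(order):
            node._backward()


class Linear:
    """Single-output linear layer: y = w . x + b (hypothetical, for illustration)."""

    def __init__(self, weights, bias):
        self.w = [Value(w) for w in weights]
        self.b = Value(bias)

    def __call__(self, xs):
        out = self.b
        for w, x in zip(self.w, xs):
            out = out + w * x
        return out

    def parameters(self):
        return self.w + [self.b]


def mse_loss(preds, targets):
    total = Value(0.0)
    for pred, target in zip(preds, targets):
        diff = pred - target
        total = total + diff * diff
    return total * (1.0 / len(preds))


def train_step(model, inputs, targets, lr):
    """One SGD-style step: zero grads, forward, backward, update. Returns the loss."""
    params = model.parameters()
    for p in params:
        p.grad = 0.0
    loss = mse_loss([model(x) for x in inputs], targets)
    loss.backward()
    for p in params:
        p.data -= lr * p.grad
    return loss.data


# Fit y = 2x + 1 from four points.
inputs = [[0.0], [1.0], [2.0], [3.0]]
targets = [1.0, 3.0, 5.0, 7.0]
model = Linear([0.0], 0.0)
losses = [train_step(model, inputs, targets, lr=0.05) for _ in range(500)]
print(f"w={model.w[0].data:.3f} b={model.b.data:.3f} final loss={losses[-1]:.2e}")
```

A compiled train step would fuse the forward, backward, and update phases of `train_step` into one graph instead of interpreting them op by op, which is where the step-time difference would come from.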
RTX 5090 results for one larger MLP training-step benchmark:
torch_eager 0.865 ms
torch_compile[max-autotune] 0.808 ms
max_compile 0.581 ms
MAX compile time was much higher, around 100 s, so this is only an early feasibility result, not a production claim.
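To make the compile-time versus step-time split concrete, here is a rough sketch of a timing harness that reports the first call (which absorbs any compilation) separately from the steady-state median. It is not the harness used for the numbers above. `bench_ms` and `make_fake_step` are hypothetical names, and the fake step only simulates a one-time compile with a sleep.

```python
import statistics
import time


def bench_ms(step, warmup=10, iters=100, sync=None):
    """Return (first_call_ms, median_steady_ms) for a zero-argument callable."""

    def timed_call():
        start = time.perf_counter()
        step()
        if sync is not None:
            sync()  # wait for queued asynchronous work before stopping the clock
        return (time.perf_counter() - start) * 1e3

    first_call_ms = timed_call()  # includes any JIT/compile cost
    for _ in range(warmup):
        timed_call()  # discard warmup iterations
    samples = [timed_call() for _ in range(iters)]
    return first_call_ms, statistics.median(samples)


def make_fake_step(compile_s=0.05):
    """Hypothetical stand-in for a compiled train step: slow on the first call only."""
    state = {"calls": 0}

    def step():
        if state["calls"] == 0:
            time.sleep(compile_s)  # pretend this is compilation
        state["calls"] += 1

    return step, state


step, state = make_fake_step()
first_ms, median_ms = bench_ms(step, warmup=5, iters=50)
print(f"first call {first_ms:.2f} ms, steady-state median {median_ms:.4f} ms")
```

On a GPU the kernels are launched asynchronously, so a real run would need something like `sync=torch.cuda.synchronize` for the PyTorch variants. Without it, the timer mostly measures launch overhead rather than the step itself.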
Curious if this benchmark setup looks reasonable to people familiar with MAX internals. Are these numbers plausible, or am I accidentally benchmarking an easy/special case?
This seems roughly normal. MAX from PyTorch tends to compile a lot slower than native MAX, since MAX expects slightly higher-level kernels as input. If you are willing to hand-write kernels, as some past experiments have done, you may see fairly dramatic compile-time improvements. Some very old numbers on my 4090 had MAX beating JAX in compile times by nearly 100x (JAX is known for bad compile times) and in runtime performance by ~2x. That was multiple years of very heavy development ago, so it has likely changed.