Training a Transformer with MAX Acceleration

I am starting a notebook series on how to train transformer neural networks in Nabla.

Part 1 is a side-by-side (Nabla vs. JAX) toy implementation from scratch: JAX vs. Nabla: Training a Transformer (CPU)

Let’s build on this foundation…

9 Likes

Thank you for sharing, this is a fantastic resource and an impressive outcome!

Why do you think Nabla is faster? And do you think that result will hold true for larger models? What, if anything, is stopping us from using Nabla for "production" training runs today?

I’m guessing the lower loss at the end of the training run is a coincidence due to the random seed?

You mention the M3 processor – does that mean this is all CPU only right now?

1 Like

Hi, thanks for the nice feedback.

There are two orthogonal reasons why Nabla might be faster than other frameworks in JIT mode:

  1. Nabla is the first AD framework to use the MAX Graph Compiler, which performs certain optimizations on its own that other compiler backends (especially XLA, in JAX's case) might not. I don't have clear insight into MAX's amazing optimization pipeline at the moment, so a member of the Modular team would be able to judge this part much better.
  2. On the other hand, Nabla aims to innovate on its function transformation engine (including transformations like vmap, jvp, vjp, and grad), which by now is able to produce genuinely lean intermediate representations (IRs) for a given task. Nabla itself also performs some clever a priori graph optimizations to make this IR (which is then passed to MAX) as concise as possible; see the sketch after this list.
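To make point 2 a bit more concrete, here is a minimal sketch (not taken from the notebook) of such a composed transformation stack, assuming Nabla's JAX-style API (`nb.jit`, `nb.vmap` with `in_axes`, `nb.grad`, `nb.array`, `nb.sum`); exact names may differ in your Nabla version:

```python
# Minimal sketch (not from the notebook), assuming Nabla's JAX-style API.
import nabla as nb


def loss(w, x):
    # Toy scalar loss for a single example: 0.5 * (w . x)^2
    return 0.5 * nb.sum(w * x) ** 2


# Per-example gradient w.r.t. w, batched over x, then JIT-compiled.
# The whole composition is traced into one program that Nabla can simplify
# before handing the resulting IR to the MAX graph compiler.
batched_grad = nb.jit(nb.vmap(nb.grad(loss), in_axes=(None, 0)))

w = nb.array([1.0, 2.0, 3.0])
x = nb.array([[1.0, 0.0, 1.0],
              [0.0, 1.0, 0.0]])  # batch of two examples

print(batched_grad(w, x))  # one gradient row per example
```

The point is that the entire composition is traced into a single graph, which is where Nabla's own simplifications and MAX's optimization pipeline both get to act.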

By the way, Nabla is not limited to the CPU; I will provide more information on GPU execution soon. And would I use Nabla for production training already? No, not yet! But let's push this idea forward…

1 Like

I took a pass at the training example you provided and modified it to use my GPU. Nabla was quite a bit faster than JAX on the same hardware; I want to say >20%.

2 Likes

Wow, that’s a great result!

1 Like

This is so cool! Which GPU are you using? Would you mind sharing the code?

Or, if you feel like it :hugs:, open a pull request directly. If your code is a .ipynb file as well, it would be super cool to have it in the root/tutorials directory next to the others, and also in the root/docs/tutorials directory (including the outputs of each cell). I believe everybody would love to see how you made this work!

A 4090M @ 165 W, which is the same die as a desktop 4080. So think of it as a 4080 with a power limit.

I did the modification in /tmp, so it’s gone, but literally all I did was swap from CPU to Accelerator and it worked on the Nabla side. On the JAX side, I swapped the dependency to the CUDA version.
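For completeness, a rough sketch of what that swap looks like (not the actual diff, which is gone); the MAX-driver device names on the Nabla side are an assumption, and the exact wiring depends on how the notebook sets up its device:

```python
# Hedged sketch of the two changes described above; not the original diff.

# JAX side: install the CUDA-enabled wheel instead of the CPU one, e.g.
#   pip install -U "jax[cuda12]"
# After that, JAX picks up the GPU automatically:
import jax
print(jax.devices())  # e.g. [CudaDevice(id=0)] instead of a CPU device

# Nabla side: swap the CPU device for an accelerator device. Assuming the
# devices come from the MAX driver API (names may differ in the notebook):
from max.driver import Accelerator, CPU

device = Accelerator()  # was: device = CPU() in the CPU-only run
```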

1 Like

Nice.