I am starting a notebook series on how to train transformer neural networks in Nabla.
Part 1 is a side-by-side (Nabla vs. JAX) toy implementation from scratch: JAX vs. Nabla: Training a Transformer (CPU)
Let’s build on this foundation…
Thank you for sharing; this is a fantastic resource and an impressive outcome!
Why do you think Nabla is faster? And do you think that result will hold true for larger models? What, if anything, is stopping us from using Nabla for "production" training runs today?
I’m guessing the lower loss at the end of the training run is a coincidence due to the random seed?
You mention the M3 processor – does that mean this is all CPU only right now?
Hi, thanks for the nice feedback.
There are two orthogonal reasons why Nabla might be faster than other frameworks in JIT mode.
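For readers unfamiliar with the term: "JIT mode" means the whole training step is traced and compiled once, and subsequent calls run the compiled program; that steady state is what the comparison measures. The Part 1 notebook uses the same JAX-style pattern on both sides. Here is a minimal JAX sketch of that measurement (the toy step and shapes are placeholders, not the notebook's model):

```python
import time

import jax
import jax.numpy as jnp

@jax.jit
def train_step(w, x):
    # Stand-in for a real transformer step: one matmul plus a nonlinearity.
    return jnp.tanh(x @ w).sum()

w = jnp.ones((512, 512))
x = jnp.ones((64, 512))

train_step(w, x).block_until_ready()  # first call: trace + compile
t0 = time.perf_counter()
train_step(w, x).block_until_ready()  # later calls run the compiled program
print(f"steady-state step time: {time.perf_counter() - t0:.6f} s")
```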
By the way, Nabla is not limited to CPU; I will share more about GPU execution soon. And would I use Nabla for production training already? No, not yet! But let’s push this idea forward…
I took a pass at the training example you provided and modified it to run on my GPU. Nabla was quite a bit faster than JAX on the same hardware; I want to say >20%.
Wow, that’s a great result!
This is so cool! Which GPU are you using? Would you mind sharing the code?
Or, if you feel like it, directly make a pull request. If your code is a .ipynb file, it would be super cool to have it in the root/tutorials directory next to the others, as well as in the root/docs/tutorials directory (including the outputs of each cell). I believe everybody would love to see how you made this work!
4090M @ 165 W, which uses the same die as a desktop 4080. So think of it as a 4080 with a power limit.
I did the modification in /tmp so it’s gone, but literally all I did was swap from `CPU` to `Accelerator` and it worked on the Nabla side. On the JAX side I swapped the dependency to the CUDA version.
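For anyone who wants to reproduce this, a rough sketch of the two changes described above (the `max.driver` import path is an assumption based on the MAX Python API; check your Nabla version's docs):

```python
# Hypothetical reproduction sketch of the swap described above, not the
# original /tmp code. Assumes Nabla accepts a MAX driver device object.
from max.driver import CPU, Accelerator

device = Accelerator()  # was: device = CPU()
```

On the JAX side the only change is the dependency, e.g. `pip install -U "jax[cuda12]"`; JAX then places arrays on the GPU by default.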
Nice.