MAX's GPU ReduceOps for non-inner axes

Good day!

Lately I have been testing Nabla's MAX backend for GPUs. Things look fairly promising and most of it works out nicely, but I came across a severe constraint when doing reduce operations (max.graph.ops.sum, max.graph.ops.mean, etc.) on a GPU device, something I do not run into when I target the CPU only. Efficient summation across multiple axes is crucial for any autodiff framework, so not having it is a real bottleneck!
In particular, the respective MAX ops complain that they cannot reduce across non-inner axes on GPU, i.e. we can only reduce along the last axis of a given input tensor value. There are workarounds that iteratively transpose → reduce → transpose back, like here, but this seems to involve a lot of internal memory copies, slowing everything down.
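For concreteness, this is roughly the pattern I mean, sketched in plain NumPy just to show the shape gymnastics (the real code goes through MAX graph ops, and `reduce_via_transpose` is only a name for this sketch):

```python
import numpy as np

def reduce_via_transpose(x: np.ndarray, axis: int) -> np.ndarray:
    """Sum over `axis` on a backend that can only reduce the innermost axis.

    Pattern: move the target axis to the end -> reduce the last axis
    (keeping the dim) -> move the size-1 axis back to its original spot.
    """
    x_t = np.moveaxis(x, axis, -1)                 # transpose target axis to the end
    reduced = x_t.sum(axis=-1, keepdims=True)      # "inner-axis only" reduce
    return np.moveaxis(reduced, -1, axis)          # transpose back

x = np.arange(24, dtype=np.float32).reshape(2, 3, 4)
assert np.allclose(reduce_via_transpose(x, axis=1), x.sum(axis=1, keepdims=True))
```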

My question: could we get dedicated reduce operations in MAX that act on non-inner axes on GPU and/or take multiple axes at once? Or are there smarter, more efficient workarounds for now? :fire:


Hey Tillie!

  • This sounds like a hole in our GPU kernel implementations. The ops semantically support an axis index, and the GPU kernel implementations are just choosing not to support it yet :sweat_smile:

  • Transpose → reduce → transpose back should be efficient. Our LayoutTensor abstraction lets ops apply compile-time layout transformations, so in this case the transposes are fused into a single reduce op. If there is a performance difference, it should come from a different memory access pattern during the reduce with the transposed layout rather than from a copy (see the small sketch after this list).
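To illustrate that last point with plain NumPy (purely an analogy; the MAX/Mojo LayoutTensor machinery is of course different): a transpose can be a zero-copy view over the same buffer, so a reduce over the transposed layout touches the same data, just with a different stride pattern.

```python
import numpy as np

x = np.random.rand(4096, 4096).astype(np.float32)

# A transpose here is a zero-copy view: same buffer, swapped strides.
xt = x.T
assert xt.base is x                       # no data was copied
assert not xt.flags["C_CONTIGUOUS"]       # only the layout metadata changed

# Reducing the transposed view equals reducing axis 0 of the original;
# any speed difference comes from the memory access pattern, not a copy.
assert np.allclose(xt.sum(axis=-1), x.sum(axis=0))
```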

Could you please file a GitHub issue for enabling reductions over non-inner dimensions, and include the above context? I'll look into whether it's as simple as removing those lines :smiley:


This is very revealing. :hugs: Thank you! I will file an issue. The line in (3) is exactly the one I encountered.


This should be fixed in the next nightly!


Works perfectly! Thank you very much. (I will skip creating an issue on this then :))

By the way, I ran some benchmarks, and this fix indeed makes things a bit faster. Where I previously got a ~10x speedup for training a simple MLP on GPU vs. CPU, I now see more like a ~15x speedup. My transposition workaround probably introduced some unnecessary non-contiguity along the way, slowing things down internally. (Not sure though… you mentioned the layout awareness, hmm :thinking:)

For anyone wondering why I am so keen on this particular type of operation:

Fast reductions are vital for Automatic Differentiation (AD) engines. In array-computing libraries like PyTorch, NumPy, MAX and Nabla, binary operations (such as addition, multiplication, etc.) are frequently performed between objects of different shapes. This is handled by an implicit “broadcasting” operation, which creates views of the underlying data to make the shapes compatible for computation. This happens extremely often in typical code.
When using reverse-mode AutoDiff (also known as Backpropagation), the system must propagate sensitivities (gradients) backward through every operation in the computation graph. This includes all the broadcasting operations that occurred during the forward pass.
So, what is the local backpropagation rule for a broadcasting operation? You guessed it! This rule, often called the vector-Jacobian product (vjp) of the broadcast, is precisely a summation (reduction) over the axes that were broadcast. This reduction therefore shows up extremely often, and it's crucial to have a fast and easy-to-use version of it.
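To make that concrete, here is a minimal NumPy sketch (purely illustrative; `broadcast_vjp` is a name I made up, and Nabla/MAX's actual implementations will differ): the vjp of broadcasting `x` to a larger shape just sums the incoming gradient over the broadcast axes.

```python
import numpy as np

def broadcast_vjp(grad_out: np.ndarray, input_shape: tuple) -> np.ndarray:
    """vjp of `np.broadcast_to(x, grad_out.shape)` for an input of `input_shape`.

    Sum the upstream gradient over every axis that was added or expanded
    by the broadcast, so the result has the original input shape again.
    """
    # Axes prepended by the broadcast (rank mismatch) are summed away entirely.
    num_new_axes = grad_out.ndim - len(input_shape)
    grad = grad_out.sum(axis=tuple(range(num_new_axes)))
    # Axes that were size-1 in the input were expanded: sum them, keep the dim.
    expanded = tuple(i for i, s in enumerate(input_shape) if s == 1 and grad.shape[i] != 1)
    if expanded:
        grad = grad.sum(axis=expanded, keepdims=True)
    return grad

x = np.ones((3, 1))                      # will be broadcast against shape (2, 3, 4)
y = np.broadcast_to(x, (2, 3, 4))        # forward: implicit broadcast
g = np.random.rand(2, 3, 4)              # upstream gradient w.r.t. y
gx = broadcast_vjp(g, x.shape)           # gradient w.r.t. x
assert gx.shape == x.shape               # (3, 1): summed over axes 0 and 2
assert np.allclose(gx, g.sum(axis=0).sum(axis=1, keepdims=True))
```

Cheers!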

