Overhead of Multi‑Dimensional Reduction

Hi, Modular team.

I’d like to report a significant slowdown (roughly 500x on an H200) in MAX when reducing over all dimensions in a single call (mean(..., axis=None)), compared to a multi‑step reduction workaround.

For example, this single‑call reduction is slow:

mean_a = F.mean(a, axis=None)  # a has shape (1, 4096, 6144); reduce over all dimensions at once

As a workaround, I’m using a “multi-step” reduction, as below:

mean_rows = F.mean(a, axis=-1)         # first reduce the innermost dimension
mean_a = F.mean(mean_rows, axis=None)  # then reduce the remaining dimensions
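
For reference, the two forms are numerically equivalent here because every row has the same number of elements, so the mean of the per‑row means equals the overall mean. A quick NumPy check (shape chosen to mirror the reproducer; this is an illustrative sketch, not the reproducer itself):

import numpy as np

a = np.random.randn(1, 4096, 6144).astype(np.float32)

single = a.mean()               # one reduction over all elements
multi = a.mean(axis=-1).mean()  # reduce the last dimension first, then the rest

assert np.allclose(single, multi, atol=1e-5)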

I put together a minimal reproducer that compares MAX (with both Module_v2 and Module_v3) and PyTorch. The script verifies that the outputs match and prints timings for both the “single reduction” and “multi‑step reduction” paths in each backend (a sketch of the timing methodology is included after the results). The results on an H200 are as follows:

Module_V2 results (max.nn)
  shape: (1, 4096, 6144)
  device: Device(type=gpu,id=0)
  single reduction mean: -0.00016507644613739103
  multi-level mean:     -0.0001650762278586626
  avg single reduction: 73.886 ms
  avg multi-level:      0.146 ms
  ratio (single/multi): 507.807

Module_V3 results (experimental)
  shape: (1, 4096, 6144)
  device: Device(type=gpu,id=0)
  single reduction mean: -0.00016507644613739103
  multi-level mean:     -0.0001650762278586626
  avg single reduction: 73.821 ms
  avg multi-level:      0.141 ms
  ratio (single/multi): 525.031

PyTorch results
  shape: (1, 4096, 6144)
  device: cuda
  single reduction mean: -0.0001650762278586626
  multi-level mean:     -0.00016507624241057783
  avg single reduction: 0.096 ms
  avg multi-level:      0.090 ms
  ratio (single/multi): 1.063
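
For context, here is a minimal sketch of the timing methodology for the PyTorch backend (the MAX portions follow the same warm‑up/synchronize/average pattern); function and variable names are illustrative, not the exact reproducer:

import time
import torch

def bench(fn, iters=50, warmup=5):
    # Warm up, then time `iters` calls, synchronizing around the timed region.
    for _ in range(warmup):
        fn()
    torch.cuda.synchronize()
    start = time.perf_counter()
    for _ in range(iters):
        fn()
    torch.cuda.synchronize()
    return (time.perf_counter() - start) / iters * 1e3  # ms per call

a = torch.randn(1, 4096, 6144, device="cuda")

single_ms = bench(lambda: a.mean())              # single reduction over all dims
multi_ms = bench(lambda: a.mean(dim=-1).mean())  # multi-step reduction

print(f"avg single reduction: {single_ms:.3f} ms")
print(f"avg multi-level:      {multi_ms:.3f} ms")
print(f"ratio (single/multi): {single_ms / multi_ms:.3f}")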

Thanks for reporting! Will let the team know.