Hi, Modular team.
I’d like to report a slowdown in MAX when using a single reduction over all dimensions (mean(..., axis=None)), compared to a multi‑step reduction workaround.
For example, this single call is slow:
mean_a = F.mean(a, axis=None)
As a workaround, I’m using a “multi-step” reduction, as below:
mean_rows = F.mean(a, axis=-1)
mean_a = F.mean(mean_rows, axis=None)
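For reference, the two approaches are mathematically equivalent here because every row being averaged in the first step has the same length. A small NumPy sketch of that equivalence check (the shape is scaled down and the array is a hypothetical stand-in for the MAX tensor `a`):

```python
import numpy as np

# Hypothetical stand-in for the tensor `a` in the report:
# a rank-3 float32 tensor, scaled down from (1, 4096, 6144).
a = np.random.default_rng(0).standard_normal((1, 64, 96)).astype(np.float32)

# Single reduction over all dimensions.
single = a.mean()

# Multi-step reduction: reduce the innermost axis first, then the rest.
# Since all rows have equal length, the mean of row means equals the
# overall mean (up to floating-point rounding).
multi = a.mean(axis=-1).mean()

assert np.isclose(single, multi, atol=1e-6)
```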
I put together a minimal reproducer that compares MAX (both Module_v2 and Module_v3) and PyTorch. The script verifies that the outputs match and prints timings for both the “single reduction” and the “multi-step reduction” in each backend. The results on an H200 are as follows:
Module_V2 results (max.nn)
shape: (1, 4096, 6144)
device: Device(type=gpu,id=0)
single reduction mean: -0.00016507644613739103
multi-level mean: -0.0001650762278586626
avg single reduction: 73.886 ms
avg multi-level: 0.146 ms
ratio (single/multi): 507.807
Module_V3 results (experimental)
shape: (1, 4096, 6144)
device: Device(type=gpu,id=0)
single reduction mean: -0.00016507644613739103
multi-level mean: -0.0001650762278586626
avg single reduction: 73.821 ms
avg multi-level: 0.141 ms
ratio (single/multi): 525.031
PyTorch results
shape: (1, 4096, 6144)
device: cuda
single reduction mean: -0.0001650762278586626
multi-level mean: -0.00016507624241057783
avg single reduction: 0.096 ms
avg multi-level: 0.090 ms
ratio (single/multi): 1.063
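For anyone reproducing the timings: the numbers above come from averaging over repeated runs after a few warmup iterations. A minimal CPU-side sketch of that kind of harness, in plain NumPy (hypothetical, not the actual reproducer; the real script times MAX and PyTorch kernels and synchronizes the GPU before reading the clock):

```python
import time
import numpy as np

def bench(fn, warmup=3, iters=10):
    """Average wall-clock time of fn() in milliseconds.
    On a GPU backend the real script also synchronizes the device
    before each clock read; that step is omitted in this CPU sketch."""
    for _ in range(warmup):
        fn()
    start = time.perf_counter()
    for _ in range(iters):
        fn()
    return (time.perf_counter() - start) / iters * 1e3

# Scaled-down stand-in for the (1, 4096, 6144) tensor in the report.
a = np.random.default_rng(0).standard_normal((1, 256, 384)).astype(np.float32)

t_single = bench(lambda: a.mean())
t_multi = bench(lambda: a.mean(axis=-1).mean())
print(f"avg single reduction: {t_single:.3f} ms")
print(f"avg multi-level: {t_multi:.3f} ms")
```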