Code review for scientific code that already beats Fortran :)

Hi Modular folks — would love a performance-focused code review (Mojo)

Hey everyone! First off: thank you for building Mojo and for all the great discussions/examples on performance patterns. I’m learning a lot from the community, and I’d really appreciate a set of eyes on two Mojo kernels I’m hoping to polish further.

I’m posting two files for review:

  • scalar_spectrum.mojo: scalar spherical degree power spectrum on a Gaussian latitude grid

  • vector_spectrum.mojo: vector (wind) spectra → total / rotational / divergent kinetic energy by spherical degree

The main thing I’m looking for is performance review (algorithmic + Mojo-specific micro-optimizations), with a secondary focus on code quality / idiomatic Mojo.


Inputs, shapes, and data movement

Both kernels currently:

  • accept Python objects (NumPy arrays / xarray .values)

  • do a single copy into native Mojo List[Float64]

  • run all heavy loops on native buffers (no Python calls in hot loops)

Typical full case

  • Grid: nlat=800, nlon=1600 (Gaussian latitudes + uniform longitudes)

  • Scalar input: T[j,i] (Float64) and lat_deg[j]

  • Vector input: u[j,i], v[j,i] (Float64)

Truncation

  • Scalar: max_degree=800 (output degrees 1..L; mean removed so degree 0 not returned)

  • Vector: max_degree=799 (output degrees 0..L)


Math / conventions (high level)

Scalar spectrum

Gaussian nodes $\mu_j=\sin\phi_j$, weights $w_j$, longitudes $\lambda_i=i\,\Delta\lambda$ with $\Delta\lambda=2\pi/\mathrm{nlon}$.

  1. Zonal Fourier coefficients per latitude row (explicit real/imag sums):
     $$\hat T_m(\phi_j) \propto \sum_i T(\lambda_i,\phi_j)\,e^{-im\lambda_i}$$

  2. Meridional projection with Gauss–Legendre quadrature using 4π fully-normalized $\bar P_\ell^m(\mu)$:
     $$C_{\ell m} \approx \sum_j w_j\,\bar P_\ell^m(\mu_j)\,\Re\hat T_m(\phi_j),\qquad S_{\ell m} \approx \sum_j w_j\,\bar P_\ell^m(\mu_j)\,\Im\hat T_m(\phi_j)$$

  3. Degree power (real-harmonic convention):
     $$E_\ell = \sum_{m=0}^{\ell}\left(C_{\ell m}^2+S_{\ell m}^2\right)$$

Notes:

  • kernel removes weighted global mean first → returns degrees 1..L

  • final scalar rescale enforces that $\sum_\ell E_\ell$ equals the Gaussian-weighted grid variance of the de-meaned field
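
In case it helps reviewers sanity-check the conventions, here is a small NumPy reference sketch of the same three steps (my own naming, not the Mojo kernel's): `rfft` for the zonal sums, the same 4π-normalized recurrence with the $(2-\delta_{m0})$ geodesy factor, and real-harmonic degree power.

```python
import numpy as np

def pbar_cols(lmax: int, m: int, mu: np.ndarray) -> np.ndarray:
    """Rows l=0..lmax of Pbar_l^m(mu), 4-pi fully normalized with the
    (2 - delta_m0) geodesy factor; rows with l < m stay zero."""
    mu = np.asarray(mu, dtype=np.float64)
    s = np.sqrt(np.maximum(0.0, 1.0 - mu * mu))
    p = np.zeros((lmax + 1, mu.size))
    pmm = np.ones_like(mu)
    if m >= 1:
        pmm = np.sqrt(3.0) * s
        for k in range(2, m + 1):
            pmm = np.sqrt((2.0 * k + 1.0) / (2.0 * k)) * s * pmm
    p[m] = pmm
    if lmax > m:
        p[m + 1] = np.sqrt(2.0 * m + 3.0) * mu * pmm
    for l in range(m + 2, lmax + 1):
        a = np.sqrt((4.0 * l * l - 1.0) / (l * l - m * m))
        b = np.sqrt((2.0 * l + 1.0) * ((l - 1.0) ** 2 - m * m)
                    / ((2.0 * l - 3.0) * (l * l - m * m)))
        p[l] = a * mu * p[l - 1] - b * p[l - 2]
    return p

def scalar_spectrum_ref(field: np.ndarray, lmax: int) -> np.ndarray:
    """Degree power E_l, l = 1..lmax, of field[j, i] on a Gaussian grid."""
    nlat, nlon = field.shape
    mu, w = np.polynomial.legendre.leggauss(nlat)   # south-to-north nodes
    fc = np.fft.rfft(field, axis=1) / nlon          # zonal coefficients
    E = np.zeros(lmax + 1)
    for m in range(min(lmax, nlon // 2) + 1):
        P = pbar_cols(lmax, m, mu)
        # With the (2 - delta_m0) convention, int Pbar^2 dmu = 2(2 - delta)
        # while A_m = (2 - delta) Re(fc_m), so the prefactor collapses to 1/2.
        c = 0.5 * (P @ (w * fc[:, m].real))
        s = 0.5 * (P @ (w * -fc[:, m].imag))
        E += c * c + s * s
    return E[1:]   # degree 0 (the mean) dropped, as in the kernel
```

For example, a field equal to $\bar P_1^0(\mu)=\sqrt3\,\mu$ should put unit power at degree 1 and (to round-off) nothing elsewhere.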

Vector KE split by degree

Build vorticity $\zeta$ and divergence $\delta$ (with care near the poles):

$$\zeta=\frac{1}{a\cos\phi}\left(\frac{\partial v}{\partial\lambda}-\frac{\partial(u\cos\phi)}{\partial\phi}\right),\qquad \delta=\frac{1}{a\cos\phi}\left(\frac{\partial u}{\partial\lambda}+\frac{\partial(v\cos\phi)}{\partial\phi}\right)$$

  • $\partial/\partial\lambda$ is done in Fourier space (multiply by $im$)

  • $\partial/\partial\phi$ uses small per-latitude finite-difference stencils built on the exact Gaussian nodes

Inverse Laplacian:

$$\psi_{\ell m}=-\frac{a^2}{\ell(\ell+1)}\,\zeta_{\ell m},\qquad \chi_{\ell m}=-\frac{a^2}{\ell(\ell+1)}\,\delta_{\ell m}$$

Energy split:

$$QE_\ell=\frac{\ell(\ell+1)}{2a^2}\sum_m|\psi_{\ell m}|^2,\qquad DE_\ell=\frac{\ell(\ell+1)}{2a^2}\sum_m|\chi_{\ell m}|^2,\qquad KE_\ell=QE_\ell+DE_\ell$$

(plus a real-harmonic normalization factor to match the reference outputs)
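
Given coefficients in that form, the inverse-Laplacian and split step is tiny. A NumPy sketch, assuming $\zeta_{\ell m}$/$\delta_{\ell m}$ arrive as complex arrays indexed `[l, m]` (my naming; the extra real-harmonic normalization factor is deliberately left out):

```python
import numpy as np

def ke_split_by_degree(zeta_lm, delta_lm, a=6.371e6):
    """Rotational/divergent/total KE per degree from vorticity and
    divergence coefficients zeta_lm[l, m], delta_lm[l, m] (m <= l)."""
    lmax = zeta_lm.shape[0] - 1
    rot = np.zeros(lmax + 1)
    div = np.zeros(lmax + 1)
    for l in range(1, lmax + 1):            # l = 0 has no inverse Laplacian
        fac = l * (l + 1)
        psi = -(a * a / fac) * zeta_lm[l, : l + 1]    # streamfunction
        chi = -(a * a / fac) * delta_lm[l, : l + 1]   # velocity potential
        rot[l] = fac / (2.0 * a * a) * np.sum(np.abs(psi) ** 2)
        div[l] = fac / (2.0 * a * a) * np.sum(np.abs(chi) ** 2)
    return rot, div, rot + div
```

For instance, with $a=1$ and a single mode $\zeta_{1,0}=2$: $\psi_{1,0}=-1$ and $QE_1=\tfrac{2}{2}\cdot|{-1}|^2=1$.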


Stage timings (hot module cache, single process, full 800×1600)

Scalar (scalar_spectrum) kernel total ~2.16 s

  • Stage 0 (Python→native buffers): ~0.17 s (~8%)

  • Stage 1 (Gauss–Legendre weights via Newton iterations): ~0.007 s (<1%)

  • Stage 2 (weighted mean remove): ~0.002 s (<1%)

  • Stage 3 (weighted variance target): ~0.002 s (<1%)

  • Stage 4 (twiddle table + explicit DFTs): ~1.46 s (~68%)

  • Stage 5 (Legendre projection, streamed recurrence): ~0.51 s (~24%)

  • Stage 6 (variance rescale): ~0 s

Vector (vector_spectrum) kernel total ~2.24 s

  • Stage 0 (Python→native buffers): ~0.34 s (~15%)

  • Stage 1 (exact Gaussian nodes/weights + derivative stencils): ~0.008 s (<1%)

  • Stage 2 (Fourier precompute for U/V): ~1.07 s (~48%)

    • includes exact opposite-longitude pairing for even nlon (cuts this stage ~in half vs naive sums for retained modes)
  • Stage 3 (assemble $\zeta/\delta$ in $(j,m)$ space): ~0.028 s (~1%)

  • Stage 4 (Legendre projection + KE split): ~0.80 s (~36%)
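
The pairing trick in vector Stage 2 rests on $\lambda_{i+nlon/2}=\lambda_i+\pi$, so $e^{-im(\lambda_i+\pi)}=(-1)^m e^{-im\lambda_i}$ and opposite-longitude samples can be folded before the $m$-sum, halving the inner loop. A quick NumPy check of the identity (illustrative values only):

```python
import numpy as np

nlon = 8
lam = 2.0 * np.pi * np.arange(nlon) / nlon
x = np.cos(3.0 * lam) + 0.5 * np.sin(lam) + 0.25      # arbitrary row
half = nlon // 2
for m in range(half + 1):
    # Fold opposite longitudes: cos(m(lam + pi)) = (-1)^m cos(m lam),
    # and likewise for sin, so the folded half-sum equals the full sum.
    folded = x[:half] + (-1) ** m * x[half:]
    assert np.isclose(np.dot(folded, np.cos(m * lam[:half])),
                      np.dot(x, np.cos(m * lam)))
    assert np.isclose(np.dot(folded, np.sin(m * lam[:half])),
                      np.dot(x, np.sin(m * lam)))
```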

So the biggest levers look like: scalar Stage 4, vector Stage 2, and Stage 0 copies.


What I’m asking for (performance-first review)

  • Algorithmic next step: keep explicit DFT sums, or switch to an FFT-based zonal transform in Mojo? Any “Mojo-idiomatic” FFT approach you’d recommend?

  • Interop / memory: best practices for reducing Stage 0 cost (zero-copy or lower-copy NumPy→Mojo) without footguns?

  • SIMD & layout: obvious opportunities to improve the inner longitude loops (SIMD/unrolling), and should I consider different buffer/layout choices than List[Float64]?

  • Parallelism: cleanest way to parallelize over latitudes and/or m blocks in Mojo for CPU scaling, while keeping determinism/stability?

  • General code quality: any red flags in staging, allocations, or recurrence precompute organization?
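
On the first bullet: whichever way I end up going, it is cheap to verify that an FFT-based zonal transform reproduces the explicit Stage 4 sums. A one-row NumPy check (illustrative only):

```python
import numpy as np

nlon = 16
rng = np.random.default_rng(0)
row = rng.normal(size=nlon)                      # one latitude row
lam = 2.0 * np.pi * np.arange(nlon) / nlon       # longitudes
n_m = nlon // 2 + 1                              # retained zonal modes

# Explicit DFT sums, exactly as the Stage 4 loops accumulate them.
re = np.array([np.dot(row, np.cos(m * lam)) for m in range(n_m)]) / nlon
im = np.array([-np.dot(row, np.sin(m * lam)) for m in range(n_m)]) / nlon

fc = np.fft.rfft(row) / nlon                     # FFT-based equivalent
assert np.allclose(re, fc.real) and np.allclose(im, fc.imag)
```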

If you’re willing to take a look, thank you so much—happy to iterate quickly and repost updated diffs + benchmark deltas based on your feedback.

  • Create fld as `var fld = List[Float64](capacity=nlat*nlon)` to avoid resizing the buffer after initial creation. You should set an initial capacity like this for all lists, including lat_rad, mu_arr, wj, cos_mi, sin_mi, etc.
  • The `LayoutTensor` and `TileTensor` types are technically better for this task than `List`, but they’re also a bit messy to use right now, so we can ignore that.
  • A lot of the loops could be done using `vectorize`, which should give some substantial performance wins. Depending on your CPU, up to 8x performance gains might be on the table.
  • You should be able to use dlpack to zero-copy the data if you use `LayoutTensor` or `TileTensor`, which should remove a bunch of data copies.
  • Parallelizing some of the per-latitude loops with `parallelize` should avoid stability issues and should be a performance win if the input is large enough.
  • For layout, I’d look at trying to make more loads and stores contiguous, since CPUs are still bad at scatter/gather.

@Caroline Could you possibly enable the discourse-math plugin so that we have some form of TeX? I didn’t realize that wasn’t enabled. See Discourse Math - Plugin - Discourse Meta

Done!


I’m really sorry, but for whatever reason I’m not able to implement your tips, Mr. Owen.

I tried for hours and read the docs, but for some weird reason

var fld_ptr = fld.unsafe_ptr()

def sub_mean_vec[width: Int](idx: Int) unified {mut}:
    var val = fld_ptr.load[width=width](idx)
    fld_ptr.store(idx, val - mean_val)
vectorize[simd_width](nlat * nlon, sub_mean_vec)

gets my program stuck (no errors, no output, just stuck) instead of accelerating it.

Would it be too much to ask to see how you would apply your notes to at least one of the two gists (preferably the vector one)?

I’ll try to take a pass at it when I have some time.


Thank you! I can’t believe my version is 2.5 times faster than optimized Fortran code and still has so much meat left on the bone.

Can you provide some Python/Mojo code that generates input data? It’s going to be a bit difficult to optimize this if I can’t benchmark it, and I’m not familiar with this problem domain.

Here is a Python script that generates synthetic data and runs scalar_spectrum.mojo:

#!/usr/bin/env python3
from __future__ import annotations

import argparse
import importlib
from pathlib import Path
import sys
import time

import numpy as np


def gaussian_latitudes_deg(nlat: int) -> np.ndarray:
    """Return Gaussian latitude centers in degrees (south->north)."""
    mu, _weights = np.polynomial.legendre.leggauss(nlat)
    return np.degrees(np.arcsin(mu))


def make_field(
    lat_deg: np.ndarray,
    lon_deg: np.ndarray,
    *,
    seed: int,
    noise_std: float,
) -> np.ndarray:
    lat_rad = np.radians(lat_deg)[:, None]
    lon_rad = np.radians(lon_deg)[None, :]

    # Structured signal with several zonal wavenumbers.
    mode_2 = 1.30 * np.cos(lat_rad) ** 2 * np.cos(2.0 * lon_rad)
    mode_5 = 0.75 * np.sin(2.0 * lat_rad) * np.sin(5.0 * lon_rad + 0.7)
    mode_9 = 0.40 * np.cos(3.0 * lat_rad) * np.cos(9.0 * lon_rad - 0.3)

    # Broad large-scale meridional background.
    background = 0.25 * (1.5 * np.sin(lat_rad) ** 2 - 0.5)

    # Localized anomaly.
    lat0 = np.radians(35.0)
    lon0 = np.radians(120.0)
    dist2 = (lat_rad - lat0) ** 2 + (np.angle(np.exp(1j * (lon_rad - lon0)))) ** 2
    blob = 0.55 * np.exp(-dist2 / (2.0 * 0.28**2))

    rng = np.random.default_rng(seed)
    noise = rng.normal(loc=0.0, scale=noise_std, size=(lat_deg.size, lon_deg.size))

    return (background + mode_2 + mode_5 + mode_9 + blob + noise).astype(np.float64)


def parse_args() -> argparse.Namespace:
    parser = argparse.ArgumentParser(
        description=(
            "Generate synthetic scalar Gaussian-grid data and run the Mojo "
            "scalar_spectrum kernel."
        )
    )
    parser.add_argument(
        "--nlat",
        type=int,
        default=800,
        help="Number of Gaussian latitudes (default: 800, stress-test size).",
    )
    parser.add_argument(
        "--nlon",
        type=int,
        default=1600,
        help="Number of longitudes (default: 1600, stress-test size).",
    )
    parser.add_argument("--seed", type=int, default=42, help="Random seed for reproducible noise.")
    parser.add_argument("--noise-std", type=float, default=0.05, help="Std-dev of additive white noise.")
    parser.add_argument(
        "--max-degree",
        type=int,
        default=0,
        help="Maximum harmonic degree (0 means full nlon/2 truncation used by the Mojo kernel).",
    )
    parser.add_argument(
        "--mojo-dir",
        type=Path,
        default=None,
        help="Directory containing scalar_spectrum.mojo (default: auto-detect).",
    )
    parser.add_argument(
        "--output",
        type=Path,
        default=Path("tests/artifacts/synthetic_scalar_input.npz"),
        help="Output .npz path for the generated synthetic field.",
    )
    return parser.parse_args()


def _resolve_mojo_dir(user_mojo_dir: Path | None) -> Path:
    if user_mojo_dir is not None:
        mojo_dir = user_mojo_dir.expanduser().resolve()
        if not (mojo_dir / "scalar_spectrum.mojo").exists():
            raise FileNotFoundError(
                f"scalar_spectrum.mojo not found in --mojo-dir: {mojo_dir}"
            )
        return mojo_dir

    script_dir = Path(__file__).resolve().parent
    candidates = [
        script_dir / "mojo",
        script_dir.parent / "mojo",
        Path.cwd() / "mojo",
        Path.cwd(),
    ]
    for cand in candidates:
        if (cand / "scalar_spectrum.mojo").exists():
            return cand.resolve()
    raise FileNotFoundError(
        "Could not locate scalar_spectrum.mojo automatically. "
        "Pass --mojo-dir <path>."
    )


def _load_scalar_kernel(mojo_dir: Path):
    if str(mojo_dir) not in sys.path:
        sys.path.insert(0, str(mojo_dir))
    import mojo.importer  # noqa: F401  # pyright: ignore[reportMissingImports]

    return importlib.import_module("scalar_spectrum")


def main() -> None:
    args = parse_args()
    if args.nlat < 4:
        raise ValueError("nlat must be >= 4.")
    if args.nlon < 8:
        raise ValueError("nlon must be >= 8.")
    if args.nlon % 2 != 0:
        raise ValueError("nlon must be even for typical spectral truncation use.")
    if args.noise_std < 0.0:
        raise ValueError("noise-std must be >= 0.")
    if args.max_degree < 0:
        raise ValueError("max-degree must be >= 0.")

    lat_deg = gaussian_latitudes_deg(args.nlat)
    lon_deg = np.linspace(0.0, 360.0, args.nlon, endpoint=False, dtype=np.float64)
    field = make_field(lat_deg, lon_deg, seed=args.seed, noise_std=args.noise_std)
    # Mirror the Mojo kernel behavior: if max_degree <= 0, use nlat then clamp to nlon/2.
    # For the stress-test default (800x1600), this gives max_degree=800.
    kernel_limit = args.nlon // 2
    full_kernel_degree = min(args.nlat, kernel_limit)
    suggested_max_degree = min(args.nlat - 1, kernel_limit)
    effective_max_degree = full_kernel_degree if args.max_degree == 0 else min(args.max_degree, full_kernel_degree)

    args.output.parent.mkdir(parents=True, exist_ok=True)
    np.savez_compressed(
        args.output,
        field=field,
        lat_deg=lat_deg,
        lon_deg=lon_deg,
        nlat=np.int64(args.nlat),
        nlon=np.int64(args.nlon),
        seed=np.int64(args.seed),
        noise_std=np.float64(args.noise_std),
        suggested_max_degree=np.int64(suggested_max_degree),
        full_kernel_degree=np.int64(full_kernel_degree),
        effective_max_degree=np.int64(effective_max_degree),
    )

    mojo_dir = _resolve_mojo_dir(args.mojo_dir)

    print(f"Wrote synthetic payload: {args.output}")
    print(f"field shape: {field.shape}, dtype={field.dtype}")
    print(f"suggested_max_degree (API-safe): {suggested_max_degree}")
    print(f"full_kernel_degree (Mojo max): {full_kernel_degree}")
    print(f"loading scalar_spectrum.mojo from: {mojo_dir}")
    print(f"running scalar_spectrum with max_degree={effective_max_degree} ...")

    scalar_mod = _load_scalar_kernel(mojo_dir)
    t0 = time.perf_counter()
    power = np.asarray(
        scalar_mod.scalar_spectrum(
            np.ascontiguousarray(field, dtype=np.float64),
            np.ascontiguousarray(lat_deg, dtype=np.float64),
            int(effective_max_degree),
        ),
        dtype=np.float64,
    )
    dt = time.perf_counter() - t0

    print(f"kernel runtime: {dt:.3f} s")
    print(f"output spectrum length: {power.size}")
    print(f"power sum: {power.sum():.6e}")
    print(f"power[0:5]: {power[:5]}")


if __name__ == "__main__":
    main()

[✓][2026-03-26 15:09:46][USER@HOST][mojo-spectrum]-[36610748.HOST-01-ib] {module-build}  - 22s 
ρ pixi run python "scripts/generate_synthetic_scalar_input.py" --nlat 800 --nlon 1600 --max-degree 800
Wrote synthetic payload: tests/artifacts/synthetic_scalar_input.npz
field shape: (800, 1600), dtype=float64
suggested_max_degree (API-safe): 799
full_kernel_degree (Mojo max): 800
loading scalar_spectrum.mojo from: /fs/site8/eccc/cmd/cmds/yor000/gitlab.science.gc.ca/yor000/mojo-experiments/tke-mojo/mojo
running scalar_spectrum with max_degree=800 ...
[mojo] grid 800 x 1600  max_degree= 800
[mojo] Stage 0: extracting field to native buffers ...
[mojo] Stage 0 time: 0.163210318016354 s
[mojo] Stage 1: computing Gauss-Legendre weights ...
[mojo] Stage 1 time: 0.006960558996070176 s
[mojo] Stage 2: area-weighted mean ...
[mojo] Stage 2 time: 0.0018916879780590534 s
[mojo] Stage 3: weighted variance target ...
[mojo] Stage 3 time: 0.0017297330196015537 s
[mojo] Stage 4: Fourier precompute (twiddle table + nlat x n_m DFTs) ...
[mojo] Stage 4 time: 1.4660470120143145 s
[mojo] Stage 5: Legendre projection (streaming recurrence) ...
[mojo] Stage 5 time: 0.5123696019873023 s
[mojo] Stage 6: variance rescaling ...
[mojo] Stage 6 time: 1.5160185284912586e-06 s
[mojo] total kernel time: 2.152271553990431 s
[mojo] done.
kernel runtime: 2.153 s
output spectrum length: 800
power sum: 6.405006e-01
power[0:5]: [0.00076523 0.45140083 0.00095973 0.00074946 0.000495  ]

EDIT: By the way, I keep trying to make vectorize work, and Opus 4.6, Gemini 3.1 Pro, and Composer 2 all give the same cute message that starts with “Ahhh, the joys of writing code with a bleeding-edge language…” and then proceed to ignore the Nightly documentation I give them :slight_smile:

This runs ~3x faster for me, although it’s running into a few fun numeric precision issues that you’re better equipped to deal with as the domain person. I only vectorized stage 4 since that was the place to show the biggest impact, but it should be reasonable to apply that pattern elsewhere. I also fixed the list capacity, which was worth ~15% on its own. Switching to LayoutTensor and bringing in parallelize should help if you want even more.

from std.math import cos, sin, abs, pi, sqrt, iota
from std.python import Python, PythonObject
from std.python.bindings import PythonModuleBuilder
from std.os import abort
from std.collections import List
from std.algorithm import vectorize
from std.sys.info import simd_width_of


def _py_getattr(obj: PythonObject, name: String) raises -> PythonObject:
    return obj.__getattr__(name)


def _py_call0(f: PythonObject) raises -> PythonObject:
    return f.__call__()


def _py_call1(f: PythonObject, a: PythonObject) raises -> PythonObject:
    return f.__call__(a)


def _now_seconds() raises -> Float64:
    var time_mod = Python.import_module("time")
    return Float64(py=_py_call0(_py_getattr(time_mod, "perf_counter")))


def _norm_assoc_legendre(l: Int, m: Int, mu: Float64) -> Float64:
    """4-pi fully-normalized associated Legendre function.

    Includes the (2-delta_m0) geodesy convention so that:
        (1/4pi) integral [Pbar_l^m cos(m phi)]^2 dOmega = 1
    Values stay O(1) for all l, m, preventing the exponential
    overflow of the unnormalized recurrence.
    """
    if m > l:
        return Float64(0.0)

    var one_minus = Float64(1.0) - mu * mu
    if one_minus < Float64(0.0):
        one_minus = Float64(0.0)
    var somx2 = sqrt(one_minus)

    # Start the recurrence at P_m^m. The normalization keeps the basis
    # O(1) across degree, which is critical for high-l stability.
    var pmm = Float64(1.0)
    if m >= 1:
        pmm = sqrt(Float64(3.0)) * somx2
        for mm in range(2, m + 1):
            pmm = sqrt(Float64(2 * mm + 1) / Float64(2 * mm)) * somx2 * pmm

    if l == m:
        return pmm

    var pmmp1 = sqrt(Float64(2 * m + 3)) * mu * pmm
    if l == m + 1:
        return pmmp1

    # Three-term recurrence in l for fixed m.
    var plm2 = pmm
    var plm1 = pmmp1
    var pl = Float64(0.0)
    for ll in range(m + 2, l + 1):
        var ll_sq = Float64(ll * ll)
        var m_sq = Float64(m * m)
        var twoL1 = Float64(2 * ll + 1)
        var alpha = sqrt((Float64(4) * ll_sq - Float64(1)) / (ll_sq - m_sq))
        var lm1_sq = Float64((ll - 1) * (ll - 1))
        var beta = sqrt(
            twoL1 * (lm1_sq - m_sq) / ((Float64(2 * ll - 3)) * (ll_sq - m_sq))
        )
        pl = alpha * mu * plm1 - beta * plm2
        plm2 = plm1
        plm1 = pl
    return pl


def _gauss_legendre_weights(nlat: Int) -> List[Float64]:
    """Compute exact Gauss-Legendre quadrature weights via Newton-Raphson.

    Finds the N roots of P_N(mu) to machine precision, then applies
    w_j = 2(1-mu_j^2) / (N^2 * P_{N-1}(mu_j)^2).
    Weights are returned in ascending mu order (south-to-north).
    """
    var wj = List[Float64](capacity=nlat)
    for _ in range(nlat):
        wj.append(Float64(0.0))

    var n = nlat
    var n_f = Float64(n)
    var nhalf = (n + 1) / 2

    for i in range(nhalf):
        # Roots are symmetric about the equator, so solve only the southern
        # half and mirror the result.
        var theta = (
            Float64(pi) * (Float64(i) + Float64(0.75)) / (n_f + Float64(0.5))
        )
        var mu = cos(theta)

        var pnm1 = Float64(0.0)
        for _ in range(200):
            # Evaluate P_N(mu) and P_{N-1}(mu) together so Newton-Raphson can
            # refine the Gaussian node to near machine precision.
            var p0 = Float64(1.0)
            var p1 = mu
            for k in range(2, n + 1):
                var k_f = Float64(k)
                var p2 = (
                    (Float64(2) * k_f - Float64(1)) * mu * p1
                    - (k_f - Float64(1)) * p0
                ) / k_f
                p0 = p1
                p1 = p2
            pnm1 = p0
            var pn = p1
            var denom_d = mu * mu - Float64(1.0)
            if abs(denom_d) < Float64(1e-30):
                break
            var dpn = n_f * (mu * pn - pnm1) / denom_d
            if abs(dpn) < Float64(1e-30):
                break
            var delta = pn / dpn
            mu = mu - delta
            if abs(delta) < Float64(1e-15):
                break

        # Final P_{N-1} at converged root for the weight formula.
        var p0f = Float64(1.0)
        var p1f = mu
        for k in range(2, n + 1):
            var k_f = Float64(k)
            var p2f = (
                (Float64(2) * k_f - Float64(1)) * mu * p1f
                - (k_f - Float64(1)) * p0f
            ) / k_f
            p0f = p1f
            p1f = p2f
        pnm1 = p0f
        var one_m_mu2 = Float64(1.0) - mu * mu
        if one_m_mu2 < Float64(1e-30):
            one_m_mu2 = Float64(1e-30)
        var wd = n_f * n_f * pnm1 * pnm1
        if wd < Float64(1e-30):
            wd = Float64(1e-30)
        var wt = Float64(2.0) * one_m_mu2 / wd

        wj[i] = wt
        wj[n - 1 - i] = wt

    return wj^


comptime deg_to_rad = Float64(pi / 180.0)

@always_inline
def scalar_spectrum_stage_0_convert_fld(
    nlat: Int, nlon: Int, field: PythonObject
) raises -> List[Float64]:
    var fld = List[Float64](capacity=nlat * nlon)
    for j in range(nlat):
        for i in range(nlon):
            fld.append(Float64(py=field[j, i]))
    return fld^

@always_inline
def scalar_spectrum_stage_0_convert_lat_deg(
    nlat: Int, lat_deg: PythonObject
) raises -> List[Float64]:
    var lat_rad = List[Float64](capacity=nlat)
    for j in range(nlat):
        lat_rad.append(Float64(py=lat_deg[j]) * deg_to_rad)
    return lat_rad^

@always_inline
def scalar_spectrum_stage_4_scalar(
    nlat: Int,
    nlon: Int,
    n_m: Int,
    inv_nlon: Float64,
    fld: Span[mut=False, Float64, ...],
    cos_mi: Span[mut=False, Float64, ...],
    sin_mi: Span[mut=False, Float64, ...],
    re_jm: Span[mut=True, Float64, ...],
    im_jm: Span[mut=True, Float64, ...],
):
    for j in range(nlat):
        var fld_off = j * nlon
        var jm_base = j * n_m
        for m in range(n_m):
            var re = Float64(0.0)
            var im = Float64(0.0)
            var mi_base = m * nlon
            for i in range(nlon):
                var x = fld[fld_off + i]
                re = re + x * cos_mi[mi_base + i]
                im = im - x * sin_mi[mi_base + i]
            re_jm[jm_base + m] = re * inv_nlon
            im_jm[jm_base + m] = im * inv_nlon

@always_inline
def scalar_spectrum_stage_4_vector(
    nlat: Int,
    nlon: Int,
    n_m: Int,
    inv_nlon: Float64,
    fld: Span[mut=False, Float64, ...],
    cos_mi: Span[mut=False, Float64, ...],
    sin_mi: Span[mut=False, Float64, ...],
    re_jm: Span[mut=True, Float64, ...],
    im_jm: Span[mut=True, Float64, ...],
):
    for j in range(nlat):
        var fld_off = j * nlon
        var jm_base = j * n_m
        for m in range(n_m):
            var mi_base = m * nlon
            var re = Float64(0.0)
            var im = Float64(0.0)

            # Vectorize over the longitude index i: fld and both twiddle rows
            # are contiguous in i, so every load below is unit-stride.
            # (Vectorizing over m instead would make the cos_mi/sin_mi loads
            # stride by nlon and read past the end of each fld row.)
            def kernel[width: Int](i: Int) unified {mut}:
                var x = fld.unsafe_ptr().load[width=width](fld_off + i)
                var c = cos_mi.unsafe_ptr().load[width=width](mi_base + i)
                var s = sin_mi.unsafe_ptr().load[width=width](mi_base + i)
                re = re + (x * c).reduce_add()
                im = im - (x * s).reduce_add()

            vectorize[simd_width_of[DType.float64](), unroll_factor=2](
                nlon, kernel
            )
            re_jm[jm_base + m] = re * inv_nlon
            im_jm[jm_base + m] = im * inv_nlon


def scalar_spectrum(
    field: PythonObject, lat_deg: PythonObject, max_degree_obj: PythonObject
) raises -> PythonObject:
    var t_total0 = _now_seconds()
    var nlat = Int(py=_py_getattr(field, "shape")[0])
    var nlon = Int(py=_py_getattr(field, "shape")[1])
    var max_degree = Int(py=max_degree_obj)
    if max_degree <= 0:
        max_degree = nlat
    if max_degree > nlon / 2:
        max_degree = nlon / 2

    comptime eps = Float64(1.0e-30)

    var dlon = Float64(2.0 * pi) / Float64(nlon)
    var inv_nlon = Float64(1.0) / Float64(nlon)

    print(
        "[mojo] grid",
        nlat,
        "x",
        nlon,
        " max_degree=",
        max_degree,
    )

    # ── Stage 0: Extract Python data into native Mojo buffers ──
    # Keep all Python interop here. Every later stage runs on native Mojo
    # buffers so the hot loops are pure numeric code.
    print("[mojo] Stage 0: extracting field to native buffers ...")
    var t0 = _now_seconds()
    var fld = scalar_spectrum_stage_0_convert_fld(nlat, nlon, field)
    var lat_rad = scalar_spectrum_stage_0_convert_lat_deg(nlat, lat_deg)
    print("[mojo] Stage 0 time:", _now_seconds() - t0, "s")

    # ── Stage 1: mu values and Gauss-Legendre quadrature weights ──
    # For a Gaussian grid, the latitude integral is naturally written in
    # mu = sin(phi), with exact Gauss-Legendre weights.
    print("[mojo] Stage 1: computing Gauss-Legendre weights ...")
    var t1 = _now_seconds()
    var mu_arr = List[Float64](capacity=nlat)
    for j in range(nlat):
        mu_arr.append(sin(lat_rad[j]))
    var wj_arr = _gauss_legendre_weights(nlat)
    print("[mojo] Stage 1 time:", _now_seconds() - t1, "s")

    # ── Stage 2: Area-weighted mean using Gaussian weights ──
    # The scalar reference spectrum is for anomaly power, so remove the
    # weighted global mean before harmonic analysis.
    print("[mojo] Stage 2: area-weighted mean ...")
    var t2 = _now_seconds()
    var mean_num = Float64(0.0)
    var mean_den = Float64(0.0)
    for j in range(nlat):
        var w = wj_arr[j]
        mean_den = mean_den + w
        var off = j * nlon
        var row_sum = Float64(0.0)
        for i in range(nlon):
            row_sum = row_sum + fld[off + i]
        mean_num = mean_num + w * row_sum * inv_nlon
    var mean_val = Float64(0.0)
    if mean_den > eps:
        mean_val = mean_num / mean_den

    for idx in range(nlat * nlon):
        fld[idx] = fld[idx] - mean_val
    print("[mojo] Stage 2 time:", _now_seconds() - t2, "s")

    # ── Stage 3: Weighted variance target using Gaussian weights ──
    # This target is used at the end to preserve the total variance seen in
    # grid space after mean removal.
    print("[mojo] Stage 3: weighted variance target ...")
    var t3 = _now_seconds()
    var sum_w = Float64(0.0)
    var var_target = Float64(0.0)
    for j in range(nlat):
        var w = wj_arr[j]
        sum_w = sum_w + w
        var row_var = Float64(0.0)
        var off = j * nlon
        for i in range(nlon):
            var x = fld[off + i]
            row_var = row_var + x * x
        row_var = row_var * inv_nlon
        var_target = var_target + w * row_var
    if sum_w > eps:
        var_target = var_target / sum_w
    print("[mojo] Stage 3 time:", _now_seconds() - t3, "s")

    # ── Stage 4: Precompute Fourier coefficients re(j,m) / im(j,m) ──
    print(
        "[mojo] Stage 4: Fourier precompute (twiddle table + nlat x n_m"
        " DFTs) ..."
    )
    var t4 = _now_seconds()
    # Factorises O(L^2 * nlat * nlon) into O(M * nlon) + O(M * nlat * nlon) + O(L^2 * nlat).
    var n_m = max_degree + 1
    var cos_mi = List[Float64](capacity=nlat * nlon)
    var sin_mi = List[Float64](capacity=nlat * nlon)
    for m in range(n_m):
        for i in range(nlon):
            # Twiddle table: compute cos/sin once and reuse for every latitude.
            var angle = Float64(m) * Float64(i) * dlon
            cos_mi.append(cos(angle))
            sin_mi.append(sin(angle))

    var re_jm = List[Float64](length=nlat * n_m, fill=Float64(0.0))
    var im_jm = List[Float64](length=nlat * n_m, fill=Float64(0.0))

    scalar_spectrum_stage_4_vector(
        nlat,
        nlon,
        n_m,
        inv_nlon,
        fld,
        cos_mi,
        sin_mi,
        re_jm,
        im_jm,
    )

    print("[mojo] Stage 4 time:", _now_seconds() - t4, "s")

    # ── Stage 5: Legendre projection → spectral power per degree ──
    print("[mojo] Stage 5: Legendre projection (streaming recurrence) ...")
    var t5 = _now_seconds()
    # 4-pi normalization: E_l = sum_{m=0}^l (C_lm^2 + S_lm^2)
    var power = List[Float64](length=max_degree, fill=Float64(0.0))

    # Stage 5 setup: precompute the m-only and (l,m)-only recurrence factors.
    # This moves sqrt-heavy setup out of the innermost projection loop.

    var pmm_const = List[Float64](capacity=n_m)
    var pmmp1_scale = List[Float64](capacity=n_m)

    for m in range(n_m):
        var coeff = Float64(1.0)
        if m >= 1:
            coeff = sqrt(Float64(3.0))
            for mm in range(2, m + 1):
                coeff = coeff * sqrt(Float64(2 * mm + 1) / Float64(2 * mm))
        pmm_const.append(coeff)
        pmmp1_scale.append(sqrt(Float64(2 * m + 3)))

    var somx2_pow_jm = List[Float64](capacity=nlat * n_m)
    for j in range(nlat):
        var one_minus = Float64(1.0) - mu_arr[j] * mu_arr[j]
        if one_minus < Float64(0.0):
            one_minus = Float64(0.0)
        var somx2 = sqrt(one_minus)
        var som_pow = Float64(1.0)
        somx2_pow_jm.append(som_pow)
        for _ in range(1, n_m):
            som_pow = som_pow * somx2
            somx2_pow_jm.append(som_pow)

    var alpha_lm = List[Float64](length=n_m * n_m, fill=Float64(0.0))
    var beta_lm = List[Float64](length=n_m * n_m, fill=Float64(0.0))

    for m in range(0, max_degree + 1):
        var m_sq = Float64(m * m)
        for l in range(m + 2, max_degree + 1):
            var ll_sq = Float64(l * l)
            var twoL1 = Float64(2 * l + 1)
            var lm1_sq = Float64((l - 1) * (l - 1))
            var lm_idx = l * n_m + m
            alpha_lm[lm_idx] = sqrt(
                (Float64(4) * ll_sq - Float64(1)) / (ll_sq - m_sq)
            )
            beta_lm[lm_idx] = sqrt(
                twoL1
                * (lm1_sq - m_sq)
                / ((Float64(2 * l - 3)) * (ll_sq - m_sq))
            )

    for m in range(0, max_degree + 1):
        # For fixed m, accumulate every degree l >= m in one latitude pass.
        # This is the main scalar analogue of the classic Fourier-Legendre
        # transform structure used by Gaussian-grid spectral models.
        var c_by_l = List[Float64](length=max_degree + 1, fill=Float64(0.0))
        var s_by_l = List[Float64](length=max_degree + 1, fill=Float64(0.0))

        for j in range(nlat):
            var idx = j * n_m + m
            var wre = wj_arr[j] * re_jm[idx]
            var wim = wj_arr[j] * im_jm[idx]
            var mu = mu_arr[j]
            # P_m^m(mu_j): factored into an m-only constant and a j,m power.
            var pmm = pmm_const[m] * somx2_pow_jm[idx]

            if m >= 1 and m <= max_degree:
                c_by_l[m] = c_by_l[m] + pmm * wre
                s_by_l[m] = s_by_l[m] + pmm * wim

            if m < max_degree:
                var pmmp1 = pmmp1_scale[m] * mu * pmm
                var lp1 = m + 1
                if lp1 >= 1:
                    c_by_l[lp1] = c_by_l[lp1] + pmmp1 * wre
                    s_by_l[lp1] = s_by_l[lp1] + pmmp1 * wim

                var plm2 = pmm
                var plm1 = pmmp1
                for l in range(m + 2, max_degree + 1):
                    var lm_idx = l * n_m + m
                    var alpha = alpha_lm[lm_idx]
                    var beta = beta_lm[lm_idx]
                    var pl = alpha * mu * plm1 - beta * plm2
                    c_by_l[l] = c_by_l[l] + pl * wre
                    s_by_l[l] = s_by_l[l] + pl * wim
                    plm2 = plm1
                    plm1 = pl

        var l_start = m
        if l_start < 1:
            l_start = 1
        for l in range(l_start, max_degree + 1):
            power[l - 1] = (
                power[l - 1] + c_by_l[l] * c_by_l[l] + s_by_l[l] * s_by_l[l]
            )
    print("[mojo] Stage 5 time:", _now_seconds() - t5, "s")

    # ── Stage 6: Variance-preserving rescaling ──
    # After the harmonic projection, apply one scalar so the summed spectral
    # power matches the weighted grid-space variance from Stage 3.
    print("[mojo] Stage 6: variance rescaling ...")
    var t6 = _now_seconds()
    var total = Float64(0.0)

    for l in range(max_degree):
        total = total + power[l]
    if total > eps and var_target > eps:
        var scale = var_target / total
        for l in range(max_degree):
            power[l] = power[l] * scale
    print("[mojo] Stage 6 time:", _now_seconds() - t6, "s")

    print("[mojo] total kernel time:", _now_seconds() - t_total0, "s")
    print("[mojo] done.")
    # ── Return as Python list ──
    var builtins = Python.import_module("builtins")
    var out = _py_call0(_py_getattr(builtins, "list"))
    for l in range(max_degree):
        _ = _py_call1(_py_getattr(out, "append"), PythonObject(power[l]))
    return out


@export
def PyInit_scalar_spectrum() -> PythonObject:
    try:
        var m = PythonModuleBuilder("scalar_spectrum")
        m.def_function[scalar_spectrum]("scalar_spectrum")
        return m.finalize()
    except e:
        abort(String("error creating scalar_spectrum module:", e))
```
$ pixi run python scripts/generate_synthetic_scalar_input.py --kernel-module scalar_spectrum_owen --mojo-dir mojo
Wrote synthetic payload: tests/artifacts/synthetic_scalar_input.npz
field shape: (800, 1600), dtype=float64
suggested_max_degree (API-safe): 799
full_kernel_degree (Mojo max): 800
loading scalar_spectrum_owen.mojo from: tke-mojo/mojo
running kernel with max_degree=800 ...
using exported function: scalar_spectrum_owen
[mojo] grid 800 x 1600  max_degree= 800
[mojo] Stage 0: extracting field to native buffers ...
[mojo] Stage 0 time: 0.1649138819775544 s
[mojo] Stage 1: computing Gauss-Legendre weights ...
[mojo] Stage 1 time: 0.006960193975828588 s
[mojo] Stage 2: area-weighted mean ...
[mojo] Stage 2 time: 0.0017336169839836657 s
[mojo] Stage 3: weighted variance target ...
[mojo] Stage 3 time: 0.0017265569767914712 s
[mojo] Stage 4: Fourier precompute (twiddle table + nlat x n_m DFTs) ...
[mojo] Stage 4 time: 0.28961051296209916 s
[mojo] Stage 5: Legendre projection (streaming recurrence) ...
[mojo] Stage 5 time: 0.42164504603715613 s
[mojo] Stage 6: variance rescaling ...
[mojo] Stage 6 time: 1.468986738473177e-06 s
[mojo] total kernel time: 0.8866714269970544 s
[mojo] done.
kernel runtime: 0.887 s
output spectrum length: 800
power sum: 6.405006e-01
power[0:5]: [0.00652359 0.05660658 0.09016741 0.0738675  0.06458382]
```

```
$ pixi run python scripts/generate_synthetic_scalar_input.py --kernel-module scalar_spectrum --mojo-dir mojo
Wrote synthetic payload: tests/artifacts/synthetic_scalar_input.npz
field shape: (800, 1600), dtype=float64
suggested_max_degree (API-safe): 799
full_kernel_degree (Mojo max): 800
loading scalar_spectrum.mojo from: /fs/site8/eccc/cmd/cmds/yor000/gitlab.science.gc.ca/yor000/mojo-experiments/tke-mojo/mojo
running kernel with max_degree=800 ...
using exported function: scalar_spectrum
[mojo] grid 800 x 1600  max_degree= 800
[mojo] Stage 0: extracting field to native buffers ...
[mojo] Stage 0 time: 0.16240417695371434 s
[mojo] Stage 1: computing Gauss-Legendre weights ...
[mojo] Stage 1 time: 0.006977727985940874 s
[mojo] Stage 2: area-weighted mean ...
[mojo] Stage 2 time: 0.0018948979559354484 s
[mojo] Stage 3: weighted variance target ...
[mojo] Stage 3 time: 0.0017366730025969446 s
[mojo] Stage 4: Fourier precompute (twiddle table + nlat x n_m DFTs) ...
[mojo] Stage 4 time: 1.5165774180204608 s
[mojo] Stage 5: Legendre projection (streaming recurrence) ...
[mojo] Stage 5 time: 0.5460519670159556 s
[mojo] Stage 6: variance rescaling ...
[mojo] Stage 6 time: 1.4870311133563519e-06 s
[mojo] total kernel time: 2.235698721022345 s
[mojo] done.
kernel runtime: 2.236 s
output spectrum length: 800
power sum: 6.405006e-01
power[0:5]: [0.00076523 0.45140083 0.00095973 0.00074946 0.000495  ]
```

0.887 s vs 2.236 s => a 2.5x speedup indeed!! Thank you for this, Mr. Owen.

Do you think a future version of the language, or shipping this to the GPU, would gain me another performance bump, or is everything else a minor optimization other than the `parallelize` and LayoutTensor changes you mentioned?

If someone ever figures out how to do autovec nicely, then maybe we can get that. There’s probably some more data layout stuff that would help too. Shipping to the GPU is going to depend on what kind of input size you have, since right now you’d waste a lot of time getting the GPU ready. It would still probably be a performance win, but if you increased the data volume by a lot it would help more. LayoutTensor gives you zero copy, which basically removes the cost of stage 0.

There’s always more stuff, but I think more vectorization and zero copy are the low-hanging fruit here. Parallelization will help as well once you have more input data.
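To put a rough number on the input-size point above: for this grid a single Float64 field is about 10 MB, so the host-to-device copy itself is cheap next to the ~0.9 s kernel time; launch and setup overhead is the bigger GPU cost, which is why larger inputs amortize better. The bandwidth figure below is an assumed PCIe 4.0 x16 ballpark, not a measurement:

```python
import numpy as np

nlat, nlon = 800, 1600
# One Float64 field as used by the scalar kernel (u + v would double this).
grid_bytes = nlat * nlon * np.dtype(np.float64).itemsize

# ASSUMPTION: effective host->device bandwidth, PCIe 4.0 x16 ballpark.
assumed_pcie_gb_s = 16.0
transfer_s = grid_bytes / (assumed_pcie_gb_s * 1e9)

print(grid_bytes / 1e6)   # 10.24 MB per field
print(transfer_s * 1e3)   # 0.64 ms, vs ~0.89 s total kernel time above
```

At this size the copy is negligible; the decision to go to the GPU hinges on kernel launch / graph setup overhead and how much work each transfer amortizes.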

Is it me, or should something like this be part of the [NuMojo](https://github.com/Mojo-Numerics-and-Algorithms-group/NuMojo) project?

I know it’s a strange question because NumPy won naturally, but what is the likelihood that NuMojo is THE module (Mojo modules should be called mojules :winking_face_with_tongue:) to use to calculate FFTs and such?

It seems like the blockers for NuMojo are the lack of parametrized traits, GPU support, and advanced memory management, so is this considered an incentive for the Mojo team to work on those parts to unlock the “bedrock” of the Mojo ecosystem (as NumPy is for Python)?

[martinvuyk/hackathon-fft](https://github.com/martinvuyk/hackathon-fft) (“A State Of The Art Fast Fourier Transform implementation written in Mojo”) is the current fastest FFT implementation for Mojo.

A lot of the stuff you’re looking for lives in MAX, and I would recommend moving to TileTensor if you can. I know having python involved isn’t fantastic, but if python sets up the graph and then only does something once every 5 minutes to kick off a new MAX graph iteration, then it shouldn’t be that bad of a performance issue.

Ok I tried to LLM this last feedback point by Mr. @owenhilyard and I am still confused.

In my mental model of the Modular stack, MAX was an inference framework with some awesome Mojo hooks for optimization, but the comment suggests I should be thinking about MAX a bit like I think about Mojo now.

If I want a blazingly fast Python module for atmospheric sciences backed by Modular, I should write a Python module that calls a MAX graph which registers operations that are highly specialized Mojo kernels, like the ones I wrote for scalar and vector field spectra, so that MAX can distribute the compute and Python only has to worry about “parametrizing” the MAX graph operation calls with the xarray DataArrays and metadata needed?

MAX is marketed for inference, but it’s a general-purpose way to express compute provided you’re willing to write some kernels, in the same way that CUDA sees heavy use in AI but can also be used for weather simulations. I’ve made some use of MAX for HPC and it tends to blow away C++ with xtensor.

Generally, you’d want to move stuff into mojo/MAX datatypes since round-trips to python are expensive, but you should see some pretty nice performance wins.

Ok ok, this is brilliant. Do you know of a project that uses an architecture like this that I can look at?

I presented the Mojo code above to some of my colleagues and I think I’ll rewrite a bunch of our internal utilities (interpolation and scientific calculations especially) in Mojo, but I want to structure it in a way that will be amenable to the eventual 1.0 and 2.0 releases of Mojo and the MAX platform.

We’ll open source them of course, especially because Python has no solid interpolation alternative to GDAL.

Sadly almost all of that code is not mine to share. However, you may find the max custom ops examples helpful since they’re less ML flavored.

Oh so there is such a module but it is not open source yet?

Do you think the people who created it would mind sharing their insights with me privately?

I’m not a Modular employee, I’m actually a PhD student at a university, and I took a crack at porting some plasma physics workloads from xtensor over to MAX as a trial to see if it yielded benefits. It’s a MAX port of a fairly proprietary library distributed in source form, so I can’t share it. All that I can share is that it works REALLY well and we saw simulation times drop by substantial amounts. A lot of the problem decomposes well into linear algebra (hence the need for a BLAS library), so the highly tuned linear algebra routines in MAX work very well. The graph compiler also helps a lot with cache locality thanks to loop fusion, with one smaller simulation actually seeing a 200x speedup after MAX fused most of the loops.

Right now they’re waiting for 1.0 before making investments, since I basically just showed up and said “hey, look at this”, but I’d advise anyone in HPC to take a very hard look at MAX given that it has thus far been able to get to “good enough” performance with very minimal tuning.
