This runs ~3x faster for me, although it’s running into a few fun numeric precision issues that you’re better equipped to deal with as the domain person. I only vectorized stage 4 since that was the place to show the biggest impact, but it should be reasonable to apply that pattern elsewhere. I also fixed the list capacity, which was worth ~15% on its own. Switching to LayoutTensor and bringing in parallelize should help if you want even more.
from std.math import cos, sin, abs, pi, sqrt, iota
from std.python import Python, PythonObject
from std.python.bindings import PythonModuleBuilder
from std.os import abort
from std.collections import List
from std.algorithm import vectorize
from std.sys.info import simd_width_of
def _py_getattr(obj: PythonObject, name: String) raises -> PythonObject:
    """Look up the attribute `name` on the Python object `obj`.

    Thin wrapper over `PythonObject.__getattr__`; whatever that lookup
    raises propagates to the caller.

    Args:
        obj: Python object to query.
        name: Attribute name to fetch.

    Returns:
        The attribute value as a `PythonObject`.
    """
    var attr = obj.__getattr__(name)
    return attr
def _py_call0(f: PythonObject) raises -> PythonObject:
    """Invoke the Python callable `f` with no arguments.

    Args:
        f: Python callable.

    Returns:
        Whatever `f()` returns, as a `PythonObject`.
    """
    var result = f.__call__()
    return result
def _py_call1(f: PythonObject, a: PythonObject) raises -> PythonObject:
    """Invoke the Python callable `f` with the single positional argument `a`.

    Args:
        f: Python callable.
        a: The one argument passed to `f`.

    Returns:
        Whatever `f(a)` returns, as a `PythonObject`.
    """
    var result = f.__call__(a)
    return result
def _now_seconds() raises -> Float64:
    """Return the current reading of Python's `time.perf_counter()`.

    The `time` module is imported on every call and the counter value is
    converted to a native `Float64`. Only differences between two readings
    are used by the callers in this file.

    Returns:
        The `perf_counter()` value, in seconds.
    """
    var perf_counter = _py_getattr(Python.import_module("time"), "perf_counter")
    var stamp = _py_call0(perf_counter)
    return Float64(py=stamp)
def _norm_assoc_legendre(l: Int, m: Int, mu: Float64) -> Float64:
    """Evaluate the 4-pi fully-normalized associated Legendre function.

    The normalization carries the (2 - delta_m0) geodesy factor, i.e.
        (1/4pi) integral [Pbar_l^m cos(m phi)]^2 dOmega = 1,
    which keeps the values O(1) for every (l, m) and avoids the exponential
    growth of the unnormalized recurrence.

    Args:
        l: Degree.
        m: Order; the function is defined as 0 whenever m > l.
        mu: Argument of the Legendre function (sine of latitude).

    Returns:
        Pbar_l^m(mu).
    """
    if m > l:
        return Float64(0.0)

    # sqrt(1 - mu^2), with the radicand clamped so rounding at |mu| ~ 1
    # can never feed a negative value to sqrt.
    var radicand = Float64(1.0) - mu * mu
    if radicand < Float64(0.0):
        radicand = Float64(0.0)
    var root = sqrt(radicand)

    # Sectoral seed Pbar_m^m, built up one order at a time.
    var sectoral = Float64(1.0)
    if m >= 1:
        sectoral = sqrt(Float64(3.0)) * root
    for k in range(2, m + 1):
        sectoral = sqrt(Float64(2 * k + 1) / Float64(2 * k)) * root * sectoral
    if l == m:
        return sectoral

    # First step above the diagonal: Pbar_{m+1}^m.
    var above = sqrt(Float64(2 * m + 3)) * mu * sectoral
    if l == m + 1:
        return above

    # Three-term recurrence in degree at fixed order. The loop always runs
    # at least once here because l >= m + 2.
    var order_sq = Float64(m * m)
    var two_back = sectoral
    var one_back = above
    for deg in range(m + 2, l + 1):
        var deg_sq = Float64(deg * deg)
        var below_sq = Float64((deg - 1) * (deg - 1))
        var a = sqrt((Float64(4) * deg_sq - Float64(1)) / (deg_sq - order_sq))
        var b = sqrt(
            Float64(2 * deg + 1)
            * (below_sq - order_sq)
            / ((Float64(2 * deg - 3)) * (deg_sq - order_sq))
        )
        var current = a * mu * one_back - b * two_back
        two_back = one_back
        one_back = current
    return one_back
def _legendre_pn_pnm1(n: Int, x: Float64) -> Tuple[Float64, Float64]:
    """Evaluate the Legendre polynomials P_n(x) and P_{n-1}(x) together.

    Runs the recurrence k P_k = (2k - 1) x P_{k-1} - (k - 1) P_{k-2} from the
    seeds P_0 = 1 and P_1 = x. Intended for n >= 1.

    Args:
        n: Polynomial degree.
        x: Evaluation point.

    Returns:
        The pair (P_n(x), P_{n-1}(x)).
    """
    var p_prev = Float64(1.0)
    var p_cur = x
    for k in range(2, n + 1):
        var k_f = Float64(k)
        var p_new = (
            (Float64(2) * k_f - Float64(1)) * x * p_cur
            - (k_f - Float64(1)) * p_prev
        ) / k_f
        p_prev = p_cur
        p_cur = p_new
    return (p_cur, p_prev)


def _gauss_legendre_weights(nlat: Int) -> List[Float64]:
    """Compute Gauss-Legendre quadrature weights via Newton-Raphson.

    Finds the roots of P_N(mu) by Newton iteration, then applies
        w_j = 2 (1 - mu_j^2) / (N^2 * P_{N-1}(mu_j)^2).

    Only half of the roots are solved for (starting from the one closest to
    mu = +1 and moving toward the equator); each weight is written to both
    index `i` and its mirror `N - 1 - i`. The returned list is therefore
    symmetric and is valid for either latitude ordering of the grid.

    Args:
        nlat: Number of quadrature nodes N. A value of 0 yields an empty list.

    Returns:
        List of `nlat` weights (they sum to 2 over the interval [-1, 1]).
    """
    var wj = List[Float64](length=nlat, fill=Float64(0.0))
    var n = nlat
    var n_f = Float64(n)
    # Explicit floor division: the loop bound must be an integer, and for odd
    # N the middle (equatorial) root is included.
    var nhalf = (n + 1) // 2
    for i in range(nhalf):
        # Standard asymptotic first guess for the i-th root of P_N.
        var theta = (
            Float64(pi) * (Float64(i) + Float64(0.75)) / (n_f + Float64(0.5))
        )
        var mu = cos(theta)
        for _ in range(200):
            var vals = _legendre_pn_pnm1(n, mu)
            var pn = vals[0]
            var pnm1 = vals[1]
            # dP_N/dmu = N (mu P_N - P_{N-1}) / (mu^2 - 1); bail out rather
            # than divide by a vanishing denominator.
            var denom_d = mu * mu - Float64(1.0)
            if abs(denom_d) < Float64(1e-30):
                break
            var dpn = n_f * (mu * pn - pnm1) / denom_d
            if abs(dpn) < Float64(1e-30):
                break
            var delta = pn / dpn
            mu = mu - delta
            if abs(delta) < Float64(1e-15):
                break
        # Re-evaluate P_{N-1} at the final root: `mu` was updated after the
        # last evaluation inside the Newton loop.
        var pnm1_root = _legendre_pn_pnm1(n, mu)[1]
        # Floors keep the weight finite if either factor underflows.
        var one_m_mu2 = Float64(1.0) - mu * mu
        if one_m_mu2 < Float64(1e-30):
            one_m_mu2 = Float64(1e-30)
        var wd = n_f * n_f * pnm1_root * pnm1_root
        if wd < Float64(1e-30):
            wd = Float64(1e-30)
        var wt = Float64(2.0) * one_m_mu2 / wd
        wj[i] = wt
        wj[n - 1 - i] = wt
    return wj^
# Compile-time factor converting an angle in degrees to radians (pi / 180).
comptime deg_to_rad = Float64(pi / 180.0)
@always_inline
def scalar_spectrum_stage_0_convert_fld(
    nlat: Int, nlon: Int, field: PythonObject
) raises -> List[Float64]:
    """Copy a 2-D Python field into a flat, row-major native list.

    Element (j, i) of `field` lands at index `j * nlon + i` of the result.

    Args:
        nlat: Number of rows to read.
        nlon: Number of columns to read.
        field: Python object supporting `field[j, i]` indexing whose
            elements convert to `Float64`.

    Returns:
        A list of `nlat * nlon` values.

    Raises:
        Any error from the Python indexing or the float conversion.
    """
    var flat = List[Float64](capacity=nlat * nlon)
    for row in range(nlat):
        for col in range(nlon):
            var value = Float64(py=field[row, col])
            flat.append(value)
    return flat^
@always_inline
def scalar_spectrum_stage_0_convert_lat_deg(
    nlat: Int, lat_deg: PythonObject
) raises -> List[Float64]:
    """Copy Python latitudes (degrees) into a native list of radians.

    Args:
        nlat: Number of latitude values to read.
        lat_deg: Python object supporting `lat_deg[j]` indexing whose
            elements convert to `Float64`; values are taken to be degrees.

    Returns:
        A list of `nlat` latitudes in radians.

    Raises:
        Any error from the Python indexing or the float conversion.
    """
    var radians = List[Float64](capacity=nlat)
    for row in range(nlat):
        var degrees = Float64(py=lat_deg[row])
        radians.append(degrees * deg_to_rad)
    return radians^
@always_inline
def scalar_spectrum_stage_4_scalar(
    nlat: Int,
    nlon: Int,
    n_m: Int,
    inv_nlon: Float64,
    fld: Span[mut=False, Float64, ...],
    cos_mi: Span[mut=False, Float64, ...],
    sin_mi: Span[mut=False, Float64, ...],
    re_jm: Span[mut=True, Float64, ...],
    im_jm: Span[mut=True, Float64, ...],
):
    """Scalar reference implementation of the per-latitude longitudinal DFT.

    For every row j and wavenumber m it computes
        re_jm[j * n_m + m] =  inv_nlon * sum_i fld[j * nlon + i] * cos_mi[m * nlon + i]
        im_jm[j * n_m + m] = -inv_nlon * sum_i fld[j * nlon + i] * sin_mi[m * nlon + i]

    Args:
        nlat: Number of rows in `fld`.
        nlon: Number of samples per row (and per twiddle-table row).
        n_m: Number of wavenumbers.
        inv_nlon: Scale applied to each sum (the caller passes 1 / nlon).
        fld: Row-major field, `nlat * nlon` values.
        cos_mi: Cosine table, `n_m * nlon` values indexed `[m * nlon + i]`.
        sin_mi: Sine table with the same layout as `cos_mi`.
        re_jm: Output real parts, `nlat * n_m` values.
        im_jm: Output imaginary parts, `nlat * n_m` values.
    """
    for row in range(nlat):
        var field_row = row * nlon
        var out_row = row * n_m
        for wave in range(n_m):
            var table_row = wave * nlon
            var acc_re = Float64(0.0)
            var acc_im = Float64(0.0)
            for col in range(nlon):
                var sample = fld[field_row + col]
                acc_re += sample * cos_mi[table_row + col]
                acc_im -= sample * sin_mi[table_row + col]
            var out_idx = out_row + wave
            re_jm[out_idx] = acc_re * inv_nlon
            im_jm[out_idx] = acc_im * inv_nlon
@always_inline
def scalar_spectrum_stage_4_vector(
    nlat: Int,
    nlon: Int,
    n_m: Int,
    inv_nlon: Float64,
    fld: Span[mut=False, Float64, ...],
    cos_mi: Span[mut=False, Float64, ...],
    sin_mi: Span[mut=False, Float64, ...],
    re_jm: Span[mut=True, Float64, ...],
    im_jm: Span[mut=True, Float64, ...],
):
    """SIMD implementation of the per-latitude longitudinal DFT.

    For every row j and wavenumber m it computes
        re_jm[j * n_m + m] =  inv_nlon * sum_i fld[j * nlon + i] * cos_mi[m * nlon + i]
        im_jm[j * n_m + m] = -inv_nlon * sum_i fld[j * nlon + i] * sin_mi[m * nlon + i]

    The SIMD lanes run over the longitude index i, because that is the
    contiguous axis of `fld`, `cos_mi` and `sin_mi`. Putting the lanes on m
    while still issuing contiguous loads would read `width` consecutive
    longitudes of a single wavenumber, spill into the next row of each
    buffer, and store one wavenumber's shifted sums into `width` different
    output slots. Each (j, m) pair is therefore reduced on its own: full
    vectors first, then a horizontal add, then a scalar tail for the
    `nlon % width` leftover samples.

    The summation order differs from the scalar reference, so results agree
    with it to rounding error rather than bit-for-bit.

    Args:
        nlat: Number of rows in `fld`.
        nlon: Number of samples per row (and per twiddle-table row).
        n_m: Number of wavenumbers.
        inv_nlon: Scale applied to each sum (the caller passes 1 / nlon).
        fld: Row-major field, `nlat * nlon` values.
        cos_mi: Cosine table, `n_m * nlon` values indexed `[m * nlon + i]`.
        sin_mi: Sine table with the same layout as `cos_mi`.
        re_jm: Output real parts, `nlat * n_m` values.
        im_jm: Output imaginary parts, `nlat * n_m` values.
    """
    comptime width = simd_width_of[DType.float64]()
    # Largest multiple of `width` that fits in one row; every vector load
    # below stays inside [row start, row start + nlon).
    var n_vec = (nlon // width) * width
    for j in range(nlat):
        var fld_off = j * nlon
        var jm_base = j * n_m
        for m in range(n_m):
            var mi_base = m * nlon
            var re_acc = SIMD[DType.float64, width](0.0)
            var im_acc = SIMD[DType.float64, width](0.0)
            for i in range(0, n_vec, width):
                var x = fld.unsafe_ptr().load[width=width](fld_off + i)
                re_acc = re_acc + x * cos_mi.unsafe_ptr().load[width=width](
                    mi_base + i
                )
                im_acc = im_acc - x * sin_mi.unsafe_ptr().load[width=width](
                    mi_base + i
                )
            # Collapse the per-lane partial sums into one scalar each.
            var re = re_acc.reduce_add()
            var im = im_acc.reduce_add()
            # Scalar tail for the samples that do not fill a whole vector.
            for i in range(n_vec, nlon):
                var x_tail = fld[fld_off + i]
                re = re + x_tail * cos_mi[mi_base + i]
                im = im - x_tail * sin_mi[mi_base + i]
            re_jm[jm_base + m] = re * inv_nlon
            im_jm[jm_base + m] = im * inv_nlon
def scalar_spectrum(
    field: PythonObject, lat_deg: PythonObject, max_degree_obj: PythonObject
) raises -> PythonObject:
    """Compute a per-degree spherical-harmonic power spectrum of a 2-D field.

    Stages (each prints its own wall-clock time):
      0. Copy `field` and `lat_deg` into native buffers.
      1. mu = sin(lat) and Gauss-Legendre weights for `nlat` nodes.
      2. Remove the weighted global mean from the field.
      3. Compute the weighted grid-space variance (the rescaling target).
      4. Longitudinal DFT of every row for m = 0 .. max_degree.
      5. Legendre projection, accumulating power for degrees 1 .. max_degree.
      6. Rescale so the summed power equals the Stage 3 variance.

    Args:
        field: Python object exposing `.shape` (rows, columns) and
            `field[j, i]` indexing; presumably a NumPy array on a Gaussian
            grid -- confirm with the caller.
        lat_deg: Python object indexable by `[j]` giving each row's latitude
            in degrees.
        max_degree_obj: Python int. Values <= 0 select `nlat`; the result is
            then capped at `nlon // 2`.

    Returns:
        A Python list of `max_degree` floats; entry `l - 1` is the power at
        degree `l`.

    Raises:
        Any error raised by the Python interop calls.
    """
    var t_total0 = _now_seconds()
    var nlat = Int(py=_py_getattr(field, "shape")[0])
    var nlon = Int(py=_py_getattr(field, "shape")[1])
    var max_degree = Int(py=max_degree_obj)
    if max_degree <= 0:
        max_degree = nlat
    # Explicit floor division keeps the cap an integer.
    if max_degree > nlon // 2:
        max_degree = nlon // 2
    comptime eps = Float64(1.0e-30)
    var dlon = Float64(2.0 * pi) / Float64(nlon)
    var inv_nlon = Float64(1.0) / Float64(nlon)
    print(
        "[mojo] grid",
        nlat,
        "x",
        nlon,
        " max_degree=",
        max_degree,
    )

    # ── Stage 0: Extract Python data into native Mojo buffers ──
    # Keep all Python interop here. Every later stage runs on native Mojo
    # buffers so the hot loops are pure numeric code.
    print("[mojo] Stage 0: extracting field to native buffers ...")
    var t0 = _now_seconds()
    var fld = scalar_spectrum_stage_0_convert_fld(nlat, nlon, field)
    var lat_rad = scalar_spectrum_stage_0_convert_lat_deg(nlat, lat_deg)
    print("[mojo] Stage 0 time:", _now_seconds() - t0, "s")

    # ── Stage 1: mu values and Gauss-Legendre quadrature weights ──
    # For a Gaussian grid, the latitude integral is naturally written in
    # mu = sin(phi), with exact Gauss-Legendre weights.
    print("[mojo] Stage 1: computing Gauss-Legendre weights ...")
    var t1 = _now_seconds()
    var mu_arr = List[Float64](capacity=nlat)
    for j in range(nlat):
        mu_arr.append(sin(lat_rad[j]))
    var wj_arr = _gauss_legendre_weights(nlat)
    print("[mojo] Stage 1 time:", _now_seconds() - t1, "s")

    # ── Stage 2: Area-weighted mean using Gaussian weights ──
    # The scalar reference spectrum is for anomaly power, so remove the
    # weighted global mean before harmonic analysis.
    print("[mojo] Stage 2: area-weighted mean ...")
    var t2 = _now_seconds()
    var mean_num = Float64(0.0)
    var mean_den = Float64(0.0)
    for j in range(nlat):
        var w = wj_arr[j]
        mean_den = mean_den + w
        var off = j * nlon
        var row_sum = Float64(0.0)
        for i in range(nlon):
            row_sum = row_sum + fld[off + i]
        mean_num = mean_num + w * row_sum * inv_nlon
    var mean_val = Float64(0.0)
    if mean_den > eps:
        mean_val = mean_num / mean_den
    for idx in range(nlat * nlon):
        fld[idx] = fld[idx] - mean_val
    print("[mojo] Stage 2 time:", _now_seconds() - t2, "s")

    # ── Stage 3: Weighted variance target using Gaussian weights ──
    # This target is used at the end to preserve the total variance seen in
    # grid space after mean removal.
    print("[mojo] Stage 3: weighted variance target ...")
    var t3 = _now_seconds()
    var sum_w = Float64(0.0)
    var var_target = Float64(0.0)
    for j in range(nlat):
        var w = wj_arr[j]
        sum_w = sum_w + w
        var row_var = Float64(0.0)
        var off = j * nlon
        for i in range(nlon):
            var x = fld[off + i]
            row_var = row_var + x * x
        row_var = row_var * inv_nlon
        var_target = var_target + w * row_var
    if sum_w > eps:
        var_target = var_target / sum_w
    print("[mojo] Stage 3 time:", _now_seconds() - t3, "s")

    # ── Stage 4: Precompute Fourier coefficients re(j,m) / im(j,m) ──
    print(
        "[mojo] Stage 4: Fourier precompute (twiddle table + nlat x n_m"
        " DFTs) ..."
    )
    var t4 = _now_seconds()
    # Factorises O(L^2 * nlat * nlon) into O(M * nlon) + O(M * nlat * nlon) + O(L^2 * nlat).
    var n_m = max_degree + 1
    # The twiddle tables hold one row of nlon entries per wavenumber, so
    # they are sized by n_m (not nlat).
    var cos_mi = List[Float64](capacity=n_m * nlon)
    var sin_mi = List[Float64](capacity=n_m * nlon)
    for m in range(n_m):
        for i in range(nlon):
            # Twiddle table: compute cos/sin once and reuse for every latitude.
            # exp(i * 2*pi * k / nlon) is periodic in k with period nlon, so
            # reduce m * i exactly in integers first; this keeps the
            # floating-point argument inside [0, 2*pi) and avoids the
            # accuracy loss of evaluating cos/sin at very large angles.
            var angle = Float64((m * i) % nlon) * dlon
            cos_mi.append(cos(angle))
            sin_mi.append(sin(angle))
    var re_jm = List[Float64](length=nlat * n_m, fill=Float64(0.0))
    var im_jm = List[Float64](length=nlat * n_m, fill=Float64(0.0))
    scalar_spectrum_stage_4_vector(
        nlat,
        nlon,
        n_m,
        inv_nlon,
        fld,
        cos_mi,
        sin_mi,
        re_jm,
        im_jm,
    )
    print("[mojo] Stage 4 time:", _now_seconds() - t4, "s")

    # ── Stage 5: Legendre projection → spectral power per degree ──
    print("[mojo] Stage 5: Legendre projection (streaming recurrence) ...")
    var t5 = _now_seconds()
    # 4-pi normalization: E_l = sum_{m=0}^l (C_lm^2 + S_lm^2)
    var power = List[Float64](length=max_degree, fill=Float64(0.0))
    # Stage 5 setup: precompute the m-only and (l,m)-only recurrence factors.
    # This moves sqrt-heavy setup out of the innermost projection loop.
    var pmm_const = List[Float64](capacity=n_m)
    var pmmp1_scale = List[Float64](capacity=n_m)
    for m in range(n_m):
        var coeff = Float64(1.0)
        if m >= 1:
            coeff = sqrt(Float64(3.0))
        for mm in range(2, m + 1):
            coeff = coeff * sqrt(Float64(2 * mm + 1) / Float64(2 * mm))
        pmm_const.append(coeff)
        pmmp1_scale.append(sqrt(Float64(2 * m + 3)))
    # One power of sqrt(1 - mu^2) per (j, m): n_m entries per latitude,
    # read back below with index j * n_m + m.
    var somx2_pow_jm = List[Float64](capacity=nlat * n_m)
    for j in range(nlat):
        var one_minus = Float64(1.0) - mu_arr[j] * mu_arr[j]
        if one_minus < Float64(0.0):
            one_minus = Float64(0.0)
        var somx2 = sqrt(one_minus)
        var som_pow = Float64(1.0)
        somx2_pow_jm.append(som_pow)
        for _ in range(1, n_m):
            som_pow = som_pow * somx2
            somx2_pow_jm.append(som_pow)
    var alpha_lm = List[Float64](length=n_m * n_m, fill=Float64(0.0))
    var beta_lm = List[Float64](length=n_m * n_m, fill=Float64(0.0))
    for m in range(0, max_degree + 1):
        var m_sq = Float64(m * m)
        for l in range(m + 2, max_degree + 1):
            var ll_sq = Float64(l * l)
            var twoL1 = Float64(2 * l + 1)
            var lm1_sq = Float64((l - 1) * (l - 1))
            var lm_idx = l * n_m + m
            alpha_lm[lm_idx] = sqrt(
                (Float64(4) * ll_sq - Float64(1)) / (ll_sq - m_sq)
            )
            beta_lm[lm_idx] = sqrt(
                twoL1
                * (lm1_sq - m_sq)
                / ((Float64(2 * l - 3)) * (ll_sq - m_sq))
            )
    for m in range(0, max_degree + 1):
        # For fixed m, accumulate every degree l >= m in one latitude pass.
        # This is the main scalar analogue of the classic Fourier-Legendre
        # transform structure used by Gaussian-grid spectral models.
        var c_by_l = List[Float64](length=max_degree + 1, fill=Float64(0.0))
        var s_by_l = List[Float64](length=max_degree + 1, fill=Float64(0.0))
        for j in range(nlat):
            var idx = j * n_m + m
            var wre = wj_arr[j] * re_jm[idx]
            var wim = wj_arr[j] * im_jm[idx]
            var mu = mu_arr[j]
            # P_m^m(mu_j): factored into an m-only constant and a j,m power.
            var pmm = pmm_const[m] * somx2_pow_jm[idx]
            # Degree 0 is the (already removed) mean and has no slot in
            # `power`, so the l == m term is only accumulated for m >= 1.
            if m >= 1:
                c_by_l[m] = c_by_l[m] + pmm * wre
                s_by_l[m] = s_by_l[m] + pmm * wim
            if m < max_degree:
                var pmmp1 = pmmp1_scale[m] * mu * pmm
                var lp1 = m + 1
                c_by_l[lp1] = c_by_l[lp1] + pmmp1 * wre
                s_by_l[lp1] = s_by_l[lp1] + pmmp1 * wim
                var plm2 = pmm
                var plm1 = pmmp1
                for l in range(m + 2, max_degree + 1):
                    var lm_idx = l * n_m + m
                    var alpha = alpha_lm[lm_idx]
                    var beta = beta_lm[lm_idx]
                    var pl = alpha * mu * plm1 - beta * plm2
                    c_by_l[l] = c_by_l[l] + pl * wre
                    s_by_l[l] = s_by_l[l] + pl * wim
                    plm2 = plm1
                    plm1 = pl
        var l_start = m
        if l_start < 1:
            l_start = 1
        for l in range(l_start, max_degree + 1):
            power[l - 1] = (
                power[l - 1] + c_by_l[l] * c_by_l[l] + s_by_l[l] * s_by_l[l]
            )
    print("[mojo] Stage 5 time:", _now_seconds() - t5, "s")

    # ── Stage 6: Variance-preserving rescaling ──
    # After the harmonic projection, apply one scalar so the summed spectral
    # power matches the weighted grid-space variance from Stage 3.
    print("[mojo] Stage 6: variance rescaling ...")
    var t6 = _now_seconds()
    var total = Float64(0.0)
    for l in range(max_degree):
        total = total + power[l]
    if total > eps and var_target > eps:
        var scale = var_target / total
        for l in range(max_degree):
            power[l] = power[l] * scale
    print("[mojo] Stage 6 time:", _now_seconds() - t6, "s")
    print("[mojo] total kernel time:", _now_seconds() - t_total0, "s")
    print("[mojo] done.")

    # ── Return as Python list ──
    var builtins = Python.import_module("builtins")
    var out = _py_call0(_py_getattr(builtins, "list"))
    for l in range(max_degree):
        _ = _py_call1(_py_getattr(out, "append"), PythonObject(power[l]))
    return out
@export
def PyInit_scalar_spectrum() -> PythonObject:
    """Build and return the `scalar_spectrum` Python extension module.

    The function name follows the `PyInit_<module>` naming used for Python
    extension entry points. It registers the `scalar_spectrum` function on
    a module of the same name; if building the module raises, the process
    is aborted with the error text.

    Returns:
        The finalized module object.
    """
    try:
        var builder = PythonModuleBuilder("scalar_spectrum")
        builder.def_function[scalar_spectrum]("scalar_spectrum")
        return builder.finalize()
    except err:
        abort(String("error creating scalar_spectrum module:", err))