Looking for some pointers on this MXFP4 implementation. Most of the code here is a translation of the llama.cpp implementation.
With the recent release of gpt-oss I wanted to implement it as a custom arch in MAX serve, but then saw there was no current support for MXFP4 in Mojo. Here is my first attempt, using llama.cpp as a reference and the kernel examples in the custom_ops/kernels section of the modular/max GitHub repo.
from math import floor, log2, ldexp
from compiler import register
from layout import Layout, LayoutTensor
from tensor_internal import InputTensor, OutputTensor
from runtime.asyncrt import DeviceContextPtr
from gpu.host import DeviceContext
from gpu.id import block_idx, thread_idx
# -----------------------------------------------------------------------------
# Constants / basic helpers
# -----------------------------------------------------------------------------
alias QK_MXFP4 = 32  # values per MXFP4 block (one shared E8M0 scale)
@always_inline
fn clamp_f32(x: Float32, lo: Float32, hi: Float32) -> Float32:
    if x < lo: return lo
    if x > hi: return hi
    return x
# E8M0 -> FP32 scale: 2^(e-127). e == 0 is special-cased to the smallest
# scale 2^-127; with this quantizer it only occurs for all-zero blocks,
# where every code is 0 and the scale is irrelevant anyway.
@always_inline
fn e8m0_to_fp32(e: UInt8) -> Float32:
    if e == 0:
        # ldexp expects a Float32 base; 2^-127 is an fp32 denormal.
        return ldexp(Float32(1.0), -127)
    return ldexp(Float32(1.0), Int(e) - 127)
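# Examples: e = 127 -> 2^0 = 1.0, e = 129 -> 2^2 = 4.0, e = 125 -> 2^-2 = 0.25.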
# Choose the shared E8M0 exponent from the block max. Heuristic: the largest
# MXFP4 magnitude is 6 (~ 2^2.58), so subtracting 2 from floor(log2(max))
# before adding the 127 bias puts the block max near the top code.
@always_inline
fn fp32_to_e8m0_from_block_max(max_val: Float32) -> UInt8:
    if max_val <= 0.0:
        return 0
    var e_est: Float32 = floor(log2(max_val)) - 2.0 + 127.0
    return UInt8(clamp_f32(e_est, 1.0, 254.0))
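# Worked example: block max 6.0 -> floor(log2(6.0)) = 2, so e = 2 - 2 + 127 = 127
# and d = 2^0 = 1.0; the peak magnitude 6.0 then lands exactly on the top code.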
# Decode a 4-bit MXFP4 code (0..15) to its signed magnitude WITHOUT the scale d.
# Magnitudes per code (low 3 bits): 0, 0.5, 1, 1.5, 2, 3, 4, 6; the high bit is the sign.
@always_inline
fn mxfp4_code_to_unit(i: Int) -> Float32:
    var mag_code: Int = i & 0x7
    var sgn: Int = (i >> 3) & 0x1
    var base: Float32
    if mag_code == 0: base = 0.0
    elif mag_code == 1: base = 0.5
    elif mag_code == 2: base = 1.0
    elif mag_code == 3: base = 1.5
    elif mag_code == 4: base = 2.0
    elif mag_code == 5: base = 3.0
    elif mag_code == 6: base = 4.0
    else: base = 6.0
    if sgn == 1: base = -base
    return base
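# Spot checks: 0x0 -> 0.0, 0x5 -> +3.0, 0x7 -> +6.0, 0xD (0x5 + sign bit) -> -3.0.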
# Pick the code whose decoded value best matches val / d (exhaustive nearest
# search over the 16 codes).
@always_inline
fn choose_code(val: Float32, d: Float32) -> UInt8:
    if d == 0.0:
        return 0
    var target: Float32 = val / d
    var best: Int = 0
    var min_err: Float32 = 3.4028235e38  # Float32 max
    for idx in range(16):
        var qv: Float32 = mxfp4_code_to_unit(idx)
        var err: Float32 = qv - target
        var aerr: Float32 = (err if err >= 0.0 else -err)
        if aerr < min_err:
            min_err = aerr
            best = idx
    return UInt8(best)
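# Minimal round-trip smoke test for the scalar helpers above (my own
# hypothetical helper, not part of the ops): quantize one 32-value block
# and report the worst absolute reconstruction error.
fn _mxfp4_roundtrip_smoke_test():
    # Synthetic block: a ramp over [-1.5, 1.6].
    var m: Float32 = 0.0
    for j in range(QK_MXFP4):
        var v: Float32 = 0.1 * Float32(j) - 1.5
        var av = (v if v >= 0.0 else -v)
        if av > m:
            m = av
    var e = fp32_to_e8m0_from_block_max(m)  # max 1.6 -> e = 125
    var d = e8m0_to_fp32(e)  # d = 0.25
    var worst: Float32 = 0.0
    for j in range(QK_MXFP4):
        var v: Float32 = 0.1 * Float32(j) - 1.5
        var back = mxfp4_code_to_unit(Int(choose_code(v, d))) * d
        var err = back - v
        var aerr = (err if err >= 0.0 else -err)
        if aerr > worst:
            worst = aerr
    print("mxfp4 round-trip worst abs err:", worst)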
# -----------------------------------------------------------------------------
# QUANTIZE: X[H,W] -> Q[H,W/2] (bytes of 2 nibbles) + E[H,W/32] (exponents)
# -----------------------------------------------------------------------------
@register("modular_ops::mxfp4_quantize_exq")
struct MXFP4QuantizeEXQ:
"""Quantize a 2D float tensor X[H,W] into separate MXFP4 buffers:
Q: UInt8[H, W/2] with two 4-bit codes per byte (low nibble first)
E: UInt8[H, W/32] with shared E8M0 exponent per 32 values.
W must be divisible by 32.
"""
@staticmethod
fn execute[
in_dtype: DType,
rank: Int,
BN: Int,
BD: Int,
target: StaticString,
](
out_q: OutputTensor[dtype=DType.uint8, rank=rank], # [H, W/2]
out_e: OutputTensor[dtype=DType.uint8, rank=rank], # [H, W/32]
x: InputTensor[dtype=in_dtype, rank=rank], # [H, W]
ctx: DeviceContextPtr,
) raises:
constrained[rank == 2, "rank must be 2"]()
var X = x.to_layout_tensor()
var Q = out_q.to_layout_tensor()
var E = out_e.to_layout_tensor()
alias H = X.shape[0]()
alias W = X.shape[1]()
constrained[W % QK_MXFP4 == 0, "W must be divisible by 32"]()
@parameter
if target == "cpu":
_mxfp4_quantize_cpu(X, Q, E)
else:
var dev = ctx.get_device_context()
_mxfp4_quantize_gpu[BN, BD](dev, X, Q, E)
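# Shape example for the op above: X[4, 64] float32 quantizes to Q[4, 32]
# uint8 (packed nibbles) plus E[4, 2] uint8 (one E8M0 exponent per 32-value
# block, i.e. 2 blocks per row).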
fn _mxfp4_quantize_cpu(X: LayoutTensor, mut Q: LayoutTensor, mut E: LayoutTensor):
    alias H = X.shape[0]()
    alias W = X.shape[1]()
    var blocks_per_row = W // QK_MXFP4
    for r in range(H):
        for b in range(blocks_per_row):
            var c0 = b * QK_MXFP4
            # Block max (largest absolute value in the 32-wide block).
            var m: Float32 = 0.0
            for j in range(QK_MXFP4):
                var f = rebind[Float32](X[r, c0 + j].cast[DType.float32]())
                var af = (f if f >= 0.0 else -f)
                if af > m:
                    m = af
            var e = fp32_to_e8m0_from_block_max(m)
            E[r, b] = rebind[E.element_type](e)
            var d = e8m0_to_fp32(e)
            # Pack two 4-bit codes per byte: element j in the low nibble,
            # element j + 16 in the high nibble (llama.cpp layout).
            var q_base = b * (QK_MXFP4 // 2)
            for j in range(QK_MXFP4 // 2):
                var v0 = rebind[Float32](X[r, c0 + j].cast[DType.float32]())
                var v1 = rebind[Float32](X[r, c0 + j + QK_MXFP4 // 2].cast[DType.float32]())
                var i0 = choose_code(v0, d)
                var i1 = choose_code(v1, d)
                var packed = (UInt8(i1) << 4) | (UInt8(i0) & 0x0F)
                Q[r, q_base + j] = rebind[Q.element_type](packed)
# GPU kernel version (tiled): each thread block owns a BN x BD tile and the
# block's 32 threads stride over the tile's rows.
fn _mxfp4_quantize_kernel[
    x_dtype: DType, x_layout: Layout,
    q_dtype: DType, q_layout: Layout,
    e_dtype: DType, e_layout: Layout,
    BN: Int, BD: Int,
](
    X: LayoutTensor[x_dtype, x_layout, MutableAnyOrigin],
    Q: LayoutTensor[q_dtype, q_layout, MutableAnyOrigin],
    E: LayoutTensor[e_dtype, e_layout, MutableAnyOrigin],
):
    var tile_x = X.tile[BN, BD](block_idx.y, block_idx.x)
    var tile_q = Q.tile[BN, BD // 2](block_idx.y, block_idx.x)  # Q has W/2 cols
    var tile_e = E.tile[BN, BD // QK_MXFP4](block_idx.y, block_idx.x)  # E has W/32 cols
    # Split the tile's rows across the 32 threads of the block.
    for r in range(Int(thread_idx.x), BN, 32):
        for cblock in range(BD // QK_MXFP4):
            var c0 = cblock * QK_MXFP4
            # block max
            var m: Float32 = 0.0
            for j in range(QK_MXFP4):
                var f = rebind[Float32](tile_x[r, c0 + j].cast[DType.float32]())
                var af = (f if f >= 0.0 else -f)
                if af > m:
                    m = af
            var e = fp32_to_e8m0_from_block_max(m)
            tile_e[r, cblock] = rebind[tile_e.element_type](e)
            var d = e8m0_to_fp32(e)
            # write 16 packed bytes
            var q_base = cblock * (QK_MXFP4 // 2)
            for j in range(QK_MXFP4 // 2):
                var v0 = rebind[Float32](tile_x[r, c0 + j].cast[DType.float32]())
                var v1 = rebind[Float32](tile_x[r, c0 + j + QK_MXFP4 // 2].cast[DType.float32]())
                var i0 = choose_code(v0, d)
                var i1 = choose_code(v1, d)
                var packed = (UInt8(i1) << 4) | (UInt8(i0) & 0x0F)
                tile_q[r, q_base + j] = rebind[tile_q.element_type](packed)
def _mxfp4_quantize_gpu[
    BN: Int, BD: Int,
](
    ctx: DeviceContext,
    X: LayoutTensor,
    mut Q: LayoutTensor,
    mut E: LayoutTensor,
):
    alias kernel = _mxfp4_quantize_kernel[
        X.dtype, X.layout,
        Q.dtype, Q.layout,
        E.dtype, E.layout,
        BN, BD,
    ]
    # One thread block per BN x BD tile; assumes H % BN == 0 and W % BD == 0.
    ctx.enqueue_function[kernel](
        X, Q, E,
        grid_dim=(X.shape[1]() // BD, X.shape[0]() // BN),
        block_dim=(32),
    )
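# Launch-shape example (hypothetical sizes): H = 64, W = 128 with BN = 16,
# BD = 64 gives grid_dim = (2, 4), i.e. eight 32-thread blocks, each
# quantizing a 16 x 64 tile (two 32-value blocks per tile row).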
# -----------------------------------------------------------------------------
# DEQUANTIZE: (Q[H,W/2], E[H,W/32]) -> X[H,W]
# -----------------------------------------------------------------------------
@register("modular_ops::mxfp4_dequantize_exq")
struct MXFP4DequantizeEXQ:
"""Dequantize MXFP4 given E[H,W/32] and Q[H,W/2] back to float (X[H,W])."""
@staticmethod
fn execute[
out_dtype: DType,
rank: Int,
BN: Int,
BD: Int,
target: StaticString,
](
out_x: OutputTensor[dtype=out_dtype, rank=rank], # [H, W]
q: InputTensor[dtype=DType.uint8, rank=rank], # [H, W/2]
e: InputTensor[dtype=DType.uint8, rank=rank], # [H, W/32]
ctx: DeviceContextPtr,
) raises:
constrained[rank == 2, "rank must be 2"]()
var Q = q.to_layout_tensor()
var E = e.to_layout_tensor()
var X = out_x.to_layout_tensor()
alias H = X.shape[0]()
alias W = X.shape[1]()
constrained[W % QK_MXFP4 == 0, "W must be divisible by 32"]()
@parameter
if target == "cpu":
_mxfp4_dequantize_cpu(Q, E, X)
else:
var dev = ctx.get_device_context()
_mxfp4_dequantize_gpu[BN, BD](dev, Q, E, X)
fn _mxfp4_dequantize_cpu(Q: LayoutTensor, E: LayoutTensor, mut X: LayoutTensor):
    alias H = X.shape[0]()
    alias W = X.shape[1]()
    var blocks_per_row = W // QK_MXFP4
    for r in range(H):
        for b in range(blocks_per_row):
            var c0 = b * QK_MXFP4
            var d = e8m0_to_fp32(rebind[UInt8](E[r, b].cast[DType.uint8]()))
            var q_base = b * (QK_MXFP4 // 2)
            for j in range(QK_MXFP4 // 2):
                var byte_val = rebind[UInt8](Q[r, q_base + j].cast[DType.uint8]())
                var i0 = Int(byte_val & 0x0F)
                var i1 = Int(byte_val >> 4)
                var v0 = mxfp4_code_to_unit(i0) * d
                var v1 = mxfp4_code_to_unit(i1) * d
                X[r, c0 + j] = rebind[X.element_type](v0)
                X[r, c0 + j + QK_MXFP4 // 2] = rebind[X.element_type](v1)
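# Unpacking example: byte 0x5D with d = 1.0 -> low nibble 0xD = -3.0 written
# at column c0 + j, high nibble 0x5 = +3.0 written at column c0 + j + 16.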
fn _mxfp4_dequantize_kernel[
    q_dtype: DType, q_layout: Layout,
    e_dtype: DType, e_layout: Layout,
    x_dtype: DType, x_layout: Layout,
    BN: Int, BD: Int,
](
    Q: LayoutTensor[q_dtype, q_layout, MutableAnyOrigin],
    E: LayoutTensor[e_dtype, e_layout, MutableAnyOrigin],
    X: LayoutTensor[x_dtype, x_layout, MutableAnyOrigin],
):
    var tile_q = Q.tile[BN, BD // 2](block_idx.y, block_idx.x)
    var tile_e = E.tile[BN, BD // QK_MXFP4](block_idx.y, block_idx.x)
    var tile_x = X.tile[BN, BD](block_idx.y, block_idx.x)
    # Split the tile's rows across the 32 threads of the block.
    for r in range(Int(thread_idx.x), BN, 32):
        for cblock in range(BD // QK_MXFP4):
            var d = e8m0_to_fp32(rebind[UInt8](tile_e[r, cblock].cast[DType.uint8]()))
            var q_base = cblock * (QK_MXFP4 // 2)
            var x_base = cblock * QK_MXFP4
            for j in range(QK_MXFP4 // 2):
                var byte_val = rebind[UInt8](tile_q[r, q_base + j].cast[DType.uint8]())
                var i0 = Int(byte_val & 0x0F)
                var i1 = Int(byte_val >> 4)
                tile_x[r, x_base + j] = rebind[tile_x.element_type](mxfp4_code_to_unit(i0) * d)
                tile_x[r, x_base + j + QK_MXFP4 // 2] = rebind[tile_x.element_type](mxfp4_code_to_unit(i1) * d)
def _mxfp4_dequantize_gpu[
    BN: Int, BD: Int,
](
    ctx: DeviceContext,
    Q: LayoutTensor,
    E: LayoutTensor,
    mut X: LayoutTensor,
):
    alias kernel = _mxfp4_dequantize_kernel[
        Q.dtype, Q.layout,
        E.dtype, E.layout,
        X.dtype, X.layout,
        BN, BD,
    ]
    ctx.enqueue_function[kernel](
        Q, E, X,
        grid_dim=(X.shape[1]() // BD, X.shape[0]() // BN),
        block_dim=(32),
    )
# -----------------------------------------------------------------------------
# (Optional) CPU reference matvec using EXQ buffers (y = A @ x)
# -----------------------------------------------------------------------------
@register("modular_ops::mxfp4_matvec_f32_exq")
struct MXFP4MatVecF32EXQ:
"""Compute y[H] = (MXFP4(E,Q) @ x[W]) on CPU for correctness checks."""
@staticmethod
fn execute[
in_dtype: DType, # dtype of x (float32)
rank_w: Int,
rank_e: Int,
rank_q: Int,
target: StaticString,
](
out_y: OutputTensor[dtype=DType.float32, rank=1], # [H]
q: InputTensor[dtype=DType.uint8, rank=rank_q], # [H, W/2]
e: InputTensor[dtype=DType.uint8, rank=rank_e], # [H, W/32]
x: InputTensor[dtype=in_dtype, rank=rank_w], # [W]
ctx: DeviceContextPtr,
) raises:
var Q = q.to_layout_tensor()
var E = e.to_layout_tensor()
var Xv = x.to_layout_tensor()
var Y = out_y.to_layout_tensor()
alias H = Q.shape[0]()
alias W2 = Q.shape[1]()
var W = W2 * 2
var blocks = W // QK_MXFP4
for r in range(H):
var acc: Float32 = 0.0
for b in range(blocks):
var d = e8m0_to_fp32(rebind[UInt8](E[r, b].cast[DType.uint8]()))
var q_base = b * (QK_MXFP4 // 2)
var x_base = b * QK_MXFP4
for j in range(QK_MXFP4 // 2):
var byte_val = rebind[UInt8](Q[r, q_base + j].cast[DType.uint8]())
var i0 = Int(byte_val & 0x0F)
var i1 = Int(byte_val >> 4)
var v0 = mxfp4_code_to_unit(i0) * d
var v1 = mxfp4_code_to_unit(i1) * d
var x0 = rebind[Float32](Xv[0, x_base + j].cast[DType.float32]())
var x1 = rebind[Float32](Xv[0, x_base + j + QK_MXFP4//2].cast[DType.float32]())
acc += v0 * x0 + v1 * x1
Y[r] = rebind[Y.element_type](acc)
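# Hypothetical local entry point (my addition) so the file can be run
# directly with `mojo run` to exercise the scalar smoke test; it is not
# needed when the file is packaged as a MAX custom-op extension.
fn main():
    _mxfp4_roundtrip_smoke_test()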