MXFP4 implementation

Looking for some pointers on this MXFP4 implementation. Most of the code here is a translation of the llama.cpp implementation.

With the recent release of gpt-oss, I wanted to implement it as a custom architecture in MAX Serve, but then saw there was no current support for MXFP4 in Mojo. Here is my first attempt at it, using llama.cpp as a reference along with the kernel examples in the custom_ops/kernels section of the modular/max GitHub repo.

from math import floor, log2, ldexp
from compiler import register
from layout import Layout, LayoutTensor
from tensor_internal import InputTensor, OutputTensor
from runtime.asyncrt import DeviceContextPtr
from gpu.host import DeviceContext
from gpu.id import block_idx

# -----------------------------------------------------------------------------
# Constants / basic helpers
# -----------------------------------------------------------------------------

# Number of values that share one E8M0 scale in an MXFP4 block (OCP MX spec).
alias QK_MXFP4 = 32

@always_inline
fn clamp_f32(x: Float32, lo: Float32, hi: Float32) -> Float32:
    # Clamp x into [lo, hi]; values already inside pass through unchanged.
    var result = x
    if result < lo:
        result = lo
    elif result > hi:
        result = hi
    return result

# E8M0 -> FP32 scale. E8M0 is a biased power-of-two exponent: code e decodes
# to 2^(e-127). Per the OCP MX spec there is no zero special case beyond the
# smallest exponent: e == 0 encodes 2^(0-127) = 2^-127, a representable
# float32 subnormal. This matches llama.cpp's ggml_e8m0_to_fp32 (bit pattern
# 0x00400000 == 2^-127). The previous code returned 2^-23 for e == 0, which
# is not a valid E8M0 decode; it was harmless for buffers produced by the
# quantizer below (e == 0 only for all-zero blocks, whose codes are all 0),
# but would mis-scale buffers produced by other MXFP4 writers.
@always_inline
fn e8m0_to_fp32(e: UInt8) -> Float32:
    if e == 0:
        # Smallest encodable scale: 2^-127 (float32 subnormal).
        return ldexp(Float32(1.0), -127)
    return ldexp(Float32(1.0), Int(e) - 127)

# Pick the shared E8M0 exponent for a block from its absolute maximum.
# Heuristic: the largest code magnitude is 6 (~2^2.58), so subtracting 2
# from floor(log2(max)) roughly maps the block max onto the peak code value.
@always_inline
fn fp32_to_e8m0_from_block_max(max_val: Float32) -> UInt8:
    # Non-positive max means an all-zero block; exponent code 0 marks it.
    if max_val <= 0.0:
        return 0
    # Biased exponent estimate, kept inside the valid E8M0 range [1, 254].
    var biased: Float32 = floor(log2(max_val)) - 2.0 + 127.0
    return UInt8(clamp_f32(biased, 1.0, 254.0))

# Decode a 4-bit MXFP4 code (0..15) to its signed magnitude WITHOUT scale d.
# Low 3 bits select the magnitude {0, 0.5, 1, 1.5, 2, 3, 4, 6}; the high
# bit flips the sign (this is FP4 E2M1 from the OCP MX spec).
@always_inline
fn mxfp4_code_to_unit(i: Int) -> Float32:
    var mag: Int = i & 0x7
    var unit: Float32
    if mag <= 4:
        # Codes 0..4 are evenly spaced in 0.5 steps: 0, 0.5, 1, 1.5, 2.
        unit = Float32(mag) * 0.5
    elif mag == 5:
        unit = 3.0
    elif mag == 6:
        unit = 4.0
    else:
        unit = 6.0
    if (i & 0x8) != 0:
        unit = -unit
    return unit

# Pick the MXFP4 code whose decoded value best matches val/d by brute-force
# nearest-value search over all 16 codes. Ties keep the lower code (strict <,
# first hit wins). A zero scale means an all-zero block; code 0 decodes to 0.
@always_inline
fn choose_code(val: Float32, d: Float32) -> UInt8:
    if d == 0.0:
        return 0
    var target: Float32 = val / d
    var best_idx: Int = 0
    var best_err: Float32 = 3.4028235e38  # FLT_MAX sentinel
    for code in range(16):
        var diff: Float32 = mxfp4_code_to_unit(code) - target
        if diff < 0.0:
            diff = -diff
        if diff < best_err:
            best_err = diff
            best_idx = code
    return UInt8(best_idx)

# -----------------------------------------------------------------------------
# QUANTIZE: X[H,W] -> Q[H,W/2] (bytes of 2 nibbles) + E[H,W/32] (exponents)
# -----------------------------------------------------------------------------

@register("modular_ops::mxfp4_quantize_exq")
struct MXFP4QuantizeEXQ:
    """Quantize a 2D float tensor X[H,W] into separate MXFP4 buffers.

    Outputs:
      Q: UInt8[H, W/2]  - two 4-bit codes per byte; within each 32-value
         block, byte j holds value j in the low nibble and value j+16 in
         the high nibble (llama.cpp layout).
      E: UInt8[H, W/32] - one shared E8M0 exponent per 32-value block.

    Constraints: rank must be 2 and W must be divisible by 32. The GPU
    path additionally assumes H % BN == 0 and W % BD == 0, since the
    launcher's grid_dim uses truncating division (no partial tiles).
    """

    @staticmethod
    fn execute[
        in_dtype: DType,
        rank: Int,
        BN: Int,  # tile rows per GPU grid block
        BD: Int,  # tile cols per GPU grid block (multiple of 32)
        target: StaticString,
    ](
        out_q: OutputTensor[dtype=DType.uint8, rank=rank],  # [H, W/2]
        out_e: OutputTensor[dtype=DType.uint8, rank=rank],  # [H, W/32]
        x: InputTensor[dtype=in_dtype, rank=rank],          # [H, W]
        ctx: DeviceContextPtr,
    ) raises:
        constrained[rank == 2, "rank must be 2"]()
        var X = x.to_layout_tensor()
        var Q = out_q.to_layout_tensor()
        var E = out_e.to_layout_tensor()
        alias H = X.shape[0]()
        alias W = X.shape[1]()
        constrained[W % QK_MXFP4 == 0, "W must be divisible by 32"]()

        # Compile-time dispatch: plain CPU loop vs. tiled GPU kernel launch.
        @parameter
        if target == "cpu":
            _mxfp4_quantize_cpu(X, Q, E)
        else:
            var dev = ctx.get_device_context()
            _mxfp4_quantize_gpu[BN, BD](dev, X, Q, E)


fn _mxfp4_quantize_cpu(X: LayoutTensor, mut Q: LayoutTensor, mut E: LayoutTensor):
    # CPU reference quantizer: X[H, W] -> packed nibbles Q[H, W/2] plus one
    # shared E8M0 exponent per 32-value block in E[H, W/32].
    alias H = X.shape[0]()
    alias W = X.shape[1]()
    var n_blocks = W // QK_MXFP4

    for row in range(H):
        for blk in range(n_blocks):
            var half = QK_MXFP4 // 2
            var col0 = blk * QK_MXFP4

            # The block's absolute maximum drives the shared scale.
            var amax: Float32 = 0.0
            for j in range(QK_MXFP4):
                var f = rebind[Float32](X[row, col0 + j].cast[DType.float32]())
                if f < 0.0:
                    f = -f
                if f > amax:
                    amax = f

            var exp = fp32_to_e8m0_from_block_max(amax)
            E[row, blk] = rebind[E.element_type](exp)
            var scale = e8m0_to_fp32(exp)

            # Pack codes: byte j holds value j in the low nibble and
            # value j+16 in the high nibble (llama.cpp layout).
            var qcol0 = blk * half
            for j in range(half):
                var lo = choose_code(
                    rebind[Float32](X[row, col0 + j].cast[DType.float32]()), scale
                )
                var hi = choose_code(
                    rebind[Float32](X[row, col0 + j + half].cast[DType.float32]()), scale
                )
                Q[row, qcol0 + j] = rebind[Q.element_type](
                    (UInt8(hi) << 4) | (UInt8(lo) & 0x0F)
                )


# GPU kernel version (tiled): each (block_idx.y, block_idx.x) grid block
# quantizes one BN x BD tile of X into the matching BN x (BD/2) tile of Q
# and BN x (BD/32) tile of E.
# NOTE(review): the kernel never reads a thread index, so every thread of
# the launched block (block_dim=(32) in the launcher) executes these loops
# and writes identical values -- correct but redundant; consider splitting
# rows or 32-value blocks across threads.
fn _mxfp4_quantize_kernel[
    x_dtype: DType, x_layout: Layout,
    q_dtype: DType, q_layout: Layout,
    e_dtype: DType, e_layout: Layout,
    BN: Int, BD: Int,
](
    X: LayoutTensor[x_dtype, x_layout, MutableAnyOrigin],
    Q: LayoutTensor[q_dtype, q_layout, MutableAnyOrigin],
    E: LayoutTensor[e_dtype, e_layout, MutableAnyOrigin],
):
    var tile_x = X.tile[BN, BD](block_idx.y, block_idx.x)
    var tile_q = Q.tile[BN, BD // 2](block_idx.y, block_idx.x)  # Q has W/2 cols
    var tile_e = E.tile[BN, BD // QK_MXFP4](block_idx.y, block_idx.x)  # E has W/32 cols

    for r in range(BN):
        for cblock in range(BD // QK_MXFP4):
            var c0 = cblock * QK_MXFP4
            # block max (absolute value) selects the shared E8M0 scale
            var m: Float32 = 0.0
            for j in range(QK_MXFP4):
                var f = rebind[Float32](tile_x[r, c0 + j].cast[DType.float32]())
                var af = (f if f >= 0.0 else -f)
                if af > m: m = af
            var e = fp32_to_e8m0_from_block_max(m)
            tile_e[r, cblock] = rebind[tile_e.element_type](e)
            var d = e8m0_to_fp32(e)

            # write 16 packed bytes: low nibble = value j, high = value j+16
            var q_base = cblock * (QK_MXFP4 // 2)
            for j in range(QK_MXFP4 // 2):
                var v0 = rebind[Float32](tile_x[r, c0 + j].cast[DType.float32]())
                var v1 = rebind[Float32](tile_x[r, c0 + j + QK_MXFP4//2].cast[DType.float32]())
                var i0 = choose_code(v0, d)
                var i1 = choose_code(v1, d)
                var packed = (UInt8(i1) << 4) | (UInt8(i0) & 0x0F)
                tile_q[r, q_base + j] = rebind[tile_q.element_type](packed)

def _mxfp4_quantize_gpu[
    BN: Int, BD: Int,
](
    ctx: DeviceContext,
    X: LayoutTensor,
    mut Q: LayoutTensor,
    mut E: LayoutTensor,
):
    # Launch one grid block per BN x BD tile of X.
    # NOTE(review): the truncating divisions below silently drop remainder
    # tiles -- this launcher requires H % BN == 0 and W % BD == 0. For
    # general shapes, use ceil-division plus bounds checks in the kernel.
    alias kernel = _mxfp4_quantize_kernel[
        X.dtype, X.layout,
        Q.dtype, Q.layout,
        E.dtype, E.layout,
        BN, BD,
    ]
    ctx.enqueue_function[kernel](
        X, Q, E,
        grid_dim=(X.shape[1]() // BD, X.shape[0]() // BN),
        block_dim=(32),
    )

# -----------------------------------------------------------------------------
# DEQUANTIZE: (Q[H,W/2], E[H,W/32]) -> X[H,W]
# -----------------------------------------------------------------------------

@register("modular_ops::mxfp4_dequantize_exq")
struct MXFP4DequantizeEXQ:
    """Dequantize MXFP4 buffers back to a float tensor X[H,W].

    Inputs:
      Q: UInt8[H, W/2]  - two 4-bit codes per byte; within each 32-value
         block the low nibble of byte j is value j, the high nibble is
         value j+16 (matches the quantizer's packing).
      E: UInt8[H, W/32] - one shared E8M0 exponent per block.

    Constraints: rank must be 2 and W must be divisible by 32. The GPU
    path additionally assumes H % BN == 0 and W % BD == 0, since the
    launcher's grid_dim uses truncating division (no partial tiles).
    """

    @staticmethod
    fn execute[
        out_dtype: DType,
        rank: Int,
        BN: Int,  # tile rows per GPU grid block
        BD: Int,  # tile cols per GPU grid block (multiple of 32)
        target: StaticString,
    ](
        out_x: OutputTensor[dtype=out_dtype, rank=rank],   # [H, W]
        q: InputTensor[dtype=DType.uint8, rank=rank],       # [H, W/2]
        e: InputTensor[dtype=DType.uint8, rank=rank],       # [H, W/32]
        ctx: DeviceContextPtr,
    ) raises:
        constrained[rank == 2, "rank must be 2"]()
        var Q = q.to_layout_tensor()
        var E = e.to_layout_tensor()
        var X = out_x.to_layout_tensor()
        alias H = X.shape[0]()
        alias W = X.shape[1]()
        constrained[W % QK_MXFP4 == 0, "W must be divisible by 32"]()

        # Compile-time dispatch: plain CPU loop vs. tiled GPU kernel launch.
        @parameter
        if target == "cpu":
            _mxfp4_dequantize_cpu(Q, E, X)
        else:
            var dev = ctx.get_device_context()
            _mxfp4_dequantize_gpu[BN, BD](dev, Q, E, X)


fn _mxfp4_dequantize_cpu(Q: LayoutTensor, E: LayoutTensor, mut X: LayoutTensor):
    # CPU reference dequantizer: rebuild X[H, W] from packed nibbles
    # Q[H, W/2] and per-block E8M0 exponents E[H, W/32].
    alias H = X.shape[0]()
    alias W = X.shape[1]()
    var n_blocks = W // QK_MXFP4

    for row in range(H):
        for blk in range(n_blocks):
            var half = QK_MXFP4 // 2
            var col0 = blk * QK_MXFP4
            var qcol0 = blk * half
            # Shared scale for all 32 values of this block.
            var scale = e8m0_to_fp32(rebind[UInt8](E[row, blk].cast[DType.uint8]()))
            for j in range(half):
                var packed = rebind[UInt8](Q[row, qcol0 + j].cast[DType.uint8]())
                # Low nibble decodes value j, high nibble value j + 16.
                X[row, col0 + j] = rebind[X.element_type](
                    mxfp4_code_to_unit(Int(packed & 0x0F)) * scale
                )
                X[row, col0 + j + half] = rebind[X.element_type](
                    mxfp4_code_to_unit(Int(packed >> 4)) * scale
                )


# GPU kernel: each (block_idx.y, block_idx.x) grid block dequantizes one
# BN x BD tile of X from the matching Q and E tiles.
# NOTE(review): as with the quantize kernel, no thread index is read, so
# all threads of the launched block perform identical redundant writes.
fn _mxfp4_dequantize_kernel[
    q_dtype: DType, q_layout: Layout,
    e_dtype: DType, e_layout: Layout,
    x_dtype: DType, x_layout: Layout,
    BN: Int, BD: Int,
](
    Q: LayoutTensor[q_dtype, q_layout, MutableAnyOrigin],
    E: LayoutTensor[e_dtype, e_layout, MutableAnyOrigin],
    X: LayoutTensor[x_dtype, x_layout, MutableAnyOrigin],
):
    var tile_q = Q.tile[BN, BD // 2](block_idx.y, block_idx.x)
    var tile_e = E.tile[BN, BD // QK_MXFP4](block_idx.y, block_idx.x)
    var tile_x = X.tile[BN, BD](block_idx.y, block_idx.x)

    for r in range(BN):
        for cblock in range(BD // QK_MXFP4):
            # shared E8M0 scale for this 32-value block
            var d = e8m0_to_fp32(rebind[UInt8](tile_e[r, cblock].cast[DType.uint8]()))
            var q_base = cblock * (QK_MXFP4 // 2)
            var x_base = cblock * QK_MXFP4
            for j in range(QK_MXFP4 // 2):
                var byte_val = rebind[UInt8](tile_q[r, q_base + j].cast[DType.uint8]())
                # low nibble = value j, high nibble = value j + 16
                var i0 = Int(byte_val & 0x0F)
                var i1 = Int(byte_val >> 4)
                tile_x[r, x_base + j] = rebind[tile_x.element_type](mxfp4_code_to_unit(i0) * d)
                tile_x[r, x_base + j + QK_MXFP4//2] = rebind[tile_x.element_type](mxfp4_code_to_unit(i1) * d)

def _mxfp4_dequantize_gpu[
    BN: Int, BD: Int,
](
    ctx: DeviceContext,
    Q: LayoutTensor,
    E: LayoutTensor,
    mut X: LayoutTensor,
):
    # Launch one grid block per BN x BD tile of X.
    # NOTE(review): the truncating divisions below silently drop remainder
    # tiles -- this launcher requires H % BN == 0 and W % BD == 0. For
    # general shapes, use ceil-division plus bounds checks in the kernel.
    alias kernel = _mxfp4_dequantize_kernel[
        Q.dtype, Q.layout,
        E.dtype, E.layout,
        X.dtype, X.layout,
        BN, BD,
    ]
    ctx.enqueue_function[kernel](
        Q, E, X,
        grid_dim=(X.shape[1]() // BD, X.shape[0]() // BN),
        block_dim=(32),
    )

# -----------------------------------------------------------------------------
# (Optional) CPU reference matvec using EXQ buffers (y = A @ x)
# -----------------------------------------------------------------------------

@register("modular_ops::mxfp4_matvec_f32_exq")
struct MXFP4MatVecF32EXQ:
    """Compute y[H] = (MXFP4(E,Q) @ x[W]) on CPU for correctness checks.

    Decodes each 32-value block on the fly (no materialized dequantized
    matrix) and accumulates the dot product in Float32.

    NOTE(review): this always runs the CPU loop regardless of `target`;
    intended as a reference implementation only.
    """

    @staticmethod
    fn execute[
        in_dtype: DType,   # dtype of x (float32)
        rank_w: Int,
        rank_e: Int,
        rank_q: Int,
        target: StaticString,
    ](
        out_y: OutputTensor[dtype=DType.float32, rank=1],   # [H]
        q: InputTensor[dtype=DType.uint8, rank=rank_q],     # [H, W/2]
        e: InputTensor[dtype=DType.uint8, rank=rank_e],     # [H, W/32]
        x: InputTensor[dtype=in_dtype, rank=rank_w],        # [W]
        ctx: DeviceContextPtr,
    ) raises:
        var Q = q.to_layout_tensor()
        var E = e.to_layout_tensor()
        var Xv = x.to_layout_tensor()
        var Y = out_y.to_layout_tensor()
        alias H = Q.shape[0]()
        alias W2 = Q.shape[1]()
        # W is recovered from Q's packed width (two 4-bit codes per byte).
        var W = W2 * 2
        var blocks = W // QK_MXFP4

        for r in range(H):
            var acc: Float32 = 0.0
            for b in range(blocks):
                # Shared E8M0 scale for this 32-value block.
                var d = e8m0_to_fp32(rebind[UInt8](E[r, b].cast[DType.uint8]()))
                var q_base = b * (QK_MXFP4 // 2)
                var x_base = b * QK_MXFP4
                for j in range(QK_MXFP4 // 2):
                    var byte_val = rebind[UInt8](Q[r, q_base + j].cast[DType.uint8]())
                    # low nibble = value j, high nibble = value j + 16
                    var i0 = Int(byte_val & 0x0F)
                    var i1 = Int(byte_val >> 4)
                    var v0 = mxfp4_code_to_unit(i0) * d
                    var v1 = mxfp4_code_to_unit(i1) * d
                    # NOTE(review): x is documented as rank-1 [W] but is
                    # indexed with two coordinates here -- confirm rank_w
                    # is actually 2 (shape [1, W]) or drop the leading 0.
                    var x0 = rebind[Float32](Xv[0, x_base + j].cast[DType.float32]())
                    var x1 = rebind[Float32](Xv[0, x_base + j + QK_MXFP4//2].cast[DType.float32]())
                    acc += v0 * x0 + v1 * x1
            Y[r] = rebind[Y.element_type](acc)
2 Likes

This topic was automatically closed 180 days after the last reply. New replies are no longer allowed.