To date, the way we’ve had people interface with accelerated matrix multiplication has been at the operation level in the graph compiler. The graph compiler could then target specific hardware and dispatch to the right matmul implementation. We recently open-sourced the definitions for every MAX graph operation, and they can all be found in one giant file here.
The top-level matmul function called in the mo.matmul operation there has a lot of specializations, but let's start by testing just that entry point. For that, I looked to the linalg unit tests for matmul and cribbed together a pretty crude port of that test into an isolated Mojo file here:
from buffer import NDBuffer
from buffer.dimlist import DimList
from linalg.matmul import matmul
from linalg.packing import pack_matmul_b_shape_func
from memory import UnsafePointer
from testing import assert_almost_equal, assert_equal
from utils.index import Index, IndexList

# Naive three-loop matmul used to produce the golden reference result.
fn gemm_naive[](
    a: NDBuffer,
    b: NDBuffer,
    c: NDBuffer[mut=True, *_],
    m: Int,
    n: Int,
    k: Int,
):
    for i in range(m):
        for p in range(k):
            for j in range(n):
                var a_val = a[i, p].cast[c.type]()
                var b_val = b[p, j].cast[c.type]()
                c[i, j] += a_val * b_val

alias alignment = 64

def test_matmul[
    a_type: DType,
    a_shape: DimList,
    b_type: DType,
    b_shape: DimList,
    c_type: DType,
    c_shape: DimList,
    transpose_b: Bool,
    b_packed: Bool,
    saturated: Bool,
](m: Int, n: Int, k: Int):
    var a_ptr = UnsafePointer[Scalar[a_type], alignment=alignment].alloc(m * k)
    var b_ptr = UnsafePointer[Scalar[b_type], alignment=alignment].alloc(k * n)
    var b = NDBuffer[b_type, 2, _, b_shape](b_ptr, Index(k, n))

    # Compute the padded shape B needs if it is to be pre-packed.
    var padded_n_k = IndexList[2]()
    padded_n_k = pack_matmul_b_shape_func[
        a_type,
        a_shape,
        b_type,
        b_shape,
        c_type,
        c_shape,
        transpose_b,
        True,
    ](b)

    var padded_n = padded_n_k[1] if b_packed else n
    var padded_k = padded_n_k[0] if b_packed else k

    var bp_ptr = UnsafePointer[Scalar[b_type], alignment=alignment].alloc(
        padded_k * padded_n
    )
    var c0_ptr = UnsafePointer[Scalar[c_type], alignment=alignment].alloc(m * n)
    var c1_ptr = UnsafePointer[Scalar[c_type], alignment=alignment].alloc(m * n)

    var a = NDBuffer[a_type, 2, _, a_shape](a_ptr, Index(m, k))
    var bp = NDBuffer[b_type, 2, _, DimList.create_unknown[2]()](
        bp_ptr, Index(padded_k, padded_n)
    )
    var c = NDBuffer[c_type, 2, _, c_shape](c0_ptr, Index(m, n))
    var golden = NDBuffer[c_type, 2, _, c_shape](c1_ptr, Index(m, n))

    # saturated VNNI only has a range [0,127] for the input a
    var vnni_range: Int = 128 if saturated else 256

    var cnt: Int = 0
    for i in range(m):
        for p in range(k):
            # uint8 but limited to [0,127]
            a[IndexList[2]((i, p))] = cnt % vnni_range
            cnt += 1

    cnt = 0
    for p in range(k):
        for j in range(n):
            # int8 [-128, 127]
            b[IndexList[2]((p, j))] = cnt % 256 - 128
            bp[IndexList[2]((p, j))] = b[IndexList[2]((p, j))]
            cnt += 1

    for i in range(m):
        for j in range(n):
            c[IndexList[2]((i, j))] = 0
            golden[IndexList[2]((i, j))] = c[IndexList[2]((i, j))]

    # Run the library matmul into c, then the naive reference into golden.
    matmul[
        transpose_b=transpose_b,
        b_packed=b_packed,
        saturated_vnni=saturated,
    ](c, a, rebind[NDBuffer[b_type, 2, bp.origin, b_shape]](bp))

    gemm_naive(a, b, golden, m, n, k)

    for i in range(m):
        for j in range(n):
            var msg = String(
                "values do not agree for ",
                m,
                "x",
                n,
                "x",
                k,
                " using the dtype=",
                a_type,
                ",",
                b_type,
                ",",
                c_type,
            )

            @parameter
            if c_type.is_floating_point():
                assert_almost_equal(c[i, j], golden[i, j], msg)
            else:
                assert_equal(c[i, j], golden[i, j], msg)

    a_ptr.free()
    b_ptr.free()
    bp_ptr.free()
    c0_ptr.free()
    c1_ptr.free()

def main():
    alias a_shape = DimList.create_unknown[2]()
    alias b_shape = DimList.create_unknown[2]()
    alias c_shape = DimList.create_unknown[2]()

    test_matmul[
        DType.float32,
        a_shape,
        DType.float32,
        b_shape,
        DType.float32,
        c_shape,
        transpose_b=False,
        b_packed=False,
        saturated=False,
    ](256, 256, 256)
If you clone the modular repository at a directory level parallel to this project, you should be able to run:

mojo -I ../modular/max/kernels/src/ matmul_test.mojo
This ran locally on my machine, and the library matmul passed the test against the naive matmul in the same file. There are a lot of specializations below this top-level entry point that you might be able to pull out and test in a similar manner, but this does seem to work against the modular repository if you manually point the imports at the kernels directory inside it.
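If you want to poke at some of those other paths through the same harness, here is a hedged sketch of extra cases you could append to main() above. The parameter combinations are my own guesses, not something I've verified: a pre-packed B with the same float32 shapes, and the uint8/int8 saturated-VNNI path that the comments inside the test describe, accumulating into int32. Whether the second case actually dispatches to VNNI instructions depends on your CPU.

    # Illustrative extra cases for main() -- assumptions, not verified locally.

    # Exercise the pre-packed-B path with the same float32 shapes.
    test_matmul[
        DType.float32,
        a_shape,
        DType.float32,
        b_shape,
        DType.float32,
        c_shape,
        transpose_b=False,
        b_packed=True,
        saturated=False,
    ](256, 256, 256)

    # Exercise the integer path hinted at by the test's comments: uint8 A
    # limited to [0, 127], int8 B, int32 accumulation, saturated VNNI on.
    test_matmul[
        DType.uint8,
        a_shape,
        DType.int8,
        b_shape,
        DType.int32,
        c_shape,
        transpose_b=False,
        b_packed=True,
        saturated=True,
    ](256, 256, 256)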