How to apply a vectorized function to a List in place in Mojo

```mojo
from memory import Span
from sys.info import simdwidthof


fn apply[
    D: DType,
    O: MutableOrigin, //,
    func: fn[w: Int] (SIMD[D, w]) -> SIMD[D, w],
](span: Span[Scalar[D], O]):
    """Apply the function to the `Span` in place.

    Parameters:
        D: The DType.
        O: The origin of the `Span`.
        func: The function to evaluate.

    Args:
        span: The `Span` whose elements are overwritten with `func(x)`.
    """

    # Candidate vector widths, largest first.
    alias widths = (256, 128, 64, 32, 16, 8, 4)
    var ptr = span.unsafe_ptr()
    var length = len(span)
    var processed = 0

    # Unrolled at compile time: one loop per width the target supports.
    @parameter
    for i in range(len(widths)):
        alias w = widths.get[i, Int]()

        @parameter
        if simdwidthof[D]() >= w:
            # Process as many full `w`-wide chunks as remain.
            for _ in range((length - processed) // w):
                var p_curr = ptr + processed
                p_curr.store(func(p_curr.load[width=w]()))
                processed += w

    # Whatever the vector loops did not cover is finished one scalar at a time.
    # The memory is already initialized, so plain assignment is enough.
    for i in range(processed, length):
        ptr[i] = func(ptr[i])


fn main():
    items = List[Byte](
        1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19
    )
    twice = items  # a copy, so `items` keeps the original values
    span = Span(twice)

    fn _twice[w: Int](x: SIMD[DType.uint8, w]) -> SIMD[DType.uint8, w]:
        return x * 2

    apply[func=_twice](span)
    for i in range(len(items)):
        print(span[i] == items[i] * 2)  # True
```
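For the 19-element example above, it can help to trace which chunk widths the cascade in `apply` actually picks. The sketch below mirrors that loop over plain integers; `plan_chunks` is a hypothetical helper (not part of Mojo's standard library), and the call in `main` assumes a target where `simdwidthof[DType.uint8]()` is 16. It only records widths and does no loads or stores.

```mojo
fn plan_chunks(length: Int, max_width: Int) -> List[Int]:
    """Hypothetical helper: the chunk widths `apply` would use for `length`
    elements when the target's SIMD width is `max_width`."""
    # Same candidate widths as `apply`, largest first.
    var widths = List[Int](256, 128, 64, 32, 16, 8, 4)
    var chunks = List[Int]()
    var processed = 0
    for i in range(len(widths)):
        var w = widths[i]
        # Mirrors the `@parameter if simdwidthof[D]() >= w` check.
        if max_width >= w:
            for _ in range((length - processed) // w):
                chunks.append(w)
                processed += w
    # Whatever the vector widths did not cover is handled one element at a time.
    for _ in range(length - processed):
        chunks.append(1)
    return chunks


fn main():
    # Assumed target: 16 uint8 lanes per SIMD register.
    var chunks = plan_chunks(19, 16)
    # Prints 16, then 1 three times (one value per line).
    for i in range(len(chunks)):
        print(chunks[i])
```

With 32 or 64 lanes the result is the same, because no 32-wide chunk fits into 19 elements.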

This is great! Thanks so much for posting. I’ve been struggling to wrap my head around this problem, and this will help me a lot.


Why are the widths from 4 to 256 listed explicitly? For example, is something like the implementation below worse for any reason?

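```mojo
from bit import bit_width
from memory import Span
from sys.info import simdwidthof


fn apply[
    D: DType,
    O: MutableOrigin, //,
    func: fn[w: Int] (SIMD[D, w]) -> SIMD[D, w],
](span: Span[Scalar[D], O]):

    var ptr = span.unsafe_ptr()
    var length = len(span)
    var processed = 0

    alias width = simdwidthof[D]()
    # Exponent of the largest power of two <= width (e.g. 32 -> 5).
    alias pow_of_2 = bit_width(width) - 1

    # Widths 2**pow_of_2, ..., 2, 1; the final w = 1 pass covers the tail.
    @parameter
    for i in range(pow_of_2, -1, -1):
        alias w = 2**i
        for _ in range((length - processed) // w):
            var p_curr = ptr + processed
            p_curr.store(func(p_curr.load[width=w]()))
            processed += w
```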
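The power-of-two version can be traced the same way. In this sketch, `plan_pow2_chunks` is again a hypothetical helper; it finds the largest power of two not exceeding the SIMD width with a small loop (equivalent to `2**(bit_width(width) - 1)` in the snippet above) and halves it down to 1, so no separate scalar tail loop is needed. The 16-lane target is an assumption.

```mojo
fn plan_pow2_chunks(length: Int, max_width: Int) -> List[Int]:
    """Hypothetical helper: the chunk widths the power-of-two cascade would
    use for `length` elements when the SIMD width is `max_width` (>= 1)."""
    # Largest power of two <= max_width, i.e. 2**(bit_width(max_width) - 1).
    var w = 1
    while w * 2 <= max_width:
        w *= 2
    var chunks = List[Int]()
    var processed = 0
    # Halve the width each pass; the final pass with w == 1 covers the tail.
    while w >= 1:
        for _ in range((length - processed) // w):
            chunks.append(w)
            processed += w
        w //= 2
    return chunks


fn main():
    # Assumed target: 16 uint8 lanes per SIMD register.
    var chunks = plan_pow2_chunks(19, 16)
    # Prints 16, 2, 1 (one value per line).
    for i in range(len(chunks)):
        print(chunks[i])
```

For 19 elements this gives three chunks (16, 2, 1) instead of four (16, 1, 1, 1). Whether two narrow vector operations beat three pipelined scalar ones depends on the target, which is the trade-off the original author raises in response.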

That does seem like a nice and straightforward solution :smile:. I was mostly thinking about 4 as the limit because small chunks are sometimes faster to process with scalar operations, since those can be pipelined. I’ll keep your approach in mind while adding these new methods to `Span` in the near future :fire:.

PS: At compile time it might be faster to index into an existing list than to calculate the values, so this will probably eventually need to become a compile-time-computed list defined as an alias at the struct level. That dovetails nicely with some plans I have for vectorizing `List` and `Span` automatically, so thanks for the idea :smile:
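One way to combine both ideas is to keep the power-of-two cascade but stop it at a minimum vector width, leaving the last few elements to a scalar loop as in the first version. This is a sketch that assumes the same Mojo APIs the snippets above use (`Span`, `MutableOrigin`, `simdwidthof`, `bit_width`), which may have been renamed in later releases; `apply_hybrid` and its `min_width` parameter are hypothetical names, and the default of 4 simply follows the reasoning above rather than any measurement.

```mojo
from bit import bit_width
from memory import Span
from sys.info import simdwidthof


fn apply_hybrid[
    D: DType,
    O: MutableOrigin, //,
    func: fn[w: Int] (SIMD[D, w]) -> SIMD[D, w],
    min_width: Int = 4,  # hypothetical: narrowest width worth vectorizing
](span: Span[Scalar[D], O]):
    var ptr = span.unsafe_ptr()
    var length = len(span)
    var processed = 0

    # Exponent of the largest power of two <= the target's SIMD width.
    alias pow_of_2 = bit_width(simdwidthof[D]()) - 1

    # Power-of-two widths from the SIMD width down to `min_width`.
    @parameter
    for i in range(pow_of_2, -1, -1):
        alias w = 2**i

        @parameter
        if w >= min_width:
            for _ in range((length - processed) // w):
                var p_curr = ptr + processed
                p_curr.store(func(p_curr.load[width=w]()))
                processed += w

    # Fewer than `min_width` elements remain (or the SIMD width is smaller
    # than `min_width`): finish them one scalar at a time.
    for i in range(processed, length):
        ptr[i] = func(ptr[i])


fn main():
    var data = List[Byte](
        1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19
    )
    var span = Span(data)

    fn _twice[w: Int](x: SIMD[DType.uint8, w]) -> SIMD[DType.uint8, w]:
        return x * 2

    apply_hybrid[func=_twice](span)
    print(span[18])  # 38
```

On a target with at least 16 uint8 lanes, the 19 elements are handled as one 16-wide chunk plus three scalar iterations, the same split as the first version, while the widths themselves are derived from the SIMD width instead of a hand-written list.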