How to uppercase and lowercase ASCII strings using SIMD in Mojo

from memory import Span
from sys.info import simdwidthof


fn apply[
    D: DType,
    O: MutableOrigin, //,
    func: fn[w: Int] (SIMD[D, w]) -> SIMD[D, w],
    *,
    where: fn[w: Int] (SIMD[D, w]) -> SIMD[DType.bool, w],
](span: Span[Scalar[D], O]):
    """Apply `func` in place to every element of `span` that satisfies `where`.

    Elements for which `where` is false are left unchanged. The bulk of the
    span is processed with SIMD loads/stores, trying the widest width first
    (never wider than `simdwidthof[D]()`); whatever is left over is handled
    one element at a time.

    Parameters:
        D: The DType of the elements.
        O: The (mutable) origin of the `Span`.
        func: The transformation applied to selected elements.
        where: The predicate selecting which elements are transformed.

    Args:
        span: The elements to modify in place.
    """

    # Candidate vector widths, widest first. Widths larger than the native
    # SIMD width of `D` are skipped at compile time (see the `@parameter if`
    # below). 2 is included so that dtypes whose native width is 2 still get
    # a vectorized path instead of falling through to the scalar tail.
    alias widths = (256, 128, 64, 32, 16, 8, 4, 2)
    var ptr = span.unsafe_ptr()
    var length = len(span)
    var processed = 0

    @parameter
    for i in range(len(widths)):
        alias w = widths.get[i, Int]()

        @parameter
        if simdwidthof[D]() >= w:
            for _ in range((length - processed) // w):
                var p_curr = ptr + processed
                var vec = p_curr.load[width=w]()
                # Blend: transformed lanes where the predicate holds,
                # original lanes everywhere else.
                p_curr.store(where(vec).select(func(vec), vec))
                processed += w

    # Scalar tail for the elements not covered by a full vector.
    for idx in range(processed, length):
        var value = ptr[idx]
        if where(value):
            # The slot already holds an initialized value, so overwrite it
            # with plain assignment rather than `init_pointee_move`.
            ptr[idx] = func(value)

fn is_lower_ascii[w: Int](value: SIMD[DType.uint8, w]) -> SIMD[DType.bool, w]:
    """Return a lane-wise mask that is true where a byte is in 'a'..'z'.

    Parameters:
        w: The SIMD width.

    Args:
        value: The bytes to test.

    Returns:
        A boolean mask with one lane per input byte (both bounds inclusive).
    """
    alias lo = Byte(ord("a"))
    alias hi = Byte(ord("z"))
    var at_least_lo = value >= lo
    var at_most_hi = value <= hi
    return at_least_lo & at_most_hi

fn is_upper_ascii[w: Int](value: SIMD[DType.uint8, w]) -> SIMD[DType.bool, w]:
    """Return a lane-wise mask that is true where a byte is in 'A'..'Z'.

    Parameters:
        w: The SIMD width.

    Args:
        value: The bytes to test.

    Returns:
        A boolean mask with one lane per input byte (both bounds inclusive).
    """
    alias lo = Byte(ord("A"))
    alias hi = Byte(ord("Z"))
    var at_least_lo = value >= lo
    var at_most_hi = value <= hi
    return at_least_lo & at_most_hi

fn toggle_case[w: Int](value: SIMD[DType.uint8, w]) -> SIMD[DType.uint8, w]:
    """Flip the case of ASCII letters by toggling bit 5 (0x20).

    An ASCII uppercase letter and its lowercase counterpart differ only in
    bit 5: 'A' is 0x41 and 'a' is 0x61. XOR-ing with that bit swaps case.
    The flip is unconditional, so non-letter bytes are altered too. Callers
    should only apply it to lanes selected by a predicate such as
    `is_lower_ascii` or `is_upper_ascii`.

    Parameters:
        w: The SIMD width.

    Args:
        value: The bytes to transform.

    Returns:
        `value` with bit 5 of every lane flipped.
    """
    alias case_bit = Byte(ord("a")) ^ Byte(ord("A"))  # == 0x20
    return case_bit ^ value

fn main():
    # Demo on "aaaAAA": first toggle only the lowercase letters (giving
    # "AAAAAA"), then, starting again from the original, toggle only the
    # uppercase letters (giving "aaaaaa"). Every printed line is True.
    var items = List[Byte](
        ord("a"), ord("a"), ord("a"), ord("A"), ord("A"), ord("A")
    )

    # Work on a copy so `items` keeps the original mixed-case input for the
    # second run.
    var tmp_copy = items
    var span = Span(tmp_copy)
    apply[func=toggle_case, where=is_lower_ascii](span)
    for i in range(len(items)):
        # Both halves end up uppercase, so no per-index branching is needed.
        print(span[i] == ord("A"))  # True

    tmp_copy = items
    span = Span(tmp_copy)
    apply[func=toggle_case, where=is_upper_ascii](span)
    for i in range(len(items)):
        print(span[i] == ord("a"))  # True
1 Like

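You might want to use the fact that you can OR with 0x20 (the ASCII code for SPACE) to lowercase ASCII letters, and AND with its inverse (~0x20) to uppercase them. For example, 'A' (0x41) | 0x20 == 'a' (0x61), and 'a' & ~0x20 == 'A'.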

That’s a neat trick. Now that I’ve looked at the stdlib implementation, I see that before full Unicode casing was added it used to do `char ^ (1 << 5)`, i.e. `^ 0x20`. Which operation is faster? I imagine each CPU manufacturer has a different number of bitwise-op circuits. I also assumed the compiler would optimize sequential operations on aliased values into some such clever bitwise op.

1 Like

`^ 0x20` toggles the case, so applied unconditionally it would turn an uppercase letter at the start of a sentence into lowercase. OR and AND instead set or clear the bit regardless of its current value. Bitwise operations generally execute at one op per cycle, even for 64-byte-wide vector instructions, and they’re very cheap in hardware (literally one gate per bit), so I think almost any ALU will be able to service them. The compiler may already be doing this for you, but this is such a well-known trick that I don’t see much issue with using it directly.

2 Likes

Edited the post to use an implementation with fewer “magic numbers”.

I think this may actually make the trick less recognizable. You may want to explain the “why” in a comment and just use 0x20 (space) directly.