How to uppercase and lowercase ASCII strings using SIMD in Mojo

from memory import Span
from sys.info import simdwidthof


fn apply[
    D: DType,
    O: MutableOrigin, //,
    func: fn[w: Int] (SIMD[D, w]) -> SIMD[D, w],
    *,
    where: fn[w: Int] (SIMD[D, w]) -> SIMD[DType.bool, w],
](span: Span[Scalar[D], O]):
    """Rewrite the elements of `span` in place wherever `where` holds.

    The span is consumed front to back in SIMD chunks, trying each entry of
    the width table from widest to narrowest and skipping widths larger than
    `simdwidthof[D]()`. Whatever is left after the narrowest chunk pass is
    handled one element at a time.

    In the chunked passes `func` is evaluated on every lane of the loaded
    vector and the result is blended with the original lanes using the
    `where` mask, so `func` must be safe to evaluate on lanes that do not
    satisfy `where`. In the element-wise tail `func` is only called when
    `where` is true.

    Parameters:
        D: The DType of the elements.
        O: The (mutable) origin of the `Span`.
        func: The transformation to apply.
        where: The predicate selecting which elements get transformed.

    Args:
        span: The elements to update in place.
    """

    # Candidate chunk widths, widest first.
    alias widths = (256, 128, 64, 32, 16, 8, 4)
    var base = span.unsafe_ptr()
    var total = len(span)
    var done = 0

    @parameter
    for idx in range(len(widths)):
        alias lanes = widths.get[idx, Int]()

        @parameter
        if simdwidthof[D]() >= lanes:
            # Consume as many full chunks of this width as still fit.
            while total - done >= lanes:
                var cursor = base + done
                var chunk = cursor.load[width=lanes]()
                # Lanes failing `where` are written back unchanged.
                cursor.store(where(chunk).select(func(chunk), chunk))
                done += lanes

    # Tail: fewer elements remain than the narrowest chunk width that ran.
    for offset in range(done, total):
        var element = base[offset]
        if where(element):
            (base + offset).init_pointee_move(func(element))

fn is_lower_ascii[w: Int](value: SIMD[DType.uint8, w]) -> SIMD[DType.bool, w]:
    """Build a per-lane mask of the bytes that are ASCII lowercase letters.

    Parameters:
        w: The SIMD width.

    Args:
        value: The bytes to classify.

    Returns:
        A mask that is true in each lane whose byte lies in `"a"..."z"`
        (both ends inclusive).
    """
    alias first_lower = Byte(ord("a"))
    alias last_lower = Byte(ord("z"))
    var not_below = value >= first_lower
    var not_above = value <= last_lower
    return not_below & not_above

fn is_upper_ascii[w: Int](value: SIMD[DType.uint8, w]) -> SIMD[DType.bool, w]:
    """Build a per-lane mask of the bytes that are ASCII uppercase letters.

    Parameters:
        w: The SIMD width.

    Args:
        value: The bytes to classify.

    Returns:
        A mask that is true in each lane whose byte lies in `"A"..."Z"`
        (both ends inclusive).
    """
    alias first_upper = Byte(ord("A"))
    alias last_upper = Byte(ord("Z"))
    var not_below = value >= first_upper
    var not_above = value <= last_upper
    return not_below & not_above

fn toggle_case[w: Int](value: SIMD[DType.uint8, w]) -> SIMD[DType.uint8, w]:
    """Flip the ASCII case bit in every lane.

    `"a"` and `"A"` differ in exactly one bit (0x20); XOR-ing with that bit
    turns an ASCII letter into its other-case counterpart. The flip is
    applied to every lane unconditionally, so non-letter bytes are altered
    as well; callers that only want letters changed must mask the result.

    Parameters:
        w: The SIMD width.

    Args:
        value: The bytes to transform.

    Returns:
        `value` with the case bit inverted in each lane.
    """
    # Derived from the letters themselves rather than hard-coding 0x20.
    alias case_bit = Byte(ord("a")) ^ Byte(ord("A"))
    return value ^ case_bit

fn main():
    """Demonstrate `apply` by upper- and lowercasing a small byte list.

    Each pass works on a fresh copy of the same six bytes (`aaaAAA`) and
    prints one comparison per byte; every printed value is expected to be
    `True`.
    """
    source = List[Byte](
        ord("a"), ord("a"), ord("a"), ord("A"), ord("A"), ord("A")
    )

    # Uppercase pass: toggle only the lowercase bytes, so all become "A".
    scratch = source
    view = Span(scratch)
    apply[func=toggle_case, where=is_lower_ascii](view)
    for i in range(len(source)):
        print(view[i] == ord("A"))  # True

    # Lowercase pass on a fresh copy: toggle only the uppercase bytes.
    scratch = source
    view = Span(scratch)

    apply[func=toggle_case, where=is_upper_ascii](view)
    for i in range(len(source)):
        print(view[i] == ord("a"))  # True
1 Like

You might want to use the fact that you can OR with 0x20 (the space character) to lowercase ASCII letters, and AND with its inverse (0xDF) to uppercase them.

That’s a neat trick. Now that I have looked at the stdlib implementation, I see that (before full Unicode casing was added) it used to do `char ^ (1 << 5)`, which is the same as `^ 0x20`. Which operation is faster? I imagine each CPU manufacturer has a different number of bitwise-op circuits. I also assumed the compiler would optimize sequential ops with aliased values into some such clever bitwise op.

1 Like

^ 0x20 inverts the case, meaning that upper case letters at the start of a sentence will become lower case. Bitwise operations are generally 1 op per cycle even for 64 byte wide vector instructions, and they’re very cheap gate wise (literally 1 gate per bit) so I think that almost any ALU will be able to service them. The compiler may be fixing it, but this is such a well-known trick that I don’t see many issues with using it directly.

2 Likes

Edited the post to add a bit less “magical-number-using” implementation

I think this may actually be a less recognizable version of the trick. You may want to write up the “why” and just use space directly.

This topic was automatically closed 180 days after the last reply. New replies are no longer allowed.