`pack_bits` inside of vectorize fails constraint check

import math

from algorithm import vectorize
from bit import pop_count
from memory import pack_bits
from sys import simdwidthof

alias U8_SIMD_WIDTH = simdwidthof[DType.uint8]()


fn count_nuc_content[
    simd_width: Int, nuc: UInt8
](sequence: Span[UInt8]) -> UInt:
    var count = 0
    var ptr = sequence.unsafe_ptr()

    # This works
    # alias nuc_vector = SIMD[DType.uint8, simd_width](nuc)
    # var aligned_end = math.align_down(len(sequence), simd_width)
    # for offset in range(0, aligned_end, simd_width):
    #     var vector = ptr.offset(offset).load[width=simd_width]()
    #     var mask = vector == nuc_vector
    #     var packed = pack_bits(mask)
    #     count += Int(pop_count(packed))

    # for offset in range(aligned_end, len(sequence)):
    #     count += 1 if sequence[offset] == nuc else 0

    # This does not
    @parameter
    fn count_nucs[width: Int](offset: Int):
        alias nuc_vector = SIMD[DType.uint8, width](nuc)
        var vector = ptr.offset(offset).load[width=width]()
        var mask = vector == nuc_vector
        var packed = pack_bits(mask)
        # ^ constraint failed: the width of the bool vector must be the same as the bitwidth of the target type
        count += Int(pop_count(packed))

    vectorize[count_nucs, simd_width](len(sequence))

    return count


def main():
    var seq = "ACTGACTGACGCCCCCCCCCCCCTTTTTTTTT".as_bytes()
    var count = count_nuc_content[U8_SIMD_WIDTH, ord("C")](seq)
    print("C Content:", count)

The commented-out “manual” SIMD loop works, but the vectorize version does not. The constraint inside pack_bits reports:

open-source/max/mojo/stdlib/stdlib/builtin/constrained.mojo:58:6: note: constraint failed: the width of the bool vector must be the same as the bitwidth of the target type

The gist has a slightly more fleshed-out example: count_nucleotide.mojo · GitHub

I implemented pack_bits, but I’m a bit conflicted about whether we should support the Scalar[bool] case. Maybe we should just return a Scalar[bool] when given one; that would be slightly more consistent. Even then, you’d still be blocked by pop_count, which only accepts an integral dtype.

On the other hand, pack_bits is supposed to be a very thin wrapper around pop.bitcast, and the fact that the following doesn’t compile seems to argue against supporting that case:

fn main():
  val = Scalar(True)
  res = __mlir_op.`pop.bitcast`[_type = Scalar[DType.bool]._mlir_type](val.value)
  # error: MLIR verification error: 
  #   'pop.bitcast' op operand type
  #     '!pop.scalar<bool>'
  #   and result type
  #     '!pop.scalar<bool>' 
  #   are cast incompatible

Ah, it’s failing because vectorize will also call the closure with width=1, and a 1-wide bool vector has no target integer type whose bitwidth is 1. That at least makes sense now.
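To see which widths the closure is actually invoked with, a small probe like the one below can help. This is only a sketch: the closure name `show` and the sizes 4 and 10 are made up for illustration, and it assumes the same `vectorize[func, simd_width](size)` form used in the original post.

```mojo
from algorithm import vectorize


fn main():
    # Hypothetical probe closure: it only reports how vectorize calls it.
    @parameter
    fn show[width: Int](offset: Int):
        print("offset", offset, "width", width)

    # 10 is not a multiple of 4, so some elements are left over after the
    # full-width steps. As noted in this thread, the closure is then also
    # called with width=1, so it must compile for that width too.
    vectorize[show, 4](10)
```

Because `width` is a compile-time parameter, any constraint inside the closure is checked for every width that vectorize instantiates, not only for the `simd_width` passed in.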

Would it be too egregious to have pack_bits return a UInt8 when given a Bool?

But I can at least fix it on my end now:

            @parameter
            if width == 1:
                count += 1 if mask else 0
            else:
                var packed = pack_bits(mask)
                count += Int(pop_count(packed))

Maybe the best fix is just an improved error message for the specific case where width == 1. Once you pointed that out, things made a lot more sense; I had simply missed that it was even a possibility.
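For reference, here is the workaround folded back into the original example. It is a sketch assembled from the snippets in this thread, assuming the same stdlib version and imports as the first post; nothing here is a new API.

```mojo
from algorithm import vectorize
from bit import pop_count
from memory import pack_bits
from sys import simdwidthof

alias U8_SIMD_WIDTH = simdwidthof[DType.uint8]()


fn count_nuc_content[
    simd_width: Int, nuc: UInt8
](sequence: Span[UInt8]) -> UInt:
    var count = 0
    var ptr = sequence.unsafe_ptr()

    @parameter
    fn count_nucs[width: Int](offset: Int):
        alias nuc_vector = SIMD[DType.uint8, width](nuc)
        var vector = ptr.offset(offset).load[width=width]()
        var mask = vector == nuc_vector

        @parameter
        if width == 1:
            # Scalar case: pack_bits rejects a 1-wide bool vector,
            # so count the single comparison result directly.
            count += 1 if mask else 0
        else:
            # Full-width case: pack the bool mask into an integer
            # and count its set bits.
            var packed = pack_bits(mask)
            count += Int(pop_count(packed))

    vectorize[count_nucs, simd_width](len(sequence))
    return count


def main():
    var seq = "ACTGACTGACGCCCCCCCCCCCCTTTTTTTTT".as_bytes()
    var count = count_nuc_content[U8_SIMD_WIDTH, ord("C")](seq)
    print("C Content:", count)
```

The `@parameter if` matters here: the branch is resolved at compile time, so pack_bits is never instantiated for width == 1 and its constraint is never evaluated for that case.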

Two small issues:

  • Should we go with scalar<bool> -> scalar<uint8> or scalar<bool> -> scalar<bool>?
  • Either way, we lose the nice invariant: width == bitwidthof[Scalar[new_type]]().

That said, I’m not opposed to either option if we can find a good justification.
We could even add a pack_bits function to the bit module that implements one of these more specific semantics (PR welcome).

P.S. You could also write count += Int(mask).


Indeed we can probably do better with constrained (*extra: StaticString for error messages) these days. Are you willing to make a PR?


Totally up for making that PR! Added to my Mojo stdlib todo list, might be a few days.

I agree that the width == bitwidthof invariant is nice to maintain… but I can see wanting everything to work with scalar too.

I’ll make a PR for the improved error message since I think that really covers the biggest need, and pushes the handling of width=1 to the caller at least as a starting point.

Good point on Int(mask) :+1:

We recently added the BitSet type (BitSet | Modular), so you could write the code as len(BitSet(some_lhs == some_rhs)).

You can see the implementation at modular/mojo/stdlib/src/collections/bitset.mojo at main · modular/modular · GitHub


Interesting, I’ll give that a try too and maybe bench the two against each other.

Maybe we should use pack_bits in the BitSet constructor that takes a SIMD[bool, _]. Also, you might want to use _words.unsafe_get(i) throughout the codebase, since InlineArray.__getitem__ has bounds checking and assertions enabled by default.

Some benchmarks with BitSet: bench_bitset.mojo · GitHub

I was already working on some benchmarks for different methods so I included BitSet in the matrix. It’s just counting the GC content (in a simplistic way) in the whole genome. See comment on main for prepping the input data.


name                    met (ms)             iters   GB/s
Manual SIMD, width 32   171.47048266666664   6       18.71625981970759
Vectorized, width 32    169.0004497142857    7       18.989808076994233
BitSet, width 32        46201.248465         2       0.06946319010039766
Manual SIMD, width 64   159.27863557142857   7       20.148879939147864
Vectorized, width 64    158.84849485714287   7       20.203440441070633
BitSet, width 64        47842.169988         2       0.06708069692083299
Naive                   5174.215648          2       0.6202459122940677