# How to build a generic collection with an iterator in Mojo

from memory import UnsafePointer, Pointer

@value
struct _SomeCollectionIter[
    is_mutable: Bool, //,
    T: CollectionElement,
    collection_origin: Origin[is_mutable],
    forward: Bool = True,
]:
    """Cursor that walks the elements of a `SomeCollection`.

    Parameters:
        is_mutable: Whether the reference to the collection is mutable.
        T: The type of the elements in the collection.
        collection_origin: The origin of the collection.
        forward: The iteration direction. `False` is backwards.
    """

    alias collection_type = SomeCollection[T]

    var index: Int
    """Cursor position: when iterating forward, the index of the next element
    to yield; when iterating backwards, one past the next element to yield."""
    var src: Pointer[Self.collection_type, collection_origin]
    """Pointer to the collection being iterated."""

    fn __iter__(self) -> Self:
        """Return this iterator by value.

        Returns:
            A copy of the iterator in its current state.
        """
        return self

    fn __next__(mut self) -> Pointer[T, collection_origin]:
        """Move the cursor one step and yield the element it passed over.

        Returns:
            A pointer to the yielded element, carrying the collection's origin.
        """

        @parameter
        if forward:
            # Remember the slot to yield, then step past it.
            var current = self.index
            self.index = current + 1
            return Pointer.address_of(self.src[][current])
        else:
            # Step back first; the new cursor position is the slot to yield.
            var previous = self.index - 1
            self.index = previous
            return Pointer.address_of(self.src[][previous])

    fn __has_next__(self) -> Bool:
        """Report whether at least one element is still to be yielded.

        Returns:
            `True` if `__len__()` is positive, `False` otherwise.
        """
        var remaining = self.__len__()
        return remaining > 0

    fn __len__(self) -> Int:
        """Count the elements that have not been yielded yet.

        Returns:
            The distance from the cursor to the end of the collection when
            iterating forward, or the cursor position itself when iterating
            backwards.
        """

        @parameter
        if not forward:
            return self.index
        else:
            var total = self.src[].length
            return total - self.index


@value
struct SomeCollection[T: CollectionElement]:
    """A minimal heap-backed collection that supports indexing and iteration.

    The collection owns the buffer behind `data`: the buffer is allocated when
    the collection is constructed or copied, and released in `__del__`.

    Parameters:
        T: The type of the elements.
    """

    var data: UnsafePointer[T]
    """The underlying storage for the `SomeCollection`."""
    var length: Int
    """The number of elements in the `SomeCollection`."""

    fn __init__(out self, value_1: T, value_2: T):
        """Create a collection holding exactly two elements.

        Args:
            value_1: The element stored at index 0.
            value_2: The element stored at index 1.
        """
        self.data = UnsafePointer[T].alloc(2)
        # `alloc` hands back uninitialized memory. Assigning through the
        # pointer assumes the slot already holds a live `T`, so the elements
        # are constructed in place instead.
        self.data.init_pointee_copy(value_1)
        (self.data + 1).init_pointee_copy(value_2)
        self.length = 2

    fn __copyinit__(out self, existing: Self):
        """Create a deep copy of another collection.

        A memberwise copy would leave two collections sharing one buffer, and
        both destructors would then destroy and free it. Each copy therefore
        gets its own allocation and its own copies of the elements.

        Args:
            existing: The collection to copy.
        """
        self.length = existing.length
        self.data = UnsafePointer[T].alloc(existing.length)
        for i in range(existing.length):
            (self.data + i).init_pointee_copy(existing.data[i])

    fn __getitem__(ref self, idx: Int) -> ref [self] T:
        """Get a reference to the element at the given index.

        Negative indices count back from the end of the collection.
        Out-of-range indices are only caught by the `debug_assert` below.

        Args:
            idx: The index of the element, in `[-length, length)`.

        Returns:
            A reference to the element, tied to the origin of `self`.
        """

        var normalized_idx = idx

        debug_assert(
            -self.length <= normalized_idx < self.length,
            "index: ",
            normalized_idx,
            " is out of bounds for `SomeCollection` of length: ",
            self.length,
        )
        # Map a negative index onto its non-negative equivalent.
        if normalized_idx < 0:
            normalized_idx += self.length

        return (self.data + normalized_idx)[]

    fn __iter__(ref self) -> _SomeCollectionIter[T, __origin_of(self)]:
        """Create a forward iterator positioned at the first element.

        Returns:
            An iterator holding a pointer to this collection.
        """
        return _SomeCollectionIter(0, Pointer.address_of(self))

    fn __del__(owned self):
        """Destroy all elements in the `SomeCollection` and free its memory."""

        # Run each element's destructor before releasing the raw buffer.
        for i in range(self.length):
            (self.data + i).destroy_pointee()
        self.data.free()


fn main():
    """Print each element of an integer collection, then of a string-literal one."""
    var numbers = SomeCollection(1, 2)
    for number in numbers:
        # The iterator yields pointers; dereference to get the element.
        print(number[])

    var words = SomeCollection("hello", "world")
    for word in words:
        print(word[])
4 Likes