from memory import UnsafePointer, Pointer
@value
struct _SomeCollectionIter[
    is_mutable: Bool, //,
    T: CollectionElement,
    collection_origin: Origin[is_mutable],
    forward: Bool = True,
]:
    """Cursor-style iterator over the elements of a `SomeCollection`.

    Each step yields a `Pointer` to an element instead of a copy of it, so
    loop bodies dereference the yielded value with `item[]`.

    Parameters:
        is_mutable: Whether the reference to the collection is mutable.
        T: The type of the elements in the collection.
        collection_origin: The origin of the collection.
        forward: The iteration direction. `False` is backwards.
    """

    alias collection_type = SomeCollection[T]

    var index: Int
    """Forward: the next position to yield. Backward: one past it."""
    var src: Pointer[Self.collection_type, collection_origin]
    """Pointer to the collection being walked."""

    fn __iter__(self) -> Self:
        """Return this iterator by value; an iterator is its own iterable."""
        return self

    fn __next__(mut self) -> Pointer[T, collection_origin]:
        """Step the cursor once and return a pointer to the element passed."""
        var position: Int

        @parameter
        if forward:
            # Yield the current slot, then move the cursor past it.
            position = self.index
            self.index = position + 1
        else:
            # Move the cursor back first; the slot it lands on is yielded.
            self.index = self.index - 1
            position = self.index
        return Pointer.address_of(self.src[][position])

    fn __has_next__(self) -> Bool:
        """Report whether at least one element is still left to yield."""
        return 0 < self.__len__()

    fn __len__(self) -> Int:
        """Return how many elements remain in the current direction."""

        @parameter
        if not forward:
            # Walking backwards: everything below the cursor is still pending.
            return self.index
        else:
            # Walking forwards: everything from the cursor to the end.
            return self.src[].length - self.index
@value
struct SomeCollection[T: CollectionElement]:
    """The `SomeCollection` type.

    A minimal heap-backed collection that owns the elements it stores: it
    allocates its buffer on construction and destroys/frees it in `__del__`.

    Parameters:
        T: The type of the elements.
    """

    var data: UnsafePointer[T]
    """The underlying storage for the `SomeCollection`."""
    var length: Int
    """The number of elements in the `SomeCollection`."""

    fn __init__(out self, value_1: T, value_2: T):
        """Create a collection holding copies of `value_1` and `value_2`.

        Args:
            value_1: The element stored at index 0.
            value_2: The element stored at index 1.
        """
        self.data = UnsafePointer[T].alloc(2)
        # The freshly allocated slots are uninitialized. Initialize them in
        # place; plain assignment would first destroy a garbage "old" value,
        # which is undefined behavior for types with destructors (e.g. String).
        self.data.init_pointee_copy(value_1)
        (self.data + 1).init_pointee_copy(value_2)
        self.length = 2

    fn __copyinit__(out self, existing: Self):
        """Deep-copy `existing` into a newly allocated buffer.

        The memberwise copy that `@value` would otherwise synthesize shares
        `data` between both instances, so each `__del__` would destroy and
        free the same buffer (double free).

        Args:
            existing: The collection to copy.
        """
        self.data = UnsafePointer[T].alloc(existing.length)
        for i in range(existing.length):
            (self.data + i).init_pointee_copy(existing.data[i])
        self.length = existing.length

    fn __getitem__(ref self, idx: Int) -> ref [self] T:
        """Return a reference to the element at `idx`.

        Negative indices count from the end (`-1` is the last element).
        Bounds are only checked by `debug_assert`.

        Args:
            idx: The element index, in `[-length, length)`.

        Returns:
            A reference to the element, tied to the origin of `self`.
        """
        var normalized_idx = idx
        debug_assert(
            -self.length <= normalized_idx < self.length,
            "index: ",
            normalized_idx,
            " is out of bounds for `SomeCollection` of length: ",
            self.length,
        )
        if normalized_idx < 0:
            normalized_idx += self.length
        return (self.data + normalized_idx)[]

    fn __iter__(ref self) -> _SomeCollectionIter[T, __origin_of(self)]:
        """Return a forward iterator yielding a `Pointer` to each element.

        Returns:
            An iterator starting at index 0 that borrows `self`.
        """
        return _SomeCollectionIter(0, Pointer.address_of(self))

    fn __del__(owned self):
        """Destroy all elements in the `SomeCollection` and free its memory."""
        for i in range(self.length):
            (self.data + i).destroy_pointee()
        self.data.free()
fn main():
    """Print every element of two small example collections."""
    # Integer elements.
    var numbers = SomeCollection(1, 2)
    for ptr in numbers:
        print(ptr[])

    # String-literal elements.
    var words = SomeCollection("hello", "world")
    for ptr in words:
        print(ptr[])
4 Likes