How to get a mutable reference in a deeply nested tree

I have a nested tree (like a DOM). I want to get a mutable reference to a deeply nested node, but I can’t quite figure out the proper syntax for such a reference.

from testing import assert_equal

@fieldwise_init
struct Node(Stringable, Copyable, Movable, Representable):
    """A tree node: a name plus an ordered list of child nodes."""

    var name: String  # Label of this node.
    var children: List[Node]  # Child nodes, in order.

    fn __str__(self: Self) -> String:
        """Render as "{" + name + repr of the children list + "}"."""
        return "{" + String(self.name) + self.children.__repr__() + "}"

    fn __repr__(self: Self) -> String:
        """Same text as `__str__`."""
        return self.__str__()

fn get(node: Node, path: List[Int]) -> Node:
    """Return the node reached by following `path` (child indices) from `node`.

    NOTE: this works on copies, not references. `var current = node` copies
    the argument, each `current = current.children[index]` copies a subtree,
    and the result is returned by value (`-> Node`). Mutating the returned
    node therefore does not change the original tree, which is why the
    assert in `main` below fails.
    """
    var current = node
    for index in path:
        current = current.children[index]
    return current

def main():
    # Build root "a" with two leaf children, "b" and "c".
    root = Node("a", [
        Node("b", []),
        Node("c", []),
    ])

    # `get` returns a copy, so renaming `node` leaves `root` unchanged.
    node = get(root, [0])
    node.name = "d"
    assert_equal(root.children[0].name, "d")  # this errors because it is still "b"

How can I change the `get` fn to make this work?

You’re returning by value; it should return `ref Node` instead. Ideally it should also be a method on `Node` that takes `ref self`, so that it is parametric over mutability.

Yes, I tried the following, but I get various errors that I don’t know how to fix:

# First attempt. `var current = node` creates a new local copy, so `current`
# does not have the origin `origin`. Returning it as `ref [origin] Node` is
# therefore rejected with the "incompatible origin" error shown below.
fn get[origin: MutableOrigin](ref [origin] node: Node, path: List[Int]) -> ref [origin] Node:
    var current = node
    for index in path:
        current = current.children[index]
    return current
/Users/romain/Projects/jsonpatch/main.mojo:21:12: error: cannot return reference with incompatible origin: 'current' vs 'origin._mlir_origin'
    return current
# Second attempt. `node` is a `ref` argument, so `node = ...` assigns through
# the reference: it overwrites the caller's value with one of its own
# children instead of moving a local cursor down the tree.
fn get[origin: MutableOrigin](ref [origin] node: Node, path: List[Int]) -> ref [origin] Node:
    for index in path:
        node = node.children[index]
    return node

And for this one, `mojo run` crashes with:

Please submit a bug report to https://github.com/modular/modular/issues and include the crash backtrace along with all the relevant source codes.
Stack dump:
0.      Program arguments: mojo run -I . main.mojo
mojo crashed!
Please file a bug report.
[81957:1036473:20250620,164039.945776:WARNING in_range_cast.h:38] value -634136515 out of range
[81957:1036473:20250620,164040.113683:WARNING crash_report_exception_handler.cc:257] UniversalExceptionRaise: (os/kern) failure (5)
from testing import assert_equal


@fieldwise_init
struct Node(Stringable, Copyable, Movable, Representable):
    """A tree node: a name plus an ordered list of child nodes."""

    var name: String  # Label of this node.
    var children: List[Node]  # Child nodes, in order.

    fn __str__(self: Self) -> String:
        """Render as "{" + name + repr of the children list + "}"."""
        return "{" + String(self.name) + self.children.__repr__() + "}"

    fn __repr__(self: Self) -> String:
        """Same text as `__str__`."""
        return self.__str__()

fn get(ref node: Node, path: List[Int]) -> ref[node]Node:
    """Return a reference to the node reached by following `path` from `node`.

    `ref node` makes the function parametric over mutability. The result is
    declared with `node`'s origin, so writes through it (on a mutable
    argument) land in the original tree. The cursor `current` is an
    `UnsafePointer` that is re-pointed at each child in turn.

    NOTE(review): this is unsafe; the compiler does not tie the child
    pointers' validity to `node`. The caller must not use the result after
    the tree (or the path's ancestors) is modified or destroyed.
    """
    var current = UnsafePointer(to=node)
    for index in path:
        current = UnsafePointer(to=current[].children[index])
    return current[]

def main():
    # Build root "a" with two leaf children, "b" and "c".
    root = Node("a", [
        Node("b", []),
        Node("c", []),
    ])

    # `ref` binds the returned reference instead of copying it, so the rename
    # below changes root.children[0] itself.
    ref node = get(root, [0])
    node.name = "d"
    assert_equal(root.children[0].name, "d")  

Did you want something like this? If you want to use the safe `Pointer` type instead, I think it currently requires `rebind`:


fn get(ref node: Node, path: List[Int]) -> ref[node]Node:
    """Return a reference to the node reached by following `path` from `node`.

    Uses the safe `Pointer` type as the cursor. The type of `current` is
    fixed by its first assignment: `Pointer[Node, __origin_of(node)]`. The
    pointer to each child does not have exactly that origin, so it is
    `rebind`-ed before being stored back into `current`. `rebind` is an
    unchecked cast: it asserts that the child lives as long as `node`, which
    holds while the tree is not restructured, since children are stored
    inside their parent's `children` list.
    """
    var current = Pointer(to=node)
    for index in path:
        current = rebind[Pointer[Node,__origin_of(node)]](Pointer(to=current[].children[index]))
    return current[]

Or, if you wanted something else, I don’t know.

Thank you!

For reference, here is what the final code looks like (it still needs `from testing import assert_equal`, as in the original example):

@fieldwise_init
struct Node(Stringable, Copyable, Movable, Representable):
    """A tree node: a name plus an ordered list of child nodes.

    Nodes below a given node are addressed by a path of child indices,
    outermost first. The empty path addresses the node itself.
    """

    var name: String  # Label of this node.
    var children: List[Node]  # Child nodes, in order.

    fn __str__(self) -> String:
        """Render as "{" + name + repr of the children list + "}"."""
        return "{" + String(self.name) + self.children.__repr__() + "}"

    fn __repr__(self) -> String:
        """Same text as `__str__`."""
        return self.__str__()

    fn get(ref self, path: List[Int]) -> ref [self] Node:
        """Return a reference to the node at `path` below `self`.

        `ref self` makes the method parametric over mutability: on a mutable
        receiver the result can be written through. Bind the result with
        `ref` (e.g. `ref n = root.get(p)`) to keep it a reference. A plain
        `n = root.get(p)` makes a copy.

        Args:
            path: Child indices to follow; an empty path returns `self`.

        Returns:
            A reference with the same origin as `self`.
        """
        var current = Pointer(to=self)
        for index in path:
            var child = Pointer(to=current[].children[index])
            # The child pointer's origin does not match `__origin_of(self)`,
            # so it is cast before being stored back into `current`. The
            # child is stored inside `self`'s subtree, so it lives as long as
            # `self` while the tree is not restructured.
            current = rebind[Pointer[Node, __origin_of(self)]](child)
        return current[]

    fn set(mut self, path: List[Int], node: Node):
        """Replace the node at `path` with a copy of `node`.

        Args:
            path: Child indices to the node to replace. As in `get`, an empty
                path addresses `self`, so the whole node is replaced.
            node: The replacement node (copied into the tree).
        """
        if len(path) == 0:
            # Keep `set` consistent with `get`, where the empty path is `self`.
            self = node
            return

        ref parent = self.get(path[:-1])
        parent.children[path[-1]] = node

    fn insert(mut self, path: List[Int], node: Node):
        """Insert a copy of `node` so that it sits at `path`.

        The last path element is the position within the parent's children.
        A position at or past the end of the children appends instead.
        An empty path has no parent to insert into, so the call does nothing.

        Args:
            path: Child indices; all but the last select the parent.
            node: The node to insert (copied into the tree).
        """
        if len(path) == 0:
            return

        ref parent = self.get(path[:-1])
        var index = path[-1]
        if index >= len(parent.children):
            parent.children.append(node)
        else:
            parent.children.insert(index, node)


def main():
    # NOTE: this snippet uses `assert_equal` without showing its import;
    # it needs `from testing import assert_equal`.
    root = Node("a", [
        Node("b", []),
        Node("c", []),
    ])

    print(root.__str__())

    print("get")
    # A plain binding copies the returned node: renaming `b1` leaves root unchanged.
    b1 = root.get([0])
    b1.name = "d"
    print(root.__str__())
    assert_equal(root.children[0].name, "b")

    # A `ref` binding keeps the reference: renaming `b2` renames root.children[0].
    ref b2 = root.get([0])
    b2.name = "d"
    print(root.__str__())
    assert_equal(root.children[0].name, "d")

    print("set")
    # Replace root's second child.
    root.set([1], Node("f", []))
    print(root.__str__())
    assert_equal(root.children[1].name, "f")

This topic was automatically closed 7 days after the last reply. New replies are no longer allowed.