diff --git a/cheuph/element_supply.py b/cheuph/element_supply.py index 5f170e3..885a969 100644 --- a/cheuph/element_supply.py +++ b/cheuph/element_supply.py @@ -1,5 +1,5 @@ import abc -from typing import Dict, List, Optional +from typing import Dict, List, Optional, Set from .element import Element, Id from .exceptions import TreeException @@ -170,7 +170,7 @@ class MemoryElementSupply(ElementSupply): def __init__(self) -> None: self._elements: Dict[Id, Element] = {} - self._children: Dict[Id, List[Id]] = {} + self._children: Dict[Optional[Id], Set[Id]] = {None: set()} def add(self, element: Element) -> None: """ @@ -181,10 +181,8 @@ class MemoryElementSupply(ElementSupply): self.remove(element.id) self._elements[element.id] = element - self._children[element.id] = [] - - if element.parent_id is not None: - self._children[element.parent_id].append(element.id) + self._children[element.id] = set() + self._children[element.parent_id].add(element.id) def remove(self, element_id: Id) -> None: """ @@ -197,9 +195,7 @@ class MemoryElementSupply(ElementSupply): self._elements.pop(element_id) self._children.pop(element_id) - - if element.parent_id is not None: - self._children[element.parent_id].remove(element.id) + self._children[element.parent_id].remove(element.id) def get(self, element_id: Id) -> Element: result = self._elements.get(element_id) @@ -210,13 +206,9 @@ class MemoryElementSupply(ElementSupply): return result def get_children_ids(self, element_id: Optional[Id]) -> List[Id]: - result: Optional[List[Id]] - if element_id is None: - result = list(element.id for element in self._elements.values()) - else: - result = self._children.get(element_id) + result = self._children.get(element_id) if result is None: raise TreeException(f"Element with id {element_id!r} could not be found") - return result + return list(sorted(result))