"""InMemoryGraphDB — NetworkX-backed property graph for testing.""" from __future__ import annotations import logging import networkx as nx from aucourt_ingest.storage.graph_db import ( GraphDB, GraphNode, GraphRelationship, _node_id, _rel_id, ) logger = logging.getLogger(__name__) class InMemoryGraphDB: """In-memory graph using NetworkX. No external dependencies beyond networkx. Suitable for testing and standalone use. Production should use Neo4jGraphDB. """ def __init__(self): self._g = nx.MultiDiGraph() self._nodes: dict[str, GraphNode] = {} self._rels: dict[str, GraphRelationship] = {} async def create_node(self, label: str, properties: dict) -> str: node_id = _node_id(label, properties) if node_id not in self._nodes: self._nodes[node_id] = GraphNode(node_id, label, properties.copy()) self._g.add_node(node_id, label=label, **properties) else: # Update existing node properties existing = self._nodes[node_id] existing.properties.update(properties) for k, v in properties.items(): self._g.nodes[node_id][k] = v return node_id async def create_relationship(self, from_id: str, to_id: str, rel_type: str, properties: dict | None = None) -> str: rel_id = _rel_id(from_id, to_id, rel_type) props = properties or {} rel = GraphRelationship(rel_id, from_id, to_id, rel_type, props) self._rels[rel_id] = rel self._g.add_edge(from_id, to_id, key=rel_type, **props) return rel_id async def get_node(self, node_id: str) -> GraphNode | None: return self._nodes.get(node_id) async def get_relationships(self, node_id: str, rel_type: str | None = None, direction: str = "outgoing") -> list[GraphRelationship]: results = [] if direction in ("outgoing", "both"): for u, v, k in self._g.out_edges(node_id, keys=True): rel = self._rels.get(_rel_id(u, v, k)) if rel and (rel_type is None or rel.rel_type == rel_type): results.append(rel) if direction in ("incoming", "both"): for u, v, k in self._g.in_edges(node_id, keys=True): rel = self._rels.get(_rel_id(u, v, k)) if rel and (rel_type is None or rel.rel_type == rel_type): if rel.rel_id not in {r.rel_id for r in results}: results.append(rel) return results async def query_nodes(self, label: str = None, properties: dict | None = None) -> list[GraphNode]: results = [] for node_id, node in self._nodes.items(): if label and node.label != label: continue if properties: if not all(node.properties.get(k) == v for k, v in properties.items()): continue results.append(node) return results async def node_count(self, label: str | None = None) -> int: if label: return sum(1 for n in self._nodes.values() if n.label == label) return len(self._nodes) async def relationship_count(self, rel_type: str | None = None) -> int: if rel_type: return sum(1 for r in self._rels.values() if r.rel_type == rel_type) return len(self._rels) async def neighbors(self, node_id: str, rel_type: str | None = None, direction: str = "outgoing") -> list[str]: """Get adjacent node IDs.""" rels = await self.get_relationships(node_id, rel_type, direction) return [r.to_id if r.from_id == node_id else r.from_id for r in rels] async def shortest_path(self, from_id: str, to_id: str) -> list[str] | None: """Find shortest path between two nodes. Returns node ID list or None.""" try: path = nx.shortest_path(self._g, from_id, to_id) return list(path) if path else None except (nx.NetworkXNoPath, nx.NodeNotFound): return None async def close(self, destroy: bool = True): """Clear graph data. Set destroy=False to keep data (e.g. for tests).""" if destroy: self._g.clear() self._nodes.clear() self._rels.clear()