aucourt-ingest/aucourt_ingest/storage/in_memory_graph_db.py

112 lines
4.4 KiB
Python
Raw Normal View History

"""InMemoryGraphDB — NetworkX-backed property graph for testing."""
from __future__ import annotations
import logging
import networkx as nx
from aucourt_ingest.storage.graph_db import (
GraphDB, GraphNode, GraphRelationship,
_node_id, _rel_id,
)
# Module-level logger named after this module's import path.
logger = logging.getLogger(__name__)
class InMemoryGraphDB:
    """In-memory graph using NetworkX. No external dependencies beyond networkx.

    Suitable for testing and standalone use. Production should use Neo4jGraphDB.

    Two structures are kept in step:

    * ``self._nodes`` / ``self._rels`` hold the ``GraphNode`` /
      ``GraphRelationship`` records handed back to callers, keyed by the ids
      produced by ``_node_id`` / ``_rel_id``.
    * ``self._g`` is a ``networkx.MultiDiGraph`` used for topology queries
      (edge traversal, shortest path). Edges are keyed by relationship type,
      so at most one edge of a given type exists per ordered node pair.
    """

    def __init__(self):
        self._g = nx.MultiDiGraph()
        self._nodes: dict[str, GraphNode] = {}
        self._rels: dict[str, GraphRelationship] = {}

    async def create_node(self, label: str, properties: dict) -> str:
        """Create a node, or merge ``properties`` into it if it already exists.

        The id is computed by ``_node_id(label, properties)``. When that id is
        already known, the stored properties are updated in place (given keys
        overwritten, other keys kept).

        Args:
            label: Node label.
            properties: Node properties. Copied; the caller's dict is not kept.

        Returns:
            The node id.
        """
        node_id = _node_id(label, properties)
        node = self._nodes.get(node_id)
        if node is None:
            node = GraphNode(node_id, label, properties.copy())
            self._nodes[node_id] = node
        else:
            node.properties.update(properties)
        # Write attributes through the node's data dict instead of
        # ``add_node(node_id, label=label, **properties)``, which raises
        # TypeError when ``properties`` has a "label" key. The stored node's
        # label is applied last so a same-named property cannot clobber it.
        self._g.add_node(node_id)
        attrs = self._g.nodes[node_id]
        attrs.update(properties)
        attrs["label"] = node.label
        return node_id

    async def create_relationship(self, from_id: str, to_id: str,
                                  rel_type: str, properties: dict | None = None) -> str:
        """Create, or replace, a directed relationship of type ``rel_type``.

        The id is computed by ``_rel_id(from_id, to_id, rel_type)``. Creating
        the same relationship again replaces its properties with the new ones.
        Endpoints are not checked against nodes made by ``create_node``. An
        unknown endpoint is added to the topology graph as a bare node, but it
        does not appear in ``get_node`` / ``query_nodes`` results.

        Args:
            from_id: Source node id.
            to_id: Target node id.
            rel_type: Relationship type; also used as the edge key in ``_g``.
            properties: Optional relationship properties. Copied; the
                caller's dict is not kept.

        Returns:
            The relationship id.
        """
        rel_id = _rel_id(from_id, to_id, rel_type)
        props = dict(properties) if properties else {}
        self._rels[rel_id] = GraphRelationship(rel_id, from_id, to_id, rel_type, props)
        # Add the edge bare and then fill its data dict: passing ``**props``
        # to ``add_edge`` raises TypeError for a property named "key",
        # "u_for_edge" or "v_for_edge". Clearing first keeps the edge data
        # equal to the stored relationship's properties on re-creation.
        self._g.add_edge(from_id, to_id, key=rel_type)
        edge_attrs = self._g.edges[from_id, to_id, rel_type]
        edge_attrs.clear()
        edge_attrs.update(props)
        return rel_id

    async def get_node(self, node_id: str) -> GraphNode | None:
        """Return the stored node for ``node_id``, or None if unknown."""
        return self._nodes.get(node_id)

    async def get_relationships(self, node_id: str,
                                rel_type: str | None = None,
                                direction: str = "outgoing") -> list[GraphRelationship]:
        """Return relationships attached to ``node_id``.

        Args:
            node_id: Node whose relationships are listed.
            rel_type: If not None, keep only relationships of this type.
            direction: "outgoing", "incoming" or "both". Any other value
                yields an empty list.

        Returns:
            Matching relationships. Outgoing ones come first, then incoming
            ones, with each relationship listed at most once (a self-loop is
            both outgoing and incoming).
        """
        # networkx treats an id that is not a node as an iterable "nbunch"
        # (a string would be walked character by character, None would mean
        # "all nodes"), so unknown ids must be rejected up front.
        if node_id not in self._g:
            return []
        edge_views = []
        if direction in ("outgoing", "both"):
            edge_views.append(self._g.out_edges(node_id, keys=True))
        if direction in ("incoming", "both"):
            edge_views.append(self._g.in_edges(node_id, keys=True))

        results: list[GraphRelationship] = []
        seen: set[str] = set()
        for edges in edge_views:
            for u, v, k in edges:
                rel = self._rels.get(_rel_id(u, v, k))
                if rel is None or rel.rel_id in seen:
                    continue
                if rel_type is not None and rel.rel_type != rel_type:
                    continue
                seen.add(rel.rel_id)
                results.append(rel)
        return results

    async def query_nodes(self, label: str | None = None,
                          properties: dict | None = None) -> list[GraphNode]:
        """Return nodes matching an optional label and property filter.

        Args:
            label: If truthy, keep only nodes with this label.
            properties: If truthy, keep only nodes whose properties equal
                every given value. Lookup uses ``dict.get``, so a filter
                value of None also matches nodes lacking that key.

        Returns:
            Matching nodes, in insertion order.
        """
        results = []
        for node in self._nodes.values():
            if label and node.label != label:
                continue
            if properties and not all(
                node.properties.get(k) == v for k, v in properties.items()
            ):
                continue
            results.append(node)
        return results

    async def node_count(self, label: str | None = None) -> int:
        """Count nodes, restricted to ``label`` when it is truthy."""
        if label:
            return sum(1 for n in self._nodes.values() if n.label == label)
        return len(self._nodes)

    async def relationship_count(self, rel_type: str | None = None) -> int:
        """Count relationships, restricted to ``rel_type`` when it is truthy."""
        if rel_type:
            return sum(1 for r in self._rels.values() if r.rel_type == rel_type)
        return len(self._rels)

    async def neighbors(self, node_id: str, rel_type: str | None = None,
                        direction: str = "outgoing") -> list[str]:
        """Get adjacent node IDs.

        Arguments match ``get_relationships``. For each matching relationship
        the endpoint other than ``node_id`` is returned (``node_id`` itself
        for a self-loop). Duplicates are possible when several relationships
        connect the same pair.
        """
        rels = await self.get_relationships(node_id, rel_type, direction)
        return [r.to_id if r.from_id == node_id else r.from_id for r in rels]

    async def shortest_path(self, from_id: str, to_id: str) -> list[str] | None:
        """Find shortest path between two nodes. Returns node ID list or None.

        The path follows edge direction and counts hops (edges are
        unweighted here). Returns None when either node is unknown or no
        directed path exists. ``[from_id]`` is returned when both ids are equal.
        """
        # Guard explicitly: with source=None networkx computes paths from all
        # sources and returns a dict instead of raising.
        if from_id not in self._g or to_id not in self._g:
            return None
        try:
            return list(nx.shortest_path(self._g, from_id, to_id))
        except (nx.NetworkXNoPath, nx.NodeNotFound):
            return None

    async def close(self, destroy: bool = True):
        """Clear graph data. Set destroy=False to keep data (e.g. for tests)."""
        if destroy:
            self._g.clear()
            self._nodes.clear()
            self._rels.clear()