- #11 VectorIndexProtocol for typed duck-typing of vector_index param - #10 Wrap all sync VectorIndex.query() calls in asyncio.to_thread - #9 Make vector_search async (consistent with to_thread wrapping) - #12 Replace module-level singleton with app.state.query_service - #13 InMemoryGraphDB.close(destroy=False) guard for test safety - #14 Parse dates with datetime.fromisoformat in list_cases sort Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
111 lines
4.4 KiB
Python
111 lines
4.4 KiB
Python
"""InMemoryGraphDB — NetworkX-backed property graph for testing."""
|
|
|
|
from __future__ import annotations
|
|
|
|
import logging
|
|
|
|
import networkx as nx
|
|
|
|
from aucourt_ingest.storage.graph_db import (
|
|
GraphDB, GraphNode, GraphRelationship,
|
|
_node_id, _rel_id,
|
|
)
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class InMemoryGraphDB:
    """In-memory property graph using NetworkX. No dependencies beyond networkx.

    Node and relationship records (``GraphNode`` / ``GraphRelationship``) are
    kept in plain dicts keyed by ID; these are the source of truth for lookups,
    queries and counts. The same topology is mirrored into a
    ``networkx.MultiDiGraph`` that is used for adjacency traversal and path
    finding. Edges in that graph are keyed by relationship type, so there is at
    most one edge of a given type per ordered ``(from_id, to_id)`` pair.

    Suitable for testing and standalone use. Production should use Neo4jGraphDB.
    """

    def __init__(self):
        self._g = nx.MultiDiGraph()
        self._nodes: dict[str, GraphNode] = {}
        self._rels: dict[str, GraphRelationship] = {}

    async def create_node(self, label: str, properties: dict) -> str:
        """Create a node, or merge ``properties`` into it if it already exists.

        The node ID comes from ``_node_id(label, properties)``. If a node with
        that ID already exists, its properties are updated in place instead of
        creating a duplicate.

        Args:
            label: Node label.
            properties: Node properties. The dict is copied when a new node is
                created, so later caller mutation does not leak into the store.

        Returns:
            The node ID.
        """
        node_id = _node_id(label, properties)
        node = self._nodes.get(node_id)
        if node is None:
            node = GraphNode(node_id, label, properties.copy())
            self._nodes[node_id] = node
            self._g.add_node(node_id)
        else:
            node.properties.update(properties)
        # Write attributes through the attribute dict instead of passing
        # ``**properties`` to add_node: a property named "label" would
        # otherwise collide with the ``label=`` keyword and raise TypeError.
        attrs = self._g.nodes[node_id]
        attrs.update(properties)
        attrs["label"] = node.label
        return node_id

    async def create_relationship(self, from_id: str, to_id: str,
                                  rel_type: str,
                                  properties: dict | None = None) -> str:
        """Create or replace the ``rel_type`` relationship ``from_id -> to_id``.

        An existing relationship with the same endpoints and type is replaced,
        properties included. NetworkX implicitly adds endpoints that do not yet
        exist in the graph. Such endpoints are not in the node store, so
        ``get_node`` / ``node_count`` do not see them until ``create_node``
        is called for them.

        Args:
            from_id: Source node ID.
            to_id: Target node ID.
            rel_type: Relationship type; also used as the NetworkX edge key.
            properties: Optional relationship properties (copied).

        Returns:
            The relationship ID from ``_rel_id(from_id, to_id, rel_type)``.
        """
        rel_id = _rel_id(from_id, to_id, rel_type)
        props = dict(properties) if properties else {}
        self._rels[rel_id] = GraphRelationship(rel_id, from_id, to_id, rel_type, props)
        self._g.add_edge(from_id, to_id, key=rel_type)
        # Mirror replace-semantics of ``self._rels`` and avoid keyword
        # collisions (e.g. a property named "key") with add_edge's signature.
        attrs = self._g.edges[from_id, to_id, rel_type]
        attrs.clear()
        attrs.update(props)
        return rel_id

    async def get_node(self, node_id: str) -> GraphNode | None:
        """Return the node with ``node_id``, or None if it was never created."""
        return self._nodes.get(node_id)

    async def get_relationships(self, node_id: str,
                                rel_type: str | None = None,
                                direction: str = "outgoing") -> list[GraphRelationship]:
        """Return relationships touching ``node_id``.

        Args:
            node_id: Node whose relationships to list.
            rel_type: If given, only relationships of this type are returned.
            direction: "outgoing", "incoming" or "both". Any other value
                yields an empty list.

        Returns:
            Outgoing relationships first, then incoming ones. Each relationship
            appears once; for example, a self-loop is reported once under "both".
            Unknown nodes yield an empty list.
        """
        # Explicit guard: networkx treats an nbunch that is not a node as an
        # iterable of nodes, so an unknown string ID would be iterated
        # character by character and could match single-character node IDs.
        if node_id not in self._g:
            return []

        edges = []
        if direction in ("outgoing", "both"):
            edges.extend(self._g.out_edges(node_id, keys=True))
        if direction in ("incoming", "both"):
            edges.extend(self._g.in_edges(node_id, keys=True))

        results: list[GraphRelationship] = []
        seen: set[str] = set()
        for u, v, k in edges:
            rel = self._rels.get(_rel_id(u, v, k))
            if rel is None or rel.rel_id in seen:
                continue
            if rel_type is not None and rel.rel_type != rel_type:
                continue
            seen.add(rel.rel_id)
            results.append(rel)
        return results

    async def query_nodes(self, label: str | None = None,
                          properties: dict | None = None) -> list[GraphNode]:
        """Return nodes matching ``label`` and all ``properties`` (exact equality).

        A falsy ``label`` or ``properties`` disables that filter.

        NOTE: matching uses ``dict.get``, so a filter value of None also matches
        nodes that lack that property entirely.
        """
        criteria = properties or {}
        return [
            node for node in self._nodes.values()
            if (not label or node.label == label)
            and all(node.properties.get(k) == v for k, v in criteria.items())
        ]

    async def node_count(self, label: str | None = None) -> int:
        """Count stored nodes, optionally only those with ``label``."""
        if label:
            return sum(1 for n in self._nodes.values() if n.label == label)
        return len(self._nodes)

    async def relationship_count(self, rel_type: str | None = None) -> int:
        """Count stored relationships, optionally only those of ``rel_type``."""
        if rel_type:
            return sum(1 for r in self._rels.values() if r.rel_type == rel_type)
        return len(self._rels)

    async def neighbors(self, node_id: str, rel_type: str | None = None,
                        direction: str = "outgoing") -> list[str]:
        """Get adjacent node IDs.

        Returns the opposite endpoint of each relationship reported by
        ``get_relationships``. The list has one entry per relationship, so the
        same neighbour can appear more than once, e.g. for A->B and B->A with
        direction="both". A self-loop yields ``node_id`` itself.
        """
        rels = await self.get_relationships(node_id, rel_type, direction)
        return [r.to_id if r.from_id == node_id else r.from_id for r in rels]

    async def shortest_path(self, from_id: str, to_id: str) -> list[str] | None:
        """Find the shortest directed, unweighted path between two nodes.

        Returns the list of node IDs from ``from_id`` to ``to_id`` inclusive.
        Returns None when no path exists or either endpoint is not in the graph.
        """
        try:
            path = nx.shortest_path(self._g, from_id, to_id)
        except (nx.NetworkXNoPath, nx.NodeNotFound):
            return None
        return list(path)

    async def close(self, destroy: bool = True):
        """Release the store.

        With ``destroy=True`` (default) all nodes and relationships are cleared.
        Pass ``destroy=False`` to keep the data, e.g. when a test still needs it.
        """
        if destroy:
            self._g.clear()
            self._nodes.clear()
            self._rels.clear()
|