aucourt-ingest/aucourt_ingest/api/service.py
slothitude f799b41ed9 Resolve all deferred audit findings: type safety, async, isolation, dates
- #11 VectorIndexProtocol for typed duck-typing of vector_index param
- #10 Wrap all sync VectorIndex.query() calls in asyncio.to_thread
- #9 Make vector_search async (consistent with to_thread wrapping)
- #12 Replace module-level singleton with app.state.query_service
- #13 InMemoryGraphDB.close(destroy=False) guard for test safety
- #14 Parse dates with datetime.fromisoformat in list_cases sort

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-05-30 12:20:56 +10:00

225 lines
7.7 KiB
Python

"""QueryService — composes SubgraphQuery + VectorIndex + GraphDB for the read API."""
from __future__ import annotations
import asyncio
import logging
from datetime import datetime
from typing import Protocol, runtime_checkable
from aucourt_ingest.jury.personas import PERSONAS, get_persona
from aucourt_ingest.jury.subgraph_query import SubgraphQuery
from aucourt_ingest.storage.graph_db import GraphDB
logger = logging.getLogger(__name__)
@runtime_checkable
class VectorIndexProtocol(Protocol):
    """Structural type for the vector index handed to the query service.

    Any object that exposes a ``query`` method and a ``count`` property
    satisfies this protocol; no inheritance is required. Being
    ``runtime_checkable``, it also works with ``isinstance``.
    """

    def query(
        self,
        text: str,
        top_k: int = 10,
        chunk_types: list[str] | None = None,
        doc_ids: list[str] | None = None,
        embedding: list[float] | None = None,
    ) -> list[dict]:
        """Search the index for *text* and return the hits as dicts.

        NOTE(review): the filter semantics of ``chunk_types`` / ``doc_ids``
        and the role of ``embedding`` are defined by the concrete index,
        not here — confirm against the implementation.
        """
        ...

    @property
    def count(self) -> int:
        """Size of the index (presumably the number of stored entries)."""
        ...
class QueryService:
    """Read-only query service for juror RAG queries.

    Wraps a graph database and a vector index, plus a ``SubgraphQuery`` built
    on the same graph database. The synchronous ``vector_index.query`` is
    always executed through ``asyncio.to_thread``.
    """

    # Neighbour node label -> key of the matching list in get_case_graph().
    _GRAPH_SECTIONS = {
        "Judge": "judges",
        "Charge": "charges",
        "Ruling": "rulings",
        "Witness": "witnesses",
        "Chunk": "chunks",
    }

    def __init__(self, graph_db: GraphDB, vector_index: VectorIndexProtocol,
                 max_tokens: int = 4000):
        """Store the backends and build the subgraph query helper.

        Args:
            graph_db: Graph database queried for nodes and relationships.
            vector_index: Index whose ``query`` method serves the searches.
            max_tokens: Default token budget used by ``get_juror_context``.
        """
        self._graph_db = graph_db
        self._vector_index = vector_index
        self._subgraph = SubgraphQuery(graph_db)
        self._max_tokens = max_tokens

    @staticmethod
    def _case_summary(properties: dict, default_mnc: str) -> dict:
        """Build the case summary dict shared by listing and search results."""
        return {
            "mnc": properties.get("mnc", default_mnc),
            "court": properties.get("court", ""),
            "date": properties.get("date", ""),
            "jurisdiction": properties.get("jurisdiction", ""),
            "matter_type": properties.get("matter_type", ""),
            "verdict": properties.get("verdict", ""),
        }

    @staticmethod
    def _date_rank(parsed: datetime | None) -> tuple[bool, datetime]:
        """Turn an optional date into a sort key that is always comparable.

        Dated entries rank above undated ones, so a descending sort lists the
        newest dates first and entries without a usable date last.
        """
        if parsed is None:
            return (False, datetime.min)
        return (True, parsed)

    @staticmethod
    def _hit_doc_id(hit: dict, default: str) -> str:
        """Return the hit's ``metadata["doc_id"]``, or *default* if absent.

        A hit whose ``metadata`` is missing or ``None`` is treated as having
        no ``doc_id`` instead of raising.
        """
        metadata = hit.get("metadata") or {}
        return metadata.get("doc_id", default)

    def list_personas(self) -> list[dict]:
        """Return one dict per entry of ``PERSONAS``.

        Each dict carries the persona's name, anchor nodes, edge types and
        chunk types.
        """
        return [
            {
                "name": name,
                "anchor_nodes": p.anchor_nodes,
                "edge_types": p.edge_types,
                "chunk_types": p.chunk_types,
            }
            for name, p in PERSONAS.items()
        ]

    async def list_cases(self, court: str | None = None,
                         limit: int = 50, offset: int = 0) -> list[dict]:
        """List case summaries, newest first.

        Args:
            court: When given, keep only cases whose ``court`` property
                equals this value.
            limit: Maximum number of cases to return.
            offset: Number of sorted cases to skip.

        Returns:
            The slice ``[offset:offset + limit]`` of all matching cases,
            sorted by date descending. Cases whose date is empty or cannot
            be parsed come after all dated cases.
        """
        nodes = await self._graph_db.query_nodes("Case")
        cases = []
        for node in nodes:
            if court and node.properties.get("court") != court:
                continue
            # The node id stands in when a case has no "mnc" property.
            cases.append(self._case_summary(node.properties, node.node_id))
        # Sorting on the raw parse result would raise TypeError as soon as
        # dated and undated cases are mixed (None is not orderable against
        # datetime), so rank the parse result first.
        cases.sort(
            key=lambda c: self._date_rank(_parse_date_key(c.get("date", ""))),
            reverse=True,
        )
        return cases[offset:offset + limit]

    async def search_cases(self, q: str, limit: int = 10) -> list[dict]:
        """Find cases through a vector search for *q*.

        The index is asked for ``limit`` hits; hits are de-duplicated by
        ``doc_id`` in hit order, and each distinct ``doc_id`` is looked up as
        a case. Hits without a ``doc_id`` and ids with no case node are
        dropped, so fewer than ``limit`` cases may be returned.
        """
        results = await asyncio.to_thread(self._vector_index.query, q, top_k=limit)
        seen: set[str] = set()
        cases = []
        for hit in results:
            doc_id = self._hit_doc_id(hit, "")
            if not doc_id or doc_id in seen:
                continue
            seen.add(doc_id)
            case = await self._get_case_by_mnc(doc_id)
            if case:
                cases.append(case)
        return cases

    async def get_juror_context(self, case_mnc: str, persona_name: str,
                                max_tokens: int | None = None) -> dict:
        """Fetch the juror context of a case for one persona as a plain dict.

        Args:
            case_mnc: Identifier of the case.
            persona_name: Name passed to ``get_persona``.
            max_tokens: Token budget; ``None`` or ``0`` falls back to the
                service default.

        Returns:
            Dict with ``persona``, ``case_mnc``, ``context_text``,
            ``source_chunk_ids`` and ``total_tokens`` copied from the
            subgraph result.

        NOTE(review): what happens for an unknown persona is decided by
        ``get_persona`` — confirm whether it raises.
        """
        persona = get_persona(persona_name)
        budget = max_tokens or self._max_tokens
        ctx = await self._subgraph.get_juror_context(case_mnc, persona, budget)
        return {
            "persona": ctx.persona,
            "case_mnc": ctx.case_mnc,
            "context_text": ctx.context_text,
            "source_chunk_ids": ctx.source_chunk_ids,
            "total_tokens": ctx.total_tokens,
        }

    async def vector_search(self, query: str, top_k: int = 10,
                            chunk_types: list[str] | None = None,
                            doc_ids: list[str] | None = None) -> list[dict]:
        """Run a raw vector search and return the index's hits unchanged."""
        return await asyncio.to_thread(
            self._vector_index.query, query,
            top_k=top_k, chunk_types=chunk_types, doc_ids=doc_ids,
        )

    async def hybrid_search(self, query: str, persona_name: str | None = None,
                            top_k: int = 10,
                            max_tokens: int | None = None) -> list[dict]:
        """Vector search grouped by document, optionally with juror context.

        Args:
            query: Search text.
            persona_name: When given and a key of ``PERSONAS``, the juror
                context for that persona is attached to every group.
            top_k: Number of hits requested from the index.
            max_tokens: Token budget forwarded to ``get_juror_context``.

        Returns:
            One dict per distinct ``doc_id`` (in first-hit order) with
            ``case_mnc``, ``chunks`` and ``juror_context`` (``None`` when no
            context was requested). Hits without a ``doc_id`` are grouped
            under ``"unknown"``.
        """
        vector_results = await asyncio.to_thread(
            self._vector_index.query, query, top_k=top_k,
        )
        # Group by doc_id
        by_doc: dict[str, list[dict]] = {}
        for hit in vector_results:
            by_doc.setdefault(self._hit_doc_id(hit, "unknown"), []).append(hit)
        # The persona check does not depend on the document; evaluate it once.
        want_context = bool(persona_name) and persona_name in PERSONAS
        results = []
        for doc_mnc, chunks in by_doc.items():
            entry: dict = {
                "case_mnc": doc_mnc,
                "chunks": chunks,
                "juror_context": None,
            }
            if want_context:
                entry["juror_context"] = await self.get_juror_context(
                    doc_mnc, persona_name, max_tokens,
                )
            results.append(entry)
        return results

    async def get_case_graph(self, case_mnc: str) -> dict | None:
        """Return a case node with its direct neighbourhood.

        Returns:
            ``None`` when the node ``Case:<case_mnc>`` does not exist.
            Otherwise a dict with the case node, all of its relationships
            (both directions) and the neighbouring nodes sorted into
            ``judges``, ``charges``, ``rulings``, ``witnesses`` and
            ``chunks`` by label. Neighbours with any other label appear only
            through their relationship.
        """
        case_id = f"Case:{case_mnc}"
        case_node = await self._graph_db.get_node(case_id)
        if case_node is None:
            return None
        result = {
            "case": _node_to_dict(case_node),
            "judges": [],
            "charges": [],
            "rulings": [],
            "witnesses": [],
            "chunks": [],
            "relationships": [],
        }
        rels = await self._graph_db.get_relationships(case_id, direction="both")
        # A dict keeps insertion order, so neighbours are visited in
        # relationship order; a set of strings would iterate in an order that
        # changes from one process to the next.
        neighbor_ids: dict[str, None] = {}
        for rel in rels:
            result["relationships"].append(_rel_to_dict(rel))
            other_id = rel.to_id if rel.from_id == case_id else rel.from_id
            neighbor_ids[other_id] = None
        for nid in neighbor_ids:
            node = await self._graph_db.get_node(nid)
            if node is None:
                continue
            key = self._GRAPH_SECTIONS.get(node.label)
            if key:
                result[key].append(_node_to_dict(node))
        return result

    async def health(self) -> dict:
        """Report graph size: total nodes and edges plus Case/Chunk counts."""
        total_nodes = await self._graph_db.node_count()
        total_edges = await self._graph_db.relationship_count()
        cases = await self._graph_db.node_count("Case")
        chunks = await self._graph_db.node_count("Chunk")
        return {
            "status": "ok",
            "nodes": total_nodes,
            "edges": total_edges,
            "cases": cases,
            "chunks": chunks,
        }

    async def _get_case_by_mnc(self, mnc: str) -> dict | None:
        """Return the summary of node ``Case:<mnc>``, or ``None`` if absent."""
        node = await self._graph_db.get_node(f"Case:{mnc}")
        if node is None:
            return None
        return self._case_summary(node.properties, "")
def _node_to_dict(node) -> dict:
return {
"node_id": node.node_id,
"label": node.label,
"properties": node.properties,
}
def _rel_to_dict(rel) -> dict:
return {
"rel_id": rel.rel_id,
"from_id": rel.from_id,
"to_id": rel.to_id,
"rel_type": rel.rel_type,
"properties": rel.properties,
}
def _parse_date_key(date_str: str) -> datetime | None:
"""Parse ISO 8601 date string for sorting. Returns None for empty/invalid."""
if not date_str:
return None
try:
return datetime.fromisoformat(date_str[:10])
except (ValueError, IndexError):
return None