aucourt-ingest/aucourt_ingest/api/service.py
slothitude 24cde4cdec Audit fixes: response_model validation, error handling, dead code, input sanitisation
- Add response_model to all 8 route endpoints for runtime validation and
  correct Swagger docs
- Remove global KeyError handler (routes catch it explicitly)
- Add catch-all Exception handler with logging for 500 responses
- Remove dead code in service.py get_case_graph (unused bucket variable)
- Explicit graph_backend validation in cmd_serve (memory|neo4j, else exit)
- Sanitise comma-separated query params (strip whitespace, filter empty)
- Move HTTPException to top-level import in routes.py
- Remove unused imports (Depends in dependencies.py, all_persona_names)
- Fix deprecated asyncio.get_event_loop() in test fixture

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-05-30 12:17:55 +10:00

192 lines
6.7 KiB
Python

"""QueryService — composes SubgraphQuery + VectorIndex + GraphDB for the read API."""
from __future__ import annotations
import logging
from aucourt_ingest.jury.personas import PERSONAS, get_persona
from aucourt_ingest.jury.subgraph_query import SubgraphQuery
from aucourt_ingest.storage.graph_db import GraphDB
logger = logging.getLogger(__name__)
class QueryService:
    """Read-only query service for juror RAG queries.

    Composes three collaborators: the graph database (node and relationship
    lookups), a vector index (similarity search over chunks) and a
    ``SubgraphQuery`` built on the same graph database (per-persona juror
    context). No method here writes to the graph database or vector index.
    """

    # Neighbour node label -> key of the list it is filed under in the
    # ``get_case_graph`` result. Neighbours with any other label are dropped.
    _NEIGHBOR_BUCKETS: dict[str, str] = {
        "Judge": "judges",
        "Charge": "charges",
        "Ruling": "rulings",
        "Witness": "witnesses",
        "Chunk": "chunks",
    }

    def __init__(self, graph_db: GraphDB, vector_index, max_tokens: int = 4000):
        """Store the collaborators and the default token budget.

        Args:
            graph_db: Graph store used for every node/relationship lookup.
            vector_index: Object exposing ``query(text, top_k=..., ...)``.
            max_tokens: Budget used by ``get_juror_context`` when the caller
                does not supply one.
        """
        self._graph_db = graph_db
        self._vector_index = vector_index
        self._subgraph = SubgraphQuery(graph_db)
        self._max_tokens = max_tokens

    @staticmethod
    def _case_summary(node, default_mnc: str = "") -> dict:
        """Build the flat case-summary dict shared by the case endpoints.

        Args:
            node: Graph node whose ``properties`` mapping holds case fields.
            default_mnc: Value reported as ``mnc`` when the node has no
                ``mnc`` property.

        Returns:
            Dict with ``mnc``, ``court``, ``date``, ``jurisdiction``,
            ``matter_type`` and ``verdict``; absent properties (other than
            ``mnc``) default to the empty string.
        """
        props = node.properties
        return {
            "mnc": props.get("mnc", default_mnc),
            "court": props.get("court", ""),
            "date": props.get("date", ""),
            "jurisdiction": props.get("jurisdiction", ""),
            "matter_type": props.get("matter_type", ""),
            "verdict": props.get("verdict", ""),
        }

    def list_personas(self) -> list[dict]:
        """Return one dict per entry of ``PERSONAS``.

        Each dict carries the persona ``name`` plus its ``anchor_nodes``,
        ``edge_types`` and ``chunk_types`` attributes.
        """
        return [
            {
                "name": name,
                "anchor_nodes": p.anchor_nodes,
                "edge_types": p.edge_types,
                "chunk_types": p.chunk_types,
            }
            for name, p in PERSONAS.items()
        ]

    async def list_cases(self, court: str | None = None,
                         limit: int = 50, offset: int = 0) -> list[dict]:
        """List case summaries, newest ``date`` first, with paging.

        Args:
            court: When truthy, keep only cases whose ``court`` property
                equals this value exactly.
            limit: Maximum number of summaries to return.
            offset: Number of summaries (after filtering and sorting) to skip.

        Returns:
            The ``[offset:offset + limit]`` slice of the sorted summaries.
            A case without an ``mnc`` property reports its node id instead.
        """
        nodes = await self._graph_db.query_nodes("Case")
        cases = [
            self._case_summary(node, default_mnc=node.node_id)
            for node in nodes
            if not court or node.properties.get("court") == court
        ]
        # ``date`` can be None when the property is stored as null; comparing
        # None with str raises TypeError, so treat it as "" (sorts last).
        cases.sort(key=lambda c: c["date"] or "", reverse=True)
        return cases[offset:offset + limit]

    async def search_cases(self, q: str, limit: int = 10) -> list[dict]:
        """Return summaries of the cases behind the top vector hits for *q*.

        ``limit`` is passed to the vector index as ``top_k``, so it bounds the
        number of hits inspected; several hits for the same ``doc_id``
        collapse into one case. Hits without a ``doc_id``, and ids with no
        matching ``Case:<doc_id>`` node, are skipped.
        """
        hits = self._vector_index.query(q, top_k=limit)
        seen: set[str] = set()
        cases: list[dict] = []
        for hit in hits:
            doc_id = hit["metadata"].get("doc_id", "")
            if not doc_id or doc_id in seen:
                continue
            seen.add(doc_id)
            case = await self._get_case_by_mnc(doc_id)
            if case is not None:
                cases.append(case)
        return cases

    async def get_juror_context(self, case_mnc: str, persona_name: str,
                                max_tokens: int | None = None) -> dict:
        """Build the juror context for one case as seen by one persona.

        Args:
            case_mnc: Case identifier handed to the subgraph query.
            persona_name: Name resolved through ``get_persona``.
            max_tokens: Token budget; ``None`` or ``0`` falls back to the
                service default given at construction.

        Returns:
            Dict with ``persona``, ``case_mnc``, ``context_text``,
            ``source_chunk_ids`` and ``total_tokens`` copied from the
            subgraph query result.
        """
        persona = get_persona(persona_name)
        budget = max_tokens or self._max_tokens
        ctx = await self._subgraph.get_juror_context(case_mnc, persona, budget)
        return {
            "persona": ctx.persona,
            "case_mnc": ctx.case_mnc,
            "context_text": ctx.context_text,
            "source_chunk_ids": ctx.source_chunk_ids,
            "total_tokens": ctx.total_tokens,
        }

    def vector_search(self, query: str, top_k: int = 10,
                      chunk_types: list[str] | None = None,
                      doc_ids: list[str] | None = None) -> list[dict]:
        """Forward a similarity search to the vector index unchanged.

        ``chunk_types`` and ``doc_ids`` are passed through as keyword
        arguments; the vector index's result is returned as-is.
        """
        return self._vector_index.query(
            query, top_k=top_k, chunk_types=chunk_types, doc_ids=doc_ids,
        )

    async def hybrid_search(self, query: str, persona_name: str | None = None,
                            top_k: int = 10,
                            max_tokens: int | None = None) -> list[dict]:
        """Vector search grouped by case, optionally enriched per persona.

        Args:
            query: Free-text query for the vector index.
            persona_name: When it names an entry of ``PERSONAS``, each result
                also carries that persona's juror context; otherwise
                ``juror_context`` stays ``None``.
            top_k: Number of vector hits requested.
            max_tokens: Token budget forwarded to ``get_juror_context``.

        Returns:
            One dict per distinct ``doc_id`` (in first-hit order) with
            ``case_mnc``, ``chunks`` and ``juror_context``. Hits lacking a
            ``doc_id`` are grouped under ``"unknown"``.
        """
        vector_results = self._vector_index.query(query, top_k=top_k)

        # Group hits by the document (case) they belong to.
        by_doc: dict[str, list[dict]] = {}
        for hit in vector_results:
            doc_id = hit["metadata"].get("doc_id", "unknown")
            by_doc.setdefault(doc_id, []).append(hit)

        # The persona check does not depend on the document; decide it once.
        include_context = bool(persona_name) and persona_name in PERSONAS

        results = []
        for doc_mnc, chunks in by_doc.items():
            entry: dict = {
                "case_mnc": doc_mnc,
                "chunks": chunks,
                "juror_context": None,
            }
            if include_context:
                entry["juror_context"] = await self.get_juror_context(
                    doc_mnc, persona_name, max_tokens,
                )
            results.append(entry)
        return results

    async def get_case_graph(self, case_mnc: str) -> dict | None:
        """Return a case node with its direct neighbours and relationships.

        Returns:
            ``None`` when no ``Case:<case_mnc>`` node exists. Otherwise a dict
            with ``case`` plus the lists ``judges``, ``charges``, ``rulings``,
            ``witnesses``, ``chunks`` and ``relationships``. Neighbours are
            listed in the order their first relationship was returned;
            neighbours whose label has no bucket are omitted from the node
            lists but their relationships are still reported.
        """
        case_id = f"Case:{case_mnc}"
        case_node = await self._graph_db.get_node(case_id)
        if case_node is None:
            return None

        result: dict = {
            "case": _node_to_dict(case_node),
            "judges": [],
            "charges": [],
            "rulings": [],
            "witnesses": [],
            "chunks": [],
            "relationships": [],
        }

        rels = await self._graph_db.get_relationships(case_id, direction="both")
        # A dict is used as an ordered set: it de-duplicates neighbours while
        # keeping first-seen order, so the output is stable between runs.
        neighbor_ids: dict[str, None] = {}
        for rel in rels:
            result["relationships"].append(_rel_to_dict(rel))
            # The neighbour is whichever endpoint is not the case itself.
            other_id = rel.to_id if rel.from_id == case_id else rel.from_id
            neighbor_ids[other_id] = None

        for nid in neighbor_ids:
            node = await self._graph_db.get_node(nid)
            if node is None:
                continue
            bucket = self._NEIGHBOR_BUCKETS.get(node.label)
            if bucket is not None:
                result[bucket].append(_node_to_dict(node))
        return result

    async def health(self) -> dict:
        """Return graph size counters together with a fixed ``"ok"`` status.

        Reports total node and relationship counts plus the counts of
        ``Case`` and ``Chunk`` nodes.
        """
        total_nodes = await self._graph_db.node_count()
        total_edges = await self._graph_db.relationship_count()
        cases = await self._graph_db.node_count("Case")
        chunks = await self._graph_db.node_count("Chunk")
        return {
            "status": "ok",
            "nodes": total_nodes,
            "edges": total_edges,
            "cases": cases,
            "chunks": chunks,
        }

    async def _get_case_by_mnc(self, mnc: str) -> dict | None:
        """Return the summary of ``Case:<mnc>``, or ``None`` if absent."""
        node = await self._graph_db.get_node(f"Case:{mnc}")
        if node is None:
            return None
        return self._case_summary(node)
def _node_to_dict(node) -> dict:
    """Convert a graph node into a plain dict.

    Args:
        node: Object exposing ``node_id``, ``label`` and ``properties``.

    Returns:
        Dict with the keys ``node_id``, ``label`` and ``properties``. The
        ``properties`` value is the node's own object, not a copy.
    """
    return {
        "node_id": node.node_id,
        "label": node.label,
        "properties": node.properties,
    }
def _rel_to_dict(rel) -> dict:
    """Convert a graph relationship into a plain dict.

    Args:
        rel: Object exposing ``rel_id``, ``from_id``, ``to_id``, ``rel_type``
            and ``properties``.

    Returns:
        Dict with the keys ``rel_id``, ``from_id``, ``to_id``, ``rel_type``
        and ``properties``. The ``properties`` value is the relationship's
        own object, not a copy.
    """
    return {
        "rel_id": rel.rel_id,
        "from_id": rel.from_id,
        "to_id": rel.to_id,
        "rel_type": rel.rel_type,
        "properties": rel.properties,
    }