Resolve all deferred audit findings: type safety, async, isolation, dates

- #11 VectorIndexProtocol for typed duck-typing of vector_index param
- #10 Wrap all sync VectorIndex.query() calls in asyncio.to_thread
- #9 Make vector_search async (consistent with to_thread wrapping)
- #12 Replace module-level singleton with app.state.query_service
- #13 InMemoryGraphDB.close(destroy=False) guard for test safety
- #14 Parse dates with datetime.fromisoformat in list_cases sort

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
slothitude 2026-05-30 12:20:56 +10:00
parent 24cde4cdec
commit f799b41ed9
5 changed files with 57 additions and 21 deletions

View file

@ -6,7 +6,6 @@ from contextlib import asynccontextmanager
from fastapi import FastAPI from fastapi import FastAPI
from aucourt_ingest.api import dependencies
from aucourt_ingest.api.errors import register_error_handlers from aucourt_ingest.api.errors import register_error_handlers
from aucourt_ingest.api.routes import router from aucourt_ingest.api.routes import router
from aucourt_ingest.api.service import QueryService from aucourt_ingest.api.service import QueryService
@ -21,7 +20,7 @@ def create_app(graph_db: GraphDB, vector_index, max_tokens: int = 4000) -> FastA
vector_index: VectorIndex instance (or duck-typed fake for tests) vector_index: VectorIndex instance (or duck-typed fake for tests)
max_tokens: Default token budget for juror context assembly max_tokens: Default token budget for juror context assembly
""" """
dependencies._query_service = QueryService(graph_db, vector_index, max_tokens) query_service = QueryService(graph_db, vector_index, max_tokens)
@asynccontextmanager @asynccontextmanager
async def lifespan(app: FastAPI): async def lifespan(app: FastAPI):
@ -35,6 +34,7 @@ def create_app(graph_db: GraphDB, vector_index, max_tokens: int = 4000) -> FastA
lifespan=lifespan, lifespan=lifespan,
) )
app.state.query_service = query_service
app.include_router(router) app.include_router(router)
register_error_handlers(app) register_error_handlers(app)

View file

@ -2,12 +2,13 @@
from __future__ import annotations from __future__ import annotations
from fastapi import Request
from aucourt_ingest.api.service import QueryService from aucourt_ingest.api.service import QueryService
_query_service: QueryService | None = None
def get_query_service(request: Request) -> QueryService:
def get_query_service() -> QueryService: svc = request.app.state.query_service
if _query_service is None: if svc is None:
raise RuntimeError("QueryService not initialised — call create_app() first") raise RuntimeError("QueryService not initialised — call create_app() first")
return _query_service return svc

View file

@ -80,7 +80,7 @@ async def vector_search(
): ):
types = [t.strip() for t in chunk_types.split(",") if t.strip()] if chunk_types else None types = [t.strip() for t in chunk_types.split(",") if t.strip()] if chunk_types else None
docs = [d.strip() for d in doc_ids.split(",") if d.strip()] if doc_ids else None docs = [d.strip() for d in doc_ids.split(",") if d.strip()] if doc_ids else None
return svc.vector_search(q, top_k=top_k, chunk_types=types, doc_ids=docs) return await svc.vector_search(q, top_k=top_k, chunk_types=types, doc_ids=docs)
@router.get("/hybrid", response_model=list[HybridResult]) @router.get("/hybrid", response_model=list[HybridResult])

View file

@ -2,7 +2,10 @@
from __future__ import annotations from __future__ import annotations
import asyncio
import logging import logging
from datetime import datetime
from typing import Protocol, runtime_checkable
from aucourt_ingest.jury.personas import PERSONAS, get_persona from aucourt_ingest.jury.personas import PERSONAS, get_persona
from aucourt_ingest.jury.subgraph_query import SubgraphQuery from aucourt_ingest.jury.subgraph_query import SubgraphQuery
@ -11,10 +14,26 @@ from aucourt_ingest.storage.graph_db import GraphDB
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@runtime_checkable
class VectorIndexProtocol(Protocol):
"""Structural type for the vector index that QueryService depends on.

Any object exposing a compatible ``query`` method and ``count`` property
satisfies this protocol, so a real VectorIndex and a duck-typed test fake
can be passed interchangeably. ``runtime_checkable`` additionally permits
``isinstance`` checks, which verify member presence only, not signatures.
"""
# Synchronous lookup; QueryService calls it through asyncio.to_thread.
# NOTE(review): the keys of the returned dicts are not visible here —
# confirm against the concrete VectorIndex implementation.
def query(self, text: str, top_k: int = 10,
chunk_types: list[str] | None = None,
doc_ids: list[str] | None = None,
embedding: list[float] | None = None) -> list[dict]:
...
# NOTE(review): presumably the number of indexed entries — confirm.
@property
def count(self) -> int:
...
class QueryService: class QueryService:
"""Read-only query service for juror RAG queries.""" """Read-only query service for juror RAG queries."""
def __init__(self, graph_db: GraphDB, vector_index, max_tokens: int = 4000): def __init__(self, graph_db: GraphDB, vector_index: VectorIndexProtocol,
max_tokens: int = 4000):
self._graph_db = graph_db self._graph_db = graph_db
self._vector_index = vector_index self._vector_index = vector_index
self._subgraph = SubgraphQuery(graph_db) self._subgraph = SubgraphQuery(graph_db)
@ -46,11 +65,11 @@ class QueryService:
"matter_type": node.properties.get("matter_type", ""), "matter_type": node.properties.get("matter_type", ""),
"verdict": node.properties.get("verdict", ""), "verdict": node.properties.get("verdict", ""),
}) })
cases.sort(key=lambda c: c.get("date", ""), reverse=True) cases.sort(key=lambda c: _parse_date_key(c.get("date", "")), reverse=True)
return cases[offset:offset + limit] return cases[offset:offset + limit]
async def search_cases(self, q: str, limit: int = 10) -> list[dict]: async def search_cases(self, q: str, limit: int = 10) -> list[dict]:
results = self._vector_index.query(q, top_k=limit) results = await asyncio.to_thread(self._vector_index.query, q, top_k=limit)
seen: set[str] = set() seen: set[str] = set()
cases = [] cases = []
for r in results: for r in results:
@ -73,17 +92,21 @@ class QueryService:
"total_tokens": ctx.total_tokens, "total_tokens": ctx.total_tokens,
} }
def vector_search(self, query: str, top_k: int = 10, async def vector_search(self, query: str, top_k: int = 10,
chunk_types: list[str] | None = None, chunk_types: list[str] | None = None,
doc_ids: list[str] | None = None) -> list[dict]: doc_ids: list[str] | None = None) -> list[dict]:
return self._vector_index.query( results = await asyncio.to_thread(
query, top_k=top_k, chunk_types=chunk_types, doc_ids=doc_ids, self._vector_index.query, query,
top_k=top_k, chunk_types=chunk_types, doc_ids=doc_ids,
) )
return results
async def hybrid_search(self, query: str, persona_name: str | None = None, async def hybrid_search(self, query: str, persona_name: str | None = None,
top_k: int = 10, top_k: int = 10,
max_tokens: int | None = None) -> list[dict]: max_tokens: int | None = None) -> list[dict]:
vector_results = self._vector_index.query(query, top_k=top_k) vector_results = await asyncio.to_thread(
self._vector_index.query, query, top_k=top_k,
)
# Group by doc_id # Group by doc_id
by_doc: dict[str, list[dict]] = {} by_doc: dict[str, list[dict]] = {}
for r in vector_results: for r in vector_results:
@ -190,3 +213,13 @@ def _rel_to_dict(rel) -> dict:
"rel_type": rel.rel_type, "rel_type": rel.rel_type,
"properties": rel.properties, "properties": rel.properties,
} }
def _parse_date_key(date_str: str) -> datetime:
    """Build a sort key from an ISO 8601 date string.

    Only the leading ``YYYY-MM-DD`` portion is parsed, so any time or
    timezone suffix is ignored and every key is a naive ``datetime``.

    Args:
        date_str: ISO 8601 date or datetime string; may be empty.

    Returns:
        The parsed date at midnight, or ``datetime.min`` when the value is
        empty, not a string, or not a valid ISO date.
    """
    # A None key cannot be ordered against datetime keys: list.sort() raises
    # TypeError as soon as dated and undated cases are mixed. datetime.min
    # keeps every key comparable and places undated cases last when the
    # caller sorts newest-first (reverse=True).
    if not date_str:
        return datetime.min
    try:
        return datetime.fromisoformat(date_str[:10])
    except (TypeError, ValueError):
        # TypeError: value is not sliceable text; ValueError: not an ISO date.
        return datetime.min

View file

@ -103,7 +103,9 @@ class InMemoryGraphDB:
except (nx.NetworkXNoPath, nx.NodeNotFound): except (nx.NetworkXNoPath, nx.NodeNotFound):
return None return None
async def close(self): async def close(self, destroy: bool = True):
self._g.clear() """Clear graph data. Set destroy=False to keep data (e.g. for tests)."""
self._nodes.clear() if destroy:
self._rels.clear() self._g.clear()
self._nodes.clear()
self._rels.clear()