Resolve all deferred audit findings: type safety, async, isolation, dates

- #11 VectorIndexProtocol for typed duck-typing of vector_index param
- #10 Wrap all sync VectorIndex.query() calls in asyncio.to_thread
- #9 Make vector_search async (consistent with to_thread wrapping)
- #12 Replace module-level singleton with app.state.query_service
- #13 InMemoryGraphDB.close(destroy=False) guard for test safety
- #14 Parse dates with datetime.fromisoformat in list_cases sort

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
slothitude 2026-05-30 12:20:56 +10:00
parent 24cde4cdec
commit f799b41ed9
5 changed files with 57 additions and 21 deletions

View file

@ -6,7 +6,6 @@ from contextlib import asynccontextmanager
from fastapi import FastAPI
from aucourt_ingest.api import dependencies
from aucourt_ingest.api.errors import register_error_handlers
from aucourt_ingest.api.routes import router
from aucourt_ingest.api.service import QueryService
@ -21,7 +20,7 @@ def create_app(graph_db: GraphDB, vector_index, max_tokens: int = 4000) -> FastA
vector_index: VectorIndex instance (or duck-typed fake for tests)
max_tokens: Default token budget for juror context assembly
"""
dependencies._query_service = QueryService(graph_db, vector_index, max_tokens)
query_service = QueryService(graph_db, vector_index, max_tokens)
@asynccontextmanager
async def lifespan(app: FastAPI):
@ -35,6 +34,7 @@ def create_app(graph_db: GraphDB, vector_index, max_tokens: int = 4000) -> FastA
lifespan=lifespan,
)
app.state.query_service = query_service
app.include_router(router)
register_error_handlers(app)

View file

@ -2,12 +2,13 @@
from __future__ import annotations
from fastapi import Request
from aucourt_ingest.api.service import QueryService
_query_service: QueryService | None = None
def get_query_service(request: Request) -> QueryService:
    """Return the QueryService attached to the application serving *request*.

    The service is looked up on ``request.app.state.query_service``, so each
    application instance carries its own service instead of sharing a
    module-level singleton.

    Args:
        request: The incoming request; only ``request.app.state`` is read.

    Returns:
        The object stored under ``app.state.query_service``.

    Raises:
        RuntimeError: If no service is stored on the application state,
            whether the attribute is missing or explicitly ``None``.
    """
    # Use getattr with a default: a plain attribute access on a state object
    # that never had ``query_service`` assigned raises AttributeError, which
    # would bypass the explicit check below and hide the helpful message.
    svc = getattr(request.app.state, "query_service", None)
    if svc is None:
        raise RuntimeError("QueryService not initialised — call create_app() first")
    return svc

View file

@ -80,7 +80,7 @@ async def vector_search(
):
types = [t.strip() for t in chunk_types.split(",") if t.strip()] if chunk_types else None
docs = [d.strip() for d in doc_ids.split(",") if d.strip()] if doc_ids else None
return svc.vector_search(q, top_k=top_k, chunk_types=types, doc_ids=docs)
return await svc.vector_search(q, top_k=top_k, chunk_types=types, doc_ids=docs)
@router.get("/hybrid", response_model=list[HybridResult])

View file

@ -2,7 +2,10 @@
from __future__ import annotations
import asyncio
import logging
from datetime import datetime
from typing import Protocol, runtime_checkable
from aucourt_ingest.jury.personas import PERSONAS, get_persona
from aucourt_ingest.jury.subgraph_query import SubgraphQuery
@ -11,10 +14,26 @@ from aucourt_ingest.storage.graph_db import GraphDB
logger = logging.getLogger(__name__)
@runtime_checkable
class VectorIndexProtocol(Protocol):
    """Structural interface for vector-index objects (type-safe duck-typing).

    Any object exposing a compatible ``query`` method and ``count`` property
    satisfies this protocol; no inheritance is required. Because the class is
    ``runtime_checkable``, ``isinstance`` checks verify that the members
    exist, not that their signatures match.
    """

    def query(
        self,
        text: str,
        top_k: int = 10,
        chunk_types: list[str] | None = None,
        doc_ids: list[str] | None = None,
        embedding: list[float] | None = None,
    ) -> list[dict]:
        """Run a query for *text* and return the matches as a list of dicts.

        NOTE(review): the exact semantics of ``chunk_types``, ``doc_ids`` and
        ``embedding`` are defined by the concrete index — presumably result
        filters and a precomputed query vector; confirm against VectorIndex.
        """
        ...

    @property
    def count(self) -> int:
        """Integer size of the index (presumably stored entries — confirm)."""
        ...
class QueryService:
"""Read-only query service for juror RAG queries."""
def __init__(self, graph_db: GraphDB, vector_index, max_tokens: int = 4000):
def __init__(self, graph_db: GraphDB, vector_index: VectorIndexProtocol,
max_tokens: int = 4000):
self._graph_db = graph_db
self._vector_index = vector_index
self._subgraph = SubgraphQuery(graph_db)
@ -46,11 +65,11 @@ class QueryService:
"matter_type": node.properties.get("matter_type", ""),
"verdict": node.properties.get("verdict", ""),
})
cases.sort(key=lambda c: c.get("date", ""), reverse=True)
cases.sort(key=lambda c: _parse_date_key(c.get("date", "")), reverse=True)
return cases[offset:offset + limit]
async def search_cases(self, q: str, limit: int = 10) -> list[dict]:
results = self._vector_index.query(q, top_k=limit)
results = await asyncio.to_thread(self._vector_index.query, q, top_k=limit)
seen: set[str] = set()
cases = []
for r in results:
@ -73,17 +92,21 @@ class QueryService:
"total_tokens": ctx.total_tokens,
}
def vector_search(self, query: str, top_k: int = 10,
chunk_types: list[str] | None = None,
doc_ids: list[str] | None = None) -> list[dict]:
return self._vector_index.query(
query, top_k=top_k, chunk_types=chunk_types, doc_ids=doc_ids,
async def vector_search(self, query: str, top_k: int = 10,
chunk_types: list[str] | None = None,
doc_ids: list[str] | None = None) -> list[dict]:
results = await asyncio.to_thread(
self._vector_index.query, query,
top_k=top_k, chunk_types=chunk_types, doc_ids=doc_ids,
)
return results
async def hybrid_search(self, query: str, persona_name: str | None = None,
top_k: int = 10,
max_tokens: int | None = None) -> list[dict]:
vector_results = self._vector_index.query(query, top_k=top_k)
vector_results = await asyncio.to_thread(
self._vector_index.query, query, top_k=top_k,
)
# Group by doc_id
by_doc: dict[str, list[dict]] = {}
for r in vector_results:
@ -190,3 +213,13 @@ def _rel_to_dict(rel) -> dict:
"rel_type": rel.rel_type,
"properties": rel.properties,
}
def _parse_date_key(date_str: str) -> datetime | None:
"""Parse ISO 8601 date string for sorting. Returns None for empty/invalid."""
if not date_str:
return None
try:
return datetime.fromisoformat(date_str[:10])
except (ValueError, IndexError):
return None

View file

@ -103,7 +103,9 @@ class InMemoryGraphDB:
except (nx.NetworkXNoPath, nx.NodeNotFound):
return None
async def close(self):
self._g.clear()
self._nodes.clear()
self._rels.clear()
async def close(self, destroy: bool = True):
"""Clear graph data. Set destroy=False to keep data (e.g. for tests)."""
if destroy:
self._g.clear()
self._nodes.clear()
self._rels.clear()