From f799b41ed98e6198bbc74aea8406f16cc70c5686 Mon Sep 17 00:00:00 2001
From: slothitude
Date: Sat, 30 May 2026 12:20:56 +1000
Subject: [PATCH] Resolve all deferred audit findings: type safety, async, isolation, dates

- #11 VectorIndexProtocol for typed duck-typing of vector_index param
- #10 Wrap all sync VectorIndex.query() calls in asyncio.to_thread
- #9 Make vector_search async (consistent with to_thread wrapping)
- #12 Replace module-level singleton with app.state.query_service
  (getattr lookup so a missing service raises the intended RuntimeError)
- #13 InMemoryGraphDB.close(destroy=False) guard for test safety
- #14 Parse dates with datetime.fromisoformat in list_cases sort
  (empty/invalid dates map to datetime.min so sort keys stay comparable)

Co-Authored-By: Claude Opus 4.6
---
 aucourt_ingest/api/app.py                    |  4 +-
 aucourt_ingest/api/dependencies.py           | 11 +++--
 aucourt_ingest/api/routes.py                 |  2 +-
 aucourt_ingest/api/service.py                | 51 ++++++++++++++++----
 aucourt_ingest/storage/in_memory_graph_db.py | 10 ++--
 5 files changed, 57 insertions(+), 21 deletions(-)

diff --git a/aucourt_ingest/api/app.py b/aucourt_ingest/api/app.py
index 1663ba6..35a0a51 100644
--- a/aucourt_ingest/api/app.py
+++ b/aucourt_ingest/api/app.py
@@ -6,7 +6,6 @@ from contextlib import asynccontextmanager
 
 from fastapi import FastAPI
 
-from aucourt_ingest.api import dependencies
 from aucourt_ingest.api.errors import register_error_handlers
 from aucourt_ingest.api.routes import router
 from aucourt_ingest.api.service import QueryService
@@ -21,7 +20,7 @@ def create_app(graph_db: GraphDB, vector_index, max_tokens: int = 4000) -> FastA
         vector_index: VectorIndex instance (or duck-typed fake for tests)
         max_tokens: Default token budget for juror context assembly
     """
-    dependencies._query_service = QueryService(graph_db, vector_index, max_tokens)
+    query_service = QueryService(graph_db, vector_index, max_tokens)
 
     @asynccontextmanager
     async def lifespan(app: FastAPI):
@@ -35,6 +34,7 @@ def create_app(graph_db: GraphDB, vector_index, max_tokens: int = 4000) -> FastA
         lifespan=lifespan,
     )
 
+    app.state.query_service = query_service
     app.include_router(router)
     register_error_handlers(app)
 
diff --git a/aucourt_ingest/api/dependencies.py b/aucourt_ingest/api/dependencies.py
index 6570b23..c53495d 100644
--- a/aucourt_ingest/api/dependencies.py
+++ b/aucourt_ingest/api/dependencies.py
@@ -2,12 +2,13 @@
 
 from __future__ import annotations
 
+from fastapi import Request
+
 from aucourt_ingest.api.service import QueryService
 
-_query_service: QueryService | None = None
-
 
-def get_query_service() -> QueryService:
-    if _query_service is None:
+def get_query_service(request: Request) -> QueryService:
+    svc = getattr(request.app.state, "query_service", None)
+    if svc is None:
         raise RuntimeError("QueryService not initialised — call create_app() first")
-    return _query_service
+    return svc
diff --git a/aucourt_ingest/api/routes.py b/aucourt_ingest/api/routes.py
index ef09e16..c797d51 100644
--- a/aucourt_ingest/api/routes.py
+++ b/aucourt_ingest/api/routes.py
@@ -80,7 +80,7 @@ async def vector_search(
 ):
     types = [t.strip() for t in chunk_types.split(",") if t.strip()] if chunk_types else None
     docs = [d.strip() for d in doc_ids.split(",") if d.strip()] if doc_ids else None
-    return svc.vector_search(q, top_k=top_k, chunk_types=types, doc_ids=docs)
+    return await svc.vector_search(q, top_k=top_k, chunk_types=types, doc_ids=docs)
 
 
 @router.get("/hybrid", response_model=list[HybridResult])
diff --git a/aucourt_ingest/api/service.py b/aucourt_ingest/api/service.py
index 2018321..5839497 100644
--- a/aucourt_ingest/api/service.py
+++ b/aucourt_ingest/api/service.py
@@ -2,7 +2,10 @@
 
 from __future__ import annotations
 
+import asyncio
 import logging
+from datetime import datetime
+from typing import Protocol, runtime_checkable
 
 from aucourt_ingest.jury.personas import PERSONAS, get_persona
 from aucourt_ingest.jury.subgraph_query import SubgraphQuery
@@ -11,10 +14,26 @@ from aucourt_ingest.storage.graph_db import GraphDB
 logger = logging.getLogger(__name__)
 
 
+@runtime_checkable
+class VectorIndexProtocol(Protocol):
+    """Protocol for VectorIndex — used for type-safe duck-typing."""
+
+    def query(self, text: str, top_k: int = 10,
+              chunk_types: list[str] | None = None,
+              doc_ids: list[str] | None = None,
+              embedding: list[float] | None = None) -> list[dict]:
+        ...
+
+    @property
+    def count(self) -> int:
+        ...
+
+
 class QueryService:
     """Read-only query service for juror RAG queries."""
 
-    def __init__(self, graph_db: GraphDB, vector_index, max_tokens: int = 4000):
+    def __init__(self, graph_db: GraphDB, vector_index: VectorIndexProtocol,
+                 max_tokens: int = 4000):
         self._graph_db = graph_db
         self._vector_index = vector_index
         self._subgraph = SubgraphQuery(graph_db)
@@ -46,11 +65,11 @@ class QueryService:
                 "matter_type": node.properties.get("matter_type", ""),
                 "verdict": node.properties.get("verdict", ""),
             })
-        cases.sort(key=lambda c: c.get("date", ""), reverse=True)
+        cases.sort(key=lambda c: _parse_date_key(c.get("date", "")), reverse=True)
         return cases[offset:offset + limit]
 
     async def search_cases(self, q: str, limit: int = 10) -> list[dict]:
-        results = self._vector_index.query(q, top_k=limit)
+        results = await asyncio.to_thread(self._vector_index.query, q, top_k=limit)
         seen: set[str] = set()
         cases = []
         for r in results:
@@ -73,17 +92,21 @@ class QueryService:
             "total_tokens": ctx.total_tokens,
         }
 
-    def vector_search(self, query: str, top_k: int = 10,
-                      chunk_types: list[str] | None = None,
-                      doc_ids: list[str] | None = None) -> list[dict]:
-        return self._vector_index.query(
-            query, top_k=top_k, chunk_types=chunk_types, doc_ids=doc_ids,
+    async def vector_search(self, query: str, top_k: int = 10,
+                            chunk_types: list[str] | None = None,
+                            doc_ids: list[str] | None = None) -> list[dict]:
+        results = await asyncio.to_thread(
+            self._vector_index.query, query,
+            top_k=top_k, chunk_types=chunk_types, doc_ids=doc_ids,
         )
+        return results
 
 
     async def hybrid_search(self, query: str, persona_name: str | None = None,
                             top_k: int = 10, max_tokens: int | None = None) -> list[dict]:
-        vector_results = self._vector_index.query(query, top_k=top_k)
+        vector_results = await asyncio.to_thread(
+            self._vector_index.query, query, top_k=top_k,
+        )
         # Group by doc_id
         by_doc: dict[str, list[dict]] = {}
         for r in vector_results:
@@ -190,3 +213,13 @@ def _rel_to_dict(rel) -> dict:
         "rel_type": rel.rel_type,
         "properties": rel.properties,
     }
+
+
+def _parse_date_key(date_str: str) -> datetime:
+    """Parse ISO 8601 date string for sorting. Returns datetime.min for empty/invalid."""
+    if not date_str:
+        return datetime.min
+    try:
+        return datetime.fromisoformat(date_str[:10])
+    except (ValueError, TypeError):
+        return datetime.min
diff --git a/aucourt_ingest/storage/in_memory_graph_db.py b/aucourt_ingest/storage/in_memory_graph_db.py
index 0892dec..ef7382d 100644
--- a/aucourt_ingest/storage/in_memory_graph_db.py
+++ b/aucourt_ingest/storage/in_memory_graph_db.py
@@ -103,7 +103,9 @@ class InMemoryGraphDB:
         except (nx.NetworkXNoPath, nx.NodeNotFound):
             return None
 
-    async def close(self):
-        self._g.clear()
-        self._nodes.clear()
-        self._rels.clear()
+    async def close(self, destroy: bool = True):
+        """Clear graph data. Set destroy=False to keep data (e.g. for tests)."""
+        if destroy:
+            self._g.clear()
+            self._nodes.clear()
+            self._rels.clear()