Resolve all deferred audit findings: type safety, async, isolation, dates
- #11 VectorIndexProtocol for typed duck-typing of the vector_index param
- #10 Wrap all sync VectorIndex.query() calls in asyncio.to_thread
- #9 Make vector_search async (consistent with to_thread wrapping)
- #12 Replace module-level singleton with app.state.query_service
- #13 InMemoryGraphDB.close(destroy=False) guard for test safety
- #14 Parse dates with datetime.fromisoformat in list_cases sort

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
parent
24cde4cdec
commit
f799b41ed9
5 changed files with 57 additions and 21 deletions
|
|
@ -6,7 +6,6 @@ from contextlib import asynccontextmanager
|
||||||
|
|
||||||
from fastapi import FastAPI
|
from fastapi import FastAPI
|
||||||
|
|
||||||
from aucourt_ingest.api import dependencies
|
|
||||||
from aucourt_ingest.api.errors import register_error_handlers
|
from aucourt_ingest.api.errors import register_error_handlers
|
||||||
from aucourt_ingest.api.routes import router
|
from aucourt_ingest.api.routes import router
|
||||||
from aucourt_ingest.api.service import QueryService
|
from aucourt_ingest.api.service import QueryService
|
||||||
|
|
@ -21,7 +20,7 @@ def create_app(graph_db: GraphDB, vector_index, max_tokens: int = 4000) -> FastA
|
||||||
vector_index: VectorIndex instance (or duck-typed fake for tests)
|
vector_index: VectorIndex instance (or duck-typed fake for tests)
|
||||||
max_tokens: Default token budget for juror context assembly
|
max_tokens: Default token budget for juror context assembly
|
||||||
"""
|
"""
|
||||||
dependencies._query_service = QueryService(graph_db, vector_index, max_tokens)
|
query_service = QueryService(graph_db, vector_index, max_tokens)
|
||||||
|
|
||||||
@asynccontextmanager
|
@asynccontextmanager
|
||||||
async def lifespan(app: FastAPI):
|
async def lifespan(app: FastAPI):
|
||||||
|
|
@ -35,6 +34,7 @@ def create_app(graph_db: GraphDB, vector_index, max_tokens: int = 4000) -> FastA
|
||||||
lifespan=lifespan,
|
lifespan=lifespan,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
app.state.query_service = query_service
|
||||||
app.include_router(router)
|
app.include_router(router)
|
||||||
register_error_handlers(app)
|
register_error_handlers(app)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -2,12 +2,13 @@
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from fastapi import Request
|
||||||
|
|
||||||
from aucourt_ingest.api.service import QueryService
|
from aucourt_ingest.api.service import QueryService
|
||||||
|
|
||||||
_query_service: QueryService | None = None
|
|
||||||
|
|
||||||
|
def get_query_service(request: Request) -> QueryService:
    """Return the QueryService stored on the application state.

    Args:
        request: The incoming request; the service is read from
            ``request.app.state.query_service``.

    Returns:
        The QueryService instance attached to the app.

    Raises:
        RuntimeError: If no service is attached to ``app.state`` (attribute
            missing or set to ``None``).
    """
    # Plain attribute access on the state object raises AttributeError when
    # the attribute was never assigned, which would bypass the explicit
    # RuntimeError below. getattr with a default keeps one error path.
    svc = getattr(request.app.state, "query_service", None)
    if svc is None:
        raise RuntimeError("QueryService not initialised — call create_app() first")
    return svc
|
||||||
|
|
|
||||||
|
|
@ -80,7 +80,7 @@ async def vector_search(
|
||||||
):
|
):
|
||||||
types = [t.strip() for t in chunk_types.split(",") if t.strip()] if chunk_types else None
|
types = [t.strip() for t in chunk_types.split(",") if t.strip()] if chunk_types else None
|
||||||
docs = [d.strip() for d in doc_ids.split(",") if d.strip()] if doc_ids else None
|
docs = [d.strip() for d in doc_ids.split(",") if d.strip()] if doc_ids else None
|
||||||
return svc.vector_search(q, top_k=top_k, chunk_types=types, doc_ids=docs)
|
return await svc.vector_search(q, top_k=top_k, chunk_types=types, doc_ids=docs)
|
||||||
|
|
||||||
|
|
||||||
@router.get("/hybrid", response_model=list[HybridResult])
|
@router.get("/hybrid", response_model=list[HybridResult])
|
||||||
|
|
|
||||||
|
|
@ -2,7 +2,10 @@
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
import logging
|
import logging
|
||||||
|
from datetime import datetime
|
||||||
|
from typing import Protocol, runtime_checkable
|
||||||
|
|
||||||
from aucourt_ingest.jury.personas import PERSONAS, get_persona
|
from aucourt_ingest.jury.personas import PERSONAS, get_persona
|
||||||
from aucourt_ingest.jury.subgraph_query import SubgraphQuery
|
from aucourt_ingest.jury.subgraph_query import SubgraphQuery
|
||||||
|
|
@ -11,10 +14,26 @@ from aucourt_ingest.storage.graph_db import GraphDB
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
@runtime_checkable
class VectorIndexProtocol(Protocol):
    """Structural type for the ``vector_index`` object handed to QueryService.

    Any object exposing a compatible ``query`` method and a ``count``
    property satisfies this protocol; no inheritance is required. Because the
    class is ``runtime_checkable``, ``isinstance`` checks only verify that the
    members exist, not that their signatures match.
    """

    def query(
        self,
        text: str,
        top_k: int = 10,
        chunk_types: list[str] | None = None,
        doc_ids: list[str] | None = None,
        embedding: list[float] | None = None,
    ) -> list[dict]:
        """Run a query for ``text`` and return the matches as dicts.

        Args:
            text: Query text.
            top_k: Maximum number of results requested.
            chunk_types: Optional filter on chunk types; ``None`` for no filter.
            doc_ids: Optional filter on document ids; ``None`` for no filter.
            embedding: Optional embedding vector — presumably a precomputed
                embedding used instead of embedding ``text``; confirm against
                the concrete VectorIndex implementation.
        """
        ...

    @property
    def count(self) -> int:
        """Integer size of the index (presumably the number of stored entries)."""
        ...
|
||||||
|
|
||||||
|
|
||||||
class QueryService:
|
class QueryService:
|
||||||
"""Read-only query service for juror RAG queries."""
|
"""Read-only query service for juror RAG queries."""
|
||||||
|
|
||||||
def __init__(self, graph_db: GraphDB, vector_index, max_tokens: int = 4000):
|
def __init__(self, graph_db: GraphDB, vector_index: VectorIndexProtocol,
|
||||||
|
max_tokens: int = 4000):
|
||||||
self._graph_db = graph_db
|
self._graph_db = graph_db
|
||||||
self._vector_index = vector_index
|
self._vector_index = vector_index
|
||||||
self._subgraph = SubgraphQuery(graph_db)
|
self._subgraph = SubgraphQuery(graph_db)
|
||||||
|
|
@ -46,11 +65,11 @@ class QueryService:
|
||||||
"matter_type": node.properties.get("matter_type", ""),
|
"matter_type": node.properties.get("matter_type", ""),
|
||||||
"verdict": node.properties.get("verdict", ""),
|
"verdict": node.properties.get("verdict", ""),
|
||||||
})
|
})
|
||||||
cases.sort(key=lambda c: c.get("date", ""), reverse=True)
|
cases.sort(key=lambda c: _parse_date_key(c.get("date", "")), reverse=True)
|
||||||
return cases[offset:offset + limit]
|
return cases[offset:offset + limit]
|
||||||
|
|
||||||
async def search_cases(self, q: str, limit: int = 10) -> list[dict]:
|
async def search_cases(self, q: str, limit: int = 10) -> list[dict]:
|
||||||
results = self._vector_index.query(q, top_k=limit)
|
results = await asyncio.to_thread(self._vector_index.query, q, top_k=limit)
|
||||||
seen: set[str] = set()
|
seen: set[str] = set()
|
||||||
cases = []
|
cases = []
|
||||||
for r in results:
|
for r in results:
|
||||||
|
|
@ -73,17 +92,21 @@ class QueryService:
|
||||||
"total_tokens": ctx.total_tokens,
|
"total_tokens": ctx.total_tokens,
|
||||||
}
|
}
|
||||||
|
|
||||||
def vector_search(self, query: str, top_k: int = 10,
|
async def vector_search(self, query: str, top_k: int = 10,
|
||||||
chunk_types: list[str] | None = None,
|
chunk_types: list[str] | None = None,
|
||||||
doc_ids: list[str] | None = None) -> list[dict]:
|
doc_ids: list[str] | None = None) -> list[dict]:
|
||||||
return self._vector_index.query(
|
results = await asyncio.to_thread(
|
||||||
query, top_k=top_k, chunk_types=chunk_types, doc_ids=doc_ids,
|
self._vector_index.query, query,
|
||||||
|
top_k=top_k, chunk_types=chunk_types, doc_ids=doc_ids,
|
||||||
)
|
)
|
||||||
|
return results
|
||||||
|
|
||||||
async def hybrid_search(self, query: str, persona_name: str | None = None,
|
async def hybrid_search(self, query: str, persona_name: str | None = None,
|
||||||
top_k: int = 10,
|
top_k: int = 10,
|
||||||
max_tokens: int | None = None) -> list[dict]:
|
max_tokens: int | None = None) -> list[dict]:
|
||||||
vector_results = self._vector_index.query(query, top_k=top_k)
|
vector_results = await asyncio.to_thread(
|
||||||
|
self._vector_index.query, query, top_k=top_k,
|
||||||
|
)
|
||||||
# Group by doc_id
|
# Group by doc_id
|
||||||
by_doc: dict[str, list[dict]] = {}
|
by_doc: dict[str, list[dict]] = {}
|
||||||
for r in vector_results:
|
for r in vector_results:
|
||||||
|
|
@ -190,3 +213,13 @@ def _rel_to_dict(rel) -> dict:
|
||||||
"rel_type": rel.rel_type,
|
"rel_type": rel.rel_type,
|
||||||
"properties": rel.properties,
|
"properties": rel.properties,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def _parse_date_key(date_str: str) -> datetime | None:
|
||||||
|
"""Parse ISO 8601 date string for sorting. Returns None for empty/invalid."""
|
||||||
|
if not date_str:
|
||||||
|
return None
|
||||||
|
try:
|
||||||
|
return datetime.fromisoformat(date_str[:10])
|
||||||
|
except (ValueError, IndexError):
|
||||||
|
return None
|
||||||
|
|
|
||||||
|
|
@ -103,7 +103,9 @@ class InMemoryGraphDB:
|
||||||
except (nx.NetworkXNoPath, nx.NodeNotFound):
|
except (nx.NetworkXNoPath, nx.NodeNotFound):
|
||||||
return None
|
return None
|
||||||
|
|
||||||
async def close(self, destroy: bool = True) -> None:
    """Close the in-memory graph, optionally keeping its contents.

    Args:
        destroy: When true (the default), empty ``self._g``, ``self._nodes``
            and ``self._rels``. When false, the call does nothing, so the
            stored data remains available afterwards (e.g. for tests).
    """
    if not destroy:
        return
    self._g.clear()
    self._nodes.clear()
    self._rels.clear()
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue