Audit fixes: response_model validation, error handling, dead code, input sanitisation

- Add response_model to all 8 route endpoints for runtime validation and
  correct Swagger docs
- Remove global KeyError handler (routes catch it explicitly)
- Add catch-all Exception handler with logging for 500 responses
- Remove dead code in service.py get_case_graph (unused bucket variable)
- Validate graph_backend explicitly in cmd_serve (memory|neo4j; any other
  value prints an error and exits)
- Sanitise comma-separated query params (strip whitespace, filter empty)
- Move HTTPException to top-level import in routes.py
- Remove unused imports (Depends in dependencies.py, all_persona_names in
  service.py)
- Replace deprecated asyncio.get_event_loop().run_until_complete() with
  asyncio.run() in test fixture

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
slothitude 2026-05-30 12:17:55 +10:00
parent 792dbceab5
commit 24cde4cdec
6 changed files with 37 additions and 25 deletions

View file

@ -2,8 +2,6 @@
from __future__ import annotations from __future__ import annotations
from fastapi import Depends
from aucourt_ingest.api.service import QueryService from aucourt_ingest.api.service import QueryService
_query_service: QueryService | None = None _query_service: QueryService | None = None

View file

@ -2,15 +2,20 @@
from __future__ import annotations from __future__ import annotations
import logging
from fastapi import FastAPI, Request from fastapi import FastAPI, Request
from fastapi.responses import JSONResponse from fastapi.responses import JSONResponse
logger = logging.getLogger(__name__)
def register_error_handlers(app: FastAPI) -> None: def register_error_handlers(app: FastAPI) -> None:
@app.exception_handler(ValueError) @app.exception_handler(ValueError)
async def value_error_handler(request: Request, exc: ValueError): async def value_error_handler(request: Request, exc: ValueError):
return JSONResponse(status_code=400, content={"detail": str(exc)}) return JSONResponse(status_code=400, content={"detail": str(exc)})
@app.exception_handler(KeyError) @app.exception_handler(Exception)
async def key_error_handler(request: Request, exc: KeyError): async def generic_error_handler(request: Request, exc: Exception):
return JSONResponse(status_code=404, content={"detail": str(exc)}) logger.exception("Unhandled exception in query API")
return JSONResponse(status_code=500, content={"detail": "Internal server error"})

View file

@ -2,24 +2,33 @@
from __future__ import annotations from __future__ import annotations
from fastapi import APIRouter, Query, Depends from fastapi import APIRouter, HTTPException, Query, Depends
from aucourt_ingest.api.dependencies import get_query_service from aucourt_ingest.api.dependencies import get_query_service
from aucourt_ingest.api.schemas import (
CaseGraphSummary,
CaseSummary,
ChunkResult,
HealthResponse,
HybridResult,
JurorContextResponse,
PersonaInfo,
)
router = APIRouter(prefix="/api/v1") router = APIRouter(prefix="/api/v1")
@router.get("/health") @router.get("/health", response_model=HealthResponse)
async def health(svc=Depends(get_query_service)): async def health(svc=Depends(get_query_service)):
return await svc.health() return await svc.health()
@router.get("/personas") @router.get("/personas", response_model=list[PersonaInfo])
async def list_personas(svc=Depends(get_query_service)): async def list_personas(svc=Depends(get_query_service)):
return svc.list_personas() return svc.list_personas()
@router.get("/cases") @router.get("/cases", response_model=list[CaseSummary])
async def list_cases( async def list_cases(
court: str | None = Query(None, description="Filter by court code"), court: str | None = Query(None, description="Filter by court code"),
limit: int = Query(50, ge=1, le=200), limit: int = Query(50, ge=1, le=200),
@ -29,7 +38,7 @@ async def list_cases(
return await svc.list_cases(court=court, limit=limit, offset=offset) return await svc.list_cases(court=court, limit=limit, offset=offset)
@router.get("/cases/search") @router.get("/cases/search", response_model=list[CaseSummary])
async def search_cases( async def search_cases(
q: str = Query(..., description="Search query text"), q: str = Query(..., description="Search query text"),
limit: int = Query(10, ge=1, le=50), limit: int = Query(10, ge=1, le=50),
@ -38,7 +47,7 @@ async def search_cases(
return await svc.search_cases(q=q, limit=limit) return await svc.search_cases(q=q, limit=limit)
@router.get("/cases/{case_mnc}/juror/{persona}") @router.get("/cases/{case_mnc}/juror/{persona}", response_model=JurorContextResponse)
async def get_juror_context( async def get_juror_context(
case_mnc: str, case_mnc: str,
persona: str, persona: str,
@ -50,20 +59,18 @@ async def get_juror_context(
case_mnc, persona, max_tokens, case_mnc, persona, max_tokens,
) )
except KeyError as e: except KeyError as e:
from fastapi import HTTPException
raise HTTPException(status_code=404, detail=str(e)) raise HTTPException(status_code=404, detail=str(e))
@router.get("/cases/{case_mnc}") @router.get("/cases/{case_mnc}", response_model=CaseGraphSummary)
async def get_case_graph(case_mnc: str, svc=Depends(get_query_service)): async def get_case_graph(case_mnc: str, svc=Depends(get_query_service)):
result = await svc.get_case_graph(case_mnc) result = await svc.get_case_graph(case_mnc)
if result is None: if result is None:
from fastapi import HTTPException
raise HTTPException(status_code=404, detail=f"Case not found: {case_mnc}") raise HTTPException(status_code=404, detail=f"Case not found: {case_mnc}")
return result return result
@router.get("/search") @router.get("/search", response_model=list[ChunkResult])
async def vector_search( async def vector_search(
q: str = Query(..., description="Query text"), q: str = Query(..., description="Query text"),
top_k: int = Query(10, ge=1, le=50), top_k: int = Query(10, ge=1, le=50),
@ -71,12 +78,12 @@ async def vector_search(
doc_ids: str | None = Query(None, description="Comma-separated doc IDs"), doc_ids: str | None = Query(None, description="Comma-separated doc IDs"),
svc=Depends(get_query_service), svc=Depends(get_query_service),
): ):
types = chunk_types.split(",") if chunk_types else None types = [t.strip() for t in chunk_types.split(",") if t.strip()] if chunk_types else None
docs = doc_ids.split(",") if doc_ids else None docs = [d.strip() for d in doc_ids.split(",") if d.strip()] if doc_ids else None
return svc.vector_search(q, top_k=top_k, chunk_types=types, doc_ids=docs) return svc.vector_search(q, top_k=top_k, chunk_types=types, doc_ids=docs)
@router.get("/hybrid") @router.get("/hybrid", response_model=list[HybridResult])
async def hybrid_search( async def hybrid_search(
q: str = Query(..., description="Query text"), q: str = Query(..., description="Query text"),
persona: str | None = Query(None, description="Juror persona name"), persona: str | None = Query(None, description="Juror persona name"),

View file

@ -4,7 +4,7 @@ from __future__ import annotations
import logging import logging
from aucourt_ingest.jury.personas import PERSONAS, get_persona, all_persona_names from aucourt_ingest.jury.personas import PERSONAS, get_persona
from aucourt_ingest.jury.subgraph_query import SubgraphQuery from aucourt_ingest.jury.subgraph_query import SubgraphQuery
from aucourt_ingest.storage.graph_db import GraphDB from aucourt_ingest.storage.graph_db import GraphDB
@ -133,8 +133,6 @@ class QueryService:
node = await self._graph_db.get_node(nid) node = await self._graph_db.get_node(nid)
if node is None: if node is None:
continue continue
bucket = node.label.lower() + "s" if node.label != "Witness" else "witnesses"
# Map labels to plural keys
label_map = { label_map = {
"Judge": "judges", "Judge": "judges",
"Charge": "charges", "Charge": "charges",

View file

@ -310,7 +310,7 @@ def cmd_serve(args):
if backend == "memory": if backend == "memory":
from aucourt_ingest.storage.in_memory_graph_db import InMemoryGraphDB from aucourt_ingest.storage.in_memory_graph_db import InMemoryGraphDB
graph_db = InMemoryGraphDB() graph_db = InMemoryGraphDB()
else: elif backend == "neo4j":
from aucourt_ingest.storage.graph_db import Neo4jGraphDB from aucourt_ingest.storage.graph_db import Neo4jGraphDB
graph_db = Neo4jGraphDB( graph_db = Neo4jGraphDB(
uri=config.storage.neo4j_uri, uri=config.storage.neo4j_uri,
@ -318,6 +318,9 @@ def cmd_serve(args):
password=config.storage.neo4j_password, password=config.storage.neo4j_password,
database=config.storage.neo4j_database, database=config.storage.neo4j_database,
) )
else:
print(f"Unknown graph backend: {backend} (valid: memory, neo4j)")
sys.exit(1)
from aucourt_ingest.storage.vector_index import VectorIndex from aucourt_ingest.storage.vector_index import VectorIndex
vector_index = VectorIndex(str(config.storage.chromadb_dir)) vector_index = VectorIndex(str(config.storage.chromadb_dir))

View file

@ -5,6 +5,8 @@ All tests use InMemoryGraphDB + FakeVectorIndex — zero external services.
from __future__ import annotations from __future__ import annotations
import asyncio
import pytest import pytest
from fastapi.testclient import TestClient from fastapi.testclient import TestClient
@ -73,8 +75,7 @@ def seeded_graph():
] ]
for c in chunks1: for c in chunks1:
c.token_count = len(c.text) // 4 c.token_count = len(c.text) // 4
import asyncio asyncio.run(builder.build_full(meta1, chunks1))
asyncio.get_event_loop().run_until_complete(builder.build_full(meta1, chunks1))
# Case 2 # Case 2
meta2 = CaseMeta( meta2 = CaseMeta(
@ -96,7 +97,7 @@ def seeded_graph():
] ]
for c in chunks2: for c in chunks2:
c.token_count = len(c.text) // 4 c.token_count = len(c.text) // 4
asyncio.get_event_loop().run_until_complete(builder.build_full(meta2, chunks2)) asyncio.run(builder.build_full(meta2, chunks2))
return db return db