Stage 9: add read-only FastAPI query API for juror RAG queries
8 GET endpoints under /api/v1 for health, personas, cases, vector search, juror context, and hybrid search. Includes QueryService composing SubgraphQuery + VectorIndex + GraphDB, Pydantic response models, error handlers, and `serve` CLI mode via uvicorn. 20 new tests, 190 total, zero regressions. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
parent
d77fe12cfc
commit
6374aea0a2
12 changed files with 809 additions and 0 deletions
0
aucourt_ingest/api/__init__.py
Normal file
0
aucourt_ingest/api/__init__.py
Normal file
41
aucourt_ingest/api/app.py
Normal file
41
aucourt_ingest/api/app.py
Normal file
|
|
@ -0,0 +1,41 @@
|
||||||
|
"""FastAPI app factory for the query API."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from contextlib import asynccontextmanager
|
||||||
|
|
||||||
|
from fastapi import FastAPI
|
||||||
|
|
||||||
|
from aucourt_ingest.api import dependencies
|
||||||
|
from aucourt_ingest.api.errors import register_error_handlers
|
||||||
|
from aucourt_ingest.api.routes import router
|
||||||
|
from aucourt_ingest.api.service import QueryService
|
||||||
|
from aucourt_ingest.storage.graph_db import GraphDB
|
||||||
|
|
||||||
|
|
||||||
|
def create_app(graph_db: GraphDB, vector_index, max_tokens: int = 4000) -> FastAPI:
|
||||||
|
"""Create and configure the FastAPI app.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
graph_db: GraphDB instance (InMemoryGraphDB or Neo4jGraphDB)
|
||||||
|
vector_index: VectorIndex instance (or duck-typed fake for tests)
|
||||||
|
max_tokens: Default token budget for juror context assembly
|
||||||
|
"""
|
||||||
|
dependencies._query_service = QueryService(graph_db, vector_index, max_tokens)
|
||||||
|
|
||||||
|
@asynccontextmanager
|
||||||
|
async def lifespan(app: FastAPI):
|
||||||
|
yield
|
||||||
|
await graph_db.close()
|
||||||
|
|
||||||
|
app = FastAPI(
|
||||||
|
title="AuCourtIngest Query API",
|
||||||
|
description="Read-only juror RAG query API for Australian legal cases",
|
||||||
|
version="0.1.0",
|
||||||
|
lifespan=lifespan,
|
||||||
|
)
|
||||||
|
|
||||||
|
app.include_router(router)
|
||||||
|
register_error_handlers(app)
|
||||||
|
|
||||||
|
return app
|
||||||
15
aucourt_ingest/api/dependencies.py
Normal file
15
aucourt_ingest/api/dependencies.py
Normal file
|
|
@ -0,0 +1,15 @@
|
||||||
|
"""FastAPI dependency providers."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from fastapi import Depends
|
||||||
|
|
||||||
|
from aucourt_ingest.api.service import QueryService
|
||||||
|
|
||||||
|
_query_service: QueryService | None = None
|
||||||
|
|
||||||
|
|
||||||
|
def get_query_service() -> QueryService:
|
||||||
|
if _query_service is None:
|
||||||
|
raise RuntimeError("QueryService not initialised — call create_app() first")
|
||||||
|
return _query_service
|
||||||
16
aucourt_ingest/api/errors.py
Normal file
16
aucourt_ingest/api/errors.py
Normal file
|
|
@ -0,0 +1,16 @@
|
||||||
|
"""Error handlers for the query API."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from fastapi import FastAPI, Request
|
||||||
|
from fastapi.responses import JSONResponse
|
||||||
|
|
||||||
|
|
||||||
|
def register_error_handlers(app: FastAPI) -> None:
|
||||||
|
@app.exception_handler(ValueError)
|
||||||
|
async def value_error_handler(request: Request, exc: ValueError):
|
||||||
|
return JSONResponse(status_code=400, content={"detail": str(exc)})
|
||||||
|
|
||||||
|
@app.exception_handler(KeyError)
|
||||||
|
async def key_error_handler(request: Request, exc: KeyError):
|
||||||
|
return JSONResponse(status_code=404, content={"detail": str(exc)})
|
||||||
87
aucourt_ingest/api/routes.py
Normal file
87
aucourt_ingest/api/routes.py
Normal file
|
|
@ -0,0 +1,87 @@
|
||||||
|
"""Route definitions for the read-only query API."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from fastapi import APIRouter, Query, Depends
|
||||||
|
|
||||||
|
from aucourt_ingest.api.dependencies import get_query_service
|
||||||
|
|
||||||
|
router = APIRouter(prefix="/api/v1")
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/health")
|
||||||
|
async def health(svc=Depends(get_query_service)):
|
||||||
|
return await svc.health()
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/personas")
|
||||||
|
async def list_personas(svc=Depends(get_query_service)):
|
||||||
|
return svc.list_personas()
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/cases")
|
||||||
|
async def list_cases(
|
||||||
|
court: str | None = Query(None, description="Filter by court code"),
|
||||||
|
limit: int = Query(50, ge=1, le=200),
|
||||||
|
offset: int = Query(0, ge=0),
|
||||||
|
svc=Depends(get_query_service),
|
||||||
|
):
|
||||||
|
return await svc.list_cases(court=court, limit=limit, offset=offset)
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/cases/search")
|
||||||
|
async def search_cases(
|
||||||
|
q: str = Query(..., description="Search query text"),
|
||||||
|
limit: int = Query(10, ge=1, le=50),
|
||||||
|
svc=Depends(get_query_service),
|
||||||
|
):
|
||||||
|
return await svc.search_cases(q=q, limit=limit)
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/cases/{case_mnc}/juror/{persona}")
|
||||||
|
async def get_juror_context(
|
||||||
|
case_mnc: str,
|
||||||
|
persona: str,
|
||||||
|
max_tokens: int | None = Query(None, ge=100, le=16000),
|
||||||
|
svc=Depends(get_query_service),
|
||||||
|
):
|
||||||
|
try:
|
||||||
|
return await svc.get_juror_context(
|
||||||
|
case_mnc, persona, max_tokens,
|
||||||
|
)
|
||||||
|
except KeyError as e:
|
||||||
|
from fastapi import HTTPException
|
||||||
|
raise HTTPException(status_code=404, detail=str(e))
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/cases/{case_mnc}")
|
||||||
|
async def get_case_graph(case_mnc: str, svc=Depends(get_query_service)):
|
||||||
|
result = await svc.get_case_graph(case_mnc)
|
||||||
|
if result is None:
|
||||||
|
from fastapi import HTTPException
|
||||||
|
raise HTTPException(status_code=404, detail=f"Case not found: {case_mnc}")
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/search")
|
||||||
|
async def vector_search(
|
||||||
|
q: str = Query(..., description="Query text"),
|
||||||
|
top_k: int = Query(10, ge=1, le=50),
|
||||||
|
chunk_types: str | None = Query(None, description="Comma-separated chunk types"),
|
||||||
|
doc_ids: str | None = Query(None, description="Comma-separated doc IDs"),
|
||||||
|
svc=Depends(get_query_service),
|
||||||
|
):
|
||||||
|
types = chunk_types.split(",") if chunk_types else None
|
||||||
|
docs = doc_ids.split(",") if doc_ids else None
|
||||||
|
return svc.vector_search(q, top_k=top_k, chunk_types=types, doc_ids=docs)
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/hybrid")
|
||||||
|
async def hybrid_search(
|
||||||
|
q: str = Query(..., description="Query text"),
|
||||||
|
persona: str | None = Query(None, description="Juror persona name"),
|
||||||
|
top_k: int = Query(10, ge=1, le=50),
|
||||||
|
max_tokens: int | None = Query(None, ge=100, le=16000),
|
||||||
|
svc=Depends(get_query_service),
|
||||||
|
):
|
||||||
|
return await svc.hybrid_search(q, persona_name=persona, top_k=top_k, max_tokens=max_tokens)
|
||||||
74
aucourt_ingest/api/schemas.py
Normal file
74
aucourt_ingest/api/schemas.py
Normal file
|
|
@ -0,0 +1,74 @@
|
||||||
|
"""Pydantic v2 response models for the query API."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
|
||||||
|
class PersonaInfo(BaseModel):
|
||||||
|
name: str
|
||||||
|
anchor_nodes: list[str]
|
||||||
|
edge_types: list[str]
|
||||||
|
chunk_types: list[str]
|
||||||
|
|
||||||
|
|
||||||
|
class CaseSummary(BaseModel):
|
||||||
|
mnc: str
|
||||||
|
court: str = ""
|
||||||
|
date: str = ""
|
||||||
|
jurisdiction: str = ""
|
||||||
|
matter_type: str = ""
|
||||||
|
verdict: str = ""
|
||||||
|
|
||||||
|
|
||||||
|
class ChunkResult(BaseModel):
|
||||||
|
chunk_id: str
|
||||||
|
text: str | None = None
|
||||||
|
metadata: dict = Field(default_factory=dict)
|
||||||
|
distance: float | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class JurorContextResponse(BaseModel):
|
||||||
|
persona: str
|
||||||
|
case_mnc: str
|
||||||
|
context_text: str
|
||||||
|
source_chunk_ids: list[str]
|
||||||
|
total_tokens: int
|
||||||
|
|
||||||
|
|
||||||
|
class CaseGraphNode(BaseModel):
|
||||||
|
node_id: str
|
||||||
|
label: str
|
||||||
|
properties: dict = Field(default_factory=dict)
|
||||||
|
|
||||||
|
|
||||||
|
class CaseGraphRelationship(BaseModel):
|
||||||
|
rel_id: str
|
||||||
|
from_id: str
|
||||||
|
to_id: str
|
||||||
|
rel_type: str
|
||||||
|
properties: dict = Field(default_factory=dict)
|
||||||
|
|
||||||
|
|
||||||
|
class CaseGraphSummary(BaseModel):
|
||||||
|
case: CaseGraphNode
|
||||||
|
judges: list[CaseGraphNode] = Field(default_factory=list)
|
||||||
|
charges: list[CaseGraphNode] = Field(default_factory=list)
|
||||||
|
rulings: list[CaseGraphNode] = Field(default_factory=list)
|
||||||
|
witnesses: list[CaseGraphNode] = Field(default_factory=list)
|
||||||
|
chunks: list[CaseGraphNode] = Field(default_factory=list)
|
||||||
|
relationships: list[CaseGraphRelationship] = Field(default_factory=list)
|
||||||
|
|
||||||
|
|
||||||
|
class HybridResult(BaseModel):
|
||||||
|
case_mnc: str
|
||||||
|
chunks: list[ChunkResult] = Field(default_factory=list)
|
||||||
|
juror_context: JurorContextResponse | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class HealthResponse(BaseModel):
|
||||||
|
status: str
|
||||||
|
nodes: int
|
||||||
|
edges: int
|
||||||
|
cases: int
|
||||||
|
chunks: int
|
||||||
194
aucourt_ingest/api/service.py
Normal file
194
aucourt_ingest/api/service.py
Normal file
|
|
@ -0,0 +1,194 @@
|
||||||
|
"""QueryService — composes SubgraphQuery + VectorIndex + GraphDB for the read API."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
|
|
||||||
|
from aucourt_ingest.jury.personas import PERSONAS, get_persona, all_persona_names
|
||||||
|
from aucourt_ingest.jury.subgraph_query import SubgraphQuery
|
||||||
|
from aucourt_ingest.storage.graph_db import GraphDB
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class QueryService:
|
||||||
|
"""Read-only query service for juror RAG queries."""
|
||||||
|
|
||||||
|
def __init__(self, graph_db: GraphDB, vector_index, max_tokens: int = 4000):
|
||||||
|
self._graph_db = graph_db
|
||||||
|
self._vector_index = vector_index
|
||||||
|
self._subgraph = SubgraphQuery(graph_db)
|
||||||
|
self._max_tokens = max_tokens
|
||||||
|
|
||||||
|
def list_personas(self) -> list[dict]:
|
||||||
|
return [
|
||||||
|
{
|
||||||
|
"name": name,
|
||||||
|
"anchor_nodes": p.anchor_nodes,
|
||||||
|
"edge_types": p.edge_types,
|
||||||
|
"chunk_types": p.chunk_types,
|
||||||
|
}
|
||||||
|
for name, p in PERSONAS.items()
|
||||||
|
]
|
||||||
|
|
||||||
|
async def list_cases(self, court: str | None = None,
|
||||||
|
limit: int = 50, offset: int = 0) -> list[dict]:
|
||||||
|
nodes = await self._graph_db.query_nodes("Case")
|
||||||
|
cases = []
|
||||||
|
for node in nodes:
|
||||||
|
if court and node.properties.get("court") != court:
|
||||||
|
continue
|
||||||
|
cases.append({
|
||||||
|
"mnc": node.properties.get("mnc", node.node_id),
|
||||||
|
"court": node.properties.get("court", ""),
|
||||||
|
"date": node.properties.get("date", ""),
|
||||||
|
"jurisdiction": node.properties.get("jurisdiction", ""),
|
||||||
|
"matter_type": node.properties.get("matter_type", ""),
|
||||||
|
"verdict": node.properties.get("verdict", ""),
|
||||||
|
})
|
||||||
|
cases.sort(key=lambda c: c.get("date", ""), reverse=True)
|
||||||
|
return cases[offset:offset + limit]
|
||||||
|
|
||||||
|
async def search_cases(self, q: str, limit: int = 10) -> list[dict]:
|
||||||
|
results = self._vector_index.query(q, top_k=limit)
|
||||||
|
seen: set[str] = set()
|
||||||
|
cases = []
|
||||||
|
for r in results:
|
||||||
|
doc_id = r["metadata"].get("doc_id", "")
|
||||||
|
if doc_id and doc_id not in seen:
|
||||||
|
seen.add(doc_id)
|
||||||
|
cases.append(await self._get_case_by_mnc(doc_id))
|
||||||
|
return [c for c in cases if c]
|
||||||
|
|
||||||
|
async def get_juror_context(self, case_mnc: str, persona_name: str,
|
||||||
|
max_tokens: int | None = None) -> dict:
|
||||||
|
persona = get_persona(persona_name)
|
||||||
|
budget = max_tokens or self._max_tokens
|
||||||
|
ctx = await self._subgraph.get_juror_context(case_mnc, persona, budget)
|
||||||
|
return {
|
||||||
|
"persona": ctx.persona,
|
||||||
|
"case_mnc": ctx.case_mnc,
|
||||||
|
"context_text": ctx.context_text,
|
||||||
|
"source_chunk_ids": ctx.source_chunk_ids,
|
||||||
|
"total_tokens": ctx.total_tokens,
|
||||||
|
}
|
||||||
|
|
||||||
|
def vector_search(self, query: str, top_k: int = 10,
|
||||||
|
chunk_types: list[str] | None = None,
|
||||||
|
doc_ids: list[str] | None = None) -> list[dict]:
|
||||||
|
return self._vector_index.query(
|
||||||
|
query, top_k=top_k, chunk_types=chunk_types, doc_ids=doc_ids,
|
||||||
|
)
|
||||||
|
|
||||||
|
async def hybrid_search(self, query: str, persona_name: str | None = None,
|
||||||
|
top_k: int = 10,
|
||||||
|
max_tokens: int | None = None) -> list[dict]:
|
||||||
|
vector_results = self._vector_index.query(query, top_k=top_k)
|
||||||
|
# Group by doc_id
|
||||||
|
by_doc: dict[str, list[dict]] = {}
|
||||||
|
for r in vector_results:
|
||||||
|
doc_id = r["metadata"].get("doc_id", "unknown")
|
||||||
|
by_doc.setdefault(doc_id, []).append(r)
|
||||||
|
|
||||||
|
results = []
|
||||||
|
for doc_mnc, chunks in by_doc.items():
|
||||||
|
entry: dict = {
|
||||||
|
"case_mnc": doc_mnc,
|
||||||
|
"chunks": chunks,
|
||||||
|
"juror_context": None,
|
||||||
|
}
|
||||||
|
if persona_name and persona_name in PERSONAS:
|
||||||
|
ctx = await self.get_juror_context(
|
||||||
|
doc_mnc, persona_name, max_tokens,
|
||||||
|
)
|
||||||
|
entry["juror_context"] = ctx
|
||||||
|
results.append(entry)
|
||||||
|
return results
|
||||||
|
|
||||||
|
async def get_case_graph(self, case_mnc: str) -> dict | None:
|
||||||
|
case_id = f"Case:{case_mnc}"
|
||||||
|
case_node = await self._graph_db.get_node(case_id)
|
||||||
|
if case_node is None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
result = {
|
||||||
|
"case": _node_to_dict(case_node),
|
||||||
|
"judges": [],
|
||||||
|
"charges": [],
|
||||||
|
"rulings": [],
|
||||||
|
"witnesses": [],
|
||||||
|
"chunks": [],
|
||||||
|
"relationships": [],
|
||||||
|
}
|
||||||
|
|
||||||
|
rels = await self._graph_db.get_relationships(case_id, direction="both")
|
||||||
|
neighbor_ids: set[str] = set()
|
||||||
|
|
||||||
|
for rel in rels:
|
||||||
|
result["relationships"].append(_rel_to_dict(rel))
|
||||||
|
neighbor_id = rel.to_id if rel.from_id == case_id else rel.from_id
|
||||||
|
neighbor_ids.add(neighbor_id)
|
||||||
|
|
||||||
|
for nid in neighbor_ids:
|
||||||
|
node = await self._graph_db.get_node(nid)
|
||||||
|
if node is None:
|
||||||
|
continue
|
||||||
|
bucket = node.label.lower() + "s" if node.label != "Witness" else "witnesses"
|
||||||
|
# Map labels to plural keys
|
||||||
|
label_map = {
|
||||||
|
"Judge": "judges",
|
||||||
|
"Charge": "charges",
|
||||||
|
"Ruling": "rulings",
|
||||||
|
"Witness": "witnesses",
|
||||||
|
"Chunk": "chunks",
|
||||||
|
}
|
||||||
|
key = label_map.get(node.label)
|
||||||
|
if key and key in result:
|
||||||
|
result[key].append(_node_to_dict(node))
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
async def health(self) -> dict:
|
||||||
|
total_nodes = await self._graph_db.node_count()
|
||||||
|
total_edges = await self._graph_db.relationship_count()
|
||||||
|
cases = await self._graph_db.node_count("Case")
|
||||||
|
chunks = await self._graph_db.node_count("Chunk")
|
||||||
|
return {
|
||||||
|
"status": "ok",
|
||||||
|
"nodes": total_nodes,
|
||||||
|
"edges": total_edges,
|
||||||
|
"cases": cases,
|
||||||
|
"chunks": chunks,
|
||||||
|
}
|
||||||
|
|
||||||
|
async def _get_case_by_mnc(self, mnc: str) -> dict | None:
|
||||||
|
case_id = f"Case:{mnc}"
|
||||||
|
node = await self._graph_db.get_node(case_id)
|
||||||
|
if node is None:
|
||||||
|
return None
|
||||||
|
return {
|
||||||
|
"mnc": node.properties.get("mnc", ""),
|
||||||
|
"court": node.properties.get("court", ""),
|
||||||
|
"date": node.properties.get("date", ""),
|
||||||
|
"jurisdiction": node.properties.get("jurisdiction", ""),
|
||||||
|
"matter_type": node.properties.get("matter_type", ""),
|
||||||
|
"verdict": node.properties.get("verdict", ""),
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def _node_to_dict(node) -> dict:
|
||||||
|
return {
|
||||||
|
"node_id": node.node_id,
|
||||||
|
"label": node.label,
|
||||||
|
"properties": node.properties,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def _rel_to_dict(rel) -> dict:
|
||||||
|
return {
|
||||||
|
"rel_id": rel.rel_id,
|
||||||
|
"from_id": rel.from_id,
|
||||||
|
"to_id": rel.to_id,
|
||||||
|
"rel_type": rel.rel_type,
|
||||||
|
"properties": rel.properties,
|
||||||
|
}
|
||||||
|
|
@ -55,6 +55,14 @@ class TelegramConfig:
|
||||||
enabled: bool = False
|
enabled: bool = False
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class ServerConfig:
|
||||||
|
host: str = "127.0.0.1"
|
||||||
|
port: int = 8000
|
||||||
|
graph_backend: str = "memory"
|
||||||
|
default_max_tokens: int = 4000
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class AppConfig:
|
class AppConfig:
|
||||||
"""Root config — loaded from config.toml."""
|
"""Root config — loaded from config.toml."""
|
||||||
|
|
@ -64,6 +72,7 @@ class AppConfig:
|
||||||
storage: StorageConfig = field(default_factory=StorageConfig)
|
storage: StorageConfig = field(default_factory=StorageConfig)
|
||||||
llm: LLMConfig = field(default_factory=LLMConfig)
|
llm: LLMConfig = field(default_factory=LLMConfig)
|
||||||
telegram: TelegramConfig = field(default_factory=TelegramConfig)
|
telegram: TelegramConfig = field(default_factory=TelegramConfig)
|
||||||
|
server: ServerConfig = field(default_factory=ServerConfig)
|
||||||
user_agent: str = "AuCourtIngest/0.1 (legal research)"
|
user_agent: str = "AuCourtIngest/0.1 (legal research)"
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
|
|
@ -138,6 +147,15 @@ class AppConfig:
|
||||||
enabled=tg.get("enabled", False),
|
enabled=tg.get("enabled", False),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Server
|
||||||
|
srv = raw.get("server", {})
|
||||||
|
config.server = ServerConfig(
|
||||||
|
host=srv.get("host", "127.0.0.1"),
|
||||||
|
port=srv.get("port", 8000),
|
||||||
|
graph_backend=srv.get("graph_backend", "memory"),
|
||||||
|
default_max_tokens=srv.get("default_max_tokens", 4000),
|
||||||
|
)
|
||||||
|
|
||||||
config.user_agent = raw.get("user_agent", "AuCourtIngest/0.1 (legal research)")
|
config.user_agent = raw.get("user_agent", "AuCourtIngest/0.1 (legal research)")
|
||||||
|
|
||||||
return config
|
return config
|
||||||
|
|
|
||||||
|
|
@ -57,6 +57,13 @@ def build_parser() -> argparse.ArgumentParser:
|
||||||
p_proc.add_argument("--db", default="data/meta.db", help="MetaDB path")
|
p_proc.add_argument("--db", default="data/meta.db", help="MetaDB path")
|
||||||
p_proc.add_argument("--raw-dir", default="data/raw", help="Raw document storage dir")
|
p_proc.add_argument("--raw-dir", default="data/raw", help="Raw document storage dir")
|
||||||
|
|
||||||
|
# serve
|
||||||
|
p_serve = sub.add_parser("serve", help="Start the read-only query API server")
|
||||||
|
p_serve.add_argument("--host", default=None, help="Bind host (default from config)")
|
||||||
|
p_serve.add_argument("--port", type=int, default=None, help="Bind port (default from config)")
|
||||||
|
p_serve.add_argument("--graph-backend", default=None,
|
||||||
|
help="Graph backend: memory|neo4j (default from config)")
|
||||||
|
|
||||||
return parser
|
return parser
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -291,6 +298,44 @@ async def cmd_process(args):
|
||||||
print(f"Processing complete: {stats}")
|
print(f"Processing complete: {stats}")
|
||||||
|
|
||||||
|
|
||||||
|
def cmd_serve(args):
|
||||||
|
"""Start the FastAPI query server."""
|
||||||
|
config = AppConfig.load(args.config)
|
||||||
|
|
||||||
|
host = args.host or config.server.host
|
||||||
|
port = args.port or config.server.port
|
||||||
|
backend = args.graph_backend or config.server.graph_backend
|
||||||
|
max_tokens = config.server.default_max_tokens
|
||||||
|
|
||||||
|
if backend == "memory":
|
||||||
|
from aucourt_ingest.storage.in_memory_graph_db import InMemoryGraphDB
|
||||||
|
graph_db = InMemoryGraphDB()
|
||||||
|
else:
|
||||||
|
from aucourt_ingest.storage.graph_db import Neo4jGraphDB
|
||||||
|
graph_db = Neo4jGraphDB(
|
||||||
|
uri=config.storage.neo4j_uri,
|
||||||
|
user=config.storage.neo4j_user,
|
||||||
|
password=config.storage.neo4j_password,
|
||||||
|
database=config.storage.neo4j_database,
|
||||||
|
)
|
||||||
|
|
||||||
|
# VectorIndex requires ChromaDB directory — skip for memory backend in tests
|
||||||
|
vector_index = None
|
||||||
|
if backend != "memory":
|
||||||
|
from aucourt_ingest.storage.vector_index import VectorIndex
|
||||||
|
vector_index = VectorIndex(str(config.storage.chromadb_dir))
|
||||||
|
else:
|
||||||
|
from aucourt_ingest.storage.vector_index import VectorIndex
|
||||||
|
vector_index = VectorIndex(str(config.storage.chromadb_dir))
|
||||||
|
|
||||||
|
from aucourt_ingest.api.app import create_app
|
||||||
|
app = create_app(graph_db, vector_index, max_tokens)
|
||||||
|
|
||||||
|
import uvicorn
|
||||||
|
print(f"Starting AuCourtIngest API on {host}:{port} (graph={backend})")
|
||||||
|
uvicorn.run(app, host=host, port=port)
|
||||||
|
|
||||||
|
|
||||||
async def async_main():
|
async def async_main():
|
||||||
parser = build_parser()
|
parser = build_parser()
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
@ -309,6 +354,8 @@ async def async_main():
|
||||||
await cmd_audit(args)
|
await cmd_audit(args)
|
||||||
elif args.mode == "process":
|
elif args.mode == "process":
|
||||||
await cmd_process(args)
|
await cmd_process(args)
|
||||||
|
elif args.mode == "serve":
|
||||||
|
cmd_serve(args)
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
|
|
|
||||||
|
|
@ -76,3 +76,9 @@ embedding_batch_size = 100
|
||||||
bot_token = ""
|
bot_token = ""
|
||||||
chat_id = ""
|
chat_id = ""
|
||||||
enabled = false
|
enabled = false
|
||||||
|
|
||||||
|
[server]
|
||||||
|
host = "127.0.0.1"
|
||||||
|
port = 8000
|
||||||
|
graph_backend = "memory"
|
||||||
|
default_max_tokens = 4000
|
||||||
|
|
|
||||||
|
|
@ -13,6 +13,8 @@ dependencies = [
|
||||||
"aiosqlite>=0.19",
|
"aiosqlite>=0.19",
|
||||||
"anthropic>=0.25",
|
"anthropic>=0.25",
|
||||||
"apscheduler>=3.10",
|
"apscheduler>=3.10",
|
||||||
|
"fastapi>=0.110",
|
||||||
|
"uvicorn[standard]>=0.29",
|
||||||
]
|
]
|
||||||
|
|
||||||
[project.scripts]
|
[project.scripts]
|
||||||
|
|
|
||||||
309
tests/test_api.py
Normal file
309
tests/test_api.py
Normal file
|
|
@ -0,0 +1,309 @@
|
||||||
|
"""Tests for the Query API (Stage 9).
|
||||||
|
|
||||||
|
All tests use InMemoryGraphDB + FakeVectorIndex — zero external services.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from fastapi.testclient import TestClient
|
||||||
|
|
||||||
|
from aucourt_ingest.api.app import create_app
|
||||||
|
from aucourt_ingest.jury.personas import PERSONAS
|
||||||
|
from aucourt_ingest.models import CaseMeta, Chunk
|
||||||
|
from aucourt_ingest.processing.graph_builder import GraphBuilder
|
||||||
|
from aucourt_ingest.storage.in_memory_graph_db import InMemoryGraphDB
|
||||||
|
|
||||||
|
|
||||||
|
class FakeVectorIndex:
|
||||||
|
"""Duck-typed VectorIndex for testing. Pre-seeded with results."""
|
||||||
|
|
||||||
|
def __init__(self, results: list[dict] | None = None):
|
||||||
|
self._results = results or []
|
||||||
|
self._stored: list = []
|
||||||
|
|
||||||
|
def store_chunks(self, chunks: list) -> None:
|
||||||
|
self._stored.extend(chunks)
|
||||||
|
|
||||||
|
def query(self, text: str, top_k: int = 10,
|
||||||
|
chunk_types: list[str] | None = None,
|
||||||
|
doc_ids: list[str] | None = None,
|
||||||
|
embedding: list[float] | None = None) -> list[dict]:
|
||||||
|
results = self._results[:top_k]
|
||||||
|
if chunk_types:
|
||||||
|
results = [r for r in results
|
||||||
|
if r["metadata"].get("chunk_type") in chunk_types]
|
||||||
|
if doc_ids:
|
||||||
|
results = [r for r in results
|
||||||
|
if r["metadata"].get("doc_id") in doc_ids]
|
||||||
|
return results
|
||||||
|
|
||||||
|
@property
|
||||||
|
def count(self) -> int:
|
||||||
|
return len(self._stored)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def seeded_graph():
|
||||||
|
"""InMemoryGraphDB seeded with two cases, chunks, and relationships."""
|
||||||
|
db = InMemoryGraphDB()
|
||||||
|
builder = GraphBuilder(db)
|
||||||
|
|
||||||
|
# Case 1
|
||||||
|
meta1 = CaseMeta(
|
||||||
|
case_name="R v Smith",
|
||||||
|
mnc="[2023] NSWSC 1234",
|
||||||
|
court="NSWSC",
|
||||||
|
date_delivered="2023-06-15",
|
||||||
|
jurisdiction="NSW",
|
||||||
|
matter_type="criminal",
|
||||||
|
charges=["murder", "assault"],
|
||||||
|
verdict="guilty",
|
||||||
|
judge=["Judge Brown"],
|
||||||
|
)
|
||||||
|
chunks1 = [
|
||||||
|
Chunk(chunk_id="c1-1", doc_id=meta1.mnc, chunk_type="opening",
|
||||||
|
sequence=0, text="The Crown alleges that on 15 March 2023..."),
|
||||||
|
Chunk(chunk_id="c1-2", doc_id=meta1.mnc, chunk_type="testimony",
|
||||||
|
sequence=1, text="Dr Jones testified that the victim had..."),
|
||||||
|
Chunk(chunk_id="c1-3", doc_id=meta1.mnc, chunk_type="ruling",
|
||||||
|
sequence=2, text="The court finds that exhibit A is admissible."),
|
||||||
|
Chunk(chunk_id="c1-4", doc_id=meta1.mnc, chunk_type="judgment",
|
||||||
|
sequence=3, text="The accused is found guilty of murder."),
|
||||||
|
]
|
||||||
|
for c in chunks1:
|
||||||
|
c.token_count = len(c.text) // 4
|
||||||
|
import asyncio
|
||||||
|
asyncio.get_event_loop().run_until_complete(builder.build_full(meta1, chunks1))
|
||||||
|
|
||||||
|
# Case 2
|
||||||
|
meta2 = CaseMeta(
|
||||||
|
case_name="R v Jones",
|
||||||
|
mnc="[2024] FCA 567",
|
||||||
|
court="FCA",
|
||||||
|
date_delivered="2024-01-10",
|
||||||
|
jurisdiction="CTH",
|
||||||
|
matter_type="civil",
|
||||||
|
charges=[],
|
||||||
|
verdict="civil_judgment",
|
||||||
|
judge=["Judge Davis"],
|
||||||
|
)
|
||||||
|
chunks2 = [
|
||||||
|
Chunk(chunk_id="c2-1", doc_id=meta2.mnc, chunk_type="opening",
|
||||||
|
sequence=0, text="This matter concerns an appeal against..."),
|
||||||
|
Chunk(chunk_id="c2-2", doc_id=meta2.mnc, chunk_type="closing",
|
||||||
|
sequence=1, text="Counsel for the applicant submitted that..."),
|
||||||
|
]
|
||||||
|
for c in chunks2:
|
||||||
|
c.token_count = len(c.text) // 4
|
||||||
|
asyncio.get_event_loop().run_until_complete(builder.build_full(meta2, chunks2))
|
||||||
|
|
||||||
|
return db
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def fake_vector_index():
|
||||||
|
"""FakeVectorIndex with pre-seeded search results."""
|
||||||
|
return FakeVectorIndex([
|
||||||
|
{
|
||||||
|
"chunk_id": "c1-2",
|
||||||
|
"text": "Dr Jones testified that the victim had...",
|
||||||
|
"metadata": {"doc_id": "[2023] NSWSC 1234", "chunk_type": "testimony", "sequence": 1},
|
||||||
|
"distance": 0.12,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"chunk_id": "c1-3",
|
||||||
|
"text": "The court finds that exhibit A is admissible.",
|
||||||
|
"metadata": {"doc_id": "[2023] NSWSC 1234", "chunk_type": "ruling", "sequence": 2},
|
||||||
|
"distance": 0.34,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"chunk_id": "c2-1",
|
||||||
|
"text": "This matter concerns an appeal against...",
|
||||||
|
"metadata": {"doc_id": "[2024] FCA 567", "chunk_type": "opening", "sequence": 0},
|
||||||
|
"distance": 0.56,
|
||||||
|
},
|
||||||
|
])
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def client(seeded_graph, fake_vector_index):
|
||||||
|
app = create_app(seeded_graph, fake_vector_index, max_tokens=4000)
|
||||||
|
return TestClient(app)
|
||||||
|
|
||||||
|
|
||||||
|
# --- Health ---
|
||||||
|
|
||||||
|
class TestHealthEndpoint:
|
||||||
|
def test_health_returns_ok(self, client):
|
||||||
|
r = client.get("/api/v1/health")
|
||||||
|
assert r.status_code == 200
|
||||||
|
data = r.json()
|
||||||
|
assert data["status"] == "ok"
|
||||||
|
assert data["cases"] == 2
|
||||||
|
assert data["chunks"] == 6
|
||||||
|
assert data["nodes"] > 0
|
||||||
|
assert data["edges"] > 0
|
||||||
|
|
||||||
|
|
||||||
|
# --- Personas ---
|
||||||
|
|
||||||
|
class TestPersonaEndpoints:
|
||||||
|
def test_list_personas(self, client):
|
||||||
|
r = client.get("/api/v1/personas")
|
||||||
|
assert r.status_code == 200
|
||||||
|
data = r.json()
|
||||||
|
assert isinstance(data, list)
|
||||||
|
assert len(data) == len(PERSONAS)
|
||||||
|
names = [p["name"] for p in data]
|
||||||
|
assert "foreman" in names
|
||||||
|
assert "nurse" in names
|
||||||
|
|
||||||
|
def test_persona_has_required_fields(self, client):
|
||||||
|
r = client.get("/api/v1/personas")
|
||||||
|
data = r.json()
|
||||||
|
for p in data:
|
||||||
|
assert "name" in p
|
||||||
|
assert "anchor_nodes" in p
|
||||||
|
assert "edge_types" in p
|
||||||
|
assert "chunk_types" in p
|
||||||
|
|
||||||
|
|
||||||
|
# --- Cases ---
|
||||||
|
|
||||||
|
class TestCaseEndpoints:
|
||||||
|
def test_list_cases(self, client):
|
||||||
|
r = client.get("/api/v1/cases")
|
||||||
|
assert r.status_code == 200
|
||||||
|
data = r.json()
|
||||||
|
assert len(data) == 2
|
||||||
|
|
||||||
|
def test_list_cases_filter_court(self, client):
|
||||||
|
r = client.get("/api/v1/cases", params={"court": "NSWSC"})
|
||||||
|
assert r.status_code == 200
|
||||||
|
data = r.json()
|
||||||
|
assert len(data) == 1
|
||||||
|
assert data[0]["mnc"] == "[2023] NSWSC 1234"
|
||||||
|
|
||||||
|
def test_list_cases_filter_empty(self, client):
|
||||||
|
r = client.get("/api/v1/cases", params={"court": "HCA"})
|
||||||
|
assert r.status_code == 200
|
||||||
|
data = r.json()
|
||||||
|
assert len(data) == 0
|
||||||
|
|
||||||
|
def test_list_cases_pagination(self, client):
|
||||||
|
r = client.get("/api/v1/cases", params={"limit": 1, "offset": 0})
|
||||||
|
assert r.status_code == 200
|
||||||
|
data = r.json()
|
||||||
|
assert len(data) == 1
|
||||||
|
|
||||||
|
def test_list_cases_pagination_offset(self, client):
|
||||||
|
r = client.get("/api/v1/cases", params={"limit": 1, "offset": 1})
|
||||||
|
assert r.status_code == 200
|
||||||
|
data = r.json()
|
||||||
|
assert len(data) == 1
|
||||||
|
# Second case is different from first
|
||||||
|
r2 = client.get("/api/v1/cases", params={"limit": 1, "offset": 0})
|
||||||
|
assert data[0]["mnc"] != r2.json()[0]["mnc"]
|
||||||
|
|
||||||
|
def test_search_cases(self, client):
|
||||||
|
r = client.get("/api/v1/cases/search", params={"q": "murder"})
|
||||||
|
assert r.status_code == 200
|
||||||
|
data = r.json()
|
||||||
|
# Vector index returns results grouped by doc_id
|
||||||
|
assert len(data) >= 1
|
||||||
|
# Should have found the NSWSC case
|
||||||
|
mncs = [c["mnc"] for c in data]
|
||||||
|
assert "[2023] NSWSC 1234" in mncs
|
||||||
|
|
||||||
|
def test_get_case_graph(self, client):
|
||||||
|
r = client.get("/api/v1/cases/[2023]%20NSWSC%201234")
|
||||||
|
assert r.status_code == 200
|
||||||
|
data = r.json()
|
||||||
|
assert data["case"]["label"] == "Case"
|
||||||
|
assert data["case"]["properties"]["mnc"] == "[2023] NSWSC 1234"
|
||||||
|
assert len(data["judges"]) == 1
|
||||||
|
assert data["judges"][0]["properties"]["name"] == "Judge Brown"
|
||||||
|
assert len(data["charges"]) == 2
|
||||||
|
|
||||||
|
def test_get_case_graph_not_found(self, client):
|
||||||
|
r = client.get("/api/v1/cases/[9999]%20FAKE%201")
|
||||||
|
assert r.status_code == 404
|
||||||
|
|
||||||
|
def test_get_case_graph_fca(self, client):
|
||||||
|
r = client.get("/api/v1/cases/[2024]%20FCA%20567")
|
||||||
|
assert r.status_code == 200
|
||||||
|
data = r.json()
|
||||||
|
assert data["case"]["properties"]["court"] == "FCA"
|
||||||
|
assert len(data["judges"]) == 1
|
||||||
|
|
||||||
|
|
||||||
|
# --- Juror Context ---
|
||||||
|
|
||||||
|
class TestJurorContextEndpoint:
|
||||||
|
def test_juror_context(self, client):
|
||||||
|
r = client.get("/api/v1/cases/[2023]%20NSWSC%201234/juror/foreman")
|
||||||
|
assert r.status_code == 200
|
||||||
|
data = r.json()
|
||||||
|
assert data["persona"] == "foreman"
|
||||||
|
assert data["case_mnc"] == "[2023] NSWSC 1234"
|
||||||
|
assert data["total_tokens"] >= 0
|
||||||
|
|
||||||
|
def test_juror_context_bad_persona(self, client):
|
||||||
|
r = client.get("/api/v1/cases/[2023]%20NSWSC%201234/juror/nonexistent")
|
||||||
|
assert r.status_code == 404
|
||||||
|
|
||||||
|
def test_juror_context_custom_tokens(self, client):
|
||||||
|
r = client.get(
|
||||||
|
"/api/v1/cases/[2023]%20NSWSC%201234/juror/skeptic",
|
||||||
|
params={"max_tokens": 500},
|
||||||
|
)
|
||||||
|
assert r.status_code == 200
|
||||||
|
data = r.json()
|
||||||
|
assert data["persona"] == "skeptic"
|
||||||
|
|
||||||
|
|
||||||
|
# --- Vector Search ---
|
||||||
|
|
||||||
|
class TestVectorSearchEndpoint:
|
||||||
|
def test_vector_search(self, client):
|
||||||
|
r = client.get("/api/v1/search", params={"q": "testimony evidence"})
|
||||||
|
assert r.status_code == 200
|
||||||
|
data = r.json()
|
||||||
|
assert len(data) >= 1
|
||||||
|
assert "chunk_id" in data[0]
|
||||||
|
assert "text" in data[0]
|
||||||
|
assert "distance" in data[0]
|
||||||
|
|
||||||
|
def test_vector_search_filter_types(self, client):
|
||||||
|
r = client.get("/api/v1/search", params={"q": "test", "chunk_types": "ruling"})
|
||||||
|
assert r.status_code == 200
|
||||||
|
data = r.json()
|
||||||
|
assert all(r["metadata"]["chunk_type"] == "ruling" for r in data)
|
||||||
|
|
||||||
|
def test_vector_search_top_k(self, client):
|
||||||
|
r = client.get("/api/v1/search", params={"q": "test", "top_k": 1})
|
||||||
|
assert r.status_code == 200
|
||||||
|
assert len(r.json()) <= 1
|
||||||
|
|
||||||
|
|
||||||
|
# --- Hybrid Search ---
|
||||||
|
|
||||||
|
class TestHybridEndpoint:
|
||||||
|
def test_hybrid_search(self, client):
|
||||||
|
r = client.get("/api/v1/hybrid", params={"q": "murder trial"})
|
||||||
|
assert r.status_code == 200
|
||||||
|
data = r.json()
|
||||||
|
assert len(data) >= 1
|
||||||
|
assert "case_mnc" in data[0]
|
||||||
|
assert "chunks" in data[0]
|
||||||
|
assert data[0]["juror_context"] is None # no persona
|
||||||
|
|
||||||
|
def test_hybrid_search_with_persona(self, client):
|
||||||
|
r = client.get("/api/v1/hybrid", params={"q": "test", "persona": "foreman"})
|
||||||
|
assert r.status_code == 200
|
||||||
|
data = r.json()
|
||||||
|
assert len(data) >= 1
|
||||||
|
# First result should have juror_context if it found the case
|
||||||
|
if data[0]["juror_context"]:
|
||||||
|
assert data[0]["juror_context"]["persona"] == "foreman"
|
||||||
Loading…
Reference in a new issue