From 6374aea0a221ab93ca10f002feeafa9a67c544d6 Mon Sep 17 00:00:00 2001
From: slothitude
Date: Sat, 30 May 2026 12:08:55 +1000
Subject: [PATCH] Stage 9: add read-only FastAPI query API for juror RAG queries

8 GET endpoints under /api/v1 for health, personas, cases, vector search,
juror context, and hybrid search. Includes QueryService composing
SubgraphQuery + VectorIndex + GraphDB, Pydantic response models, error
handlers, and `serve` CLI mode via uvicorn.

20 new tests, 190 total, zero regressions.

Co-Authored-By: Claude Opus 4.6
---
 aucourt_ingest/api/__init__.py     |   0
 aucourt_ingest/api/app.py          |  49 ++++
 aucourt_ingest/api/dependencies.py |  21 ++
 aucourt_ingest/api/errors.py       |  21 ++
 aucourt_ingest/api/routes.py       | 108 ++++++++
 aucourt_ingest/api/schemas.py      |  74 +++++++
 aucourt_ingest/api/service.py      | 194 ++++++++++++++++++
 aucourt_ingest/config.py           |  18 ++
 aucourt_ingest/main.py             |  47 +++++
 config.toml                        |   6 +
 pyproject.toml                     |   2 +
 tests/test_api.py                  | 309 +++++++++++++++++++++++++++++
 12 files changed, 849 insertions(+)
 create mode 100644 aucourt_ingest/api/__init__.py
 create mode 100644 aucourt_ingest/api/app.py
 create mode 100644 aucourt_ingest/api/dependencies.py
 create mode 100644 aucourt_ingest/api/errors.py
 create mode 100644 aucourt_ingest/api/routes.py
 create mode 100644 aucourt_ingest/api/schemas.py
 create mode 100644 aucourt_ingest/api/service.py
 create mode 100644 tests/test_api.py

diff --git a/aucourt_ingest/api/__init__.py b/aucourt_ingest/api/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/aucourt_ingest/api/app.py b/aucourt_ingest/api/app.py
new file mode 100644
index 0000000..1663ba6
--- /dev/null
+++ b/aucourt_ingest/api/app.py
@@ -0,0 +1,49 @@
+"""FastAPI app factory for the query API."""
+
+from __future__ import annotations
+
+from contextlib import asynccontextmanager
+
+from fastapi import FastAPI
+
+from aucourt_ingest.api import dependencies
+from aucourt_ingest.api.errors import register_error_handlers
+from aucourt_ingest.api.routes import router
+from aucourt_ingest.api.service import QueryService
+from aucourt_ingest.storage.graph_db import GraphDB
+
+
+def create_app(graph_db: GraphDB, vector_index, max_tokens: int = 4000) -> FastAPI:
+    """Create and configure the FastAPI app.
+
+    Args:
+        graph_db: GraphDB instance (InMemoryGraphDB or Neo4jGraphDB)
+        vector_index: VectorIndex instance (or duck-typed fake for tests)
+        max_tokens: Default token budget for juror context assembly
+
+    Returns:
+        The configured app, with the query router and error handlers attached.
+    """
+    # The routes resolve their QueryService through this module-level slot.
+    dependencies._query_service = QueryService(graph_db, vector_index, max_tokens)
+
+    @asynccontextmanager
+    async def lifespan(app: FastAPI):
+        # try/finally so the graph backend is closed even when an exception
+        # is thrown into the generator at shutdown.
+        try:
+            yield
+        finally:
+            await graph_db.close()
+
+    app = FastAPI(
+        title="AuCourtIngest Query API",
+        description="Read-only juror RAG query API for Australian legal cases",
+        version="0.1.0",
+        lifespan=lifespan,
+    )
+
+    app.include_router(router)
+    register_error_handlers(app)
+
+    return app
diff --git a/aucourt_ingest/api/dependencies.py b/aucourt_ingest/api/dependencies.py
new file mode 100644
index 0000000..9d78182
--- /dev/null
+++ b/aucourt_ingest/api/dependencies.py
@@ -0,0 +1,21 @@
+"""FastAPI dependency providers."""
+
+from __future__ import annotations
+
+from fastapi import Depends
+
+from aucourt_ingest.api.service import QueryService
+
+# Set by create_app(); stays None until an app has been built.
+_query_service: QueryService | None = None
+
+
+def get_query_service() -> QueryService:
+    """Return the QueryService installed by create_app().
+
+    Raises:
+        RuntimeError: If create_app() has not been called yet.
+    """
+    if _query_service is None:
+        raise RuntimeError("QueryService not initialised — call create_app() first")
+    return _query_service
diff --git a/aucourt_ingest/api/errors.py b/aucourt_ingest/api/errors.py
new file mode 100644
index 0000000..ea1edc0
--- /dev/null
+++ b/aucourt_ingest/api/errors.py
@@ -0,0 +1,21 @@
+"""Error handlers for the query API."""
+
+from __future__ import annotations
+
+from fastapi import FastAPI, Request
+from fastapi.responses import JSONResponse
+
+
+def register_error_handlers(app: FastAPI) -> None:
+    """Register JSON handlers: ValueError -> 400, KeyError -> 404."""
+
+    @app.exception_handler(ValueError)
+    async def value_error_handler(request: Request, exc: ValueError):
+        return JSONResponse(status_code=400, content={"detail": str(exc)})
+
+    @app.exception_handler(KeyError)
+    async def key_error_handler(request: Request, exc: KeyError):
+        # str(KeyError("x")) is "'x'" (the repr of the key), so read the
+        # message from args to keep stray quotes out of the response.
+        detail = str(exc.args[0]) if exc.args else str(exc)
+        return JSONResponse(status_code=404, content={"detail": detail})
diff --git a/aucourt_ingest/api/routes.py b/aucourt_ingest/api/routes.py
new file mode 100644
index 0000000..61041ef
--- /dev/null
+++ b/aucourt_ingest/api/routes.py
@@ -0,0 +1,108 @@
+"""Route definitions for the read-only query API."""
+
+from __future__ import annotations
+
+from fastapi import APIRouter, Depends, HTTPException, Query
+
+from aucourt_ingest.api.dependencies import get_query_service
+
+router = APIRouter(prefix="/api/v1")
+
+
+def _split_csv(value: str | None) -> list[str] | None:
+    """Split a comma-separated query value, dropping blanks and padding.
+
+    Returns None when nothing usable is left, which means "no filter".
+    """
+    if not value:
+        return None
+    items = [part.strip() for part in value.split(",")]
+    return [item for item in items if item] or None
+
+
+@router.get("/health")
+async def health(svc=Depends(get_query_service)):
+    """Report graph node, edge, case and chunk counts."""
+    return await svc.health()
+
+
+@router.get("/personas")
+async def list_personas(svc=Depends(get_query_service)):
+    """List juror personas with their anchor nodes, edge and chunk types."""
+    return svc.list_personas()
+
+
+@router.get("/cases")
+async def list_cases(
+    court: str | None = Query(None, description="Filter by court code"),
+    limit: int = Query(50, ge=1, le=200),
+    offset: int = Query(0, ge=0),
+    svc=Depends(get_query_service),
+):
+    """List case summaries, optionally filtered by court, with paging."""
+    return await svc.list_cases(court=court, limit=limit, offset=offset)
+
+
+# Declared before /cases/{case_mnc} so "search" is not captured as an MNC.
+@router.get("/cases/search")
+async def search_cases(
+    q: str = Query(..., description="Search query text"),
+    limit: int = Query(10, ge=1, le=50),
+    svc=Depends(get_query_service),
+):
+    """Return summaries of the cases behind the top vector hits for q."""
+    return await svc.search_cases(q=q, limit=limit)
+
+
+@router.get("/cases/{case_mnc}/juror/{persona}")
+async def get_juror_context(
+    case_mnc: str,
+    persona: str,
+    max_tokens: int | None = Query(None, ge=100, le=16000),
+    svc=Depends(get_query_service),
+):
+    """Return the juror context for one case and persona (404 on KeyError)."""
+    try:
+        return await svc.get_juror_context(case_mnc, persona, max_tokens)
+    except KeyError as e:
+        # str(KeyError("x")) is "'x'"; use the raw argument for the detail.
+        detail = str(e.args[0]) if e.args else str(e)
+        raise HTTPException(status_code=404, detail=detail) from e
+
+
+@router.get("/cases/{case_mnc}")
+async def get_case_graph(case_mnc: str, svc=Depends(get_query_service)):
+    """Return a case node with its neighbours, or 404 if it is unknown."""
+    result = await svc.get_case_graph(case_mnc)
+    if result is None:
+        raise HTTPException(status_code=404, detail=f"Case not found: {case_mnc}")
+    return result
+
+
+@router.get("/search")
+async def vector_search(
+    q: str = Query(..., description="Query text"),
+    top_k: int = Query(10, ge=1, le=50),
+    chunk_types: str | None = Query(None, description="Comma-separated chunk types"),
+    doc_ids: str | None = Query(None, description="Comma-separated doc IDs"),
+    svc=Depends(get_query_service),
+):
+    """Run a vector search, optionally filtered by chunk type and doc ID."""
+    return svc.vector_search(
+        q,
+        top_k=top_k,
+        chunk_types=_split_csv(chunk_types),
+        doc_ids=_split_csv(doc_ids),
+    )
+
+
+@router.get("/hybrid")
+async def hybrid_search(
+    q: str = Query(..., description="Query text"),
+    persona: str | None = Query(None, description="Juror persona name"),
+    top_k: int = Query(10, ge=1, le=50),
+    max_tokens: int | None = Query(None, ge=100, le=16000),
+    svc=Depends(get_query_service),
+):
+    """Run a vector search grouped by case, with optional juror context."""
+    return await svc.hybrid_search(q, persona_name=persona, top_k=top_k, max_tokens=max_tokens)
diff --git a/aucourt_ingest/api/schemas.py b/aucourt_ingest/api/schemas.py
new file mode 100644
index 0000000..faff951
--- /dev/null
+++ b/aucourt_ingest/api/schemas.py
@@ -0,0 +1,74 @@
+"""Pydantic v2 response models for the query API."""
+
+from __future__ import annotations
+
+from pydantic import BaseModel, Field
+
+
+class PersonaInfo(BaseModel):
+    name: str
+    anchor_nodes: list[str]
+    edge_types: list[str]
+    chunk_types: list[str]
+
+
+class CaseSummary(BaseModel):
+    mnc: str
+    court: str = ""
+    date: str = ""
+    jurisdiction: str = ""
+    matter_type: str = ""
+    verdict: str = ""
+
+
+class ChunkResult(BaseModel):
+    chunk_id: str
+    text: str | None = None
+    metadata: dict = Field(default_factory=dict)
+    distance: float | None = None
+
+
+class JurorContextResponse(BaseModel):
+    persona: str
+    case_mnc: str
+    context_text: str
+    source_chunk_ids: list[str]
+    total_tokens: int
+
+
+class CaseGraphNode(BaseModel):
+    node_id: str
+    label: str
+    properties: dict = Field(default_factory=dict)
+
+
+class CaseGraphRelationship(BaseModel):
+    rel_id: str
+    from_id: str
+    to_id: str
+    rel_type: str
+    properties: dict = Field(default_factory=dict)
+
+
+class CaseGraphSummary(BaseModel):
+    case: CaseGraphNode
+    judges: list[CaseGraphNode] = Field(default_factory=list)
+    charges: list[CaseGraphNode] = Field(default_factory=list)
+    rulings: list[CaseGraphNode] = Field(default_factory=list)
+    witnesses: list[CaseGraphNode] = Field(default_factory=list)
+    chunks: list[CaseGraphNode] = Field(default_factory=list)
+    relationships: list[CaseGraphRelationship] = Field(default_factory=list)
+
+
+class HybridResult(BaseModel):
+    case_mnc: str
+    chunks: list[ChunkResult] = Field(default_factory=list)
+    juror_context: JurorContextResponse | None = None
+
+
+class HealthResponse(BaseModel):
+    status: str
+    nodes: int
+    edges: int
+    cases: int
+    chunks: int
diff --git a/aucourt_ingest/api/service.py b/aucourt_ingest/api/service.py
new file mode 100644
index 0000000..d35f930
--- /dev/null
+++ b/aucourt_ingest/api/service.py
@@ -0,0 +1,194 @@
+"""QueryService — composes SubgraphQuery + VectorIndex + GraphDB for the read API."""
+
+from __future__ import annotations
+
+import logging
+
+from aucourt_ingest.jury.personas import PERSONAS, get_persona, all_persona_names
+from aucourt_ingest.jury.subgraph_query import SubgraphQuery
+from aucourt_ingest.storage.graph_db import GraphDB
+
+logger = logging.getLogger(__name__)
+
+# Neighbour node label -> bucket key in the get_case_graph() result.
+_LABEL_KEYS = {
+    "Judge": "judges",
+    "Charge": "charges",
+    "Ruling": "rulings",
+    "Witness": "witnesses",
+    "Chunk": "chunks",
+}
+
+
+class QueryService:
+    """Read-only query service for juror RAG queries."""
+
+    def __init__(self, graph_db: GraphDB, vector_index, max_tokens: int = 4000):
+        self._graph_db = graph_db
+        self._vector_index = vector_index
+        self._subgraph = SubgraphQuery(graph_db)
+        self._max_tokens = max_tokens
+
+    def list_personas(self) -> list[dict]:
+        """Describe every persona registered in PERSONAS."""
+        return [
+            {
+                "name": name,
+                "anchor_nodes": p.anchor_nodes,
+                "edge_types": p.edge_types,
+                "chunk_types": p.chunk_types,
+            }
+            for name, p in PERSONAS.items()
+        ]
+
+    async def list_cases(self, court: str | None = None,
+                         limit: int = 50, offset: int = 0) -> list[dict]:
+        """Return case summaries, newest first, filtered by court and paged."""
+        nodes = await self._graph_db.query_nodes("Case")
+        cases = [
+            _case_summary(node) for node in nodes
+            if not court or node.properties.get("court") == court
+        ]
+        cases.sort(key=lambda c: c.get("date", ""), reverse=True)
+        return cases[offset:offset + limit]
+
+    async def search_cases(self, q: str, limit: int = 10) -> list[dict]:
+        """Return summaries of the distinct cases behind the top vector hits."""
+        results = self._vector_index.query(q, top_k=limit)
+        seen: set[str] = set()
+        cases = []
+        for r in results:
+            doc_id = r["metadata"].get("doc_id", "")
+            if doc_id and doc_id not in seen:
+                seen.add(doc_id)
+                cases.append(await self._get_case_by_mnc(doc_id))
+        return [c for c in cases if c]
+
+    async def get_juror_context(self, case_mnc: str, persona_name: str,
+                                max_tokens: int | None = None) -> dict:
+        """Build one persona's context for a case within a token budget."""
+        persona = get_persona(persona_name)
+        # Only None selects the default budget; an explicit value is kept.
+        budget = self._max_tokens if max_tokens is None else max_tokens
+        ctx = await self._subgraph.get_juror_context(case_mnc, persona, budget)
+        return {
+            "persona": ctx.persona,
+            "case_mnc": ctx.case_mnc,
+            "context_text": ctx.context_text,
+            "source_chunk_ids": ctx.source_chunk_ids,
+            "total_tokens": ctx.total_tokens,
+        }
+
+    def vector_search(self, query: str, top_k: int = 10,
+                      chunk_types: list[str] | None = None,
+                      doc_ids: list[str] | None = None) -> list[dict]:
+        """Query the vector index with optional chunk-type and doc-ID filters."""
+        return self._vector_index.query(
+            query, top_k=top_k, chunk_types=chunk_types, doc_ids=doc_ids,
+        )
+
+    async def hybrid_search(self, query: str, persona_name: str | None = None,
+                            top_k: int = 10,
+                            max_tokens: int | None = None) -> list[dict]:
+        """Group vector hits by case and attach juror context when possible.
+
+        A persona_name that is not in PERSONAS is ignored rather than rejected.
+        """
+        by_doc: dict[str, list[dict]] = {}
+        for r in self._vector_index.query(query, top_k=top_k):
+            doc_id = r["metadata"].get("doc_id", "unknown")
+            by_doc.setdefault(doc_id, []).append(r)
+
+        want_context = bool(persona_name) and persona_name in PERSONAS
+        results = []
+        for doc_mnc, chunks in by_doc.items():
+            entry: dict = {"case_mnc": doc_mnc, "chunks": chunks, "juror_context": None}
+            # "unknown" marks hits without a doc_id, so there is no case to load.
+            if want_context and doc_mnc != "unknown":
+                entry["juror_context"] = await self.get_juror_context(
+                    doc_mnc, persona_name, max_tokens,
+                )
+            results.append(entry)
+        return results
+
+    async def get_case_graph(self, case_mnc: str) -> dict | None:
+        """Return the case with its relationships and neighbours, or None if absent."""
+        case_id = f"Case:{case_mnc}"
+        case_node = await self._graph_db.get_node(case_id)
+        if case_node is None:
+            return None
+
+        result = {
+            "case": _node_to_dict(case_node),
+            "judges": [],
+            "charges": [],
+            "rulings": [],
+            "witnesses": [],
+            "chunks": [],
+            "relationships": [],
+        }
+
+        rels = await self._graph_db.get_relationships(case_id, direction="both")
+        neighbor_ids: set[str] = set()
+        for rel in rels:
+            result["relationships"].append(_rel_to_dict(rel))
+            neighbor_ids.add(rel.to_id if rel.from_id == case_id else rel.from_id)
+
+        # Sorted so the neighbour lists come back in a stable order.
+        for nid in sorted(neighbor_ids):
+            node = await self._graph_db.get_node(nid)
+            if node is None:
+                continue
+            key = _LABEL_KEYS.get(node.label)
+            if key:
+                result[key].append(_node_to_dict(node))
+        return result
+
+    async def health(self) -> dict:
+        """Return node, edge, case and chunk counts with an "ok" status."""
+        total_nodes = await self._graph_db.node_count()
+        total_edges = await self._graph_db.relationship_count()
+        cases = await self._graph_db.node_count("Case")
+        chunks = await self._graph_db.node_count("Chunk")
+        return {
+            "status": "ok",
+            "nodes": total_nodes,
+            "edges": total_edges,
+            "cases": cases,
+            "chunks": chunks,
+        }
+
+    async def _get_case_by_mnc(self, mnc: str) -> dict | None:
+        """Return the summary of the "Case:<mnc>" node, or None if absent."""
+        node = await self._graph_db.get_node(f"Case:{mnc}")
+        return None if node is None else _case_summary(node)
+
+
+def _case_summary(node) -> dict:
+    """Flatten a Case node into the summary dict used by the case endpoints."""
+    return {
+        "mnc": node.properties.get("mnc", node.node_id),
+        "court": node.properties.get("court", ""),
+        "date": node.properties.get("date", ""),
+        "jurisdiction": node.properties.get("jurisdiction", ""),
+        "matter_type": node.properties.get("matter_type", ""),
+        "verdict": node.properties.get("verdict", ""),
+    }
+
+
+def _node_to_dict(node) -> dict:
+    return {
+        "node_id": node.node_id,
+        "label": node.label,
+        "properties": node.properties,
+    }
+
+
+def _rel_to_dict(rel) -> dict:
+    return {
+        "rel_id": rel.rel_id,
+        "from_id": rel.from_id,
+        "to_id": rel.to_id,
+        "rel_type": rel.rel_type,
+        "properties": rel.properties,
+    }
diff --git a/aucourt_ingest/config.py b/aucourt_ingest/config.py
index 205cbc0..a968fac 100644
--- a/aucourt_ingest/config.py
+++ b/aucourt_ingest/config.py
@@ -55,6 +55,14 @@ class TelegramConfig:
     enabled: bool = False
 
 
+@dataclass
+class ServerConfig:
+    host: str = "127.0.0.1"
+    port: int = 8000
+    graph_backend: str = "memory"
+    default_max_tokens: int = 4000
+
+
 @dataclass
 class AppConfig:
     """Root config — loaded from config.toml."""
@@ -64,6 +72,7 @@ class AppConfig:
     storage: StorageConfig = field(default_factory=StorageConfig)
     llm: LLMConfig = field(default_factory=LLMConfig)
     telegram: TelegramConfig = field(default_factory=TelegramConfig)
+    server: ServerConfig = field(default_factory=ServerConfig)
     user_agent: str = "AuCourtIngest/0.1 (legal research)"
 
     @classmethod
@@ -138,6 +147,15 @@ class AppConfig:
             enabled=tg.get("enabled", False),
         )
 
+        # Server
+        srv = raw.get("server", {})
+        config.server = ServerConfig(
+            host=srv.get("host", "127.0.0.1"),
+            port=srv.get("port", 8000),
+            graph_backend=srv.get("graph_backend", "memory"),
+            default_max_tokens=srv.get("default_max_tokens", 4000),
+        )
+
         config.user_agent = raw.get("user_agent", "AuCourtIngest/0.1 (legal research)")
 
         return config
diff --git a/aucourt_ingest/main.py b/aucourt_ingest/main.py
index 78e23bc..0accdf6 100644
--- a/aucourt_ingest/main.py
+++ b/aucourt_ingest/main.py
@@ -57,6 +57,13 @@ def build_parser() -> argparse.ArgumentParser:
     p_proc.add_argument("--db", default="data/meta.db", help="MetaDB path")
     p_proc.add_argument("--raw-dir", default="data/raw", help="Raw document storage dir")
 
+    # serve
+    p_serve = sub.add_parser("serve", help="Start the read-only query API server")
+    p_serve.add_argument("--host", default=None, help="Bind host (default from config)")
+    p_serve.add_argument("--port", type=int, default=None, help="Bind port (default from config)")
+    p_serve.add_argument("--graph-backend", default=None,
+                         help="Graph backend: memory|neo4j (default from config)")
+
     return parser
 
 
@@ -291,6 +298,44 @@ async def cmd_process(args):
     print(f"Processing complete: {stats}")
 
 
+async def cmd_serve(args):
+    """Start the FastAPI query server and serve until it is shut down.
+
+    A coroutine because it is called from the async_main() coroutine, and
+    uvicorn.run() starts its own event loop, which cannot be nested there.
+    """
+    import uvicorn
+
+    from aucourt_ingest.api.app import create_app
+    from aucourt_ingest.storage.vector_index import VectorIndex
+
+    config = AppConfig.load(args.config)
+    # "is None" rather than "or", so an explicit --port 0 is honoured.
+    host = config.server.host if args.host is None else args.host
+    port = config.server.port if args.port is None else args.port
+    backend = args.graph_backend or config.server.graph_backend
+
+    if backend == "memory":
+        from aucourt_ingest.storage.in_memory_graph_db import InMemoryGraphDB
+        graph_db = InMemoryGraphDB()
+    elif backend == "neo4j":
+        from aucourt_ingest.storage.graph_db import Neo4jGraphDB
+        graph_db = Neo4jGraphDB(
+            uri=config.storage.neo4j_uri,
+            user=config.storage.neo4j_user,
+            password=config.storage.neo4j_password,
+            database=config.storage.neo4j_database,
+        )
+    else:
+        raise ValueError(f"Unknown graph backend: {backend!r} (use memory|neo4j)")
+
+    vector_index = VectorIndex(str(config.storage.chromadb_dir))
+    app = create_app(graph_db, vector_index, config.server.default_max_tokens)
+
+    print(f"Starting AuCourtIngest API on {host}:{port} (graph={backend})")
+    await uvicorn.Server(uvicorn.Config(app, host=host, port=port)).serve()
+
+
 async def async_main():
     parser = build_parser()
     args = parser.parse_args()
@@ -309,6 +354,8 @@ async def async_main():
         await cmd_audit(args)
     elif args.mode == "process":
         await cmd_process(args)
+    elif args.mode == "serve":
+        await cmd_serve(args)
 
 
 def main():
diff --git a/config.toml b/config.toml
index cf8ad2b..bff5aca 100644
--- a/config.toml
+++ b/config.toml
@@ -76,3 +76,9 @@ embedding_batch_size = 100
 bot_token = ""
 chat_id = ""
 enabled = false
+
+[server]
+host = "127.0.0.1"
+port = 8000
+graph_backend = "memory"
+default_max_tokens = 4000
diff --git a/pyproject.toml b/pyproject.toml
index 03cc40b..a5a2dc1 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -13,6 +13,8 @@ dependencies = [
     "aiosqlite>=0.19",
     "anthropic>=0.25",
     "apscheduler>=3.10",
+    "fastapi>=0.110",
+    "uvicorn[standard]>=0.29",
 ]
 
 [project.scripts]
diff --git a/tests/test_api.py b/tests/test_api.py
new file mode 100644
index 0000000..a07d6d2
--- /dev/null
+++ b/tests/test_api.py
@@ -0,0 +1,309 @@
+"""Tests for the Query API (Stage 9).
+
+All tests use InMemoryGraphDB + FakeVectorIndex — zero external services.
+"""
+
+from __future__ import annotations
+
+import asyncio
+
+import pytest
+from fastapi.testclient import TestClient
+
+from aucourt_ingest.api.app import create_app
+from aucourt_ingest.jury.personas import PERSONAS
+from aucourt_ingest.models import CaseMeta, Chunk
+from aucourt_ingest.processing.graph_builder import GraphBuilder
+from aucourt_ingest.storage.in_memory_graph_db import InMemoryGraphDB
+
+
+class FakeVectorIndex:
+    """Duck-typed VectorIndex for testing. Pre-seeded with results."""
+
+    def __init__(self, results: list[dict] | None = None):
+        self._results = results or []
+        self._stored: list = []
+
+    def store_chunks(self, chunks: list) -> None:
+        self._stored.extend(chunks)
+
+    def query(self, text: str, top_k: int = 10,
+              chunk_types: list[str] | None = None,
+              doc_ids: list[str] | None = None,
+              embedding: list[float] | None = None) -> list[dict]:
+        results = self._results[:top_k]
+        if chunk_types:
+            results = [r for r in results
+                       if r["metadata"].get("chunk_type") in chunk_types]
+        if doc_ids:
+            results = [r for r in results
+                       if r["metadata"].get("doc_id") in doc_ids]
+        return results
+
+    @property
+    def count(self) -> int:
+        return len(self._stored)
+
+
+@pytest.fixture
+def seeded_graph():
+    """InMemoryGraphDB seeded with two cases, chunks, and relationships."""
+    db = InMemoryGraphDB()
+    builder = GraphBuilder(db)
+
+    # Case 1
+    meta1 = CaseMeta(
+        case_name="R v Smith",
+        mnc="[2023] NSWSC 1234",
+        court="NSWSC",
+        date_delivered="2023-06-15",
+        jurisdiction="NSW",
+        matter_type="criminal",
+        charges=["murder", "assault"],
+        verdict="guilty",
+        judge=["Judge Brown"],
+    )
+    chunks1 = [
+        Chunk(chunk_id="c1-1", doc_id=meta1.mnc, chunk_type="opening",
+              sequence=0, text="The Crown alleges that on 15 March 2023..."),
+        Chunk(chunk_id="c1-2", doc_id=meta1.mnc, chunk_type="testimony",
+              sequence=1, text="Dr Jones testified that the victim had..."),
+        Chunk(chunk_id="c1-3", doc_id=meta1.mnc, chunk_type="ruling",
+              sequence=2, text="The court finds that exhibit A is admissible."),
+        Chunk(chunk_id="c1-4", doc_id=meta1.mnc, chunk_type="judgment",
+              sequence=3, text="The accused is found guilty of murder."),
+    ]
+    for c in chunks1:
+        c.token_count = len(c.text) // 4
+    asyncio.run(builder.build_full(meta1, chunks1))
+
+    # Case 2
+    meta2 = CaseMeta(
+        case_name="R v Jones",
+        mnc="[2024] FCA 567",
+        court="FCA",
+        date_delivered="2024-01-10",
+        jurisdiction="CTH",
+        matter_type="civil",
+        charges=[],
+        verdict="civil_judgment",
+        judge=["Judge Davis"],
+    )
+    chunks2 = [
+        Chunk(chunk_id="c2-1", doc_id=meta2.mnc, chunk_type="opening",
+              sequence=0, text="This matter concerns an appeal against..."),
+        Chunk(chunk_id="c2-2", doc_id=meta2.mnc, chunk_type="closing",
+              sequence=1, text="Counsel for the applicant submitted that..."),
+    ]
+    for c in chunks2:
+        c.token_count = len(c.text) // 4
+    asyncio.run(builder.build_full(meta2, chunks2))
+
+    return db
+
+
+@pytest.fixture
+def fake_vector_index():
+    """FakeVectorIndex with pre-seeded search results."""
+    return FakeVectorIndex([
+        {
+            "chunk_id": "c1-2",
+            "text": "Dr Jones testified that the victim had...",
+            "metadata": {"doc_id": "[2023] NSWSC 1234", "chunk_type": "testimony", "sequence": 1},
+            "distance": 0.12,
+        },
+        {
+            "chunk_id": "c1-3",
+            "text": "The court finds that exhibit A is admissible.",
+            "metadata": {"doc_id": "[2023] NSWSC 1234", "chunk_type": "ruling", "sequence": 2},
+            "distance": 0.34,
+        },
+        {
+            "chunk_id": "c2-1",
+            "text": "This matter concerns an appeal against...",
+            "metadata": {"doc_id": "[2024] FCA 567", "chunk_type": "opening", "sequence": 0},
+            "distance": 0.56,
+        },
+    ])
+
+
+@pytest.fixture
+def client(seeded_graph, fake_vector_index):
+    app = create_app(seeded_graph, fake_vector_index, max_tokens=4000)
+    return TestClient(app)
+
+
+# --- Health ---
+
+class TestHealthEndpoint:
+    def test_health_returns_ok(self, client):
+        r = client.get("/api/v1/health")
+        assert r.status_code == 200
+        data = r.json()
+        assert data["status"] == "ok"
+        assert data["cases"] == 2
+        assert data["chunks"] == 6
+        assert data["nodes"] > 0
+        assert data["edges"] > 0
+
+
+# --- Personas ---
+
+class TestPersonaEndpoints:
+    def test_list_personas(self, client):
+        r = client.get("/api/v1/personas")
+        assert r.status_code == 200
+        data = r.json()
+        assert isinstance(data, list)
+        assert len(data) == len(PERSONAS)
+        names = [p["name"] for p in data]
+        assert "foreman" in names
+        assert "nurse" in names
+
+    def test_persona_has_required_fields(self, client):
+        r = client.get("/api/v1/personas")
+        data = r.json()
+        for p in data:
+            assert "name" in p
+            assert "anchor_nodes" in p
+            assert "edge_types" in p
+            assert "chunk_types" in p
+
+
+# --- Cases ---
+
+class TestCaseEndpoints:
+    def test_list_cases(self, client):
+        r = client.get("/api/v1/cases")
+        assert r.status_code == 200
+        data = r.json()
+        assert len(data) == 2
+
+    def test_list_cases_filter_court(self, client):
+        r = client.get("/api/v1/cases", params={"court": "NSWSC"})
+        assert r.status_code == 200
+        data = r.json()
+        assert len(data) == 1
+        assert data[0]["mnc"] == "[2023] NSWSC 1234"
+
+    def test_list_cases_filter_empty(self, client):
+        r = client.get("/api/v1/cases", params={"court": "HCA"})
+        assert r.status_code == 200
+        data = r.json()
+        assert len(data) == 0
+
+    def test_list_cases_pagination(self, client):
+        r = client.get("/api/v1/cases", params={"limit": 1, "offset": 0})
+        assert r.status_code == 200
+        data = r.json()
+        assert len(data) == 1
+
+    def test_list_cases_pagination_offset(self, client):
+        r = client.get("/api/v1/cases", params={"limit": 1, "offset": 1})
+        assert r.status_code == 200
+        data = r.json()
+        assert len(data) == 1
+        # Second case is different from first
+        r2 = client.get("/api/v1/cases", params={"limit": 1, "offset": 0})
+        assert data[0]["mnc"] != r2.json()[0]["mnc"]
+
+    def test_search_cases(self, client):
+        r = client.get("/api/v1/cases/search", params={"q": "murder"})
+        assert r.status_code == 200
+        data = r.json()
+        assert len(data) >= 1
+        # Hits are de-duplicated by doc_id; the NSWSC case must be among them.
+        mncs = [c["mnc"] for c in data]
+        assert "[2023] NSWSC 1234" in mncs
+
+    def test_get_case_graph(self, client):
+        r = client.get("/api/v1/cases/[2023]%20NSWSC%201234")
+        assert r.status_code == 200
+        data = r.json()
+        assert data["case"]["label"] == "Case"
+        assert data["case"]["properties"]["mnc"] == "[2023] NSWSC 1234"
+        assert len(data["judges"]) == 1
+        assert data["judges"][0]["properties"]["name"] == "Judge Brown"
+        assert len(data["charges"]) == 2
+
+    def test_get_case_graph_not_found(self, client):
+        r = client.get("/api/v1/cases/[9999]%20FAKE%201")
+        assert r.status_code == 404
+
+    def test_get_case_graph_fca(self, client):
+        r = client.get("/api/v1/cases/[2024]%20FCA%20567")
+        assert r.status_code == 200
+        data = r.json()
+        assert data["case"]["properties"]["court"] == "FCA"
+        assert len(data["judges"]) == 1
+
+
+# --- Juror Context ---
+
+class TestJurorContextEndpoint:
+    def test_juror_context(self, client):
+        r = client.get("/api/v1/cases/[2023]%20NSWSC%201234/juror/foreman")
+        assert r.status_code == 200
+        data = r.json()
+        assert data["persona"] == "foreman"
+        assert data["case_mnc"] == "[2023] NSWSC 1234"
+        assert data["total_tokens"] >= 0
+
+    def test_juror_context_bad_persona(self, client):
+        r = client.get("/api/v1/cases/[2023]%20NSWSC%201234/juror/nonexistent")
+        assert r.status_code == 404
+
+    def test_juror_context_custom_tokens(self, client):
+        r = client.get(
+            "/api/v1/cases/[2023]%20NSWSC%201234/juror/skeptic",
+            params={"max_tokens": 500},
+        )
+        assert r.status_code == 200
+        data = r.json()
+        assert data["persona"] == "skeptic"
+
+
+# --- Vector Search ---
+
+class TestVectorSearchEndpoint:
+    def test_vector_search(self, client):
+        r = client.get("/api/v1/search", params={"q": "testimony evidence"})
+        assert r.status_code == 200
+        data = r.json()
+        assert len(data) >= 1
+        assert "chunk_id" in data[0]
+        assert "text" in data[0]
+        assert "distance" in data[0]
+
+    def test_vector_search_filter_types(self, client):
+        r = client.get("/api/v1/search", params={"q": "test", "chunk_types": "ruling"})
+        assert r.status_code == 200
+        data = r.json()
+        assert all(hit["metadata"]["chunk_type"] == "ruling" for hit in data)
+
+    def test_vector_search_top_k(self, client):
+        r = client.get("/api/v1/search", params={"q": "test", "top_k": 1})
+        assert r.status_code == 200
+        assert len(r.json()) <= 1
+
+
+# --- Hybrid Search ---
+
+class TestHybridEndpoint:
+    def test_hybrid_search(self, client):
+        r = client.get("/api/v1/hybrid", params={"q": "murder trial"})
+        assert r.status_code == 200
+        data = r.json()
+        assert len(data) >= 1
+        assert "case_mnc" in data[0]
+        assert "chunks" in data[0]
+        assert data[0]["juror_context"] is None  # no persona
+
+    def test_hybrid_search_with_persona(self, client):
+        r = client.get("/api/v1/hybrid", params={"q": "test", "persona": "foreman"})
+        assert r.status_code == 200
+        data = r.json()
+        assert len(data) >= 1
+        # First result should have juror_context if it found the case
+        if data[0]["juror_context"]:
+            assert data[0]["juror_context"]["persona"] == "foreman"