aucourt-ingest/tests/test_api.py
slothitude 6374aea0a2 Stage 9: add read-only FastAPI query API for juror RAG queries
8 GET endpoints under /api/v1 for health, personas, cases, vector search,
juror context, and hybrid search. Includes QueryService composing SubgraphQuery
+ VectorIndex + GraphDB, Pydantic response models, error handlers, and
`serve` CLI mode via uvicorn. 20 new tests, 190 total, zero regressions.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-05-30 12:08:55 +10:00

309 lines
10 KiB
Python

"""Tests for the Query API (Stage 9).
All tests use InMemoryGraphDB + FakeVectorIndex — zero external services.
"""
from __future__ import annotations
import pytest
from fastapi.testclient import TestClient
from aucourt_ingest.api.app import create_app
from aucourt_ingest.jury.personas import PERSONAS
from aucourt_ingest.models import CaseMeta, Chunk
from aucourt_ingest.processing.graph_builder import GraphBuilder
from aucourt_ingest.storage.in_memory_graph_db import InMemoryGraphDB
class FakeVectorIndex:
"""Duck-typed VectorIndex for testing. Pre-seeded with results."""
def __init__(self, results: list[dict] | None = None):
self._results = results or []
self._stored: list = []
def store_chunks(self, chunks: list) -> None:
self._stored.extend(chunks)
def query(self, text: str, top_k: int = 10,
chunk_types: list[str] | None = None,
doc_ids: list[str] | None = None,
embedding: list[float] | None = None) -> list[dict]:
results = self._results[:top_k]
if chunk_types:
results = [r for r in results
if r["metadata"].get("chunk_type") in chunk_types]
if doc_ids:
results = [r for r in results
if r["metadata"].get("doc_id") in doc_ids]
return results
@property
def count(self) -> int:
return len(self._stored)
@pytest.fixture
def seeded_graph():
    """InMemoryGraphDB seeded with two cases, chunks, and relationships.

    * ``[2023] NSWSC 1234`` (R v Smith): criminal, charges murder and
      assault, verdict guilty, four chunks.
    * ``[2024] FCA 567`` (R v Jones): civil, no charges, two chunks.
    """
    import asyncio

    db = InMemoryGraphDB()
    builder = GraphBuilder(db)

    # Case 1
    meta1 = CaseMeta(
        case_name="R v Smith",
        mnc="[2023] NSWSC 1234",
        court="NSWSC",
        date_delivered="2023-06-15",
        jurisdiction="NSW",
        matter_type="criminal",
        charges=["murder", "assault"],
        verdict="guilty",
        judge=["Judge Brown"],
    )
    chunks1 = [
        Chunk(chunk_id="c1-1", doc_id=meta1.mnc, chunk_type="opening",
              sequence=0, text="The Crown alleges that on 15 March 2023..."),
        Chunk(chunk_id="c1-2", doc_id=meta1.mnc, chunk_type="testimony",
              sequence=1, text="Dr Jones testified that the victim had..."),
        Chunk(chunk_id="c1-3", doc_id=meta1.mnc, chunk_type="ruling",
              sequence=2, text="The court finds that exhibit A is admissible."),
        Chunk(chunk_id="c1-4", doc_id=meta1.mnc, chunk_type="judgment",
              sequence=3, text="The accused is found guilty of murder."),
    ]

    # Case 2
    meta2 = CaseMeta(
        case_name="R v Jones",
        mnc="[2024] FCA 567",
        court="FCA",
        date_delivered="2024-01-10",
        jurisdiction="CTH",
        matter_type="civil",
        charges=[],
        verdict="civil_judgment",
        judge=["Judge Davis"],
    )
    chunks2 = [
        Chunk(chunk_id="c2-1", doc_id=meta2.mnc, chunk_type="opening",
              sequence=0, text="This matter concerns an appeal against..."),
        Chunk(chunk_id="c2-2", doc_id=meta2.mnc, chunk_type="closing",
              sequence=1, text="Counsel for the applicant submitted that..."),
    ]

    # Build on a private event loop that is always closed afterwards.
    # Calling asyncio.get_event_loop() with no running loop is deprecated
    # and raises RuntimeError from Python 3.14. It also used to leave an
    # unclosed loop installed as the thread's current loop.
    loop = asyncio.new_event_loop()
    try:
        for meta, chunks in ((meta1, chunks1), (meta2, chunks2)):
            for c in chunks:
                # Rough token estimate: one token per four characters.
                c.token_count = len(c.text) // 4
            loop.run_until_complete(builder.build_full(meta, chunks))
    finally:
        loop.close()
    return db
@pytest.fixture
def fake_vector_index():
    """Provide a FakeVectorIndex holding three canned search hits.

    Hits are listed in ascending distance order. The first two come from
    ``[2023] NSWSC 1234`` (testimony, then ruling) and the third is the
    opening chunk of ``[2024] FCA 567``.
    """
    nswsc = "[2023] NSWSC 1234"
    fca = "[2024] FCA 567"

    testimony_hit = {
        "chunk_id": "c1-2",
        "text": "Dr Jones testified that the victim had...",
        "metadata": {"doc_id": nswsc, "chunk_type": "testimony", "sequence": 1},
        "distance": 0.12,
    }
    ruling_hit = {
        "chunk_id": "c1-3",
        "text": "The court finds that exhibit A is admissible.",
        "metadata": {"doc_id": nswsc, "chunk_type": "ruling", "sequence": 2},
        "distance": 0.34,
    }
    opening_hit = {
        "chunk_id": "c2-1",
        "text": "This matter concerns an appeal against...",
        "metadata": {"doc_id": fca, "chunk_type": "opening", "sequence": 0},
        "distance": 0.56,
    }
    return FakeVectorIndex([testimony_hit, ruling_hit, opening_hit])
@pytest.fixture
def client(seeded_graph, fake_vector_index):
    """Return a TestClient for an app wired to the seeded graph and fake index."""
    application = create_app(
        seeded_graph,
        fake_vector_index,
        max_tokens=4000,
    )
    test_client = TestClient(application)
    return test_client
# --- Health ---
class TestHealthEndpoint:
    """GET /api/v1/health reports an ok status plus content counts."""

    def test_health_returns_ok(self, client):
        resp = client.get("/api/v1/health")
        assert resp.status_code == 200

        body = resp.json()
        assert body["status"] == "ok"
        # Two seeded cases carrying 4 + 2 chunks.
        assert body["cases"] == 2
        assert body["chunks"] == 6
        # Node and edge counts depend on the graph builder; just require some.
        for key in ("nodes", "edges"):
            assert body[key] > 0
# --- Personas ---
class TestPersonaEndpoints:
    """GET /api/v1/personas lists every configured juror persona."""

    def test_list_personas(self, client):
        r = client.get("/api/v1/personas")
        assert r.status_code == 200
        data = r.json()
        assert isinstance(data, list)
        assert len(data) == len(PERSONAS)
        names = [p["name"] for p in data]
        assert "foreman" in names
        assert "nurse" in names

    def test_persona_has_required_fields(self, client):
        r = client.get("/api/v1/personas")
        assert r.status_code == 200
        data = r.json()
        # Without this, the loop below would pass vacuously on an empty list.
        assert data
        required = {"name", "anchor_nodes", "edge_types", "chunk_types"}
        for p in data:
            missing = required - p.keys()
            assert not missing, f"persona {p.get('name')!r} missing {sorted(missing)}"
# --- Cases ---
class TestCaseEndpoints:
    """Case listing, filtering, pagination, search and per-case graph lookup."""

    # MNCs of the two cases created by the ``seeded_graph`` fixture.
    _NSWSC_MNC = "[2023] NSWSC 1234"
    _FCA_MNC = "[2024] FCA 567"

    def test_list_cases(self, client):
        r = client.get("/api/v1/cases")
        assert r.status_code == 200
        data = r.json()
        assert len(data) == 2
        assert {c["mnc"] for c in data} == {self._NSWSC_MNC, self._FCA_MNC}

    def test_list_cases_filter_court(self, client):
        r = client.get("/api/v1/cases", params={"court": "NSWSC"})
        assert r.status_code == 200
        data = r.json()
        assert len(data) == 1
        assert data[0]["mnc"] == self._NSWSC_MNC

    def test_list_cases_filter_empty(self, client):
        r = client.get("/api/v1/cases", params={"court": "HCA"})
        assert r.status_code == 200
        data = r.json()
        assert len(data) == 0

    def test_list_cases_pagination(self, client):
        r = client.get("/api/v1/cases", params={"limit": 1, "offset": 0})
        assert r.status_code == 200
        data = r.json()
        assert len(data) == 1

    def test_list_cases_pagination_offset(self, client):
        first = client.get("/api/v1/cases", params={"limit": 1, "offset": 0})
        second = client.get("/api/v1/cases", params={"limit": 1, "offset": 1})
        assert first.status_code == 200
        assert second.status_code == 200
        page0 = first.json()
        page1 = second.json()
        assert len(page0) == 1
        assert len(page1) == 1
        # Two one-item pages must be disjoint and together cover both cases.
        assert {page0[0]["mnc"], page1[0]["mnc"]} == {self._NSWSC_MNC, self._FCA_MNC}

    def test_search_cases(self, client):
        r = client.get("/api/v1/cases/search", params={"q": "murder"})
        assert r.status_code == 200
        data = r.json()
        # Vector index returns results grouped by doc_id
        assert len(data) >= 1
        # Should have found the NSWSC case
        mncs = [c["mnc"] for c in data]
        assert self._NSWSC_MNC in mncs

    def test_get_case_graph(self, client):
        r = client.get("/api/v1/cases/[2023]%20NSWSC%201234")
        assert r.status_code == 200
        data = r.json()
        assert data["case"]["label"] == "Case"
        assert data["case"]["properties"]["mnc"] == self._NSWSC_MNC
        assert len(data["judges"]) == 1
        assert data["judges"][0]["properties"]["name"] == "Judge Brown"
        assert len(data["charges"]) == 2

    def test_get_case_graph_not_found(self, client):
        r = client.get("/api/v1/cases/[9999]%20FAKE%201")
        assert r.status_code == 404

    def test_get_case_graph_fca(self, client):
        r = client.get("/api/v1/cases/[2024]%20FCA%20567")
        assert r.status_code == 200
        data = r.json()
        assert data["case"]["properties"]["court"] == "FCA"
        assert len(data["judges"]) == 1
# --- Juror Context ---
class TestJurorContextEndpoint:
    """GET /api/v1/cases/{mnc}/juror/{persona} returns a persona's context."""

    def test_juror_context(self, client):
        resp = client.get("/api/v1/cases/[2023]%20NSWSC%201234/juror/foreman")
        assert resp.status_code == 200
        ctx = resp.json()
        expected = {"persona": "foreman", "case_mnc": "[2023] NSWSC 1234"}
        for key, value in expected.items():
            assert ctx[key] == value
        assert ctx["total_tokens"] >= 0

    def test_juror_context_bad_persona(self, client):
        # Unknown persona names map to 404.
        resp = client.get("/api/v1/cases/[2023]%20NSWSC%201234/juror/nonexistent")
        assert resp.status_code == 404

    def test_juror_context_custom_tokens(self, client):
        budget = {"max_tokens": 500}
        resp = client.get(
            "/api/v1/cases/[2023]%20NSWSC%201234/juror/skeptic",
            params=budget,
        )
        assert resp.status_code == 200
        assert resp.json()["persona"] == "skeptic"
# --- Vector Search ---
class TestVectorSearchEndpoint:
    """GET /api/v1/search returns individual vector-index hits."""

    def test_vector_search(self, client):
        r = client.get("/api/v1/search", params={"q": "testimony evidence"})
        assert r.status_code == 200
        data = r.json()
        assert len(data) >= 1
        for field in ("chunk_id", "text", "distance"):
            assert field in data[0]

    def test_vector_search_filter_types(self, client):
        r = client.get("/api/v1/search", params={"q": "test", "chunk_types": "ruling"})
        assert r.status_code == 200
        data = r.json()
        # The fake index is seeded with one ruling chunk (c1-3). Without this
        # check, an empty result would make the all() below pass vacuously.
        assert data
        assert all(hit["metadata"]["chunk_type"] == "ruling" for hit in data)

    def test_vector_search_top_k(self, client):
        r = client.get("/api/v1/search", params={"q": "test", "top_k": 1})
        assert r.status_code == 200
        # Three hits are seeded, so top_k=1 must return exactly one, not zero.
        assert len(r.json()) == 1
# --- Hybrid Search ---
class TestHybridEndpoint:
    """GET /api/v1/hybrid returns case-grouped hits, optionally with juror context."""

    def test_hybrid_search(self, client):
        resp = client.get("/api/v1/hybrid", params={"q": "murder trial"})
        assert resp.status_code == 200
        hits = resp.json()
        assert len(hits) >= 1
        top = hits[0]
        for key in ("case_mnc", "chunks"):
            assert key in top
        # No persona was requested, so no juror context is expected.
        assert top["juror_context"] is None

    def test_hybrid_search_with_persona(self, client):
        resp = client.get("/api/v1/hybrid", params={"q": "test", "persona": "foreman"})
        assert resp.status_code == 200
        hits = resp.json()
        assert len(hits) >= 1
        ctx = hits[0]["juror_context"]
        # Only check the persona when the top hit actually carries a context.
        if ctx:
            assert ctx["persona"] == "foreman"