- Add response_model to all 8 route endpoints for runtime validation and correct Swagger docs - Remove global KeyError handler (routes catch it explicitly) - Add catch-all Exception handler with logging for 500 responses - Remove dead code in service.py get_case_graph (unused bucket variable) - Explicit graph_backend validation in cmd_serve (memory|neo4j, else exit) - Sanitise comma-separated query params (strip whitespace, filter empty) - Move HTTPException to top-level import in routes.py - Remove unused imports (Depends in dependencies.py, all_persona_names) - Fix deprecated asyncio.get_event_loop() in test fixture Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
310 lines
10 KiB
Python
310 lines
10 KiB
Python
"""Tests for the Query API (Stage 9).

All tests use InMemoryGraphDB + FakeVectorIndex — zero external services.
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import asyncio
|
|
|
|
import pytest
|
|
from fastapi.testclient import TestClient
|
|
|
|
from aucourt_ingest.api.app import create_app
|
|
from aucourt_ingest.jury.personas import PERSONAS
|
|
from aucourt_ingest.models import CaseMeta, Chunk
|
|
from aucourt_ingest.processing.graph_builder import GraphBuilder
|
|
from aucourt_ingest.storage.in_memory_graph_db import InMemoryGraphDB
|
|
|
|
|
|
class FakeVectorIndex:
    """Duck-typed VectorIndex for testing. Pre-seeded with results.

    ``query`` ignores the query text / embedding and returns the pre-seeded
    results, optionally filtered by ``chunk_types`` and ``doc_ids``. Filters
    are applied *before* the ``top_k`` cut, so a filtered query still returns
    up to ``top_k`` matching hits. ``store_chunks`` only records what it is
    given (visible through ``count``); stored chunks are never returned by
    ``query``.
    """

    def __init__(self, results: list[dict] | None = None):
        # Canned hits returned by ``query``, in the order given.
        self._results = results or []
        # Chunks passed to ``store_chunks``; only used for ``count``.
        self._stored: list = []

    def store_chunks(self, chunks: list) -> None:
        """Record *chunks*; they only affect :attr:`count`."""
        self._stored.extend(chunks)

    def query(self, text: str, top_k: int = 10,
              chunk_types: list[str] | None = None,
              doc_ids: list[str] | None = None,
              embedding: list[float] | None = None) -> list[dict]:
        """Return up to ``top_k`` pre-seeded results matching the filters.

        ``text`` and ``embedding`` are accepted for interface compatibility
        and ignored. When a filter is given, each result must carry a
        ``metadata`` dict; missing ``chunk_type`` / ``doc_id`` keys simply
        fail to match.

        Returns:
            A new list of result dicts (the seeded dicts themselves).
        """
        results = self._results
        if chunk_types:
            results = [r for r in results
                       if r["metadata"].get("chunk_type") in chunk_types]
        if doc_ids:
            results = [r for r in results
                       if r["metadata"].get("doc_id") in doc_ids]
        # Cut to top_k only after filtering: slicing first could drop matches
        # that sit beyond the first top_k unfiltered results.
        return results[:top_k]

    @property
    def count(self) -> int:
        """Number of chunks recorded via :meth:`store_chunks`."""
        return len(self._stored)
|
|
|
|
|
|
@pytest.fixture
def seeded_graph():
    """InMemoryGraphDB populated with two cases via ``GraphBuilder.build_full``.

    * ``[2023] NSWSC 1234`` (R v Smith): criminal, judge "Judge Brown",
      charges murder and assault, four chunks (opening, testimony, ruling,
      judgment).
    * ``[2024] FCA 567`` (R v Jones): civil, judge "Judge Davis", no charges,
      two chunks (opening, closing).

    Before each case is built, every chunk's ``token_count`` is set to the
    rough estimate ``len(text) // 4``.
    """
    graph = InMemoryGraphDB()
    graph_builder = GraphBuilder(graph)

    def ingest(case_meta, case_chunks):
        # ~4 characters per token, then drive the build_full coroutine to
        # completion on its own event loop.
        for chunk in case_chunks:
            chunk.token_count = len(chunk.text) // 4
        asyncio.run(graph_builder.build_full(case_meta, case_chunks))

    smith = CaseMeta(
        case_name="R v Smith",
        mnc="[2023] NSWSC 1234",
        court="NSWSC",
        date_delivered="2023-06-15",
        jurisdiction="NSW",
        matter_type="criminal",
        charges=["murder", "assault"],
        verdict="guilty",
        judge=["Judge Brown"],
    )
    ingest(smith, [
        Chunk(chunk_id="c1-1", doc_id=smith.mnc, chunk_type="opening",
              sequence=0, text="The Crown alleges that on 15 March 2023..."),
        Chunk(chunk_id="c1-2", doc_id=smith.mnc, chunk_type="testimony",
              sequence=1, text="Dr Jones testified that the victim had..."),
        Chunk(chunk_id="c1-3", doc_id=smith.mnc, chunk_type="ruling",
              sequence=2, text="The court finds that exhibit A is admissible."),
        Chunk(chunk_id="c1-4", doc_id=smith.mnc, chunk_type="judgment",
              sequence=3, text="The accused is found guilty of murder."),
    ])

    jones = CaseMeta(
        case_name="R v Jones",
        mnc="[2024] FCA 567",
        court="FCA",
        date_delivered="2024-01-10",
        jurisdiction="CTH",
        matter_type="civil",
        charges=[],
        verdict="civil_judgment",
        judge=["Judge Davis"],
    )
    ingest(jones, [
        Chunk(chunk_id="c2-1", doc_id=jones.mnc, chunk_type="opening",
              sequence=0, text="This matter concerns an appeal against..."),
        Chunk(chunk_id="c2-2", doc_id=jones.mnc, chunk_type="closing",
              sequence=1, text="Counsel for the applicant submitted that..."),
    ])

    return graph
|
|
|
|
|
|
@pytest.fixture
def fake_vector_index():
    """FakeVectorIndex loaded with three canned search hits.

    Hits are listed in ascending distance: the NSWSC testimony and ruling
    chunks, then the FCA opening chunk.
    """
    # (chunk_id, text, doc_id, chunk_type, sequence, distance)
    canned_rows = (
        ("c1-2", "Dr Jones testified that the victim had...",
         "[2023] NSWSC 1234", "testimony", 1, 0.12),
        ("c1-3", "The court finds that exhibit A is admissible.",
         "[2023] NSWSC 1234", "ruling", 2, 0.34),
        ("c2-1", "This matter concerns an appeal against...",
         "[2024] FCA 567", "opening", 0, 0.56),
    )
    hits = [
        {
            "chunk_id": chunk_id,
            "text": text,
            "metadata": {"doc_id": doc_id, "chunk_type": chunk_type, "sequence": seq},
            "distance": dist,
        }
        for chunk_id, text, doc_id, chunk_type, seq, dist in canned_rows
    ]
    return FakeVectorIndex(hits)
|
|
|
|
|
|
@pytest.fixture
def client(seeded_graph, fake_vector_index):
    """TestClient for an app wired to the seeded graph and fake vector index.

    The app is built with ``max_tokens=4000``.
    """
    application = create_app(seeded_graph, fake_vector_index, max_tokens=4000)
    http = TestClient(application)
    return http
|
|
|
|
|
|
# --- Health ---
|
|
|
|
class TestHealthEndpoint:
    """GET /api/v1/health."""

    def test_health_returns_ok(self, client):
        """Health reports ``ok`` and counts matching the seeded graph."""
        response = client.get("/api/v1/health")
        assert response.status_code == 200
        body = response.json()
        assert body["status"] == "ok"
        # The seeded_graph fixture ingests two cases with 4 + 2 chunks.
        assert body["cases"] == 2
        assert body["chunks"] == 6
        assert body["nodes"] > 0
        assert body["edges"] > 0
|
|
|
|
|
|
# --- Personas ---
|
|
|
|
class TestPersonaEndpoints:
    """GET /api/v1/personas."""

    def test_list_personas(self, client):
        """Every configured persona is listed, including foreman and nurse."""
        response = client.get("/api/v1/personas")
        assert response.status_code == 200
        personas = response.json()
        assert isinstance(personas, list)
        assert len(personas) == len(PERSONAS)
        listed_names = [entry["name"] for entry in personas]
        assert "foreman" in listed_names
        assert "nurse" in listed_names

    def test_persona_has_required_fields(self, client):
        """Each listed persona exposes the fields clients rely on."""
        required_fields = ("name", "anchor_nodes", "edge_types", "chunk_types")
        for persona in client.get("/api/v1/personas").json():
            for field in required_fields:
                assert field in persona
|
|
|
|
|
|
# --- Cases ---
|
|
|
|
class TestCaseEndpoints:
    """Case listing, search and per-case graph endpoints."""

    def test_list_cases(self, client):
        """Without filters both seeded cases are returned."""
        resp = client.get("/api/v1/cases")
        assert resp.status_code == 200
        assert len(resp.json()) == 2

    def test_list_cases_filter_court(self, client):
        """Filtering by court keeps only the NSWSC case."""
        resp = client.get("/api/v1/cases", params={"court": "NSWSC"})
        assert resp.status_code == 200
        cases = resp.json()
        assert len(cases) == 1
        assert cases[0]["mnc"] == "[2023] NSWSC 1234"

    def test_list_cases_filter_empty(self, client):
        """A court with no seeded cases yields an empty list, not an error."""
        resp = client.get("/api/v1/cases", params={"court": "HCA"})
        assert resp.status_code == 200
        assert len(resp.json()) == 0

    def test_list_cases_pagination(self, client):
        """``limit=1`` caps the page at one case."""
        resp = client.get("/api/v1/cases", params={"limit": 1, "offset": 0})
        assert resp.status_code == 200
        assert len(resp.json()) == 1

    def test_list_cases_pagination_offset(self, client):
        """``offset=1`` yields a different case than ``offset=0``."""
        second = client.get("/api/v1/cases", params={"limit": 1, "offset": 1})
        assert second.status_code == 200
        second_page = second.json()
        assert len(second_page) == 1
        first_page = client.get("/api/v1/cases", params={"limit": 1, "offset": 0}).json()
        assert second_page[0]["mnc"] != first_page[0]["mnc"]

    def test_search_cases(self, client):
        """Case search surfaces the NSWSC case from the fake vector hits."""
        resp = client.get("/api/v1/cases/search", params={"q": "murder"})
        assert resp.status_code == 200
        found = resp.json()
        # Vector index returns results grouped by doc_id
        assert len(found) >= 1
        assert "[2023] NSWSC 1234" in [case["mnc"] for case in found]

    def test_get_case_graph(self, client):
        """The NSWSC case graph has its case node, one judge and two charges."""
        resp = client.get("/api/v1/cases/[2023]%20NSWSC%201234")
        assert resp.status_code == 200
        graph = resp.json()
        case_node = graph["case"]
        assert case_node["label"] == "Case"
        assert case_node["properties"]["mnc"] == "[2023] NSWSC 1234"
        judges = graph["judges"]
        assert len(judges) == 1
        assert judges[0]["properties"]["name"] == "Judge Brown"
        assert len(graph["charges"]) == 2

    def test_get_case_graph_not_found(self, client):
        """An unknown MNC maps to 404."""
        assert client.get("/api/v1/cases/[9999]%20FAKE%201").status_code == 404

    def test_get_case_graph_fca(self, client):
        """The FCA case graph reports its court and single judge."""
        resp = client.get("/api/v1/cases/[2024]%20FCA%20567")
        assert resp.status_code == 200
        graph = resp.json()
        assert graph["case"]["properties"]["court"] == "FCA"
        assert len(graph["judges"]) == 1
|
|
|
|
|
|
# --- Juror Context ---
|
|
|
|
class TestJurorContextEndpoint:
    """GET /api/v1/cases/{mnc}/juror/{persona}."""

    def test_juror_context(self, client):
        """Foreman context echoes the persona and case MNC."""
        resp = client.get("/api/v1/cases/[2023]%20NSWSC%201234/juror/foreman")
        assert resp.status_code == 200
        ctx = resp.json()
        assert ctx["persona"] == "foreman"
        assert ctx["case_mnc"] == "[2023] NSWSC 1234"
        assert ctx["total_tokens"] >= 0

    def test_juror_context_bad_persona(self, client):
        """An unknown persona maps to 404."""
        resp = client.get("/api/v1/cases/[2023]%20NSWSC%201234/juror/nonexistent")
        assert resp.status_code == 404

    def test_juror_context_custom_tokens(self, client):
        """A per-request ``max_tokens`` override is accepted."""
        url = "/api/v1/cases/[2023]%20NSWSC%201234/juror/skeptic"
        resp = client.get(url, params={"max_tokens": 500})
        assert resp.status_code == 200
        assert resp.json()["persona"] == "skeptic"
|
|
|
|
|
|
# --- Vector Search ---
|
|
|
|
class TestVectorSearchEndpoint:
    """GET /api/v1/search, backed by the fake_vector_index fixture."""

    def test_vector_search(self, client):
        """Results are chunk hits carrying id, text and distance."""
        resp = client.get("/api/v1/search", params={"q": "testimony evidence"})
        assert resp.status_code == 200
        hits = resp.json()
        assert len(hits) >= 1
        for key in ("chunk_id", "text", "distance"):
            assert key in hits[0]

    def test_vector_search_filter_types(self, client):
        """``chunk_types`` restricts hits to the requested type."""
        resp = client.get("/api/v1/search", params={"q": "test", "chunk_types": "ruling"})
        assert resp.status_code == 200
        hits = resp.json()
        # all() is vacuously True on an empty list, so require a hit: the
        # fixture seeds exactly one "ruling" chunk (c1-3).
        assert hits, "expected the seeded 'ruling' chunk to be returned"
        assert all(hit["metadata"]["chunk_type"] == "ruling" for hit in hits)

    def test_vector_search_top_k(self, client):
        """``top_k`` caps the number of hits."""
        resp = client.get("/api/v1/search", params={"q": "test", "top_k": 1})
        assert resp.status_code == 200
        # Three hits are seeded, so top_k=1 must return exactly one; a bare
        # "<= 1" would also pass if the endpoint returned nothing.
        assert len(resp.json()) == 1
|
|
|
|
|
|
# --- Hybrid Search ---
|
|
|
|
class TestHybridEndpoint:
    """GET /api/v1/hybrid."""

    def test_hybrid_search(self, client):
        """Without a persona, results carry no juror context."""
        resp = client.get("/api/v1/hybrid", params={"q": "murder trial"})
        assert resp.status_code == 200
        results = resp.json()
        assert len(results) >= 1
        top = results[0]
        assert "case_mnc" in top
        assert "chunks" in top
        assert top["juror_context"] is None  # no persona requested

    def test_hybrid_search_with_persona(self, client):
        """With a persona, any attached juror context is for that persona."""
        resp = client.get("/api/v1/hybrid", params={"q": "test", "persona": "foreman"})
        assert resp.status_code == 200
        results = resp.json()
        assert len(results) >= 1
        context = results[0]["juror_context"]
        # Only checked when the first hit's case produced a context.
        if context:
            assert context["persona"] == "foreman"
|