"""Tests for the Query API (Stage 9). All tests use InMemoryGraphDB + FakeVectorIndex — zero external services. """ from __future__ import annotations import asyncio import pytest from fastapi.testclient import TestClient from aucourt_ingest.api.app import create_app from aucourt_ingest.jury.personas import PERSONAS from aucourt_ingest.models import CaseMeta, Chunk from aucourt_ingest.processing.graph_builder import GraphBuilder from aucourt_ingest.storage.in_memory_graph_db import InMemoryGraphDB class FakeVectorIndex: """Duck-typed VectorIndex for testing. Pre-seeded with results.""" def __init__(self, results: list[dict] | None = None): self._results = results or [] self._stored: list = [] def store_chunks(self, chunks: list) -> None: self._stored.extend(chunks) def query(self, text: str, top_k: int = 10, chunk_types: list[str] | None = None, doc_ids: list[str] | None = None, embedding: list[float] | None = None) -> list[dict]: results = self._results[:top_k] if chunk_types: results = [r for r in results if r["metadata"].get("chunk_type") in chunk_types] if doc_ids: results = [r for r in results if r["metadata"].get("doc_id") in doc_ids] return results @property def count(self) -> int: return len(self._stored) @pytest.fixture def seeded_graph(): """InMemoryGraphDB seeded with two cases, chunks, and relationships.""" db = InMemoryGraphDB() builder = GraphBuilder(db) # Case 1 meta1 = CaseMeta( case_name="R v Smith", mnc="[2023] NSWSC 1234", court="NSWSC", date_delivered="2023-06-15", jurisdiction="NSW", matter_type="criminal", charges=["murder", "assault"], verdict="guilty", judge=["Judge Brown"], ) chunks1 = [ Chunk(chunk_id="c1-1", doc_id=meta1.mnc, chunk_type="opening", sequence=0, text="The Crown alleges that on 15 March 2023..."), Chunk(chunk_id="c1-2", doc_id=meta1.mnc, chunk_type="testimony", sequence=1, text="Dr Jones testified that the victim had..."), Chunk(chunk_id="c1-3", doc_id=meta1.mnc, chunk_type="ruling", sequence=2, text="The court finds that exhibit A is admissible."), Chunk(chunk_id="c1-4", doc_id=meta1.mnc, chunk_type="judgment", sequence=3, text="The accused is found guilty of murder."), ] for c in chunks1: c.token_count = len(c.text) // 4 asyncio.run(builder.build_full(meta1, chunks1)) # Case 2 meta2 = CaseMeta( case_name="R v Jones", mnc="[2024] FCA 567", court="FCA", date_delivered="2024-01-10", jurisdiction="CTH", matter_type="civil", charges=[], verdict="civil_judgment", judge=["Judge Davis"], ) chunks2 = [ Chunk(chunk_id="c2-1", doc_id=meta2.mnc, chunk_type="opening", sequence=0, text="This matter concerns an appeal against..."), Chunk(chunk_id="c2-2", doc_id=meta2.mnc, chunk_type="closing", sequence=1, text="Counsel for the applicant submitted that..."), ] for c in chunks2: c.token_count = len(c.text) // 4 asyncio.run(builder.build_full(meta2, chunks2)) return db @pytest.fixture def fake_vector_index(): """FakeVectorIndex with pre-seeded search results.""" return FakeVectorIndex([ { "chunk_id": "c1-2", "text": "Dr Jones testified that the victim had...", "metadata": {"doc_id": "[2023] NSWSC 1234", "chunk_type": "testimony", "sequence": 1}, "distance": 0.12, }, { "chunk_id": "c1-3", "text": "The court finds that exhibit A is admissible.", "metadata": {"doc_id": "[2023] NSWSC 1234", "chunk_type": "ruling", "sequence": 2}, "distance": 0.34, }, { "chunk_id": "c2-1", "text": "This matter concerns an appeal against...", "metadata": {"doc_id": "[2024] FCA 567", "chunk_type": "opening", "sequence": 0}, "distance": 0.56, }, ]) @pytest.fixture def client(seeded_graph, fake_vector_index): app = create_app(seeded_graph, fake_vector_index, max_tokens=4000) return TestClient(app) # --- Health --- class TestHealthEndpoint: def test_health_returns_ok(self, client): r = client.get("/api/v1/health") assert r.status_code == 200 data = r.json() assert data["status"] == "ok" assert data["cases"] == 2 assert data["chunks"] == 6 assert data["nodes"] > 0 assert data["edges"] > 0 # --- Personas --- class TestPersonaEndpoints: def test_list_personas(self, client): r = client.get("/api/v1/personas") assert r.status_code == 200 data = r.json() assert isinstance(data, list) assert len(data) == len(PERSONAS) names = [p["name"] for p in data] assert "foreman" in names assert "nurse" in names def test_persona_has_required_fields(self, client): r = client.get("/api/v1/personas") data = r.json() for p in data: assert "name" in p assert "anchor_nodes" in p assert "edge_types" in p assert "chunk_types" in p # --- Cases --- class TestCaseEndpoints: def test_list_cases(self, client): r = client.get("/api/v1/cases") assert r.status_code == 200 data = r.json() assert len(data) == 2 def test_list_cases_filter_court(self, client): r = client.get("/api/v1/cases", params={"court": "NSWSC"}) assert r.status_code == 200 data = r.json() assert len(data) == 1 assert data[0]["mnc"] == "[2023] NSWSC 1234" def test_list_cases_filter_empty(self, client): r = client.get("/api/v1/cases", params={"court": "HCA"}) assert r.status_code == 200 data = r.json() assert len(data) == 0 def test_list_cases_pagination(self, client): r = client.get("/api/v1/cases", params={"limit": 1, "offset": 0}) assert r.status_code == 200 data = r.json() assert len(data) == 1 def test_list_cases_pagination_offset(self, client): r = client.get("/api/v1/cases", params={"limit": 1, "offset": 1}) assert r.status_code == 200 data = r.json() assert len(data) == 1 # Second case is different from first r2 = client.get("/api/v1/cases", params={"limit": 1, "offset": 0}) assert data[0]["mnc"] != r2.json()[0]["mnc"] def test_search_cases(self, client): r = client.get("/api/v1/cases/search", params={"q": "murder"}) assert r.status_code == 200 data = r.json() # Vector index returns results grouped by doc_id assert len(data) >= 1 # Should have found the NSWSC case mncs = [c["mnc"] for c in data] assert "[2023] NSWSC 1234" in mncs def test_get_case_graph(self, client): r = client.get("/api/v1/cases/[2023]%20NSWSC%201234") assert r.status_code == 200 data = r.json() assert data["case"]["label"] == "Case" assert data["case"]["properties"]["mnc"] == "[2023] NSWSC 1234" assert len(data["judges"]) == 1 assert data["judges"][0]["properties"]["name"] == "Judge Brown" assert len(data["charges"]) == 2 def test_get_case_graph_not_found(self, client): r = client.get("/api/v1/cases/[9999]%20FAKE%201") assert r.status_code == 404 def test_get_case_graph_fca(self, client): r = client.get("/api/v1/cases/[2024]%20FCA%20567") assert r.status_code == 200 data = r.json() assert data["case"]["properties"]["court"] == "FCA" assert len(data["judges"]) == 1 # --- Juror Context --- class TestJurorContextEndpoint: def test_juror_context(self, client): r = client.get("/api/v1/cases/[2023]%20NSWSC%201234/juror/foreman") assert r.status_code == 200 data = r.json() assert data["persona"] == "foreman" assert data["case_mnc"] == "[2023] NSWSC 1234" assert data["total_tokens"] >= 0 def test_juror_context_bad_persona(self, client): r = client.get("/api/v1/cases/[2023]%20NSWSC%201234/juror/nonexistent") assert r.status_code == 404 def test_juror_context_custom_tokens(self, client): r = client.get( "/api/v1/cases/[2023]%20NSWSC%201234/juror/skeptic", params={"max_tokens": 500}, ) assert r.status_code == 200 data = r.json() assert data["persona"] == "skeptic" # --- Vector Search --- class TestVectorSearchEndpoint: def test_vector_search(self, client): r = client.get("/api/v1/search", params={"q": "testimony evidence"}) assert r.status_code == 200 data = r.json() assert len(data) >= 1 assert "chunk_id" in data[0] assert "text" in data[0] assert "distance" in data[0] def test_vector_search_filter_types(self, client): r = client.get("/api/v1/search", params={"q": "test", "chunk_types": "ruling"}) assert r.status_code == 200 data = r.json() assert all(r["metadata"]["chunk_type"] == "ruling" for r in data) def test_vector_search_top_k(self, client): r = client.get("/api/v1/search", params={"q": "test", "top_k": 1}) assert r.status_code == 200 assert len(r.json()) <= 1 # --- Hybrid Search --- class TestHybridEndpoint: def test_hybrid_search(self, client): r = client.get("/api/v1/hybrid", params={"q": "murder trial"}) assert r.status_code == 200 data = r.json() assert len(data) >= 1 assert "case_mnc" in data[0] assert "chunks" in data[0] assert data[0]["juror_context"] is None # no persona def test_hybrid_search_with_persona(self, client): r = client.get("/api/v1/hybrid", params={"q": "test", "persona": "foreman"}) assert r.status_code == 200 data = r.json() assert len(data) >= 1 # First result should have juror_context if it found the case if data[0]["juror_context"]: assert data[0]["juror_context"]["persona"] == "foreman"