Stage 9: add read-only FastAPI query API for juror RAG queries

8 GET endpoints under /api/v1 for health, personas, cases, vector search,
juror context, and hybrid search. Includes QueryService composing SubgraphQuery
+ VectorIndex + GraphDB, Pydantic response models, error handlers, and
`serve` CLI mode via uvicorn. 20 new tests, 190 total, zero regressions.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
slothitude 2026-05-30 12:08:55 +10:00
parent d77fe12cfc
commit 6374aea0a2
12 changed files with 809 additions and 0 deletions

View file

41
aucourt_ingest/api/app.py Normal file
View file

@ -0,0 +1,41 @@
"""FastAPI app factory for the query API."""
from __future__ import annotations
from contextlib import asynccontextmanager
from fastapi import FastAPI
from aucourt_ingest.api import dependencies
from aucourt_ingest.api.errors import register_error_handlers
from aucourt_ingest.api.routes import router
from aucourt_ingest.api.service import QueryService
from aucourt_ingest.storage.graph_db import GraphDB
def create_app(graph_db: GraphDB, vector_index, max_tokens: int = 4000) -> FastAPI:
"""Create and configure the FastAPI app.
Args:
graph_db: GraphDB instance (InMemoryGraphDB or Neo4jGraphDB)
vector_index: VectorIndex instance (or duck-typed fake for tests)
max_tokens: Default token budget for juror context assembly
"""
dependencies._query_service = QueryService(graph_db, vector_index, max_tokens)
@asynccontextmanager
async def lifespan(app: FastAPI):
yield
await graph_db.close()
app = FastAPI(
title="AuCourtIngest Query API",
description="Read-only juror RAG query API for Australian legal cases",
version="0.1.0",
lifespan=lifespan,
)
app.include_router(router)
register_error_handlers(app)
return app

View file

@ -0,0 +1,15 @@
"""FastAPI dependency providers."""
from __future__ import annotations
from fastapi import Depends
from aucourt_ingest.api.service import QueryService
_query_service: QueryService | None = None
def get_query_service() -> QueryService:
if _query_service is None:
raise RuntimeError("QueryService not initialised — call create_app() first")
return _query_service

View file

@ -0,0 +1,16 @@
"""Error handlers for the query API."""
from __future__ import annotations
from fastapi import FastAPI, Request
from fastapi.responses import JSONResponse
def register_error_handlers(app: FastAPI) -> None:
@app.exception_handler(ValueError)
async def value_error_handler(request: Request, exc: ValueError):
return JSONResponse(status_code=400, content={"detail": str(exc)})
@app.exception_handler(KeyError)
async def key_error_handler(request: Request, exc: KeyError):
return JSONResponse(status_code=404, content={"detail": str(exc)})

View file

@ -0,0 +1,87 @@
"""Route definitions for the read-only query API."""
from __future__ import annotations
from fastapi import APIRouter, Query, Depends
from aucourt_ingest.api.dependencies import get_query_service
router = APIRouter(prefix="/api/v1")
@router.get("/health")
async def health(svc=Depends(get_query_service)):
return await svc.health()
@router.get("/personas")
async def list_personas(svc=Depends(get_query_service)):
return svc.list_personas()
@router.get("/cases")
async def list_cases(
court: str | None = Query(None, description="Filter by court code"),
limit: int = Query(50, ge=1, le=200),
offset: int = Query(0, ge=0),
svc=Depends(get_query_service),
):
return await svc.list_cases(court=court, limit=limit, offset=offset)
@router.get("/cases/search")
async def search_cases(
q: str = Query(..., description="Search query text"),
limit: int = Query(10, ge=1, le=50),
svc=Depends(get_query_service),
):
return await svc.search_cases(q=q, limit=limit)
@router.get("/cases/{case_mnc}/juror/{persona}")
async def get_juror_context(
case_mnc: str,
persona: str,
max_tokens: int | None = Query(None, ge=100, le=16000),
svc=Depends(get_query_service),
):
try:
return await svc.get_juror_context(
case_mnc, persona, max_tokens,
)
except KeyError as e:
from fastapi import HTTPException
raise HTTPException(status_code=404, detail=str(e))
@router.get("/cases/{case_mnc}")
async def get_case_graph(case_mnc: str, svc=Depends(get_query_service)):
result = await svc.get_case_graph(case_mnc)
if result is None:
from fastapi import HTTPException
raise HTTPException(status_code=404, detail=f"Case not found: {case_mnc}")
return result
@router.get("/search")
async def vector_search(
q: str = Query(..., description="Query text"),
top_k: int = Query(10, ge=1, le=50),
chunk_types: str | None = Query(None, description="Comma-separated chunk types"),
doc_ids: str | None = Query(None, description="Comma-separated doc IDs"),
svc=Depends(get_query_service),
):
types = chunk_types.split(",") if chunk_types else None
docs = doc_ids.split(",") if doc_ids else None
return svc.vector_search(q, top_k=top_k, chunk_types=types, doc_ids=docs)
@router.get("/hybrid")
async def hybrid_search(
q: str = Query(..., description="Query text"),
persona: str | None = Query(None, description="Juror persona name"),
top_k: int = Query(10, ge=1, le=50),
max_tokens: int | None = Query(None, ge=100, le=16000),
svc=Depends(get_query_service),
):
return await svc.hybrid_search(q, persona_name=persona, top_k=top_k, max_tokens=max_tokens)

View file

@ -0,0 +1,74 @@
"""Pydantic v2 response models for the query API."""
from __future__ import annotations
from pydantic import BaseModel, Field
class PersonaInfo(BaseModel):
name: str
anchor_nodes: list[str]
edge_types: list[str]
chunk_types: list[str]
class CaseSummary(BaseModel):
mnc: str
court: str = ""
date: str = ""
jurisdiction: str = ""
matter_type: str = ""
verdict: str = ""
class ChunkResult(BaseModel):
chunk_id: str
text: str | None = None
metadata: dict = Field(default_factory=dict)
distance: float | None = None
class JurorContextResponse(BaseModel):
persona: str
case_mnc: str
context_text: str
source_chunk_ids: list[str]
total_tokens: int
class CaseGraphNode(BaseModel):
node_id: str
label: str
properties: dict = Field(default_factory=dict)
class CaseGraphRelationship(BaseModel):
rel_id: str
from_id: str
to_id: str
rel_type: str
properties: dict = Field(default_factory=dict)
class CaseGraphSummary(BaseModel):
case: CaseGraphNode
judges: list[CaseGraphNode] = Field(default_factory=list)
charges: list[CaseGraphNode] = Field(default_factory=list)
rulings: list[CaseGraphNode] = Field(default_factory=list)
witnesses: list[CaseGraphNode] = Field(default_factory=list)
chunks: list[CaseGraphNode] = Field(default_factory=list)
relationships: list[CaseGraphRelationship] = Field(default_factory=list)
class HybridResult(BaseModel):
case_mnc: str
chunks: list[ChunkResult] = Field(default_factory=list)
juror_context: JurorContextResponse | None = None
class HealthResponse(BaseModel):
status: str
nodes: int
edges: int
cases: int
chunks: int

View file

@ -0,0 +1,194 @@
"""QueryService — composes SubgraphQuery + VectorIndex + GraphDB for the read API."""
from __future__ import annotations
import logging
from aucourt_ingest.jury.personas import PERSONAS, get_persona, all_persona_names
from aucourt_ingest.jury.subgraph_query import SubgraphQuery
from aucourt_ingest.storage.graph_db import GraphDB
logger = logging.getLogger(__name__)
class QueryService:
"""Read-only query service for juror RAG queries."""
def __init__(self, graph_db: GraphDB, vector_index, max_tokens: int = 4000):
self._graph_db = graph_db
self._vector_index = vector_index
self._subgraph = SubgraphQuery(graph_db)
self._max_tokens = max_tokens
def list_personas(self) -> list[dict]:
return [
{
"name": name,
"anchor_nodes": p.anchor_nodes,
"edge_types": p.edge_types,
"chunk_types": p.chunk_types,
}
for name, p in PERSONAS.items()
]
async def list_cases(self, court: str | None = None,
limit: int = 50, offset: int = 0) -> list[dict]:
nodes = await self._graph_db.query_nodes("Case")
cases = []
for node in nodes:
if court and node.properties.get("court") != court:
continue
cases.append({
"mnc": node.properties.get("mnc", node.node_id),
"court": node.properties.get("court", ""),
"date": node.properties.get("date", ""),
"jurisdiction": node.properties.get("jurisdiction", ""),
"matter_type": node.properties.get("matter_type", ""),
"verdict": node.properties.get("verdict", ""),
})
cases.sort(key=lambda c: c.get("date", ""), reverse=True)
return cases[offset:offset + limit]
async def search_cases(self, q: str, limit: int = 10) -> list[dict]:
results = self._vector_index.query(q, top_k=limit)
seen: set[str] = set()
cases = []
for r in results:
doc_id = r["metadata"].get("doc_id", "")
if doc_id and doc_id not in seen:
seen.add(doc_id)
cases.append(await self._get_case_by_mnc(doc_id))
return [c for c in cases if c]
async def get_juror_context(self, case_mnc: str, persona_name: str,
max_tokens: int | None = None) -> dict:
persona = get_persona(persona_name)
budget = max_tokens or self._max_tokens
ctx = await self._subgraph.get_juror_context(case_mnc, persona, budget)
return {
"persona": ctx.persona,
"case_mnc": ctx.case_mnc,
"context_text": ctx.context_text,
"source_chunk_ids": ctx.source_chunk_ids,
"total_tokens": ctx.total_tokens,
}
def vector_search(self, query: str, top_k: int = 10,
chunk_types: list[str] | None = None,
doc_ids: list[str] | None = None) -> list[dict]:
return self._vector_index.query(
query, top_k=top_k, chunk_types=chunk_types, doc_ids=doc_ids,
)
async def hybrid_search(self, query: str, persona_name: str | None = None,
top_k: int = 10,
max_tokens: int | None = None) -> list[dict]:
vector_results = self._vector_index.query(query, top_k=top_k)
# Group by doc_id
by_doc: dict[str, list[dict]] = {}
for r in vector_results:
doc_id = r["metadata"].get("doc_id", "unknown")
by_doc.setdefault(doc_id, []).append(r)
results = []
for doc_mnc, chunks in by_doc.items():
entry: dict = {
"case_mnc": doc_mnc,
"chunks": chunks,
"juror_context": None,
}
if persona_name and persona_name in PERSONAS:
ctx = await self.get_juror_context(
doc_mnc, persona_name, max_tokens,
)
entry["juror_context"] = ctx
results.append(entry)
return results
async def get_case_graph(self, case_mnc: str) -> dict | None:
case_id = f"Case:{case_mnc}"
case_node = await self._graph_db.get_node(case_id)
if case_node is None:
return None
result = {
"case": _node_to_dict(case_node),
"judges": [],
"charges": [],
"rulings": [],
"witnesses": [],
"chunks": [],
"relationships": [],
}
rels = await self._graph_db.get_relationships(case_id, direction="both")
neighbor_ids: set[str] = set()
for rel in rels:
result["relationships"].append(_rel_to_dict(rel))
neighbor_id = rel.to_id if rel.from_id == case_id else rel.from_id
neighbor_ids.add(neighbor_id)
for nid in neighbor_ids:
node = await self._graph_db.get_node(nid)
if node is None:
continue
bucket = node.label.lower() + "s" if node.label != "Witness" else "witnesses"
# Map labels to plural keys
label_map = {
"Judge": "judges",
"Charge": "charges",
"Ruling": "rulings",
"Witness": "witnesses",
"Chunk": "chunks",
}
key = label_map.get(node.label)
if key and key in result:
result[key].append(_node_to_dict(node))
return result
async def health(self) -> dict:
total_nodes = await self._graph_db.node_count()
total_edges = await self._graph_db.relationship_count()
cases = await self._graph_db.node_count("Case")
chunks = await self._graph_db.node_count("Chunk")
return {
"status": "ok",
"nodes": total_nodes,
"edges": total_edges,
"cases": cases,
"chunks": chunks,
}
async def _get_case_by_mnc(self, mnc: str) -> dict | None:
case_id = f"Case:{mnc}"
node = await self._graph_db.get_node(case_id)
if node is None:
return None
return {
"mnc": node.properties.get("mnc", ""),
"court": node.properties.get("court", ""),
"date": node.properties.get("date", ""),
"jurisdiction": node.properties.get("jurisdiction", ""),
"matter_type": node.properties.get("matter_type", ""),
"verdict": node.properties.get("verdict", ""),
}
def _node_to_dict(node) -> dict:
return {
"node_id": node.node_id,
"label": node.label,
"properties": node.properties,
}
def _rel_to_dict(rel) -> dict:
return {
"rel_id": rel.rel_id,
"from_id": rel.from_id,
"to_id": rel.to_id,
"rel_type": rel.rel_type,
"properties": rel.properties,
}

View file

@ -55,6 +55,14 @@ class TelegramConfig:
enabled: bool = False
@dataclass
class ServerConfig:
host: str = "127.0.0.1"
port: int = 8000
graph_backend: str = "memory"
default_max_tokens: int = 4000
@dataclass
class AppConfig:
"""Root config — loaded from config.toml."""
@ -64,6 +72,7 @@ class AppConfig:
storage: StorageConfig = field(default_factory=StorageConfig)
llm: LLMConfig = field(default_factory=LLMConfig)
telegram: TelegramConfig = field(default_factory=TelegramConfig)
server: ServerConfig = field(default_factory=ServerConfig)
user_agent: str = "AuCourtIngest/0.1 (legal research)"
@classmethod
@ -138,6 +147,15 @@ class AppConfig:
enabled=tg.get("enabled", False),
)
# Server
srv = raw.get("server", {})
config.server = ServerConfig(
host=srv.get("host", "127.0.0.1"),
port=srv.get("port", 8000),
graph_backend=srv.get("graph_backend", "memory"),
default_max_tokens=srv.get("default_max_tokens", 4000),
)
config.user_agent = raw.get("user_agent", "AuCourtIngest/0.1 (legal research)")
return config

View file

@ -57,6 +57,13 @@ def build_parser() -> argparse.ArgumentParser:
p_proc.add_argument("--db", default="data/meta.db", help="MetaDB path")
p_proc.add_argument("--raw-dir", default="data/raw", help="Raw document storage dir")
# serve
p_serve = sub.add_parser("serve", help="Start the read-only query API server")
p_serve.add_argument("--host", default=None, help="Bind host (default from config)")
p_serve.add_argument("--port", type=int, default=None, help="Bind port (default from config)")
p_serve.add_argument("--graph-backend", default=None,
help="Graph backend: memory|neo4j (default from config)")
return parser
@ -291,6 +298,44 @@ async def cmd_process(args):
print(f"Processing complete: {stats}")
def cmd_serve(args):
"""Start the FastAPI query server."""
config = AppConfig.load(args.config)
host = args.host or config.server.host
port = args.port or config.server.port
backend = args.graph_backend or config.server.graph_backend
max_tokens = config.server.default_max_tokens
if backend == "memory":
from aucourt_ingest.storage.in_memory_graph_db import InMemoryGraphDB
graph_db = InMemoryGraphDB()
else:
from aucourt_ingest.storage.graph_db import Neo4jGraphDB
graph_db = Neo4jGraphDB(
uri=config.storage.neo4j_uri,
user=config.storage.neo4j_user,
password=config.storage.neo4j_password,
database=config.storage.neo4j_database,
)
# VectorIndex requires ChromaDB directory — skip for memory backend in tests
vector_index = None
if backend != "memory":
from aucourt_ingest.storage.vector_index import VectorIndex
vector_index = VectorIndex(str(config.storage.chromadb_dir))
else:
from aucourt_ingest.storage.vector_index import VectorIndex
vector_index = VectorIndex(str(config.storage.chromadb_dir))
from aucourt_ingest.api.app import create_app
app = create_app(graph_db, vector_index, max_tokens)
import uvicorn
print(f"Starting AuCourtIngest API on {host}:{port} (graph={backend})")
uvicorn.run(app, host=host, port=port)
async def async_main():
parser = build_parser()
args = parser.parse_args()
@ -309,6 +354,8 @@ async def async_main():
await cmd_audit(args)
elif args.mode == "process":
await cmd_process(args)
elif args.mode == "serve":
cmd_serve(args)
def main():

View file

@ -76,3 +76,9 @@ embedding_batch_size = 100
bot_token = ""
chat_id = ""
enabled = false
[server]
host = "127.0.0.1"
port = 8000
graph_backend = "memory"
default_max_tokens = 4000

View file

@ -13,6 +13,8 @@ dependencies = [
"aiosqlite>=0.19",
"anthropic>=0.25",
"apscheduler>=3.10",
"fastapi>=0.110",
"uvicorn[standard]>=0.29",
]
[project.scripts]

309
tests/test_api.py Normal file
View file

@ -0,0 +1,309 @@
"""Tests for the Query API (Stage 9).
All tests use InMemoryGraphDB + FakeVectorIndex zero external services.
"""
from __future__ import annotations
import pytest
from fastapi.testclient import TestClient
from aucourt_ingest.api.app import create_app
from aucourt_ingest.jury.personas import PERSONAS
from aucourt_ingest.models import CaseMeta, Chunk
from aucourt_ingest.processing.graph_builder import GraphBuilder
from aucourt_ingest.storage.in_memory_graph_db import InMemoryGraphDB
class FakeVectorIndex:
"""Duck-typed VectorIndex for testing. Pre-seeded with results."""
def __init__(self, results: list[dict] | None = None):
self._results = results or []
self._stored: list = []
def store_chunks(self, chunks: list) -> None:
self._stored.extend(chunks)
def query(self, text: str, top_k: int = 10,
chunk_types: list[str] | None = None,
doc_ids: list[str] | None = None,
embedding: list[float] | None = None) -> list[dict]:
results = self._results[:top_k]
if chunk_types:
results = [r for r in results
if r["metadata"].get("chunk_type") in chunk_types]
if doc_ids:
results = [r for r in results
if r["metadata"].get("doc_id") in doc_ids]
return results
@property
def count(self) -> int:
return len(self._stored)
@pytest.fixture
def seeded_graph():
"""InMemoryGraphDB seeded with two cases, chunks, and relationships."""
db = InMemoryGraphDB()
builder = GraphBuilder(db)
# Case 1
meta1 = CaseMeta(
case_name="R v Smith",
mnc="[2023] NSWSC 1234",
court="NSWSC",
date_delivered="2023-06-15",
jurisdiction="NSW",
matter_type="criminal",
charges=["murder", "assault"],
verdict="guilty",
judge=["Judge Brown"],
)
chunks1 = [
Chunk(chunk_id="c1-1", doc_id=meta1.mnc, chunk_type="opening",
sequence=0, text="The Crown alleges that on 15 March 2023..."),
Chunk(chunk_id="c1-2", doc_id=meta1.mnc, chunk_type="testimony",
sequence=1, text="Dr Jones testified that the victim had..."),
Chunk(chunk_id="c1-3", doc_id=meta1.mnc, chunk_type="ruling",
sequence=2, text="The court finds that exhibit A is admissible."),
Chunk(chunk_id="c1-4", doc_id=meta1.mnc, chunk_type="judgment",
sequence=3, text="The accused is found guilty of murder."),
]
for c in chunks1:
c.token_count = len(c.text) // 4
import asyncio
asyncio.get_event_loop().run_until_complete(builder.build_full(meta1, chunks1))
# Case 2
meta2 = CaseMeta(
case_name="R v Jones",
mnc="[2024] FCA 567",
court="FCA",
date_delivered="2024-01-10",
jurisdiction="CTH",
matter_type="civil",
charges=[],
verdict="civil_judgment",
judge=["Judge Davis"],
)
chunks2 = [
Chunk(chunk_id="c2-1", doc_id=meta2.mnc, chunk_type="opening",
sequence=0, text="This matter concerns an appeal against..."),
Chunk(chunk_id="c2-2", doc_id=meta2.mnc, chunk_type="closing",
sequence=1, text="Counsel for the applicant submitted that..."),
]
for c in chunks2:
c.token_count = len(c.text) // 4
asyncio.get_event_loop().run_until_complete(builder.build_full(meta2, chunks2))
return db
@pytest.fixture
def fake_vector_index():
"""FakeVectorIndex with pre-seeded search results."""
return FakeVectorIndex([
{
"chunk_id": "c1-2",
"text": "Dr Jones testified that the victim had...",
"metadata": {"doc_id": "[2023] NSWSC 1234", "chunk_type": "testimony", "sequence": 1},
"distance": 0.12,
},
{
"chunk_id": "c1-3",
"text": "The court finds that exhibit A is admissible.",
"metadata": {"doc_id": "[2023] NSWSC 1234", "chunk_type": "ruling", "sequence": 2},
"distance": 0.34,
},
{
"chunk_id": "c2-1",
"text": "This matter concerns an appeal against...",
"metadata": {"doc_id": "[2024] FCA 567", "chunk_type": "opening", "sequence": 0},
"distance": 0.56,
},
])
@pytest.fixture
def client(seeded_graph, fake_vector_index):
app = create_app(seeded_graph, fake_vector_index, max_tokens=4000)
return TestClient(app)
# --- Health ---
class TestHealthEndpoint:
def test_health_returns_ok(self, client):
r = client.get("/api/v1/health")
assert r.status_code == 200
data = r.json()
assert data["status"] == "ok"
assert data["cases"] == 2
assert data["chunks"] == 6
assert data["nodes"] > 0
assert data["edges"] > 0
# --- Personas ---
class TestPersonaEndpoints:
def test_list_personas(self, client):
r = client.get("/api/v1/personas")
assert r.status_code == 200
data = r.json()
assert isinstance(data, list)
assert len(data) == len(PERSONAS)
names = [p["name"] for p in data]
assert "foreman" in names
assert "nurse" in names
def test_persona_has_required_fields(self, client):
r = client.get("/api/v1/personas")
data = r.json()
for p in data:
assert "name" in p
assert "anchor_nodes" in p
assert "edge_types" in p
assert "chunk_types" in p
# --- Cases ---
class TestCaseEndpoints:
def test_list_cases(self, client):
r = client.get("/api/v1/cases")
assert r.status_code == 200
data = r.json()
assert len(data) == 2
def test_list_cases_filter_court(self, client):
r = client.get("/api/v1/cases", params={"court": "NSWSC"})
assert r.status_code == 200
data = r.json()
assert len(data) == 1
assert data[0]["mnc"] == "[2023] NSWSC 1234"
def test_list_cases_filter_empty(self, client):
r = client.get("/api/v1/cases", params={"court": "HCA"})
assert r.status_code == 200
data = r.json()
assert len(data) == 0
def test_list_cases_pagination(self, client):
r = client.get("/api/v1/cases", params={"limit": 1, "offset": 0})
assert r.status_code == 200
data = r.json()
assert len(data) == 1
def test_list_cases_pagination_offset(self, client):
r = client.get("/api/v1/cases", params={"limit": 1, "offset": 1})
assert r.status_code == 200
data = r.json()
assert len(data) == 1
# Second case is different from first
r2 = client.get("/api/v1/cases", params={"limit": 1, "offset": 0})
assert data[0]["mnc"] != r2.json()[0]["mnc"]
def test_search_cases(self, client):
r = client.get("/api/v1/cases/search", params={"q": "murder"})
assert r.status_code == 200
data = r.json()
# Vector index returns results grouped by doc_id
assert len(data) >= 1
# Should have found the NSWSC case
mncs = [c["mnc"] for c in data]
assert "[2023] NSWSC 1234" in mncs
def test_get_case_graph(self, client):
r = client.get("/api/v1/cases/[2023]%20NSWSC%201234")
assert r.status_code == 200
data = r.json()
assert data["case"]["label"] == "Case"
assert data["case"]["properties"]["mnc"] == "[2023] NSWSC 1234"
assert len(data["judges"]) == 1
assert data["judges"][0]["properties"]["name"] == "Judge Brown"
assert len(data["charges"]) == 2
def test_get_case_graph_not_found(self, client):
r = client.get("/api/v1/cases/[9999]%20FAKE%201")
assert r.status_code == 404
def test_get_case_graph_fca(self, client):
r = client.get("/api/v1/cases/[2024]%20FCA%20567")
assert r.status_code == 200
data = r.json()
assert data["case"]["properties"]["court"] == "FCA"
assert len(data["judges"]) == 1
# --- Juror Context ---
class TestJurorContextEndpoint:
def test_juror_context(self, client):
r = client.get("/api/v1/cases/[2023]%20NSWSC%201234/juror/foreman")
assert r.status_code == 200
data = r.json()
assert data["persona"] == "foreman"
assert data["case_mnc"] == "[2023] NSWSC 1234"
assert data["total_tokens"] >= 0
def test_juror_context_bad_persona(self, client):
r = client.get("/api/v1/cases/[2023]%20NSWSC%201234/juror/nonexistent")
assert r.status_code == 404
def test_juror_context_custom_tokens(self, client):
r = client.get(
"/api/v1/cases/[2023]%20NSWSC%201234/juror/skeptic",
params={"max_tokens": 500},
)
assert r.status_code == 200
data = r.json()
assert data["persona"] == "skeptic"
# --- Vector Search ---
class TestVectorSearchEndpoint:
def test_vector_search(self, client):
r = client.get("/api/v1/search", params={"q": "testimony evidence"})
assert r.status_code == 200
data = r.json()
assert len(data) >= 1
assert "chunk_id" in data[0]
assert "text" in data[0]
assert "distance" in data[0]
def test_vector_search_filter_types(self, client):
r = client.get("/api/v1/search", params={"q": "test", "chunk_types": "ruling"})
assert r.status_code == 200
data = r.json()
assert all(r["metadata"]["chunk_type"] == "ruling" for r in data)
def test_vector_search_top_k(self, client):
r = client.get("/api/v1/search", params={"q": "test", "top_k": 1})
assert r.status_code == 200
assert len(r.json()) <= 1
# --- Hybrid Search ---
class TestHybridEndpoint:
def test_hybrid_search(self, client):
r = client.get("/api/v1/hybrid", params={"q": "murder trial"})
assert r.status_code == 200
data = r.json()
assert len(data) >= 1
assert "case_mnc" in data[0]
assert "chunks" in data[0]
assert data[0]["juror_context"] is None # no persona
def test_hybrid_search_with_persona(self, client):
r = client.get("/api/v1/hybrid", params={"q": "test", "persona": "foreman"})
assert r.status_code == 200
data = r.json()
assert len(data) >= 1
# First result should have juror_context if it found the case
if data[0]["juror_context"]:
assert data[0]["juror_context"]["persona"] == "foreman"