Stage 9: add read-only FastAPI query API for juror RAG queries
8 GET endpoints under /api/v1 for health, personas, cases, vector search, juror context, and hybrid search. Includes QueryService composing SubgraphQuery + VectorIndex + GraphDB, Pydantic response models, error handlers, and `serve` CLI mode via uvicorn. 20 new tests, 190 total, zero regressions. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
parent
d77fe12cfc
commit
6374aea0a2
12 changed files with 809 additions and 0 deletions
0
aucourt_ingest/api/__init__.py
Normal file
0
aucourt_ingest/api/__init__.py
Normal file
41
aucourt_ingest/api/app.py
Normal file
41
aucourt_ingest/api/app.py
Normal file
|
|
@ -0,0 +1,41 @@
|
|||
"""FastAPI app factory for the query API."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from contextlib import asynccontextmanager
|
||||
|
||||
from fastapi import FastAPI
|
||||
|
||||
from aucourt_ingest.api import dependencies
|
||||
from aucourt_ingest.api.errors import register_error_handlers
|
||||
from aucourt_ingest.api.routes import router
|
||||
from aucourt_ingest.api.service import QueryService
|
||||
from aucourt_ingest.storage.graph_db import GraphDB
|
||||
|
||||
|
||||
def create_app(graph_db: GraphDB, vector_index, max_tokens: int = 4000) -> FastAPI:
    """Create and configure the FastAPI app.

    Builds a QueryService from the supplied backends, publishes it through
    the ``dependencies`` module so route handlers can resolve it, then
    attaches the router and the error handlers to a new FastAPI instance.

    Args:
        graph_db: GraphDB instance (InMemoryGraphDB or Neo4jGraphDB)
        vector_index: VectorIndex instance (or duck-typed fake for tests)
        max_tokens: Default token budget for juror context assembly

    Returns:
        The configured FastAPI application.
    """
    # The service is stored in a module-level slot, so the most recent
    # create_app() call wins if several apps are built in one process.
    dependencies._query_service = QueryService(graph_db, vector_index, max_tokens)

    @asynccontextmanager
    async def lifespan(app: FastAPI):
        try:
            yield
        finally:
            # Close the graph backend even when shutdown is driven by an
            # exception or a cancellation travelling through the lifespan.
            await graph_db.close()

    app = FastAPI(
        title="AuCourtIngest Query API",
        description="Read-only juror RAG query API for Australian legal cases",
        version="0.1.0",
        lifespan=lifespan,
    )

    app.include_router(router)
    register_error_handlers(app)

    return app
|
||||
15
aucourt_ingest/api/dependencies.py
Normal file
15
aucourt_ingest/api/dependencies.py
Normal file
|
|
@ -0,0 +1,15 @@
|
|||
"""FastAPI dependency providers."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from fastapi import Depends
|
||||
|
||||
from aucourt_ingest.api.service import QueryService
|
||||
|
||||
_query_service: QueryService | None = None
|
||||
|
||||
|
||||
def get_query_service() -> QueryService:
    """Return the QueryService stored in this module's ``_query_service`` slot.

    Raises:
        RuntimeError: if the slot is still ``None``.
    """
    service = _query_service
    if service is not None:
        return service
    raise RuntimeError("QueryService not initialised — call create_app() first")
|
||||
16
aucourt_ingest/api/errors.py
Normal file
16
aucourt_ingest/api/errors.py
Normal file
|
|
@ -0,0 +1,16 @@
|
|||
"""Error handlers for the query API."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from fastapi import FastAPI, Request
|
||||
from fastapi.responses import JSONResponse
|
||||
|
||||
|
||||
def register_error_handlers(app: FastAPI) -> None:
    """Attach JSON exception handlers to *app*.

    ``ValueError`` is answered with HTTP 400 and ``KeyError`` with HTTP 404;
    both bodies have the shape ``{"detail": <message>}``.

    Args:
        app: The FastAPI application to register the handlers on.
    """

    @app.exception_handler(ValueError)
    async def value_error_handler(request: Request, exc: ValueError):
        return JSONResponse(status_code=400, content={"detail": str(exc)})

    @app.exception_handler(KeyError)
    async def key_error_handler(request: Request, exc: KeyError):
        # str(KeyError("x")) is the repr of the key ("'x'"), which would leak
        # stray quotes into the response; use the bare argument instead.
        detail = str(exc.args[0]) if len(exc.args) == 1 else str(exc)
        return JSONResponse(status_code=404, content={"detail": detail})
|
||||
87
aucourt_ingest/api/routes.py
Normal file
87
aucourt_ingest/api/routes.py
Normal file
|
|
@ -0,0 +1,87 @@
|
|||
"""Route definitions for the read-only query API."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from fastapi import APIRouter, Query, Depends
|
||||
|
||||
from aucourt_ingest.api.dependencies import get_query_service
|
||||
|
||||
router = APIRouter(prefix="/api/v1")
|
||||
|
||||
|
||||
@router.get("/health")
async def health(svc=Depends(get_query_service)):
    # Thin pass-through: the service builds the status/count payload.
    report = await svc.health()
    return report
|
||||
|
||||
|
||||
@router.get("/personas")
async def list_personas(svc=Depends(get_query_service)):
    # The service call is synchronous, so there is nothing to await here.
    personas = svc.list_personas()
    return personas
|
||||
|
||||
|
||||
@router.get("/cases")
async def list_cases(
    court: str | None = Query(None, description="Filter by court code"),
    limit: int = Query(50, ge=1, le=200),
    offset: int = Query(0, ge=0),
    svc=Depends(get_query_service),
):
    # Query parameters are validated by FastAPI (limit 1..200, offset >= 0)
    # and handed to the service unchanged.
    page = await svc.list_cases(court=court, limit=limit, offset=offset)
    return page
|
||||
|
||||
|
||||
@router.get("/cases/search")
async def search_cases(
    q: str = Query(..., description="Search query text"),
    limit: int = Query(10, ge=1, le=50),
    svc=Depends(get_query_service),
):
    # Registered before "/cases/{case_mnc}" so the literal "search" segment
    # is matched by this route first.
    matches = await svc.search_cases(q=q, limit=limit)
    return matches
|
||||
|
||||
|
||||
@router.get("/cases/{case_mnc}/juror/{persona}")
async def get_juror_context(
    case_mnc: str,
    persona: str,
    max_tokens: int | None = Query(None, ge=100, le=16000),
    svc=Depends(get_query_service),
):
    # Assemble the persona-specific context for one case. A KeyError raised
    # by the service is reported as 404.
    from fastapi import HTTPException

    try:
        return await svc.get_juror_context(
            case_mnc, persona, max_tokens,
        )
    except KeyError as e:
        # str(KeyError("x")) is "'x'"; unwrap the argument so the detail
        # does not carry repr quotes, and keep the cause chained.
        detail = str(e.args[0]) if len(e.args) == 1 else str(e)
        raise HTTPException(status_code=404, detail=detail) from e
|
||||
|
||||
|
||||
@router.get("/cases/{case_mnc}")
async def get_case_graph(case_mnc: str, svc=Depends(get_query_service)):
    # Return the case graph, or 404 when the service reports no such case.
    graph = await svc.get_case_graph(case_mnc)
    if graph is not None:
        return graph

    from fastapi import HTTPException

    raise HTTPException(status_code=404, detail=f"Case not found: {case_mnc}")
|
||||
|
||||
|
||||
def _split_csv(raw: str | None) -> list[str] | None:
    """Split a comma-separated query value into trimmed, non-empty items.

    Returns ``None`` (meaning "no filter") when *raw* is missing, empty, or
    contains only separators and whitespace.
    """
    if not raw:
        return None
    pieces = [piece.strip() for piece in raw.split(",")]
    return [piece for piece in pieces if piece] or None


@router.get("/search")
async def vector_search(
    q: str = Query(..., description="Query text"),
    top_k: int = Query(10, ge=1, le=50),
    chunk_types: str | None = Query(None, description="Comma-separated chunk types"),
    doc_ids: str | None = Query(None, description="Comma-separated doc IDs"),
    svc=Depends(get_query_service),
):
    # "ruling, testimony" and "ruling,,testimony" both become
    # ["ruling", "testimony"]; a bare split(",") would keep the leading
    # space and the empty entry, which then match nothing.
    return svc.vector_search(
        q,
        top_k=top_k,
        chunk_types=_split_csv(chunk_types),
        doc_ids=_split_csv(doc_ids),
    )
|
||||
|
||||
|
||||
@router.get("/hybrid")
async def hybrid_search(
    q: str = Query(..., description="Query text"),
    persona: str | None = Query(None, description="Juror persona name"),
    top_k: int = Query(10, ge=1, le=50),
    max_tokens: int | None = Query(None, ge=100, le=16000),
    svc=Depends(get_query_service),
):
    # The query parameter is called "persona"; the service expects it as
    # the keyword "persona_name".
    grouped = await svc.hybrid_search(
        q,
        persona_name=persona,
        top_k=top_k,
        max_tokens=max_tokens,
    )
    return grouped
|
||||
74
aucourt_ingest/api/schemas.py
Normal file
74
aucourt_ingest/api/schemas.py
Normal file
|
|
@ -0,0 +1,74 @@
|
|||
"""Pydantic v2 response models for the query API."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class PersonaInfo(BaseModel):
    # One juror persona; field names match the keys produced by
    # QueryService.list_personas().
    name: str
    anchor_nodes: list[str]
    edge_types: list[str]
    chunk_types: list[str]
|
||||
|
||||
|
||||
class CaseSummary(BaseModel):
    # Flat summary of a case; only the MNC is required, every other field
    # falls back to an empty string.
    mnc: str
    court: str = ""
    date: str = ""
    jurisdiction: str = ""
    matter_type: str = ""
    verdict: str = ""
|
||||
|
||||
|
||||
class ChunkResult(BaseModel):
    # A single vector-search hit. Text and distance are optional; metadata
    # defaults to a fresh empty dict per instance.
    chunk_id: str
    text: str | None = None
    metadata: dict = Field(default_factory=dict)
    distance: float | None = None
|
||||
|
||||
|
||||
class JurorContextResponse(BaseModel):
    # Field names match the keys of the dict returned by
    # QueryService.get_juror_context(); all fields are required.
    persona: str
    case_mnc: str
    context_text: str
    source_chunk_ids: list[str]
    total_tokens: int
|
||||
|
||||
|
||||
class CaseGraphNode(BaseModel):
    # Serialised graph node: id, label, and a free-form property bag that
    # defaults to a fresh empty dict.
    node_id: str
    label: str
    properties: dict = Field(default_factory=dict)
|
||||
|
||||
|
||||
class CaseGraphRelationship(BaseModel):
    # Serialised graph edge between from_id and to_id, with its type and an
    # optional property bag.
    rel_id: str
    from_id: str
    to_id: str
    rel_type: str
    properties: dict = Field(default_factory=dict)
|
||||
|
||||
|
||||
class CaseGraphSummary(BaseModel):
    # A case node plus its neighbours grouped into per-kind lists, and the
    # relationships themselves. Only `case` is required; each list defaults
    # to empty.
    case: CaseGraphNode
    judges: list[CaseGraphNode] = Field(default_factory=list)
    charges: list[CaseGraphNode] = Field(default_factory=list)
    rulings: list[CaseGraphNode] = Field(default_factory=list)
    witnesses: list[CaseGraphNode] = Field(default_factory=list)
    chunks: list[CaseGraphNode] = Field(default_factory=list)
    relationships: list[CaseGraphRelationship] = Field(default_factory=list)
|
||||
|
||||
|
||||
class HybridResult(BaseModel):
    # One hybrid-search entry: the hits for a single case and, optionally,
    # a juror context for that case.
    case_mnc: str
    chunks: list[ChunkResult] = Field(default_factory=list)
    juror_context: JurorContextResponse | None = None
|
||||
|
||||
|
||||
class HealthResponse(BaseModel):
    # Status string plus four integer counters; field names match the keys
    # returned by QueryService.health().
    status: str
    nodes: int
    edges: int
    cases: int
    chunks: int
|
||||
194
aucourt_ingest/api/service.py
Normal file
194
aucourt_ingest/api/service.py
Normal file
|
|
@ -0,0 +1,194 @@
|
|||
"""QueryService — composes SubgraphQuery + VectorIndex + GraphDB for the read API."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
|
||||
from aucourt_ingest.jury.personas import PERSONAS, get_persona, all_persona_names
|
||||
from aucourt_ingest.jury.subgraph_query import SubgraphQuery
|
||||
from aucourt_ingest.storage.graph_db import GraphDB
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class QueryService:
    """Read-only query service for juror RAG queries.

    Composes three collaborators: a GraphDB for case and graph lookups, a
    vector index for similarity search, and a SubgraphQuery (built on the
    same GraphDB) for persona-specific juror context.
    """

    # Graph node label -> key of the list it is filed under in the result
    # of get_case_graph(). Neighbours with any other label are not listed
    # (their relationships are still reported).
    _LABEL_BUCKETS = {
        "Judge": "judges",
        "Charge": "charges",
        "Ruling": "rulings",
        "Witness": "witnesses",
        "Chunk": "chunks",
    }

    def __init__(self, graph_db: GraphDB, vector_index, max_tokens: int = 4000):
        """Store the backends and the default juror-context token budget."""
        self._graph_db = graph_db
        self._vector_index = vector_index
        self._subgraph = SubgraphQuery(graph_db)
        self._max_tokens = max_tokens

    def list_personas(self) -> list[dict]:
        """Return one dict per entry in PERSONAS with its retrieval settings."""
        return [
            {
                "name": name,
                "anchor_nodes": p.anchor_nodes,
                "edge_types": p.edge_types,
                "chunk_types": p.chunk_types,
            }
            for name, p in PERSONAS.items()
        ]

    async def list_cases(self, court: str | None = None,
                         limit: int = 50, offset: int = 0) -> list[dict]:
        """Return case summaries, newest date first, one page at a time.

        Args:
            court: When truthy, keep only cases whose "court" property
                equals this value.
            limit: Maximum number of summaries to return.
            offset: Number of summaries to skip after sorting.
        """
        nodes = await self._graph_db.query_nodes("Case")
        cases = [
            self._case_summary(node, mnc_fallback=node.node_id)
            for node in nodes
            if not court or node.properties.get("court") == court
        ]
        # Dates are compared as plain strings.
        cases.sort(key=lambda c: c["date"], reverse=True)
        return cases[offset:offset + limit]

    async def search_cases(self, q: str, limit: int = 10) -> list[dict]:
        """Return summaries of the distinct cases behind the top vector hits.

        ``limit`` bounds the number of hits fetched, so the number of cases
        returned can be smaller. Hits without a doc_id, and doc_ids with no
        matching Case node, are dropped.
        """
        hits = self._vector_index.query(q, top_k=limit)
        seen: set[str] = set()
        cases = []
        for hit in hits:
            doc_id = hit["metadata"].get("doc_id", "")
            if not doc_id or doc_id in seen:
                continue
            seen.add(doc_id)
            case = await self._get_case_by_mnc(doc_id)
            if case:
                cases.append(case)
        return cases

    async def get_juror_context(self, case_mnc: str, persona_name: str,
                                max_tokens: int | None = None) -> dict:
        """Assemble the juror context for one case and persona.

        Args:
            case_mnc: Identifier of the case, passed to SubgraphQuery.
            persona_name: Persona name, resolved through get_persona().
            max_tokens: Token budget; a falsy value (None or 0) selects the
                default given to the constructor.
        """
        persona = get_persona(persona_name)
        budget = max_tokens or self._max_tokens
        ctx = await self._subgraph.get_juror_context(case_mnc, persona, budget)
        return {
            "persona": ctx.persona,
            "case_mnc": ctx.case_mnc,
            "context_text": ctx.context_text,
            "source_chunk_ids": ctx.source_chunk_ids,
            "total_tokens": ctx.total_tokens,
        }

    def vector_search(self, query: str, top_k: int = 10,
                      chunk_types: list[str] | None = None,
                      doc_ids: list[str] | None = None) -> list[dict]:
        """Forward a similarity query and its filters to the vector index."""
        return self._vector_index.query(
            query, top_k=top_k, chunk_types=chunk_types, doc_ids=doc_ids,
        )

    async def hybrid_search(self, query: str, persona_name: str | None = None,
                            top_k: int = 10,
                            max_tokens: int | None = None) -> list[dict]:
        """Group vector hits by case and optionally attach juror context.

        Hits are grouped by their "doc_id" metadata (hits without one land
        under "unknown"). A persona name that is not a key of PERSONAS is
        ignored, leaving "juror_context" as None on every entry.
        """
        vector_results = self._vector_index.query(query, top_k=top_k)

        # Group by doc_id, preserving first-seen order.
        by_doc: dict[str, list[dict]] = {}
        for hit in vector_results:
            doc_id = hit["metadata"].get("doc_id", "unknown")
            by_doc.setdefault(doc_id, []).append(hit)

        # The persona check does not depend on the document: evaluate once.
        with_context = bool(persona_name) and persona_name in PERSONAS

        results = []
        for doc_mnc, chunks in by_doc.items():
            entry: dict = {
                "case_mnc": doc_mnc,
                "chunks": chunks,
                "juror_context": None,
            }
            if with_context:
                entry["juror_context"] = await self.get_juror_context(
                    doc_mnc, persona_name, max_tokens,
                )
            results.append(entry)
        return results

    async def get_case_graph(self, case_mnc: str) -> dict | None:
        """Return a case node, its direct neighbours, and their relationships.

        Returns None when no node with id ``Case:<case_mnc>`` exists.
        Neighbours are grouped by label; within each group they appear in
        the order their relationships were returned by the graph backend.
        """
        case_id = f"Case:{case_mnc}"
        case_node = await self._graph_db.get_node(case_id)
        if case_node is None:
            return None

        result = {
            "case": _node_to_dict(case_node),
            "judges": [],
            "charges": [],
            "rulings": [],
            "witnesses": [],
            "chunks": [],
            "relationships": [],
        }

        rels = await self._graph_db.get_relationships(case_id, direction="both")

        # A dict is used as an ordered set so each neighbour is fetched once
        # and the output order does not depend on string-hash randomisation.
        neighbor_ids: dict[str, None] = {}
        for rel in rels:
            result["relationships"].append(_rel_to_dict(rel))
            other_end = rel.to_id if rel.from_id == case_id else rel.from_id
            neighbor_ids[other_end] = None

        for nid in neighbor_ids:
            node = await self._graph_db.get_node(nid)
            if node is None:
                continue
            key = self._LABEL_BUCKETS.get(node.label)
            if key:
                result[key].append(_node_to_dict(node))

        return result

    async def health(self) -> dict:
        """Return a status of "ok" together with node, edge, case and chunk counts."""
        total_nodes = await self._graph_db.node_count()
        total_edges = await self._graph_db.relationship_count()
        cases = await self._graph_db.node_count("Case")
        chunks = await self._graph_db.node_count("Chunk")
        return {
            "status": "ok",
            "nodes": total_nodes,
            "edges": total_edges,
            "cases": cases,
            "chunks": chunks,
        }

    async def _get_case_by_mnc(self, mnc: str) -> dict | None:
        """Return the summary of node ``Case:<mnc>``, or None if it is absent."""
        node = await self._graph_db.get_node(f"Case:{mnc}")
        if node is None:
            return None
        return self._case_summary(node)

    @staticmethod
    def _case_summary(node, mnc_fallback: str = "") -> dict:
        """Build the flat summary dict shared by list_cases and search_cases.

        ``mnc_fallback`` is used for "mnc" when the node has no such
        property; every other missing property becomes an empty string.
        """
        props = node.properties
        return {
            "mnc": props.get("mnc", mnc_fallback),
            "court": props.get("court", ""),
            "date": props.get("date", ""),
            "jurisdiction": props.get("jurisdiction", ""),
            "matter_type": props.get("matter_type", ""),
            "verdict": props.get("verdict", ""),
        }
|
||||
|
||||
|
||||
def _node_to_dict(node) -> dict:
|
||||
return {
|
||||
"node_id": node.node_id,
|
||||
"label": node.label,
|
||||
"properties": node.properties,
|
||||
}
|
||||
|
||||
|
||||
def _rel_to_dict(rel) -> dict:
|
||||
return {
|
||||
"rel_id": rel.rel_id,
|
||||
"from_id": rel.from_id,
|
||||
"to_id": rel.to_id,
|
||||
"rel_type": rel.rel_type,
|
||||
"properties": rel.properties,
|
||||
}
|
||||
|
|
@ -55,6 +55,14 @@ class TelegramConfig:
|
|||
enabled: bool = False
|
||||
|
||||
|
||||
@dataclass
class ServerConfig:
    """Settings for the query API server (the ``[server]`` config table)."""

    # Interface and port the server binds to.
    host: str = "127.0.0.1"
    port: int = 8000
    # Graph backend selector; the serve command compares it against "memory".
    graph_backend: str = "memory"
    # Default token budget handed to the app factory.
    default_max_tokens: int = 4000
|
||||
|
||||
|
||||
@dataclass
|
||||
class AppConfig:
|
||||
"""Root config — loaded from config.toml."""
|
||||
|
|
@ -64,6 +72,7 @@ class AppConfig:
|
|||
storage: StorageConfig = field(default_factory=StorageConfig)
|
||||
llm: LLMConfig = field(default_factory=LLMConfig)
|
||||
telegram: TelegramConfig = field(default_factory=TelegramConfig)
|
||||
server: ServerConfig = field(default_factory=ServerConfig)
|
||||
user_agent: str = "AuCourtIngest/0.1 (legal research)"
|
||||
|
||||
@classmethod
|
||||
|
|
@ -138,6 +147,15 @@ class AppConfig:
|
|||
enabled=tg.get("enabled", False),
|
||||
)
|
||||
|
||||
# Server
|
||||
srv = raw.get("server", {})
|
||||
config.server = ServerConfig(
|
||||
host=srv.get("host", "127.0.0.1"),
|
||||
port=srv.get("port", 8000),
|
||||
graph_backend=srv.get("graph_backend", "memory"),
|
||||
default_max_tokens=srv.get("default_max_tokens", 4000),
|
||||
)
|
||||
|
||||
config.user_agent = raw.get("user_agent", "AuCourtIngest/0.1 (legal research)")
|
||||
|
||||
return config
|
||||
|
|
|
|||
|
|
@ -57,6 +57,13 @@ def build_parser() -> argparse.ArgumentParser:
|
|||
p_proc.add_argument("--db", default="data/meta.db", help="MetaDB path")
|
||||
p_proc.add_argument("--raw-dir", default="data/raw", help="Raw document storage dir")
|
||||
|
||||
# serve
|
||||
p_serve = sub.add_parser("serve", help="Start the read-only query API server")
|
||||
p_serve.add_argument("--host", default=None, help="Bind host (default from config)")
|
||||
p_serve.add_argument("--port", type=int, default=None, help="Bind port (default from config)")
|
||||
p_serve.add_argument("--graph-backend", default=None,
|
||||
help="Graph backend: memory|neo4j (default from config)")
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
|
|
@ -291,6 +298,44 @@ async def cmd_process(args):
|
|||
print(f"Processing complete: {stats}")
|
||||
|
||||
|
||||
def cmd_serve(args):
    """Start the FastAPI query server.

    Command-line values override the ``[server]`` section of the loaded
    config; anything left unset on the command line falls back to it.

    Args:
        args: Parsed CLI namespace; reads ``config``, ``host``, ``port``
            and ``graph_backend``.
    """
    config = AppConfig.load(args.config)

    host = args.host or config.server.host
    # Compare against None rather than truthiness: 0 is a legitimate
    # --port value and must not be replaced by the configured port.
    port = args.port if args.port is not None else config.server.port
    backend = args.graph_backend or config.server.graph_backend
    max_tokens = config.server.default_max_tokens

    if backend == "memory":
        from aucourt_ingest.storage.in_memory_graph_db import InMemoryGraphDB

        # NOTE(review): this graph starts empty and nothing in this function
        # loads data into it — confirm that is intended for `serve`.
        graph_db = InMemoryGraphDB()
    else:
        # Any value other than "memory" selects Neo4j.
        from aucourt_ingest.storage.graph_db import Neo4jGraphDB

        graph_db = Neo4jGraphDB(
            uri=config.storage.neo4j_uri,
            user=config.storage.neo4j_user,
            password=config.storage.neo4j_password,
            database=config.storage.neo4j_database,
        )

    # The vector index is built the same way for every graph backend.
    from aucourt_ingest.storage.vector_index import VectorIndex

    vector_index = VectorIndex(str(config.storage.chromadb_dir))

    from aucourt_ingest.api.app import create_app

    app = create_app(graph_db, vector_index, max_tokens)

    import uvicorn

    print(f"Starting AuCourtIngest API on {host}:{port} (graph={backend})")
    # NOTE(review): this function is called from inside the coroutine
    # async_main(); uvicorn.run() drives its own event loop — confirm it
    # can start while another loop is already running.
    uvicorn.run(app, host=host, port=port)
|
||||
|
||||
|
||||
async def async_main():
|
||||
parser = build_parser()
|
||||
args = parser.parse_args()
|
||||
|
|
@ -309,6 +354,8 @@ async def async_main():
|
|||
await cmd_audit(args)
|
||||
elif args.mode == "process":
|
||||
await cmd_process(args)
|
||||
elif args.mode == "serve":
|
||||
cmd_serve(args)
|
||||
|
||||
|
||||
def main():
|
||||
|
|
|
|||
|
|
@ -76,3 +76,9 @@ embedding_batch_size = 100
|
|||
bot_token = ""
|
||||
chat_id = ""
|
||||
enabled = false
|
||||
|
||||
[server]
|
||||
host = "127.0.0.1"
|
||||
port = 8000
|
||||
graph_backend = "memory"
|
||||
default_max_tokens = 4000
|
||||
|
|
|
|||
|
|
@ -13,6 +13,8 @@ dependencies = [
|
|||
"aiosqlite>=0.19",
|
||||
"anthropic>=0.25",
|
||||
"apscheduler>=3.10",
|
||||
"fastapi>=0.110",
|
||||
"uvicorn[standard]>=0.29",
|
||||
]
|
||||
|
||||
[project.scripts]
|
||||
|
|
|
|||
309
tests/test_api.py
Normal file
309
tests/test_api.py
Normal file
|
|
@ -0,0 +1,309 @@
|
|||
"""Tests for the Query API (Stage 9).
|
||||
|
||||
All tests use InMemoryGraphDB + FakeVectorIndex — zero external services.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
from aucourt_ingest.api.app import create_app
|
||||
from aucourt_ingest.jury.personas import PERSONAS
|
||||
from aucourt_ingest.models import CaseMeta, Chunk
|
||||
from aucourt_ingest.processing.graph_builder import GraphBuilder
|
||||
from aucourt_ingest.storage.in_memory_graph_db import InMemoryGraphDB
|
||||
|
||||
|
||||
class FakeVectorIndex:
    """Duck-typed VectorIndex for testing. Pre-seeded with results.

    ``query`` ignores the query text and the embedding; it returns the
    seeded rows, in seeding order, that satisfy the metadata filters.
    """

    def __init__(self, results: list[dict] | None = None):
        self._results = results or []
        self._stored: list = []

    def store_chunks(self, chunks: list) -> None:
        """Record *chunks*; they only affect ``count``, not ``query``."""
        self._stored.extend(chunks)

    def query(self, text: str, top_k: int = 10,
              chunk_types: list[str] | None = None,
              doc_ids: list[str] | None = None,
              embedding: list[float] | None = None) -> list[dict]:
        """Return up to *top_k* seeded rows matching the given filters.

        Filters are applied before truncation, so a matching row that sits
        beyond the first *top_k* seeded rows is still returned.
        """
        matches = self._results
        if chunk_types:
            matches = [row for row in matches
                       if row["metadata"].get("chunk_type") in chunk_types]
        if doc_ids:
            matches = [row for row in matches
                       if row["metadata"].get("doc_id") in doc_ids]
        return matches[:top_k]

    @property
    def count(self) -> int:
        """Number of chunks passed to ``store_chunks`` so far."""
        return len(self._stored)
|
||||
|
||||
|
||||
@pytest.fixture
def seeded_graph():
    """InMemoryGraphDB seeded with two cases, chunks, and relationships."""
    import asyncio

    db = InMemoryGraphDB()
    builder = GraphBuilder(db)

    # Case 1
    meta1 = CaseMeta(
        case_name="R v Smith",
        mnc="[2023] NSWSC 1234",
        court="NSWSC",
        date_delivered="2023-06-15",
        jurisdiction="NSW",
        matter_type="criminal",
        charges=["murder", "assault"],
        verdict="guilty",
        judge=["Judge Brown"],
    )
    chunks1 = [
        Chunk(chunk_id="c1-1", doc_id=meta1.mnc, chunk_type="opening",
              sequence=0, text="The Crown alleges that on 15 March 2023..."),
        Chunk(chunk_id="c1-2", doc_id=meta1.mnc, chunk_type="testimony",
              sequence=1, text="Dr Jones testified that the victim had..."),
        Chunk(chunk_id="c1-3", doc_id=meta1.mnc, chunk_type="ruling",
              sequence=2, text="The court finds that exhibit A is admissible."),
        Chunk(chunk_id="c1-4", doc_id=meta1.mnc, chunk_type="judgment",
              sequence=3, text="The accused is found guilty of murder."),
    ]

    # Case 2
    meta2 = CaseMeta(
        case_name="R v Jones",
        mnc="[2024] FCA 567",
        court="FCA",
        date_delivered="2024-01-10",
        jurisdiction="CTH",
        matter_type="civil",
        charges=[],
        verdict="civil_judgment",
        judge=["Judge Davis"],
    )
    chunks2 = [
        Chunk(chunk_id="c2-1", doc_id=meta2.mnc, chunk_type="opening",
              sequence=0, text="This matter concerns an appeal against..."),
        Chunk(chunk_id="c2-2", doc_id=meta2.mnc, chunk_type="closing",
              sequence=1, text="Counsel for the applicant submitted that..."),
    ]

    # Rough token estimate: one token per four characters of text.
    for c in chunks1 + chunks2:
        c.token_count = len(c.text) // 4

    async def _seed() -> None:
        await builder.build_full(meta1, chunks1)
        await builder.build_full(meta2, chunks2)

    # asyncio.run() creates and closes its own loop; get_event_loop() is
    # deprecated when no loop is running and fails on newer interpreters.
    asyncio.run(_seed())

    return db
|
||||
|
||||
|
||||
@pytest.fixture
def fake_vector_index():
    """FakeVectorIndex with pre-seeded search results."""
    # (chunk_id, text, doc_id, chunk_type, sequence, distance)
    rows = (
        ("c1-2", "Dr Jones testified that the victim had...",
         "[2023] NSWSC 1234", "testimony", 1, 0.12),
        ("c1-3", "The court finds that exhibit A is admissible.",
         "[2023] NSWSC 1234", "ruling", 2, 0.34),
        ("c2-1", "This matter concerns an appeal against...",
         "[2024] FCA 567", "opening", 0, 0.56),
    )
    seeded = [
        {
            "chunk_id": chunk_id,
            "text": text,
            "metadata": {"doc_id": doc_id, "chunk_type": chunk_type, "sequence": sequence},
            "distance": distance,
        }
        for chunk_id, text, doc_id, chunk_type, sequence, distance in rows
    ]
    return FakeVectorIndex(seeded)
|
||||
|
||||
|
||||
@pytest.fixture
def client(seeded_graph, fake_vector_index):
    """TestClient bound to an app built from the seeded graph and fake index."""
    return TestClient(create_app(seeded_graph, fake_vector_index, max_tokens=4000))
|
||||
|
||||
|
||||
# --- Health ---
|
||||
|
||||
class TestHealthEndpoint:
    """Checks for GET /api/v1/health against the seeded graph."""

    def test_health_returns_ok(self, client):
        resp = client.get("/api/v1/health")
        assert resp.status_code == 200

        body = resp.json()
        assert body["status"] == "ok"
        # The seeded graph holds two cases with 4 + 2 chunks.
        assert body["cases"] == 2
        assert body["chunks"] == 6
        assert body["nodes"] > 0
        assert body["edges"] > 0
|
||||
|
||||
|
||||
# --- Personas ---
|
||||
|
||||
class TestPersonaEndpoints:
    """Checks for GET /api/v1/personas."""

    def test_list_personas(self, client):
        resp = client.get("/api/v1/personas")
        assert resp.status_code == 200

        personas = resp.json()
        assert isinstance(personas, list)
        assert len(personas) == len(PERSONAS)

        names = [entry["name"] for entry in personas]
        assert "foreman" in names
        assert "nurse" in names

    def test_persona_has_required_fields(self, client):
        personas = client.get("/api/v1/personas").json()
        required = ("name", "anchor_nodes", "edge_types", "chunk_types")
        for entry in personas:
            for field_name in required:
                assert field_name in entry
|
||||
|
||||
|
||||
# --- Cases ---
|
||||
|
||||
class TestCaseEndpoints:
    """Checks for the /api/v1/cases listing, search, and graph routes."""

    LIST_URL = "/api/v1/cases"

    def test_list_cases(self, client):
        resp = client.get(self.LIST_URL)
        assert resp.status_code == 200
        assert len(resp.json()) == 2

    def test_list_cases_filter_court(self, client):
        resp = client.get(self.LIST_URL, params={"court": "NSWSC"})
        assert resp.status_code == 200
        cases = resp.json()
        assert len(cases) == 1
        assert cases[0]["mnc"] == "[2023] NSWSC 1234"

    def test_list_cases_filter_empty(self, client):
        resp = client.get(self.LIST_URL, params={"court": "HCA"})
        assert resp.status_code == 200
        assert len(resp.json()) == 0

    def test_list_cases_pagination(self, client):
        resp = client.get(self.LIST_URL, params={"limit": 1, "offset": 0})
        assert resp.status_code == 200
        assert len(resp.json()) == 1

    def test_list_cases_pagination_offset(self, client):
        second_page = client.get(self.LIST_URL, params={"limit": 1, "offset": 1})
        assert second_page.status_code == 200
        second = second_page.json()
        assert len(second) == 1

        # Second case is different from first
        first_page = client.get(self.LIST_URL, params={"limit": 1, "offset": 0})
        assert second[0]["mnc"] != first_page.json()[0]["mnc"]

    def test_search_cases(self, client):
        resp = client.get("/api/v1/cases/search", params={"q": "murder"})
        assert resp.status_code == 200
        cases = resp.json()
        # Vector index returns results grouped by doc_id
        assert len(cases) >= 1
        # Should have found the NSWSC case
        assert "[2023] NSWSC 1234" in [case["mnc"] for case in cases]

    def test_get_case_graph(self, client):
        resp = client.get("/api/v1/cases/[2023]%20NSWSC%201234")
        assert resp.status_code == 200

        graph = resp.json()
        case = graph["case"]
        assert case["label"] == "Case"
        assert case["properties"]["mnc"] == "[2023] NSWSC 1234"

        judges = graph["judges"]
        assert len(judges) == 1
        assert judges[0]["properties"]["name"] == "Judge Brown"
        assert len(graph["charges"]) == 2

    def test_get_case_graph_not_found(self, client):
        resp = client.get("/api/v1/cases/[9999]%20FAKE%201")
        assert resp.status_code == 404

    def test_get_case_graph_fca(self, client):
        resp = client.get("/api/v1/cases/[2024]%20FCA%20567")
        assert resp.status_code == 200
        graph = resp.json()
        assert graph["case"]["properties"]["court"] == "FCA"
        assert len(graph["judges"]) == 1
|
||||
|
||||
|
||||
# --- Juror Context ---
|
||||
|
||||
class TestJurorContextEndpoint:
    """Checks for GET /api/v1/cases/{case_mnc}/juror/{persona}."""

    def test_juror_context(self, client):
        resp = client.get("/api/v1/cases/[2023]%20NSWSC%201234/juror/foreman")
        assert resp.status_code == 200

        ctx = resp.json()
        assert ctx["persona"] == "foreman"
        assert ctx["case_mnc"] == "[2023] NSWSC 1234"
        assert ctx["total_tokens"] >= 0

    def test_juror_context_bad_persona(self, client):
        resp = client.get("/api/v1/cases/[2023]%20NSWSC%201234/juror/nonexistent")
        assert resp.status_code == 404

    def test_juror_context_custom_tokens(self, client):
        url = "/api/v1/cases/[2023]%20NSWSC%201234/juror/skeptic"
        resp = client.get(url, params={"max_tokens": 500})
        assert resp.status_code == 200
        assert resp.json()["persona"] == "skeptic"
|
||||
|
||||
|
||||
# --- Vector Search ---
|
||||
|
||||
class TestVectorSearchEndpoint:
    """Checks for GET /api/v1/search."""

    def test_vector_search(self, client):
        r = client.get("/api/v1/search", params={"q": "testimony evidence"})
        assert r.status_code == 200
        data = r.json()
        assert len(data) >= 1
        assert "chunk_id" in data[0]
        assert "text" in data[0]
        assert "distance" in data[0]

    def test_vector_search_filter_types(self, client):
        r = client.get("/api/v1/search", params={"q": "test", "chunk_types": "ruling"})
        assert r.status_code == 200
        data = r.json()
        # The seeded index contains a "ruling" chunk; without this check an
        # empty response would satisfy all() vacuously.
        assert data
        assert all(hit["metadata"]["chunk_type"] == "ruling" for hit in data)

    def test_vector_search_top_k(self, client):
        r = client.get("/api/v1/search", params={"q": "test", "top_k": 1})
        assert r.status_code == 200
        assert len(r.json()) <= 1
|
||||
|
||||
|
||||
# --- Hybrid Search ---
|
||||
|
||||
class TestHybridEndpoint:
    """Checks for GET /api/v1/hybrid."""

    def test_hybrid_search(self, client):
        r = client.get("/api/v1/hybrid", params={"q": "murder trial"})
        assert r.status_code == 200
        data = r.json()
        assert len(data) >= 1
        assert "case_mnc" in data[0]
        assert "chunks" in data[0]
        assert data[0]["juror_context"] is None  # no persona

    def test_hybrid_search_with_persona(self, client):
        r = client.get("/api/v1/hybrid", params={"q": "test", "persona": "foreman"})
        assert r.status_code == 200
        data = r.json()
        assert len(data) >= 1
        # "foreman" is a known persona, so every entry must carry a juror
        # context; asserting unconditionally lets the test fail if the
        # context is ever dropped.
        juror_context = data[0]["juror_context"]
        assert juror_context is not None
        assert juror_context["persona"] == "foreman"
|
||||
Loading…
Reference in a new issue