aucourt-ingest/tests/test_orchestrator.py

422 lines
14 KiB
Python
Raw Normal View History

"""Tests for Orchestrator, TelegramAlert, and watch/backfill wiring."""
from __future__ import annotations
import asyncio
from typing import Any
import pytest
from unittest.mock import AsyncMock, MagicMock
from aucourt_ingest.models import CaseMeta, Chunk, FetchStatus, RawDocument, Verdict, MatterType
from aucourt_ingest.orchestrator import Orchestrator, NoOpAlert
from aucourt_ingest.utils.telegram import TelegramAlert, ALERT_TEMPLATES
from aucourt_ingest.storage.in_memory_graph_db import InMemoryGraphDB
# ── TelegramAlert ──
class TestTelegramAlert:
    """TelegramAlert enablement rules, disabled-path results and message templates."""

    def test_disabled_by_default(self):
        alert = TelegramAlert()
        assert not alert.enabled

    def test_enabled_when_configured(self):
        alert = TelegramAlert(bot_token="tok123", chat_id="chat456", enabled=True)
        assert alert.enabled

    def test_not_enabled_without_token(self):
        alert = TelegramAlert(chat_id="chat456", enabled=True)
        assert not alert.enabled

    def test_not_enabled_without_chat_id(self):
        alert = TelegramAlert(bot_token="tok123", enabled=True)
        assert not alert.enabled

    def test_not_enabled_when_flag_false(self):
        alert = TelegramAlert(bot_token="tok123", chat_id="chat456", enabled=False)
        assert not alert.enabled

    @pytest.mark.asyncio
    async def test_send_disabled_returns_false(self):
        alert = TelegramAlert()
        assert await alert.send("test message") is False

    @pytest.mark.asyncio
    async def test_alert_disabled_returns_false(self):
        alert = TelegramAlert()
        assert await alert.alert("source_degraded", source_id="test") is False

    def test_alert_templates(self):
        assert "source_degraded" in ALERT_TEMPLATES
        assert "daily_summary" in ALERT_TEMPLATES
        assert "milestone" in ALERT_TEMPLATES
        assert "fetch_error" in ALERT_TEMPLATES

    def test_alert_template_format(self):
        template = ALERT_TEMPLATES["daily_summary"]
        result = template.format(new_docs=42, nodes=1000, errors=3)
        assert "42" in result
        assert "1000" in result

    @pytest.mark.asyncio
    async def test_alert_with_missing_key(self):
        """alert() must report failure, not raise, when template kwargs are missing."""
        alert = TelegramAlert(bot_token="tok", chat_id="chat", enabled=True)
        # Stub the client factory so no real HTTP request can be made.
        alert._get_client = MagicMock()
        # "daily_summary" is formatted with new_docs, nodes and errors
        # (see test_alert_template_format); nodes and errors are omitted on purpose.
        result = await alert.alert("daily_summary", new_docs=42)
        # NOTE(review): calling a MagicMock returns another MagicMock (not None),
        # so any request made through the stub does not yield a real response.
        # Whether False comes from the missing-key handling or from a failed
        # send is not visible here; the contract pinned is "returns False,
        # does not raise".
        assert result is False
# ── NoOpAlert ──
class TestNoOpAlert:
    """NoOpAlert must decline every notification without raising."""

    # Every test in this class is a coroutine, so mark them all at once.
    pytestmark = pytest.mark.asyncio

    async def test_send_returns_false(self):
        assert await NoOpAlert().send("anything") is False

    async def test_alert_returns_false(self):
        assert await NoOpAlert().alert("source_degraded") is False

    async def test_close_is_noop(self):
        # Completing without an exception is the whole assertion.
        await NoOpAlert().close()
# ── Mock processing pipeline ──
class MockPipeline:
    """Deterministic stand-in for the processing pipeline.

    Each stage increments its entry in ``call_counts`` first. The stage whose
    name equals ``fail_at`` then raises ``RuntimeError``; the others return
    fixed, predictable results.
    """

    def __init__(self, fail_at: str | None = None):
        self._fail_at = fail_at
        self.call_counts: dict[str, int] = dict.fromkeys(
            ("parse", "extract_meta", "chunk", "embed"), 0
        )

    def _enter(self, stage: str, message: str) -> None:
        """Count a call to *stage*; raise RuntimeError(message) if it is the failing stage."""
        self.call_counts[stage] += 1
        if stage == self._fail_at:
            raise RuntimeError(message)

    async def parse(self, raw: RawDocument) -> tuple[str, str]:
        self._enter("parse", "parse failed")
        preview = raw.raw_text[:50]
        return f"Parsed: {preview}", raw.format

    async def extract_meta(self, doc_id: str, text: str) -> CaseMeta | None:
        self._enter("extract_meta", "extract failed")
        meta = CaseMeta(
            case_name=f"Test Case {doc_id}",
            mnc=doc_id,
            court="NSWSC",
            judge=["Judge Test"],
            charges=["test charge"],
            verdict=Verdict.GUILTY,
            matter_type=MatterType.CRIMINAL,
        )
        return meta

    async def chunk(self, doc_id: str, text: str, meta: CaseMeta | None = None) -> list[Chunk]:
        self._enter("chunk", "chunk failed")
        single = Chunk(
            chunk_id=f"{doc_id}_c0",
            doc_id=doc_id,
            chunk_type="testimony",
            sequence=0,
            text="Test chunk text",
        )
        return [single]

    async def embed(self, chunks: list[Chunk]) -> list[Chunk]:
        self._enter("embed", "embed failed")
        # Mutates the given chunks in place and hands the same list back;
        # each chunk gets its own fresh embedding list.
        for item in chunks:
            item.embedding = [0.1, 0.2, 0.3]
        return chunks
class MockMetaDB:
    """In-memory MetaDB double that records every write for later assertions."""

    def __init__(self):
        # doc_id -> {"status": ..., "source_id": ...}; entries are created by add_to_queue.
        self._documents: dict[str, dict] = {}
        # FIFO of {"doc_id": ..., "source_id": ...} records handed out by dequeue().
        self._queue: list[dict] = []
        # Every update_status call as (doc_id, status, error), in call order.
        self._status_updates: list[tuple[str, str, str | None]] = []
        # Every update_doc_meta call as (doc_id, meta), in call order.
        self._meta_updates: list[tuple[str, dict]] = []

    async def connect(self):
        """No-op: the mock has no backing store to connect to."""

    async def dequeue(self) -> dict | None:
        """Remove and return the oldest queued record, or ``None`` if the queue is empty."""
        if self._queue:
            return self._queue.pop(0)
        return None

    async def update_status(self, doc_id: str, status: str, error: str | None = None):
        """Log the status change and apply it to the document's entry, if one exists."""
        self._status_updates.append((doc_id, status, error))
        # Write the new status back so get_documents_by_status reflects the
        # latest transition instead of the status set by add_to_queue.
        # Unknown doc_ids are only logged, never created.
        entry = self._documents.get(doc_id)
        if entry is not None:
            entry["status"] = status

    async def update_doc_meta(self, doc_id: str, meta: dict):
        """Log a metadata write."""
        self._meta_updates.append((doc_id, meta))

    async def get_documents_by_status(self, status: str) -> list[dict]:
        """Return ``{"doc_id", "source_id"}`` records whose current status equals *status*."""
        return [
            {"doc_id": doc_id, "source_id": data.get("source_id", "")}
            for doc_id, data in self._documents.items()
            if data.get("status") == status
        ]

    async def close(self):
        """No-op: nothing to release."""

    def add_to_queue(self, doc_id: str, source_id: str = "test"):
        """Queue *doc_id* for processing and register it with PENDING status."""
        self._queue.append({"doc_id": doc_id, "source_id": source_id})
        self._documents[doc_id] = {"status": FetchStatus.PENDING, "source_id": source_id}
class MockDocStore:
    """In-memory DocStore double keyed by ``(source_id, doc_id)``."""

    def __init__(self):
        # Keys are (source_id, doc_id) tuples, not plain strings.
        self._documents: dict[tuple[str, str], str] = {}

    def store(self, source_id: str, doc_id: str, content: str):
        """Synchronously save *content*, overwriting any previous value for the key."""
        self._documents[(source_id, doc_id)] = content

    async def load(self, source_id: str, doc_id: str) -> str | None:
        """Return the stored content, or ``None`` if nothing was stored for the key."""
        return self._documents.get((source_id, doc_id))

    def exists(self, source_id: str, doc_id: str) -> bool:
        """Return whether content has been stored for ``(source_id, doc_id)``."""
        return (source_id, doc_id) in self._documents
class MockVectorIndex:
    """Vector-index double that remembers every ``store_chunks`` call."""

    def __init__(self):
        # (doc_id, chunks) pairs in call order; the chunk list is kept by reference.
        self._stored: list[tuple[str, list[Chunk]]] = []

    async def store_chunks(self, doc_id: str, chunks: list[Chunk]):
        """Record *chunks* for *doc_id*; nothing is actually indexed."""
        record = (doc_id, chunks)
        self._stored.append(record)
# ── Orchestrator ──
class TestOrchestrator:
    """End-to-end Orchestrator behaviour against the in-file mocks."""

    # Every test in this class is a coroutine, so mark them all at once.
    pytestmark = pytest.mark.asyncio

    @staticmethod
    def _seeded(
        contents: dict[str, str], *, enqueue: bool = True
    ) -> tuple[MockMetaDB, MockDocStore]:
        """Fresh mocks with each doc stored under "test_src" and, unless told otherwise, queued."""
        meta_db = MockMetaDB()
        doc_store = MockDocStore()
        for doc_id, content in contents.items():
            doc_store.store("test_src", doc_id, content)
            if enqueue:
                meta_db.add_to_queue(doc_id, "test_src")
        return meta_db, doc_store

    @staticmethod
    def _fully_wired(
        meta_db: MockMetaDB, doc_store: MockDocStore
    ) -> tuple[Orchestrator, InMemoryGraphDB, MockVectorIndex]:
        """Orchestrator with a MockPipeline, an InMemoryGraphDB and a MockVectorIndex."""
        pipeline = MockPipeline()
        graph_db = InMemoryGraphDB()
        vector_index = MockVectorIndex()
        orch = Orchestrator(meta_db, doc_store, graph_db, vector_index, pipeline)
        return orch, graph_db, vector_index

    @staticmethod
    def _failed_count(meta_db: MockMetaDB) -> int:
        """Number of recorded status updates that set FetchStatus.FAILED."""
        return sum(
            1
            for _doc_id, status, _error in meta_db._status_updates
            if status == FetchStatus.FAILED
        )

    async def test_process_queue_empty(self):
        meta_db, doc_store = self._seeded({})
        stats = await Orchestrator(meta_db, doc_store).process_queue()
        assert stats["processed"] == 0
        assert stats["errors"] == 0

    async def test_process_single_document(self):
        meta_db, doc_store = self._seeded({"doc1": "Raw content for doc1"})
        orch, _graph_db, _vector_index = self._fully_wired(meta_db, doc_store)
        stats = await orch.process_queue()
        assert stats["processed"] == 1
        assert stats["errors"] == 0
        assert stats["parsed"] == 1
        assert stats["embedded"] == 1
        assert stats["graphed"] == 1

    async def test_process_multiple_documents(self):
        meta_db, doc_store = self._seeded({f"doc{i}": f"Content {i}" for i in range(3)})
        orch = Orchestrator(meta_db, doc_store, pipeline=MockPipeline())
        stats = await orch.process_queue()
        assert stats["processed"] == 3
        assert stats["parsed"] == 3
        assert stats["errors"] == 0

    async def test_process_with_limit(self):
        meta_db, doc_store = self._seeded({f"doc{i}": f"Content {i}" for i in range(5)})
        orch = Orchestrator(meta_db, doc_store, pipeline=MockPipeline())
        stats = await orch.process_queue(limit=2)
        assert stats["processed"] == 2

    async def test_process_missing_document_marks_failed(self):
        meta_db, doc_store = self._seeded({})
        # Queued, but no content was ever stored for it.
        meta_db.add_to_queue("missing_doc", "test_src")
        stats = await Orchestrator(meta_db, doc_store).process_queue()
        assert stats["errors"] == 1
        assert self._failed_count(meta_db) == 1

    async def test_process_pipeline_error_marks_failed(self):
        meta_db, doc_store = self._seeded({"doc1": "Content"})
        orch = Orchestrator(meta_db, doc_store, pipeline=MockPipeline(fail_at="parse"))
        stats = await orch.process_queue()
        assert stats["errors"] == 1
        assert self._failed_count(meta_db) == 1

    async def test_status_transitions(self):
        meta_db, doc_store = self._seeded({"doc1": "Content"})
        orch, _graph_db, _vector_index = self._fully_wired(meta_db, doc_store)
        await orch.process_queue()
        # A list (not a set) on purpose: membership must use == only.
        seen = [status for _doc_id, status, _error in meta_db._status_updates]
        assert FetchStatus.PARSED in seen
        assert FetchStatus.EMBEDDED in seen
        assert FetchStatus.GRAPHED in seen

    async def test_graph_built_for_document(self):
        meta_db, doc_store = self._seeded({"doc1": "Content"})
        orch, graph_db, _vector_index = self._fully_wired(meta_db, doc_store)
        await orch.process_queue()
        for label in ("Case", "Judge", "Charge", "Chunk"):
            assert await graph_db.node_count(label) == 1
        for rel_type in ("HEARD_BY", "CHARGED_WITH"):
            assert await graph_db.relationship_count(rel_type) == 1

    async def test_vector_index_stores_chunks(self):
        meta_db, doc_store = self._seeded({"doc1": "Content"})
        orch, _graph_db, vector_index = self._fully_wired(meta_db, doc_store)
        await orch.process_queue()
        assert len(vector_index._stored) == 1
        stored_id, stored_chunks = vector_index._stored[0]
        assert stored_id == "doc1"
        assert len(stored_chunks) == 1

    async def test_process_document_single(self):
        meta_db, doc_store = self._seeded({"doc1": "Content"}, enqueue=False)
        orch = Orchestrator(meta_db, doc_store, pipeline=MockPipeline())
        stats = await orch.process_document("doc1", "test_src")
        assert stats["processed"] == 1
        assert stats["errors"] == 0

    async def test_process_document_error(self):
        meta_db, doc_store = self._seeded({"doc1": "Content"}, enqueue=False)
        orch = Orchestrator(meta_db, doc_store, pipeline=MockPipeline(fail_at="chunk"))
        stats = await orch.process_document("doc1", "test_src")
        assert stats["errors"] == 1
        assert stats["processed"] == 0

    async def test_meta_extracted_and_stored(self):
        meta_db, doc_store = self._seeded({"doc1": "Content"})
        orch = Orchestrator(meta_db, doc_store, pipeline=MockPipeline())
        await orch.process_queue()
        assert len(meta_db._meta_updates) == 1
        _doc_id, meta = meta_db._meta_updates[0]
        assert "case_name" in meta
        assert "mnc" in meta

    async def test_without_pipeline_still_parses(self):
        """Without pipeline, should still mark as parsed with raw text."""
        meta_db, doc_store = self._seeded({"doc1": "Content"})
        stats = await Orchestrator(meta_db, doc_store).process_queue()
        assert stats["processed"] == 1
        assert stats["parsed"] == 1
        assert stats["embedded"] == 0  # No pipeline → no chunking/embedding

    async def test_get_queue_status(self):
        meta_db, doc_store = self._seeded({})
        status = await Orchestrator(meta_db, doc_store).get_queue_status()
        for key in ("pending", "fetched", "failed"):
            assert key in status