"""Tests for Orchestrator, TelegramAlert, and watch/backfill wiring."""
|
||
|
|
|
||
|
|
from __future__ import annotations
|
||
|
|
|
||
|
|
import asyncio
|
||
|
|
from typing import Any
|
||
|
|
|
||
|
|
import pytest
|
||
|
|
from unittest.mock import AsyncMock, MagicMock
|
||
|
|
|
||
|
|
from aucourt_ingest.models import CaseMeta, Chunk, FetchStatus, RawDocument, Verdict, MatterType
|
||
|
|
from aucourt_ingest.orchestrator import Orchestrator, NoOpAlert
|
||
|
|
from aucourt_ingest.utils.telegram import TelegramAlert, ALERT_TEMPLATES
|
||
|
|
from aucourt_ingest.storage.in_memory_graph_db import InMemoryGraphDB
|
||
|
|
|
||
|
|
|
||
|
|
# ── TelegramAlert ──
|
||
|
|
|
||
|
|
class TestTelegramAlert:
    """Checks TelegramAlert's ``enabled`` flag, disabled sends, and ALERT_TEMPLATES."""

    def test_disabled_by_default(self):
        """A TelegramAlert built with no arguments reports ``enabled`` as falsy."""
        alert = TelegramAlert()
        assert not alert.enabled

    def test_enabled_when_configured(self):
        """Token, chat id and ``enabled=True`` together make ``enabled`` truthy."""
        alert = TelegramAlert(bot_token="tok123", chat_id="chat456", enabled=True)
        assert alert.enabled

    def test_not_enabled_without_token(self):
        """Omitting ``bot_token`` leaves ``enabled`` falsy even with ``enabled=True``."""
        alert = TelegramAlert(chat_id="chat456", enabled=True)
        assert not alert.enabled

    def test_not_enabled_without_chat_id(self):
        """Omitting ``chat_id`` leaves ``enabled`` falsy even with ``enabled=True``."""
        alert = TelegramAlert(bot_token="tok123", enabled=True)
        assert not alert.enabled

    def test_not_enabled_when_flag_false(self):
        """Full credentials with ``enabled=False`` still report ``enabled`` as falsy."""
        alert = TelegramAlert(bot_token="tok123", chat_id="chat456", enabled=False)
        assert not alert.enabled

    @pytest.mark.asyncio
    async def test_send_disabled_returns_false(self):
        """``send`` on a default-constructed alert returns exactly ``False``."""
        alert = TelegramAlert()
        assert await alert.send("test message") is False

    @pytest.mark.asyncio
    async def test_alert_disabled_returns_false(self):
        """``alert`` on a default-constructed alert returns exactly ``False``."""
        alert = TelegramAlert()
        assert await alert.alert("source_degraded", source_id="test") is False

    def test_alert_templates(self):
        """ALERT_TEMPLATES contains the four template keys the tests rely on."""
        assert "source_degraded" in ALERT_TEMPLATES
        assert "daily_summary" in ALERT_TEMPLATES
        assert "milestone" in ALERT_TEMPLATES
        assert "fetch_error" in ALERT_TEMPLATES

    def test_alert_template_format(self):
        """The daily_summary template renders the ``new_docs`` and ``nodes`` values."""
        template = ALERT_TEMPLATES["daily_summary"]
        result = template.format(new_docs=42, nodes=1000, errors=3)
        assert "42" in result
        assert "1000" in result

    @pytest.mark.asyncio
    async def test_alert_with_missing_key(self):
        """``alert`` with too few template fields returns ``False`` instead of raising."""
        alert = TelegramAlert(bot_token="tok", chat_id="chat", enabled=True)
        # Patch _get_client to avoid real HTTP calls
        alert._get_client = MagicMock()
        # Missing template key — should not crash
        result = await alert.alert("daily_summary", new_docs=42)  # missing nodes, errors
        # NOTE(review): the original rationale was "httpx mock returns None, so
        # post fails"; confirm against TelegramAlert.alert which path yields False.
        assert result is False
|
||
|
|
|
||
|
|
|
||
|
|
# ── NoOpAlert ──
|
||
|
|
|
||
|
|
class TestNoOpAlert:
    """Checks that NoOpAlert's ``send``/``alert`` return False and ``close`` is awaitable."""

    @pytest.mark.asyncio
    async def test_send_returns_false(self):
        """``send`` returns exactly ``False`` for an arbitrary message."""
        alert = NoOpAlert()
        assert await alert.send("anything") is False

    @pytest.mark.asyncio
    async def test_alert_returns_false(self):
        """``alert`` returns exactly ``False`` for a template name with no fields."""
        alert = NoOpAlert()
        assert await alert.alert("source_degraded") is False

    @pytest.mark.asyncio
    async def test_close_is_noop(self):
        """Awaiting ``close`` completes without raising."""
        alert = NoOpAlert()
        await alert.close()  # Should not raise
|
||
|
|
|
||
|
|
|
||
|
|
# ── Mock processing pipeline ──
|
||
|
|
|
||
|
|
class MockPipeline:
    """Mock pipeline that returns predictable results.

    Every stage method increments its entry in ``call_counts`` and then raises
    ``RuntimeError`` if ``fail_at`` names that stage; otherwise it returns a
    fixed, deterministic value.
    """

    def __init__(self, fail_at: str | None = None):
        """Remember which stage (if any) should fail and zero the call counters.

        Args:
            fail_at: One of "parse", "extract_meta", "chunk", "embed" to make
                that stage raise, or ``None`` for a pipeline that never fails.
        """
        self._fail_at = fail_at
        self.call_counts: dict[str, int] = {
            "parse": 0,
            "extract_meta": 0,
            "chunk": 0,
            "embed": 0,
        }

    def _enter(self, stage: str, failure: str) -> None:
        """Count a call to ``stage`` and raise ``RuntimeError(failure)`` if it is the failing stage."""
        # The counter is bumped before the failure check so failed calls are counted too.
        self.call_counts[stage] += 1
        if self._fail_at == stage:
            raise RuntimeError(failure)

    async def parse(self, raw: RawDocument) -> tuple[str, str]:
        """Return ``("Parsed: " + first 50 chars of raw.raw_text, raw.format)``.

        Raises:
            RuntimeError: "parse failed" when ``fail_at == "parse"``.
        """
        self._enter("parse", "parse failed")
        return f"Parsed: {raw.raw_text[:50]}", raw.format

    async def extract_meta(self, doc_id: str, text: str) -> CaseMeta | None:
        """Return a fixed CaseMeta whose name and ``mnc`` are derived from ``doc_id``.

        ``text`` is accepted for signature compatibility and is not inspected.

        Raises:
            RuntimeError: "extract failed" when ``fail_at == "extract_meta"``.
        """
        self._enter("extract_meta", "extract failed")
        return CaseMeta(
            case_name=f"Test Case {doc_id}",
            mnc=doc_id,
            court="NSWSC",
            judge=["Judge Test"],
            charges=["test charge"],
            verdict=Verdict.GUILTY,
            matter_type=MatterType.CRIMINAL,
        )

    async def chunk(self, doc_id: str, text: str, meta: CaseMeta | None = None) -> list[Chunk]:
        """Return a single fixed Chunk with id ``f"{doc_id}_c0"``.

        ``text`` and ``meta`` are accepted but not used.

        Raises:
            RuntimeError: "chunk failed" when ``fail_at == "chunk"``.
        """
        self._enter("chunk", "chunk failed")
        return [
            Chunk(
                chunk_id=f"{doc_id}_c0",
                doc_id=doc_id,
                chunk_type="testimony",
                sequence=0,
                text="Test chunk text",
            ),
        ]

    async def embed(self, chunks: list[Chunk]) -> list[Chunk]:
        """Set ``embedding = [0.1, 0.2, 0.3]`` on every chunk and return the same list.

        Raises:
            RuntimeError: "embed failed" when ``fail_at == "embed"``.
        """
        self._enter("embed", "embed failed")
        for item in chunks:
            # A fresh list per chunk, so chunks never share one embedding object.
            item.embedding = [0.1, 0.2, 0.3]
        return chunks
|
||
|
|
|
||
|
|
|
||
|
|
class MockMetaDB:
    """Mock MetaDB for testing orchestrator.

    Keeps a FIFO work queue and a per-document record in memory, and logs every
    status and metadata update so tests can assert on them afterwards.
    """

    def __init__(self) -> None:
        """Start with no documents, an empty queue, and empty update logs."""
        # doc_id -> {"status": ..., "source_id": ...}
        self._documents: dict[str, dict] = {}
        # Pending work items, consumed from the front by dequeue().
        self._queue: list[dict] = []
        # (doc_id, status, error) in call order.
        self._status_updates: list[tuple[str, str, str | None]] = []
        # (doc_id, meta) in call order.
        self._meta_updates: list[tuple[str, dict]] = []

    async def connect(self) -> None:
        """Do nothing; present so the mock can be awaited like a connectable store."""

    async def dequeue(self) -> dict | None:
        """Remove and return the oldest queued item, or ``None`` when the queue is empty."""
        return self._queue.pop(0) if self._queue else None

    async def update_status(self, doc_id: str, status: str, error: str | None = None) -> None:
        """Record a ``(doc_id, status, error)`` tuple in ``_status_updates``."""
        self._status_updates.append((doc_id, status, error))

    async def update_doc_meta(self, doc_id: str, meta: dict) -> None:
        """Record a ``(doc_id, meta)`` tuple in ``_meta_updates``."""
        self._meta_updates.append((doc_id, meta))

    async def get_documents_by_status(self, status: str) -> list[dict]:
        """Return ``{"doc_id", "source_id"}`` dicts for documents whose status equals ``status``.

        A document record without a ``source_id`` is reported with ``""``.
        """
        return [
            {"doc_id": doc_id, "source_id": data.get("source_id", "")}
            for doc_id, data in self._documents.items()
            if data.get("status") == status
        ]

    async def close(self) -> None:
        """Do nothing; present so the mock can be awaited like a closable store."""

    def add_to_queue(self, doc_id: str, source_id: str = "test") -> None:
        """Queue ``doc_id`` and register it with status ``FetchStatus.PENDING``."""
        self._queue.append({"doc_id": doc_id, "source_id": source_id})
        self._documents[doc_id] = {"status": FetchStatus.PENDING, "source_id": source_id}
|
||
|
|
|
||
|
|
|
||
|
|
class MockDocStore:
    """Mock DocStore for testing orchestrator.

    Holds document content in memory, keyed by the ``(source_id, doc_id)`` pair.
    """

    def __init__(self) -> None:
        """Start with an empty store."""
        # Keys are (source_id, doc_id) tuples, not bare strings.
        self._documents: dict[tuple[str, str], str] = {}

    def store(self, source_id: str, doc_id: str, content: str) -> None:
        """Save ``content`` under ``(source_id, doc_id)``, replacing any previous value."""
        self._documents[(source_id, doc_id)] = content

    async def load(self, source_id: str, doc_id: str) -> str | None:
        """Return the content stored for ``(source_id, doc_id)``, or ``None`` if absent."""
        return self._documents.get((source_id, doc_id))

    def exists(self, source_id: str, doc_id: str) -> bool:
        """Return whether content has been stored for ``(source_id, doc_id)``."""
        return (source_id, doc_id) in self._documents
|
||
|
|
|
||
|
|
|
||
|
|
class MockVectorIndex:
    """Mock vector index for testing.

    Records every ``store_chunks`` call so tests can inspect what was written.
    """

    def __init__(self) -> None:
        """Start with no recorded writes."""
        # (doc_id, chunks) in call order.
        self._stored: list[tuple[str, list[Chunk]]] = []

    async def store_chunks(self, doc_id: str, chunks: list[Chunk]) -> None:
        """Append ``(doc_id, chunks)`` to ``_stored`` without copying ``chunks``."""
        self._stored.append((doc_id, chunks))
|
||
|
|
|
||
|
|
|
||
|
|
# ── Orchestrator ──
|
||
|
|
|
||
|
|
class TestOrchestrator:
    """Drives Orchestrator against the mock stores, MockPipeline and InMemoryGraphDB."""

    @staticmethod
    def _seed(
        contents: dict[str, str], *, enqueue: bool = True
    ) -> tuple[MockMetaDB, MockDocStore]:
        """Return fresh ``(meta_db, doc_store)`` mocks populated from ``contents``.

        Each ``doc_id -> content`` pair is stored under source "test_src" and,
        when ``enqueue`` is true, queued immediately after it is stored.
        """
        meta_db = MockMetaDB()
        doc_store = MockDocStore()
        for doc_id, content in contents.items():
            doc_store.store("test_src", doc_id, content)
            if enqueue:
                meta_db.add_to_queue(doc_id, "test_src")
        return meta_db, doc_store

    @staticmethod
    def _wire_all(meta_db, doc_store):
        """Return ``(orchestrator, graph_db, vector_index)`` with every collaborator attached."""
        graph_db = InMemoryGraphDB()
        vector_index = MockVectorIndex()
        orch = Orchestrator(meta_db, doc_store, graph_db, vector_index, MockPipeline())
        return orch, graph_db, vector_index

    @pytest.mark.asyncio
    async def test_process_queue_empty(self):
        """An empty queue yields zero processed documents and zero errors."""
        meta_db, doc_store = self._seed({})
        orch = Orchestrator(meta_db, doc_store)
        stats = await orch.process_queue()

        assert stats["processed"] == 0
        assert stats["errors"] == 0

    @pytest.mark.asyncio
    async def test_process_single_document(self):
        """One queued document is counted once in every stage counter."""
        meta_db, doc_store = self._seed({"doc1": "Raw content for doc1"})
        orch, _, _ = self._wire_all(meta_db, doc_store)
        stats = await orch.process_queue()

        assert stats["processed"] == 1
        assert stats["errors"] == 0
        assert stats["parsed"] == 1
        assert stats["embedded"] == 1
        assert stats["graphed"] == 1

    @pytest.mark.asyncio
    async def test_process_multiple_documents(self):
        """Three queued documents are all processed and parsed without errors."""
        meta_db, doc_store = self._seed({f"doc{i}": f"Content {i}" for i in range(3)})
        orch = Orchestrator(meta_db, doc_store, pipeline=MockPipeline())
        stats = await orch.process_queue()

        assert stats["processed"] == 3
        assert stats["parsed"] == 3
        assert stats["errors"] == 0

    @pytest.mark.asyncio
    async def test_process_with_limit(self):
        """``limit=2`` processes only two of the five queued documents."""
        meta_db, doc_store = self._seed({f"doc{i}": f"Content {i}" for i in range(5)})
        orch = Orchestrator(meta_db, doc_store, pipeline=MockPipeline())
        stats = await orch.process_queue(limit=2)

        assert stats["processed"] == 2

    @pytest.mark.asyncio
    async def test_process_missing_document_marks_failed(self):
        """A queued document with no stored content produces one error and one FAILED update."""
        meta_db, doc_store = self._seed({})
        # No document stored
        meta_db.add_to_queue("missing_doc", "test_src")

        orch = Orchestrator(meta_db, doc_store)
        stats = await orch.process_queue()

        assert stats["errors"] == 1
        failed_updates = [u for u in meta_db._status_updates if u[1] == FetchStatus.FAILED]
        assert len(failed_updates) == 1

    @pytest.mark.asyncio
    async def test_process_pipeline_error_marks_failed(self):
        """A pipeline that raises in ``parse`` produces one error and one FAILED update."""
        meta_db, doc_store = self._seed({"doc1": "Content"})
        orch = Orchestrator(meta_db, doc_store, pipeline=MockPipeline(fail_at="parse"))
        stats = await orch.process_queue()

        assert stats["errors"] == 1
        failed_updates = [u for u in meta_db._status_updates if u[1] == FetchStatus.FAILED]
        assert len(failed_updates) == 1

    @pytest.mark.asyncio
    async def test_status_transitions(self):
        """A fully wired run records PARSED, EMBEDDED and GRAPHED status updates."""
        meta_db, doc_store = self._seed({"doc1": "Content"})
        orch, _, _ = self._wire_all(meta_db, doc_store)
        await orch.process_queue()

        statuses = [u[1] for u in meta_db._status_updates]
        assert FetchStatus.PARSED in statuses
        assert FetchStatus.EMBEDDED in statuses
        assert FetchStatus.GRAPHED in statuses

    @pytest.mark.asyncio
    async def test_graph_built_for_document(self):
        """One document yields one node of each expected label and one of each relationship."""
        meta_db, doc_store = self._seed({"doc1": "Content"})
        orch, graph_db, _ = self._wire_all(meta_db, doc_store)
        await orch.process_queue()

        assert await graph_db.node_count("Case") == 1
        assert await graph_db.node_count("Judge") == 1
        assert await graph_db.node_count("Charge") == 1
        assert await graph_db.node_count("Chunk") == 1
        assert await graph_db.relationship_count("HEARD_BY") == 1
        assert await graph_db.relationship_count("CHARGED_WITH") == 1

    @pytest.mark.asyncio
    async def test_vector_index_stores_chunks(self):
        """The vector index receives exactly one write holding doc1's single chunk."""
        meta_db, doc_store = self._seed({"doc1": "Content"})
        orch, _, vector_index = self._wire_all(meta_db, doc_store)
        await orch.process_queue()

        assert len(vector_index._stored) == 1
        doc_id, chunks = vector_index._stored[0]
        assert doc_id == "doc1"
        assert len(chunks) == 1

    @pytest.mark.asyncio
    async def test_process_document_single(self):
        """``process_document`` handles a stored, un-queued document without errors."""
        meta_db, doc_store = self._seed({"doc1": "Content"}, enqueue=False)
        orch = Orchestrator(meta_db, doc_store, pipeline=MockPipeline())
        stats = await orch.process_document("doc1", "test_src")

        assert stats["processed"] == 1
        assert stats["errors"] == 0

    @pytest.mark.asyncio
    async def test_process_document_error(self):
        """``process_document`` reports one error and nothing processed when ``chunk`` raises."""
        meta_db, doc_store = self._seed({"doc1": "Content"}, enqueue=False)
        orch = Orchestrator(meta_db, doc_store, pipeline=MockPipeline(fail_at="chunk"))
        stats = await orch.process_document("doc1", "test_src")

        assert stats["errors"] == 1
        assert stats["processed"] == 0

    @pytest.mark.asyncio
    async def test_meta_extracted_and_stored(self):
        """Processing records one metadata update containing ``case_name`` and ``mnc``."""
        meta_db, doc_store = self._seed({"doc1": "Content"})
        orch = Orchestrator(meta_db, doc_store, pipeline=MockPipeline())
        await orch.process_queue()

        assert len(meta_db._meta_updates) == 1
        _, meta = meta_db._meta_updates[0]
        assert "case_name" in meta
        assert "mnc" in meta

    @pytest.mark.asyncio
    async def test_without_pipeline_still_parses(self):
        """Without pipeline, should still mark as parsed with raw text."""
        meta_db, doc_store = self._seed({"doc1": "Content"})
        orch = Orchestrator(meta_db, doc_store)
        stats = await orch.process_queue()

        assert stats["processed"] == 1
        assert stats["parsed"] == 1
        assert stats["embedded"] == 0  # No pipeline → no chunking/embedding

    @pytest.mark.asyncio
    async def test_get_queue_status(self):
        """``get_queue_status`` reports "pending", "fetched" and "failed" keys."""
        meta_db, doc_store = self._seed({})
        orch = Orchestrator(meta_db, doc_store)

        status = await orch.get_queue_status()
        assert "pending" in status
        assert "fetched" in status
        assert "failed" in status
|