aucourt-ingest/aucourt_ingest/orchestrator.py

213 lines
7 KiB
Python
Raw Permalink Normal View History

"""Orchestrator — main loop for processing fetched documents.
Dispatches documents through the processing pipeline:
fetched → parsed → embedded → graphed
Handles stage transitions, error recovery, and alerts.
"""
from __future__ import annotations
import logging
from typing import Protocol, runtime_checkable
from aucourt_ingest.models import FetchStatus, RawDocument, CaseMeta, Chunk
# Module-level logger; its name comes from __name__, so records can be
# filtered or configured per module.
logger = logging.getLogger(__name__)
@runtime_checkable
class TelegramAlertProtocol(Protocol):
    """Structural interface for the alert sender used by the orchestrator.

    Decorated with ``@runtime_checkable`` (matching ``ProcessingPipeline``)
    so ``isinstance(obj, TelegramAlertProtocol)`` can be used to verify that
    an object provides the three coroutine methods. As with any runtime
    protocol check, only method presence is verified, not signatures.
    """

    async def send(self, message: str) -> bool:
        """Send a free-form *message*.

        Returns a bool; presumably ``True`` when the message was delivered
        (confirm against the concrete Telegram implementation).
        """
        ...

    async def alert(self, alert_type: str, **kwargs) -> bool:
        """Send a structured alert of kind *alert_type* with extra fields in *kwargs*."""
        ...

    async def close(self) -> None:
        """Close the sender."""
        ...
class NoOpAlert:
    """Stand-in alert sender used when Telegram alerting is disabled.

    Accepts the same calls as the real sender but does nothing with them:
    both delivery methods report ``False`` and ``close`` has nothing to
    release.
    """

    async def send(self, message: str) -> bool:
        """Discard *message* and report that nothing was sent."""
        sent = False
        return sent

    async def alert(self, alert_type: str, **kwargs) -> bool:
        """Discard the structured alert and report that nothing was sent."""
        sent = False
        return sent

    async def close(self) -> None:
        """Nothing to close for the no-op sender."""
        return None
@runtime_checkable
class ProcessingPipeline(Protocol):
    """Structural interface for the per-document processing components.

    Any object that provides these four coroutine methods satisfies the
    protocol. Because of ``@runtime_checkable``, ``isinstance`` can confirm
    that the methods exist (it does not check their signatures).
    """

    async def parse(self, raw: RawDocument) -> tuple[str, str]:
        """Parse *raw* and return a ``(text, format)`` pair."""

    async def extract_meta(self, doc_id: str, text: str) -> CaseMeta | None:
        """Extract structured metadata from *text*; may return ``None``."""

    async def chunk(
        self,
        doc_id: str,
        text: str,
        meta: CaseMeta | None = None,
    ) -> list[Chunk]:
        """Split *text* into semantic units, optionally using *meta*."""

    async def embed(self, chunks: list[Chunk]) -> list[Chunk]:
        """Generate embeddings for *chunks* and return the resulting list."""
class Orchestrator:
    """Main processing loop: fetch_queue → parse → extract → chunk → embed → graph.

    Reads from MetaDB fetch_queue, processes each document through
    the pipeline stages, updates status along the way.

    Collaborators are duck-typed. The calls made on them here are:

    * ``meta_db``: ``dequeue()``, ``update_status(doc_id, status[, error])``,
      ``update_doc_meta(doc_id, fields)``, ``get_documents_by_status(value)``
    * ``doc_store``: ``load(source_id, doc_id)``
    * ``vector_index`` (optional): ``store_chunks(doc_id, chunks)``
    * ``graph_db`` (optional): handed to ``GraphBuilder``
    * ``pipeline`` (optional): without it only the raw document is loaded
      and no processing stage runs
    * ``alert`` (optional): defaults to ``NoOpAlert``
    """

    def __init__(
        self,
        meta_db,
        doc_store,
        graph_db=None,
        vector_index=None,
        pipeline: ProcessingPipeline | None = None,
        alert: TelegramAlertProtocol | None = None,
    ):
        self._meta_db = meta_db
        self._doc_store = doc_store
        self._graph_db = graph_db
        self._vector_index = vector_index
        self._pipeline = pipeline
        self._alert = alert or NoOpAlert()

    @staticmethod
    def _new_stats() -> dict:
        """Return a fresh, zeroed counter dict shared by both entry points."""
        return {"processed": 0, "parsed": 0, "embedded": 0, "graphed": 0, "errors": 0}

    @staticmethod
    def _queue_item_ids(item: dict) -> tuple[str, str]:
        """Return ``(doc_id, source_id)`` for a dequeued item.

        Falls back to the item's ``url`` when ``doc_id`` is missing *or*
        empty/None, and to ``""`` when neither is usable.
        """
        doc_id = item.get("doc_id") or item.get("url") or ""
        return doc_id, item.get("source_id", "")

    @staticmethod
    def _detect_format(doc_id: str) -> str:
        """Guess the raw format from *doc_id*'s extension (default ``"html"``)."""
        lowered = doc_id.lower()
        if lowered.endswith(".pdf"):
            return "pdf"
        if lowered.endswith(".docx"):
            return "docx"
        return "html"

    @staticmethod
    def _serialize_verdict(verdict):
        """Convert a verdict into a value suitable for ``update_doc_meta``.

        Enum-like values (anything with ``.value``) are unwrapped and other
        values are stringified. ``None`` stays ``None`` instead of being
        stored as the literal string ``"None"``.
        """
        if verdict is None:
            return None
        return verdict.value if hasattr(verdict, "value") else str(verdict)

    async def _send_alert_safely(self, alert_type: str, **kwargs) -> bool:
        """Send an alert best-effort: log, rather than raise, if the sender fails.

        Returns the sender's result, or ``False`` if it raised.
        """
        try:
            return await self._alert.alert(alert_type, **kwargs)
        except Exception:
            logger.exception("Failed to send %s alert", alert_type)
            return False

    async def process_queue(self, limit: int = 0) -> dict:
        """Process items from the fetch queue.

        Args:
            limit: Max documents to process (0 = process all pending).
                Negative values behave like 0.

        Returns:
            Summary dict with keys ``processed``, ``parsed``, ``embedded``,
            ``graphed`` and ``errors``.

        A document that fails is logged, marked ``FetchStatus.FAILED`` and
        reported via the alert sender, and the loop moves on to the next
        item. Alert delivery is best-effort, so an exception from the alert
        sender is logged and does not stop the run. An exception while
        updating the status still propagates.
        """
        stats = self._new_stats()
        count = 0
        while limit <= 0 or count < limit:
            item = await self._meta_db.dequeue()
            if item is None:
                logger.info("Queue empty")
                break
            doc_id, source_id = self._queue_item_ids(item)
            logger.info("Processing queue item: %s (source=%s)", doc_id, source_id)
            try:
                await self._process_document(doc_id, source_id, stats)
            except Exception as e:
                logger.exception("Failed to process %s: %s", doc_id, e)
                await self._meta_db.update_status(doc_id, FetchStatus.FAILED, str(e))
                await self._send_alert_safely(
                    "fetch_error", source_id=source_id, error=str(e)[:100]
                )
                stats["errors"] += 1
            else:
                stats["processed"] += 1
            count += 1
        logger.info("Queue processing complete: %s", stats)
        return stats

    async def process_document(self, doc_id: str, source_id: str = "") -> dict:
        """Process a single document through the full pipeline.

        Convenience method for processing documents outside the queue
        (e.g., audit mode re-processing). Unlike ``process_queue``, a
        failure is only logged and counted in ``stats["errors"]``. The
        document's status is not set to FAILED and no alert is sent.

        Returns:
            Summary dict with the same keys as ``process_queue``.
        """
        stats = self._new_stats()
        try:
            await self._process_document(doc_id, source_id, stats)
        except Exception as e:
            logger.exception("Failed to process %s: %s", doc_id, e)
            stats["errors"] = 1
        else:
            stats["processed"] = 1
        return stats

    async def _process_document(self, doc_id: str, source_id: str, stats: dict):
        """Internal: run a document through the pipeline stages.

        Stages: load raw → parse → extract metadata → chunk → embed (and
        store in the vector index, if configured) → graph. The status is
        advanced to PARSED, EMBEDDED and GRAPHED as the corresponding stage
        completes, and the matching counter in *stats* is incremented in
        place. Exceptions propagate to the caller.

        Raises:
            FileNotFoundError: if the document store has no raw document.
        """
        # Stage 1: Load raw document
        raw_data = await self._doc_store.load(source_id, doc_id)
        if raw_data is None:
            raise FileNotFoundError(f"raw document not found: {source_id}/{doc_id}")
        raw = RawDocument(
            source_id=source_id,
            doc_id=doc_id,
            url="",  # not available here; only the stored payload is loaded
            fetch_timestamp="",
            raw_text=raw_data,
            format=self._detect_format(doc_id),
        )

        meta = None
        chunks = []
        if self._pipeline:
            # Stage 2: Parse
            text, parsed_fmt = await self._pipeline.parse(raw)
            raw.raw_text = text
            raw.format = parsed_fmt
            await self._meta_db.update_status(doc_id, FetchStatus.PARSED)
            stats["parsed"] += 1

            # Stage 3: Extract metadata
            meta = await self._pipeline.extract_meta(doc_id, raw.raw_text)
            if meta:
                await self._meta_db.update_doc_meta(doc_id, {
                    "case_name": meta.case_name,
                    "mnc": meta.mnc,
                    "court": meta.court,
                    "verdict": self._serialize_verdict(meta.verdict),
                })

            # Stage 4: Chunk
            chunks = await self._pipeline.chunk(doc_id, raw.raw_text, meta)

            # Stage 5: Embed (skipped when chunking produced nothing)
            if chunks:
                chunks = await self._pipeline.embed(chunks)
                await self._meta_db.update_status(doc_id, FetchStatus.EMBEDDED)
                stats["embedded"] += 1
                if self._vector_index:
                    await self._vector_index.store_chunks(doc_id, chunks)

        # Stage 6: Graph (needs metadata; runs even if there are no chunks)
        if self._graph_db and meta:
            # Deferred import. NOTE(review): presumably avoids an import cycle
            # or keeps graph dependencies optional; confirm before hoisting.
            from aucourt_ingest.processing.graph_builder import GraphBuilder
            builder = GraphBuilder(self._graph_db)
            await builder.build_full(meta, chunks)
            await self._meta_db.update_status(doc_id, FetchStatus.GRAPHED)
            stats["graphed"] += 1

    async def get_queue_status(self) -> dict:
        """Return ``{status.value: document_count}`` for every ``FetchStatus``."""
        return {
            status.value: len(await self._meta_db.get_documents_by_status(status.value))
            for status in FetchStatus
        }

    async def close(self):
        """Close the alert sender (the only collaborator this class closes)."""
        await self._alert.close()