213 lines
7 KiB
Python
213 lines
7 KiB
Python
|
|
"""Orchestrator — main loop for processing fetched documents.
|
||
|
|
|
||
|
|
Dispatches documents through the processing pipeline:
|
||
|
|
fetched → parsed → embedded → graphed
|
||
|
|
|
||
|
|
Handles stage transitions, error recovery, and alerts.
|
||
|
|
"""
|
||
|
|
|
||
|
|
from __future__ import annotations
|
||
|
|
|
||
|
|
import logging
|
||
|
|
from typing import Protocol, runtime_checkable
|
||
|
|
|
||
|
|
from aucourt_ingest.models import FetchStatus, RawDocument, CaseMeta, Chunk
|
||
|
|
|
||
|
|
# Module-level logger, named after this module via __name__.
logger = logging.getLogger(__name__)
|
||
|
|
|
||
|
|
|
||
|
|
class TelegramAlertProtocol(Protocol):
    """Structural interface for an asynchronous alert sender.

    Any object providing these three coroutine methods satisfies the
    protocol; ``NoOpAlert`` is the do-nothing implementation. This protocol is
    not ``runtime_checkable``, so it is meant for static type checking only.
    """

    async def send(self, message: str) -> bool:
        """Send the plain-text *message* and report a boolean status."""

    async def alert(self, alert_type: str, **kwargs) -> bool:
        """Send an alert of kind *alert_type*, with details as keyword arguments."""

    async def close(self) -> None:
        """Shut the sender down."""
|
||
|
|
|
||
|
|
|
||
|
|
class NoOpAlert:
    """Alert sender that silently discards everything (Telegram disabled).

    All methods are coroutines so this can stand in wherever a real sender is
    expected; the bool-returning methods always report ``False``.
    """

    async def send(self, message: str) -> bool:
        """Ignore *message*; always returns ``False``."""
        return False

    async def alert(self, alert_type: str, **kwargs) -> bool:
        """Ignore the alert and its details; always returns ``False``."""
        return False

    async def close(self) -> None:
        """Nothing to release."""
        return None
|
||
|
|
|
||
|
|
|
||
|
|
@runtime_checkable
class ProcessingPipeline(Protocol):
    """Structural interface for the document-processing stages.

    Any object exposing these four coroutine methods satisfies the protocol.
    Because it is ``runtime_checkable``, ``isinstance`` verifies only that the
    methods exist, not their signatures.
    """

    async def parse(self, raw: RawDocument) -> tuple[str, str]:
        """Parse *raw*, returning a ``(text, format)`` pair."""

    async def extract_meta(self, doc_id: str, text: str) -> CaseMeta | None:
        """Extract structured metadata from *text*, or return ``None``."""

    async def chunk(self, doc_id: str, text: str, meta: CaseMeta | None = None) -> list[Chunk]:
        """Split *text* into semantic units, optionally using *meta*."""

    async def embed(self, chunks: list[Chunk]) -> list[Chunk]:
        """Generate embeddings for *chunks* and return the resulting list."""
|
||
|
|
|
||
|
|
|
||
|
|
class Orchestrator:
    """Main processing loop: fetch_queue → parse → extract → chunk → embed → graph.

    Reads items from the MetaDB fetch queue, runs each document through the
    pipeline stages, and records status transitions in MetaDB along the way.
    Stages whose collaborator (pipeline, vector index, graph DB) was not
    supplied are skipped.
    """

    def __init__(
        self,
        meta_db,
        doc_store,
        graph_db=None,
        vector_index=None,
        pipeline: ProcessingPipeline | None = None,
        alert: TelegramAlertProtocol | None = None,
    ):
        """Wire up the orchestrator's collaborators.

        Args:
            meta_db: Queue/status store. Uses ``dequeue``, ``update_status``,
                ``update_doc_meta`` and ``get_documents_by_status``.
            doc_store: Raw document store. Uses ``load(source_id, doc_id)``.
            graph_db: Optional graph backend, handed to ``GraphBuilder``.
            vector_index: Optional index. Uses ``store_chunks(doc_id, chunks)``.
            pipeline: Optional processing stages. Without one, only loading
                (and nothing else) happens.
            alert: Optional alert sender. Defaults to ``NoOpAlert``.
        """
        self._meta_db = meta_db
        self._doc_store = doc_store
        self._graph_db = graph_db
        self._vector_index = vector_index
        self._pipeline = pipeline
        self._alert = alert or NoOpAlert()

    @staticmethod
    def _new_stats() -> dict:
        """Return a fresh, zeroed per-run counter dict."""
        return {"processed": 0, "parsed": 0, "embedded": 0, "graphed": 0, "errors": 0}

    async def process_queue(self, limit: int = 0) -> dict:
        """Process items from the fetch queue.

        Args:
            limit: Max number of dequeued items to handle, counting failures.
                ``0`` (or any non-positive value) processes all pending items.

        Returns:
            Counter dict with keys ``processed``, ``parsed``, ``embedded``,
            ``graphed`` and ``errors``.
        """
        stats = self._new_stats()
        count = 0

        while limit <= 0 or count < limit:
            item = await self._meta_db.dequeue()
            if item is None:
                logger.info("Queue empty")
                break
            count += 1

            # Fall back to the URL when doc_id is absent *or* empty/None.
            doc_id = item.get("doc_id") or item.get("url") or ""
            source_id = item.get("source_id") or ""
            if not doc_id:
                # Nothing to load or mark as failed; don't touch a "" record.
                logger.warning("Skipping queue item with no doc_id or url: %r", item)
                stats["errors"] += 1
                continue

            logger.info("Processing queue item: %s (source=%s)", doc_id, source_id)
            try:
                await self._process_document(doc_id, source_id, stats)
            except Exception as e:
                logger.error("Failed to process %s: %s", doc_id, e, exc_info=True)
                stats["errors"] += 1
                await self._record_failure(doc_id, source_id, e)
            else:
                stats["processed"] += 1

        logger.info("Queue processing complete: %s", stats)
        return stats

    async def _record_failure(self, doc_id: str, source_id: str, error: Exception) -> None:
        """Best-effort: mark *doc_id* FAILED and send an alert, never raising.

        Failures here are logged rather than propagated. Otherwise a broken
        MetaDB or alert channel would abort the whole queue run.
        """
        message = str(error)
        try:
            await self._meta_db.update_status(doc_id, FetchStatus.FAILED, message)
        except Exception:
            logger.exception("Could not mark %s as failed", doc_id)
        try:
            await self._alert.alert("fetch_error", source_id=source_id, error=message[:100])
        except Exception:
            logger.exception("Could not send failure alert for %s", doc_id)

    async def process_document(self, doc_id: str, source_id: str = "") -> dict:
        """Process a single document through the full pipeline.

        Convenience method for processing documents outside the queue
        (e.g., audit mode re-processing). Unlike ``process_queue``, a failure
        is only logged: the status is not set to FAILED and no alert is sent.

        Returns:
            Counter dict with the same keys as ``process_queue``; ``processed``
            or ``errors`` is 1 depending on the outcome.
        """
        stats = self._new_stats()
        try:
            await self._process_document(doc_id, source_id, stats)
        except Exception as e:
            logger.error("Failed to process %s: %s", doc_id, e, exc_info=True)
            stats["errors"] = 1
        else:
            stats["processed"] = 1
        return stats

    @staticmethod
    def _detect_format(doc_id: str) -> str:
        """Guess the raw format from the doc_id extension ("pdf", "docx", else "html")."""
        name = doc_id.lower()
        if name.endswith(".pdf"):
            return "pdf"
        if name.endswith(".docx"):
            return "docx"
        return "html"

    async def _process_document(self, doc_id: str, source_id: str, stats: dict) -> None:
        """Run one document through the pipeline stages, updating *stats* in place.

        Raises:
            FileNotFoundError: if the document store returns no raw data.
            Any exception from a pipeline stage or backend propagates.
        """
        # Stage 1: Load raw document
        raw_data = await self._doc_store.load(source_id, doc_id)
        if raw_data is None:
            raise FileNotFoundError(f"raw document not found: {source_id}/{doc_id}")

        raw = RawDocument(
            source_id=source_id,
            doc_id=doc_id,
            url="",
            fetch_timestamp="",
            raw_text=raw_data,
            format=self._detect_format(doc_id),  # from the doc_id extension only
        )

        pipeline = self._pipeline

        # Stage 2: Parse. The parser's output replaces the raw text and format.
        if pipeline:
            text, parsed_fmt = await pipeline.parse(raw)
            raw.raw_text = text
            raw.format = parsed_fmt
            await self._meta_db.update_status(doc_id, FetchStatus.PARSED)
            stats["parsed"] += 1

        # Stage 3: Extract metadata and persist its summary fields
        meta = None
        if pipeline:
            meta = await pipeline.extract_meta(doc_id, raw.raw_text)
            if meta:
                await self._meta_db.update_doc_meta(doc_id, {
                    "case_name": meta.case_name,
                    "mnc": meta.mnc,
                    "court": meta.court,
                    # Enum-like verdicts store .value; anything else goes through str().
                    # NOTE(review): a None verdict is stored as the string "None".
                    "verdict": meta.verdict.value if hasattr(meta.verdict, "value") else str(meta.verdict),
                })

        # Stage 4: Chunk
        chunks = []
        if pipeline:
            chunks = await pipeline.chunk(doc_id, raw.raw_text, meta)

        # Stage 5: Embed, then store the embedded chunks in the vector index
        if pipeline and chunks:
            chunks = await pipeline.embed(chunks)
            await self._meta_db.update_status(doc_id, FetchStatus.EMBEDDED)
            stats["embedded"] += 1
            if self._vector_index:
                await self._vector_index.store_chunks(doc_id, chunks)

        # Stage 6: Graph
        if self._graph_db and meta:
            # Imported lazily, presumably to keep the graph dependency optional
            # or avoid an import cycle (TODO confirm).
            from aucourt_ingest.processing.graph_builder import GraphBuilder

            builder = GraphBuilder(self._graph_db)
            await builder.build_full(meta, chunks)
            await self._meta_db.update_status(doc_id, FetchStatus.GRAPHED)
            stats["graphed"] += 1

    async def get_queue_status(self) -> dict:
        """Return ``{status value: document count}`` for every FetchStatus.

        Counts come from fetching each status's full document list and taking
        its length.
        """
        result = {}
        for status in FetchStatus:
            docs = await self._meta_db.get_documents_by_status(status.value)
            result[status.value] = len(docs)
        return result

    async def close(self):
        """Close the alert sender."""
        await self._alert.close()
|