import requests
import aiohttp
import asyncio
import logging
import os
import sys
import time
from typing import List, Dict, Any
from contextlib import asynccontextmanager

from langchain_core.documents import Document
from open_webui.env import SRC_LOG_LEVELS, GLOBAL_LOG_LEVEL

logging.basicConfig(stream=sys.stdout, level=GLOBAL_LOG_LEVEL)
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["RAG"])


class MistralLoader:
    """
    Enhanced Mistral OCR loader with both sync and async support.
    Loads documents by processing them through the Mistral OCR API.
    """

    BASE_API_URL = "https://api.mistral.ai/v1"

    def __init__(
        self,
        api_key: str,
        file_path: str,
        timeout: int = 300,  # 5 minutes default
        max_retries: int = 3,
        enable_debug_logging: bool = False,
    ):
        """
        Initializes the loader with enhanced features.

        Args:
            api_key: Your Mistral API key.
            file_path: The local path to the PDF file to process.
            timeout: Request timeout in seconds.
            max_retries: Maximum number of retry attempts.
            enable_debug_logging: Enable detailed debug logs.
        """
        if not api_key:
            raise ValueError("API key cannot be empty.")
        if not os.path.exists(file_path):
            raise FileNotFoundError(f"File not found at {file_path}")

        self.api_key = api_key
        self.file_path = file_path
        self.timeout = timeout
        self.max_retries = max_retries
        self.debug = enable_debug_logging

        # Pre-compute file info for performance
        self.file_name = os.path.basename(file_path)
        self.file_size = os.path.getsize(file_path)

        self.headers = {
            "Authorization": f"Bearer {self.api_key}",
            "User-Agent": "OpenWebUI-MistralLoader/2.0",
        }

    def _debug_log(self, message: str, *args) -> None:
        """Conditional debug logging for performance."""
        if self.debug:
            log.debug(message, *args)

    def _handle_response(self, response: requests.Response) -> Dict[str, Any]:
        """Checks response status and returns JSON content."""
        try:
            response.raise_for_status()  # Raises HTTPError for bad responses (4xx or 5xx)
            # Handle potential empty responses for certain successful requests (e.g., DELETE)
            if response.status_code == 204 or not response.content:
                return {}  # Return empty dict if no content
            return response.json()
        except requests.exceptions.HTTPError as http_err:
            log.error(f"HTTP error occurred: {http_err} - Response: {response.text}")
            raise
        except requests.exceptions.RequestException as req_err:
            log.error(f"Request exception occurred: {req_err}")
            raise
        except ValueError as json_err:  # Includes JSONDecodeError
            log.error(f"JSON decode error: {json_err} - Response: {response.text}")
            raise  # Re-raise after logging

    async def _handle_response_async(
        self, response: aiohttp.ClientResponse
    ) -> Dict[str, Any]:
        """Async version of response handling with better error info."""
        try:
            if response.status >= 400:
                # Read the error body *before* raise_for_status(): raising
                # releases the connection, after which the body can no
                # longer be read for logging.
                error_text = await response.text()
                log.error(
                    f"HTTP {response.status}: {response.reason} - Response: {error_text[:500]}"
                )
            response.raise_for_status()

            # Check content type
            content_type = response.headers.get("content-type", "")
            if "application/json" not in content_type:
                if response.status == 204:
                    return {}
                text = await response.text()
                raise ValueError(
                    f"Unexpected content type: {content_type}, body: {text[:200]}..."
                )

            return await response.json()
        except aiohttp.ClientError as e:
            log.error(f"Client error: {e}")
            raise
        except Exception as e:
            log.error(f"Unexpected error processing response: {e}")
            raise
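
    # Backoff arithmetic for the retry helpers below (illustrative note, not
    # an extra tunable): wait_time = 2**attempt + 0.5, so with the default
    # max_retries=3 the schedule is 1.5s after the first failure and 2.5s
    # after the second; the third failure re-raises to the caller.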

    def _retry_request_sync(self, request_func, *args, **kwargs):
        """Synchronous retry logic with exponential backoff."""
        for attempt in range(self.max_retries):
            try:
                return request_func(*args, **kwargs)
            except Exception as e:
                if attempt == self.max_retries - 1:
                    raise
                wait_time = (2**attempt) + 0.5
                log.warning(
                    f"Request failed (attempt {attempt + 1}/{self.max_retries}): {e}. Retrying in {wait_time}s..."
                )
                time.sleep(wait_time)

    async def _retry_request_async(self, request_func, *args, **kwargs):
        """Async retry logic with exponential backoff."""
        for attempt in range(self.max_retries):
            try:
                return await request_func(*args, **kwargs)
            except (aiohttp.ClientError, asyncio.TimeoutError) as e:
                if attempt == self.max_retries - 1:
                    raise
                wait_time = (2**attempt) + 0.5
                log.warning(
                    f"Request failed (attempt {attempt + 1}/{self.max_retries}): {e}. Retrying in {wait_time}s..."
                )
                await asyncio.sleep(wait_time)

    def _upload_file(self) -> str:
        """Uploads the file to Mistral for OCR processing (sync version)."""
        log.info("Uploading file to Mistral API")
        url = f"{self.BASE_API_URL}/files"

        def upload_request():
            with open(self.file_path, "rb") as f:
                files = {"file": (self.file_name, f, "application/pdf")}
                data = {"purpose": "ocr"}
                response = requests.post(
                    url,
                    headers=self.headers,
                    files=files,
                    data=data,
                    timeout=self.timeout,
                )
                return self._handle_response(response)

        try:
            response_data = self._retry_request_sync(upload_request)
            file_id = response_data.get("id")
            if not file_id:
                raise ValueError("File ID not found in upload response.")
            log.info(f"File uploaded successfully. File ID: {file_id}")
            return file_id
        except Exception as e:
            log.error(f"Failed to upload file: {e}")
            raise

    async def _upload_file_async(self, session: aiohttp.ClientSession) -> str:
        """Async file upload with streaming for better memory efficiency."""
        url = f"{self.BASE_API_URL}/files"

        async def upload_request():
            # Create multipart writer for streaming upload
            writer = aiohttp.MultipartWriter("form-data")

            # Add purpose field
            purpose_part = writer.append("ocr")
            purpose_part.set_content_disposition("form-data", name="purpose")

            # Add the file part from an open binary handle; aiohttp streams
            # file objects chunk by chunk instead of loading the whole
            # document into memory.
            with open(self.file_path, "rb") as f:
                file_part = writer.append(f, {"Content-Type": "application/pdf"})
                file_part.set_content_disposition(
                    "form-data", name="file", filename=self.file_name
                )

                self._debug_log(
                    f"Uploading file: {self.file_name} ({self.file_size:,} bytes)"
                )

                async with session.post(
                    url,
                    data=writer,
                    headers=self.headers,
                    timeout=aiohttp.ClientTimeout(total=self.timeout),
                ) as response:
                    return await self._handle_response_async(response)

        response_data = await self._retry_request_async(upload_request)
        file_id = response_data.get("id")
        if not file_id:
            raise ValueError("File ID not found in upload response.")

        log.info(f"File uploaded successfully. File ID: {file_id}")
        return file_id
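
    # The helpers below exchange an uploaded file ID for a short-lived signed
    # download URL that the OCR endpoint can fetch. `expiry` is passed as 1,
    # the requested URL lifetime (in hours, per Mistral's files API docs),
    # which is ample for the OCR call issued immediately afterwards.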

    def _get_signed_url(self, file_id: str) -> str:
        """Retrieves a temporary signed URL for the uploaded file (sync version)."""
        log.info(f"Getting signed URL for file ID: {file_id}")
        url = f"{self.BASE_API_URL}/files/{file_id}/url"
        params = {"expiry": 1}
        signed_url_headers = {**self.headers, "Accept": "application/json"}

        def url_request():
            response = requests.get(
                url, headers=signed_url_headers, params=params, timeout=self.timeout
            )
            return self._handle_response(response)

        try:
            response_data = self._retry_request_sync(url_request)
            signed_url = response_data.get("url")
            if not signed_url:
                raise ValueError("Signed URL not found in response.")
            log.info("Signed URL received.")
            return signed_url
        except Exception as e:
            log.error(f"Failed to get signed URL: {e}")
            raise

    async def _get_signed_url_async(
        self, session: aiohttp.ClientSession, file_id: str
    ) -> str:
        """Async signed URL retrieval."""
        url = f"{self.BASE_API_URL}/files/{file_id}/url"
        params = {"expiry": 1}
        headers = {**self.headers, "Accept": "application/json"}

        async def url_request():
            self._debug_log(f"Getting signed URL for file ID: {file_id}")
            async with session.get(
                url,
                headers=headers,
                params=params,
                timeout=aiohttp.ClientTimeout(total=self.timeout),
            ) as response:
                return await self._handle_response_async(response)

        response_data = await self._retry_request_async(url_request)
        signed_url = response_data.get("url")
        if not signed_url:
            raise ValueError("Signed URL not found in response.")

        self._debug_log("Signed URL received successfully")
        return signed_url

    def _process_ocr(self, signed_url: str) -> Dict[str, Any]:
        """Sends the signed URL to the OCR endpoint for processing (sync version)."""
        log.info("Processing OCR via Mistral API")
        url = f"{self.BASE_API_URL}/ocr"
        ocr_headers = {
            **self.headers,
            "Content-Type": "application/json",
            "Accept": "application/json",
        }
        payload = {
            "model": "mistral-ocr-latest",
            "document": {
                "type": "document_url",
                "document_url": signed_url,
            },
            "include_image_base64": False,
        }

        def ocr_request():
            response = requests.post(
                url, headers=ocr_headers, json=payload, timeout=self.timeout
            )
            return self._handle_response(response)

        try:
            ocr_response = self._retry_request_sync(ocr_request)
            log.info("OCR processing done.")
            self._debug_log("OCR response: %s", ocr_response)
            return ocr_response
        except Exception as e:
            log.error(f"Failed during OCR processing: {e}")
            raise

    async def _process_ocr_async(
        self, session: aiohttp.ClientSession, signed_url: str
    ) -> Dict[str, Any]:
        """Async OCR processing with timing metrics."""
        url = f"{self.BASE_API_URL}/ocr"
        headers = {
            **self.headers,
            "Content-Type": "application/json",
            "Accept": "application/json",
        }
        payload = {
            "model": "mistral-ocr-latest",
            "document": {
                "type": "document_url",
                "document_url": signed_url,
            },
            "include_image_base64": False,
        }

        async def ocr_request():
            log.info("Starting OCR processing via Mistral API")
            start_time = time.time()
            async with session.post(
                url,
                json=payload,
                headers=headers,
                timeout=aiohttp.ClientTimeout(total=self.timeout),
            ) as response:
                ocr_response = await self._handle_response_async(response)
                processing_time = time.time() - start_time
                log.info(f"OCR processing completed in {processing_time:.2f}s")
                return ocr_response

        return await self._retry_request_async(ocr_request)
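
    # Shape of the OCR response consumed by _process_results() further below
    # (abridged sketch; only the fields this loader actually reads):
    #
    #   {
    #       "pages": [
    #           {"index": 0, "markdown": "# Page one text..."},
    #           {"index": 1, "markdown": "Page two text..."},
    #       ]
    #   }
    #
    # Any other fields returned by the API are ignored here.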

    def _delete_file(self, file_id: str) -> None:
        """Deletes the file from Mistral storage (sync version)."""
        log.info(f"Deleting uploaded file ID: {file_id}")
        url = f"{self.BASE_API_URL}/files/{file_id}"
        try:
            response = requests.delete(url, headers=self.headers, timeout=30)
            delete_response = self._handle_response(response)
            log.info(f"File deleted successfully: {delete_response}")
        except Exception as e:
            # Log error but don't necessarily halt execution if deletion fails
            log.error(f"Failed to delete file ID {file_id}: {e}")

    async def _delete_file_async(
        self, session: aiohttp.ClientSession, file_id: str
    ) -> None:
        """Async file deletion with error tolerance."""
        try:

            async def delete_request():
                self._debug_log(f"Deleting file ID: {file_id}")
                async with session.delete(
                    url=f"{self.BASE_API_URL}/files/{file_id}",
                    headers=self.headers,
                    timeout=aiohttp.ClientTimeout(total=30),  # Shorter timeout for cleanup
                ) as response:
                    return await self._handle_response_async(response)

            await self._retry_request_async(delete_request)
            self._debug_log(f"File {file_id} deleted successfully")
        except Exception as e:
            # Don't fail the entire process if cleanup fails
            log.warning(f"Failed to delete file ID {file_id}: {e}")

    @asynccontextmanager
    async def _get_session(self):
        """Context manager for HTTP session with optimized settings."""
        connector = aiohttp.TCPConnector(
            limit=10,  # Total connection limit
            limit_per_host=5,  # Per-host connection limit
            ttl_dns_cache=300,  # DNS cache TTL
            use_dns_cache=True,
            keepalive_timeout=30,
            enable_cleanup_closed=True,
        )
        async with aiohttp.ClientSession(
            connector=connector,
            timeout=aiohttp.ClientTimeout(total=self.timeout),
            headers={"User-Agent": "OpenWebUI-MistralLoader/2.0"},
        ) as session:
            yield session

    def _process_results(self, ocr_response: Dict[str, Any]) -> List[Document]:
        """Process OCR results into Document objects with enhanced metadata."""
        pages_data = ocr_response.get("pages")
        if not pages_data:
            log.warning("No pages found in OCR response.")
            return [
                Document(
                    page_content="No text content found", metadata={"error": "no_pages"}
                )
            ]

        documents = []
        total_pages = len(pages_data)
        skipped_pages = 0

        for page_data in pages_data:
            page_content = page_data.get("markdown")
            page_index = page_data.get("index")  # API uses 0-based index

            if page_content is not None and page_index is not None:
                # Clean up content efficiently
                cleaned_content = (
                    page_content.strip()
                    if isinstance(page_content, str)
                    else str(page_content)
                )

                if cleaned_content:  # Only add non-empty pages
                    documents.append(
                        Document(
                            page_content=cleaned_content,
                            metadata={
                                "page": page_index,  # 0-based index from API
                                "page_label": page_index + 1,  # 1-based label for convenience
                                "total_pages": total_pages,
                                "file_name": self.file_name,
                                "file_size": self.file_size,
                                "processing_engine": "mistral-ocr",
                            },
                        )
                    )
                else:
                    skipped_pages += 1
                    self._debug_log(f"Skipping empty page {page_index}")
            else:
                skipped_pages += 1
                self._debug_log(
                    f"Skipping page due to missing 'markdown' or 'index'. Data: {page_data}"
                )

        if skipped_pages > 0:
            log.info(
                f"Processed {len(documents)} pages, skipped {skipped_pages} empty/invalid pages"
            )

        if not documents:
            # Case where pages existed but none had valid markdown/index
            log.warning(
                "OCR response contained pages, but none had valid content/index."
            )
            return [
                Document(
                    page_content="No valid text content found in document",
                    metadata={"error": "no_valid_pages", "total_pages": total_pages},
                )
            ]

        return documents
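
    # Example of a Document produced by _process_results() for the second
    # page of a three-page file named report.pdf (hypothetical values):
    #
    #   Document(
    #       page_content="...page markdown...",
    #       metadata={
    #           "page": 1, "page_label": 2, "total_pages": 3,
    #           "file_name": "report.pdf", "file_size": 123456,
    #           "processing_engine": "mistral-ocr",
    #       },
    #   )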

    def load(self) -> List[Document]:
        """
        Executes the full OCR workflow: upload, get URL, process OCR, delete file.
        Synchronous version for backward compatibility.

        Returns:
            A list of Document objects, one for each page processed.
        """
        file_id = None
        start_time = time.time()

        try:
            # 1. Upload file
            file_id = self._upload_file()

            # 2. Get signed URL
            signed_url = self._get_signed_url(file_id)

            # 3. Process OCR
            ocr_response = self._process_ocr(signed_url)

            # 4. Process results
            documents = self._process_results(ocr_response)

            total_time = time.time() - start_time
            log.info(
                f"Sync OCR workflow completed in {total_time:.2f}s, produced {len(documents)} documents"
            )

            return documents
        except Exception as e:
            total_time = time.time() - start_time
            log.error(
                f"An error occurred during the loading process after {total_time:.2f}s: {e}"
            )
            # Return an error document on failure
            return [
                Document(
                    page_content=f"Error during processing: {e}",
                    metadata={
                        "error": "processing_failed",
                        "file_name": self.file_name,
                    },
                )
            ]
        finally:
            # 5. Delete file (attempt even if prior steps failed after upload)
            if file_id:
                try:
                    self._delete_file(file_id)
                except Exception as del_e:
                    # Log deletion error, but don't overwrite original error if one occurred
                    log.error(
                        f"Cleanup error: Could not delete file ID {file_id}. Reason: {del_e}"
                    )

    async def load_async(self) -> List[Document]:
        """
        Asynchronous OCR workflow execution with optimized performance.

        Returns:
            A list of Document objects, one for each page processed.
        """
        file_id = None
        start_time = time.time()

        try:
            async with self._get_session() as session:
                # 1. Upload file with streaming
                file_id = await self._upload_file_async(session)

                # 2. Get signed URL
                signed_url = await self._get_signed_url_async(session, file_id)

                # 3. Process OCR
                ocr_response = await self._process_ocr_async(session, signed_url)

                # 4. Process results
                documents = self._process_results(ocr_response)

                total_time = time.time() - start_time
                log.info(
                    f"Async OCR workflow completed in {total_time:.2f}s, produced {len(documents)} documents"
                )

                return documents
        except Exception as e:
            total_time = time.time() - start_time
            log.error(f"Async OCR workflow failed after {total_time:.2f}s: {e}")
            return [
                Document(
                    page_content=f"Error during OCR processing: {e}",
                    metadata={
                        "error": "processing_failed",
                        "file_name": self.file_name,
                    },
                )
            ]
        finally:
            # 5. Cleanup - always attempt file deletion
            if file_id:
                try:
                    async with self._get_session() as session:
                        await self._delete_file_async(session, file_id)
                except Exception as cleanup_error:
                    log.error(f"Cleanup failed for file ID {file_id}: {cleanup_error}")

    @staticmethod
    async def load_multiple_async(
        loaders: List["MistralLoader"],
    ) -> List[List[Document]]:
        """
        Process multiple files concurrently for maximum performance.

        Args:
            loaders: List of MistralLoader instances

        Returns:
            List of document lists, one for each loader
        """
        if not loaders:
            return []

        log.info(f"Starting concurrent processing of {len(loaders)} files")
        start_time = time.time()

        # Process all files concurrently
        tasks = [loader.load_async() for loader in loaders]
        results = await asyncio.gather(*tasks, return_exceptions=True)

        # Handle any exceptions in results
        processed_results = []
        for i, result in enumerate(results):
            if isinstance(result, Exception):
                log.error(f"File {i} failed: {result}")
                processed_results.append(
                    [
                        Document(
                            page_content=f"Error processing file: {result}",
                            metadata={
                                "error": "batch_processing_failed",
                                "file_index": i,
                            },
                        )
                    ]
                )
            else:
                processed_results.append(result)

        total_time = time.time() - start_time
        total_docs = sum(len(docs) for docs in processed_results)
        log.info(
            f"Batch processing completed in {total_time:.2f}s, produced {total_docs} total documents"
        )

        return processed_results
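

# Minimal usage sketch (not part of the loader itself). Assumes a
# MISTRAL_API_KEY environment variable and a local sample.pdf; both names
# are illustrative placeholders.
if __name__ == "__main__":
    api_key = os.environ.get("MISTRAL_API_KEY", "")

    # Synchronous, single file:
    docs = MistralLoader(api_key=api_key, file_path="sample.pdf").load()
    print(f"{len(docs)} documents from sync load")

    # Asynchronous, several files concurrently:
    async def demo():
        loaders = [
            MistralLoader(api_key=api_key, file_path=path)
            for path in ["sample.pdf"]  # extend with more paths as needed
        ]
        results = await MistralLoader.load_multiple_async(loaders)
        for file_docs in results:
            print(f"{len(file_docs)} documents")

    asyncio.run(demo())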