import openai
import traceback
import json
import asyncio
from typing import Dict, Optional, Tuple, List, AsyncGenerator
from langsmith import traceable

try:
    import config
    from utils import format_context_for_openai
except ImportError:
    print("Error: Failed to import config or utils in openai_service.py")
    raise SystemExit("Failed imports in openai_service.py")

# --- Globals ---
openai_async_client: Optional[openai.AsyncOpenAI] = None
is_openai_ready: bool = False
openai_status_message: str = "OpenAI service not initialized."

# --- Initialization ---
def init_openai_client() -> Tuple[bool, str]:
    """Initializes the OpenAI async client."""
    global openai_async_client, is_openai_ready, openai_status_message
    if is_openai_ready:
        return True, openai_status_message
    if not config.OPENAI_API_KEY:
        openai_status_message = "Error: OPENAI_API_KEY not found in Secrets."
        is_openai_ready = False
        return False, openai_status_message
    try:
        openai_async_client = openai.AsyncOpenAI(api_key=config.OPENAI_API_KEY)
        openai_status_message = (
            f"OpenAI service ready (Validate: {config.OPENAI_VALIDATION_MODEL}, "
            f"Generate: {config.OPENAI_GENERATION_MODEL})."
        )
        is_openai_ready = True
        print("OpenAI Service: Async client initialized.")
        return True, openai_status_message
    except Exception as e:
        error_msg = f"Error initializing OpenAI async client: {type(e).__name__} - {e}"
        print(error_msg)
        traceback.print_exc()
        openai_status_message = error_msg
        is_openai_ready = False
        openai_async_client = None
        return False, openai_status_message

def get_openai_status() -> Tuple[bool, str]:
    """Returns the current status of the OpenAI service, initializing it lazily if needed."""
    if not is_openai_ready:
        init_openai_client()
    return is_openai_ready, openai_status_message
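
# Hypothetical startup usage (a sketch, not part of this module's API): callers
# would typically run init_openai_client() once at app startup and surface the
# status message in the UI, e.g.:
#
#     ready, status = init_openai_client()
#     if not ready:
#         print(f"Startup warning: {status}")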

# --- Validation Function (uses template) ---
async def validate_relevance_openai(
    paragraph_data: Dict, user_question: str, paragraph_index: int
) -> Optional[Dict]:
    """Asks the validation model whether one paragraph is relevant to the question.

    Returns {"validation": ..., "paragraph_data": ...}, or None if the client
    is not ready.
    """
    global openai_async_client
    ready, msg = get_openai_status()
    if not ready or openai_async_client is None:
        print(f"OpenAI validation failed (Para {paragraph_index+1}): Client not ready - {msg}")
        return None
    safe_paragraph_data = paragraph_data.copy() if isinstance(paragraph_data, dict) else {}
    # Read from the sanitized copy so a non-dict input cannot raise here.
    hebrew_text = safe_paragraph_data.get('hebrew_text', '').strip()
    english_text = safe_paragraph_data.get('english_text', '').strip()
    if not hebrew_text and not english_text:
        return {
            "validation": {"contains_relevant_info": False, "justification": "Paragraph text empty."},
            "paragraph_data": safe_paragraph_data
        }
    validation_model = config.OPENAI_VALIDATION_MODEL
    prompt_content = config.VALIDATION_PROMPT_TEMPLATE.format(
        user_question=user_question,
        paragraph_index=paragraph_index+1,
        hebrew_text=hebrew_text or "(No Hebrew)",
        english_text=english_text or "(No English)"
    )
    try:
        response = await openai_async_client.chat.completions.create(
            model=validation_model,
            messages=[{"role": "user", "content": prompt_content}],
            temperature=0.1,
            max_tokens=150,
            response_format={"type": "json_object"}
        )
        validation_result = json.loads(response.choices[0].message.content)
        return {"validation": validation_result, "paragraph_data": safe_paragraph_data}
    except Exception as e:
        print(f"Error (OpenAI Validate {paragraph_index+1}): {e}")
        traceback.print_exc()
        return {
            "validation": {"contains_relevant_info": False, "justification": "Error during validation."},
            "paragraph_data": safe_paragraph_data
        }
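
# Illustrative sketch (an addition, not original service code): the validator is
# async and scoped to a single paragraph, so callers would typically fan out over
# all retrieved paragraphs with asyncio.gather. The function name is hypothetical.
async def validate_paragraphs_concurrently(
    paragraphs: List[Dict], user_question: str
) -> List[Optional[Dict]]:
    tasks = [
        validate_relevance_openai(para, user_question, idx)
        for idx, para in enumerate(paragraphs)
    ]
    # validate_relevance_openai already converts API errors into a
    # "not relevant" verdict, so exceptions propagating here are unexpected.
    return await asyncio.gather(*tasks)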

# --- Generation Function ---
async def generate_openai_stream(
    messages: List[Dict],
    context_documents: List[Dict],
) -> AsyncGenerator[str, None]:
    """
    Generates a response with the configured OpenAI model from chat history and context.
    Streams tokens for standard models; for o-series models, or when streaming fails,
    falls back to a single non-streaming call.
    """
    global openai_async_client
    ready, msg = get_openai_status()
    if not ready or openai_async_client is None:
        yield f"--- Error: OpenAI client not available: {msg} ---"
        return
    # Format and validate the retrieved context
    if not isinstance(context_documents, list) or not all(isinstance(item, dict) for item in context_documents):
        yield "--- Error: Invalid context_documents format ---"
        return
    formatted_context = format_context_for_openai(context_documents)
    if not formatted_context or formatted_context.startswith("No"):
        yield "--- Error: No valid context provided ---"
        return
    last_user = next((m['content'] for m in reversed(messages) if m.get('role') == 'user'), "")
    user_prompt = (
        f"Source Texts:\n{formatted_context}\n\n"
        f"User Question:\n{last_user}\n\n"
        "Answer (in Hebrew, based ONLY on the Source Texts provided):"
    )
    sys_msg = config.OPENAI_SYSTEM_PROMPT
    api_messages = [{"role": "system", "content": sys_msg}, {"role": "user", "content": user_prompt}]
    model = config.OPENAI_GENERATION_MODEL
    print(f"Using generation model: {model}")
    # o-series (reasoning) models take max_completion_tokens; others take max_tokens
    token_key = "max_completion_tokens" if model.startswith(("o1", "o3", "o4")) else "max_tokens"
    kwargs = {"model": model, "messages": api_messages, token_key: 3000}
    # Attempt streaming for non-o-series models
    if not model.startswith(("o1", "o3", "o4")):
        kwargs.update({"stream": True, "temperature": 0.5})
        try:
            stream = await openai_async_client.chat.completions.create(**kwargs)
            async for chunk in stream:
                if not chunk.choices:
                    continue  # some chunks (e.g. usage-only) carry no choices
                c = chunk.choices[0].delta.content
                if c:
                    yield c
            return
        except Exception as e:
            print(f"Streaming failed for model {model}: {e}")
            traceback.print_exc()
    # Fallback or direct call (o-series models, or streaming raised above)
    try:
        # Drop the streaming flag: otherwise a failed stream would be retried
        # as another stream, and resp.choices below would not exist.
        kwargs.pop("stream", None)
        resp = await openai_async_client.chat.completions.create(**kwargs)
        text = resp.choices[0].message.content
        yield text or ""
    except Exception as e:
        err = f"--- Error generating response: {e} ---"
        print(err)
        traceback.print_exc()
        yield err
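
# Minimal local smoke test (a sketch; the context-document field names below are
# assumptions -- the real schema is whatever utils.format_context_for_openai expects):
if __name__ == "__main__":
    async def _demo() -> None:
        history = [{"role": "user", "content": "Example question"}]
        docs = [{"hebrew_text": "טקסט לדוגמה", "english_text": "Example text"}]
        async for piece in generate_openai_stream(history, docs):
            print(piece, end="", flush=True)
        print()

    asyncio.run(_demo())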