Spaces:
Sleeping
Sleeping
File size: 6,515 Bytes
7710196 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 |
import openai
import traceback
import json
import asyncio
from typing import Dict, Optional, Tuple, List, AsyncGenerator
from langsmith import traceable
try:
import config
from utils import format_context_for_openai
except ImportError:
print("Error: Failed to import config or utils in openai_service.py")
raise SystemExit("Failed imports in openai_service.py")
# --- Globals ---
# Shared async client; stays None until init_openai_client() succeeds.
openai_async_client: Optional[openai.AsyncOpenAI] = None
# Latches True after a successful initialization (checked before API calls).
is_openai_ready: bool = False
# Human-readable status surfaced to callers via get_openai_status().
openai_status_message: str = "OpenAI service not initialized."
# --- Initialization ---
def init_openai_client() -> Tuple[bool, str]:
    """Initialize the shared OpenAI async client.

    Idempotent: once the client is ready, subsequent calls return the
    cached status without touching the API. Returns (ready, status_message).
    """
    global openai_async_client, is_openai_ready, openai_status_message
    # Fast path: already initialized on a previous call.
    if is_openai_ready:
        return True, openai_status_message
    # Missing credentials — fail with a descriptive status, don't raise.
    if not config.OPENAI_API_KEY:
        is_openai_ready = False
        openai_status_message = "Error: OPENAI_API_KEY not found in Secrets."
        return False, openai_status_message
    try:
        openai_async_client = openai.AsyncOpenAI(api_key=config.OPENAI_API_KEY)
        openai_status_message = (
            f"OpenAI service ready (Validate: {config.OPENAI_VALIDATION_MODEL}, "
            f"Generate: {config.OPENAI_GENERATION_MODEL})."
        )
    except Exception as exc:
        failure = f"Error initializing OpenAI async client: {type(exc).__name__} - {exc}"
        print(failure)
        traceback.print_exc()
        # Reset to a clean "not ready" state so a later retry can succeed.
        openai_async_client = None
        is_openai_ready = False
        openai_status_message = failure
        return False, openai_status_message
    is_openai_ready = True
    print("OpenAI Service: Async client initialized.")
    return True, openai_status_message
def get_openai_status() -> Tuple[bool, str]:
    """Return (ready, status_message), lazily initializing on first use.

    Safe to call repeatedly; initialization only runs while the client
    is not yet ready.
    """
    if is_openai_ready:
        return is_openai_ready, openai_status_message
    init_openai_client()
    return is_openai_ready, openai_status_message
# --- Validation Function (uses template) ---
@traceable(name="openai-validate-paragraph")
async def validate_relevance_openai(
    paragraph_data: Dict, user_question: str, paragraph_index: int
) -> Optional[Dict]:
    """Ask the validation model whether one paragraph is relevant to the question.

    Args:
        paragraph_data: Paragraph record; reads 'hebrew_text' / 'english_text'
            if present. Non-dict input is tolerated and treated as empty.
        user_question: The user's question to validate against.
        paragraph_index: Zero-based index (shown as 1-based in logs/prompt).

    Returns:
        A dict {"validation": {...}, "paragraph_data": {...}} — with a
        not-relevant verdict on empty text or API error — or None when the
        OpenAI client is not ready.
    """
    global openai_async_client
    ready, msg = get_openai_status()
    if not ready or openai_async_client is None:
        print(f"OpenAI validation failed (Para {paragraph_index+1}): Client not ready - {msg}")
        return None
    # Defensive copy; a non-dict argument becomes an empty record.
    safe_paragraph_data = paragraph_data.copy() if isinstance(paragraph_data, dict) else {}
    # BUG FIX: read from the sanitized dict — the original called
    # paragraph_data.get(...) on the raw argument, crashing on non-dict input
    # despite the isinstance guard above. The `or ''` also guards against a
    # key present with a None value, which would break .strip().
    hebrew_text = (safe_paragraph_data.get('hebrew_text') or '').strip()
    english_text = (safe_paragraph_data.get('english_text') or '').strip()
    if not hebrew_text and not english_text:
        # Nothing to validate — short-circuit without an API call.
        return {
            "validation": {"contains_relevant_info": False, "justification": "Paragraph text empty."},
            "paragraph_data": safe_paragraph_data
        }
    validation_model = config.OPENAI_VALIDATION_MODEL
    prompt_content = config.VALIDATION_PROMPT_TEMPLATE.format(
        user_question=user_question,
        paragraph_index=paragraph_index+1,
        hebrew_text=hebrew_text or "(No Hebrew)",
        english_text=english_text or "(No English)"
    )
    try:
        response = await openai_async_client.chat.completions.create(
            model=validation_model,
            messages=[{"role": "user", "content": prompt_content}],
            temperature=0.1,          # low temperature for a deterministic verdict
            max_tokens=150,
            response_format={"type": "json_object"}  # model must emit valid JSON
        )
        validation_result = json.loads(response.choices[0].message.content)
        return {"validation": validation_result, "paragraph_data": safe_paragraph_data}
    except Exception as e:
        # API/JSON failure degrades to a "not relevant" verdict rather than raising.
        print(f"Error (OpenAI Validate {paragraph_index+1}): {e}")
        traceback.print_exc()
        return {
            "validation": {"contains_relevant_info": False, "justification": "Error during validation."},
            "paragraph_data": safe_paragraph_data
        }
# --- Generation Function (unchanged) ---
@traceable(name="openai-generate-stream")
async def generate_openai_stream(
    messages: List[Dict],
    context_documents: List[Dict],
) -> AsyncGenerator[str, None]:
    """Generate a Hebrew answer grounded in the retrieved source texts.

    Streams response chunks for standard chat models; o-series models
    (o1/o3/o4) — and any model whose streaming attempt fails — fall back to a
    single non-streaming completion. All failures are yielded as
    "--- Error ... ---" strings rather than raised.

    Args:
        messages: Chat history; the last message with role 'user' supplies
            the question.
        context_documents: Retrieved paragraphs (list of dicts) used as the
            only allowed source material.

    Yields:
        Response text chunks (or one complete string / error string).
    """
    global openai_async_client
    ready, msg = get_openai_status()
    if not ready or openai_async_client is None:
        yield f"--- Error: OpenAI client not available: {msg} ---"
        return
    # Validate context shape before spending tokens.
    if not isinstance(context_documents, list) or not all(isinstance(item, dict) for item in context_documents):
        yield "--- Error: Invalid context_documents format ---"
        return
    formatted_context = format_context_for_openai(context_documents)
    # format_context_for_openai signals "nothing usable" with a "No..." string.
    if not formatted_context or formatted_context.startswith("No"):
        yield "--- Error: No valid context provided ---"
        return
    last_user = next((m['content'] for m in reversed(messages) if m.get('role') == 'user'), "")
    user_prompt = (
        f"Source Texts:\n{formatted_context}\n\n"
        f"User Question:\n{last_user}\n\n"
        "Answer (in Hebrew, based ONLY on the Source Texts provided):"
    )
    sys_msg = config.OPENAI_SYSTEM_PROMPT
    api_messages = [{"role":"system","content":sys_msg},{"role":"user","content":user_prompt}]
    model = config.OPENAI_GENERATION_MODEL
    print(f"Using generation model: {model}")
    # o-series reasoning models require max_completion_tokens, not max_tokens.
    is_o_series = model.startswith(("o1", "o3", "o4"))
    token_key = "max_completion_tokens" if is_o_series else "max_tokens"
    kwargs = {"model":model, "messages":api_messages, token_key:3000}
    # Attempt streaming for non-o-series models.
    if not is_o_series:
        kwargs.update({"stream":True, "temperature":0.5})
        try:
            stream = await openai_async_client.chat.completions.create(**kwargs)
            async for chunk in stream:
                delta = chunk.choices[0].delta.content
                if delta:
                    yield delta
            return
        except Exception as e:
            print(f"Streaming failed for model {model}: {e}")
            traceback.print_exc()
            # BUG FIX: remove the streaming flag before retrying. The original
            # left "stream": True in kwargs, so the "non-streaming fallback"
            # opened another stream and then crashed dereferencing
            # resp.choices[0].message on the async-stream object.
            kwargs.pop("stream", None)
    # Fallback or direct call (o-series, or after a streaming error).
    try:
        resp = await openai_async_client.chat.completions.create(**kwargs)
        yield resp.choices[0].message.content
    except Exception as e:
        err = f"--- Error generating response: {e} ---"
        print(err)
        traceback.print_exc()
        yield err
|