import openai
import traceback
import json
import asyncio
from typing import Dict, Optional, Tuple, List, AsyncGenerator
from langsmith import traceable

try:
    import config
    from utils import format_context_for_openai
except ImportError:
    print("Error: Failed to import config or utils in openai_service.py")
    raise SystemExit("Failed imports in openai_service.py")

# --- Globals ---
openai_async_client: Optional[openai.AsyncOpenAI] = None
is_openai_ready: bool = False
openai_status_message: str = "OpenAI service not initialized."

# --- Initialization ---
def init_openai_client() -> Tuple[bool, str]:
    """Initializes the OpenAI async client."""
    global openai_async_client, is_openai_ready, openai_status_message
    if is_openai_ready:
        return True, openai_status_message
    if not config.OPENAI_API_KEY:
        openai_status_message = "Error: OPENAI_API_KEY not found in Secrets."
        is_openai_ready = False
        return False, openai_status_message
    try:
        openai_async_client = openai.AsyncOpenAI(api_key=config.OPENAI_API_KEY)
        openai_status_message = (
            f"OpenAI service ready (Validate: {config.OPENAI_VALIDATION_MODEL}, "
            f"Generate: {config.OPENAI_GENERATION_MODEL})."
        )
        is_openai_ready = True
        print("OpenAI Service: Async client initialized.")
        return True, openai_status_message
    except Exception as e:
        error_msg = f"Error initializing OpenAI async client: {type(e).__name__} - {e}"
        print(error_msg)
        traceback.print_exc()
        openai_status_message = error_msg
        is_openai_ready = False
        openai_async_client = None
        return False, openai_status_message


def get_openai_status() -> Tuple[bool, str]:
    """Returns the current status of the OpenAI service."""
    if not is_openai_ready:
        init_openai_client()
    return is_openai_ready, openai_status_message

# --- Validation Function (uses template) ---
@traceable(name="openai-validate-paragraph")
async def validate_relevance_openai(
    paragraph_data: Dict, user_question: str, paragraph_index: int
) -> Optional[Dict]:
    """Asks the validation model whether a paragraph is relevant to the question.

    Returns a dict with the model's JSON verdict under 'validation' and the
    original paragraph under 'paragraph_data', or None if the client is not ready.
    """
    global openai_async_client
    ready, msg = get_openai_status()
    if not ready or openai_async_client is None:
        print(f"OpenAI validation failed (Para {paragraph_index+1}): Client not ready - {msg}")
        return None

    # Work from a defensive copy so a non-dict payload cannot raise below.
    safe_paragraph_data = paragraph_data.copy() if isinstance(paragraph_data, dict) else {}
    hebrew_text = safe_paragraph_data.get('hebrew_text', '').strip()
    english_text = safe_paragraph_data.get('english_text', '').strip()
    if not hebrew_text and not english_text:
        return {
            "validation": {"contains_relevant_info": False, "justification": "Paragraph text empty."},
            "paragraph_data": safe_paragraph_data
        }

    validation_model = config.OPENAI_VALIDATION_MODEL
    prompt_content = config.VALIDATION_PROMPT_TEMPLATE.format(
        user_question=user_question,
        paragraph_index=paragraph_index+1,
        hebrew_text=hebrew_text or "(No Hebrew)",
        english_text=english_text or "(No English)"
    )

    try:
        response = await openai_async_client.chat.completions.create(
            model=validation_model,
            messages=[{"role": "user", "content": prompt_content}],
            temperature=0.1,
            max_tokens=150,
            response_format={"type": "json_object"}
        )
        validation_result = json.loads(response.choices[0].message.content)
        return {"validation": validation_result, "paragraph_data": safe_paragraph_data}
    except Exception as e:
        print(f"Error (OpenAI Validate {paragraph_index+1}): {e}")
        traceback.print_exc()
        return {
            "validation": {"contains_relevant_info": False, "justification": "Error during validation."},
            "paragraph_data": safe_paragraph_data
        }
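
# A sketch of the verdict object this module expects the validation model to
# return (the exact schema is set by config.VALIDATION_PROMPT_TEMPLATE; only
# the two keys consumed above are assumed here):
#
#     {"contains_relevant_info": true, "justification": "Short reason."}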

# --- Generation Function ---
@traceable(name="openai-generate-stream")
async def generate_openai_stream(
    messages: List[Dict],
    context_documents: List[Dict],
) -> AsyncGenerator[str, None]:
    """
    Generates a response using OpenAI GPT model based on history and context.
    Falls back to non-streaming for o-series models if streaming fails.
    """
    global openai_async_client
    ready, msg = get_openai_status()
    if not ready or openai_async_client is None:
        yield f"--- Error: OpenAI client not available: {msg} ---"
        return

    # Format context
    if not isinstance(context_documents, list) or not all(isinstance(item, dict) for item in context_documents):
        yield "--- Error: Invalid context_documents format ---"
        return
    formatted_context = format_context_for_openai(context_documents)
    # format_context_for_openai signals an empty result with a message starting "No".
    if not formatted_context or formatted_context.startswith("No"):
        yield "--- Error: No valid context provided ---"
        return

    last_user = next((m['content'] for m in reversed(messages) if m.get('role') == 'user'), "")
    user_prompt = (
        f"Source Texts:\n{formatted_context}\n\n"
        f"User Question:\n{last_user}\n\n"
        "Answer (in Hebrew, based ONLY on the Source Texts provided):"
    )
    sys_msg = config.OPENAI_SYSTEM_PROMPT
    api_messages = [
        {"role": "system", "content": sys_msg},
        {"role": "user", "content": user_prompt},
    ]

    model = config.OPENAI_GENERATION_MODEL
    print(f"Using generation model: {model}")

    # o-series reasoning models take `max_completion_tokens`; other chat models take `max_tokens`.
    token_key = "max_completion_tokens" if model.startswith(("o1", "o3", "o4")) else "max_tokens"
    kwargs = {"model": model, "messages": api_messages, token_key: 3000}

    # Attempt streaming for non-o-series models
    if not model.startswith(("o1", "o3", "o4")):
        kwargs.update({"stream": True, "temperature": 0.5})
        try:
            stream = await openai_async_client.chat.completions.create(**kwargs)
            async for chunk in stream:
                # Some chunks may arrive with empty choices or no delta content.
                if chunk.choices and chunk.choices[0].delta.content:
                    yield chunk.choices[0].delta.content
            return
        except Exception as e:
            print(f"Streaming failed for model {model}: {e}")
            traceback.print_exc()
            # Drop the streaming flag so the fallback call below returns a
            # regular completion object instead of another stream.
            kwargs.pop("stream", None)

    # Fallback or direct call (o-series or streaming error)
    try:
        resp = await openai_async_client.chat.completions.create(**kwargs)
        # message.content can be None if the model returned no text.
        yield resp.choices[0].message.content or ""
    except Exception as e:
        err = f"--- Error generating response: {e} ---"
        print(err)
        traceback.print_exc()
        yield err
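

# --- Manual smoke test ---
# A minimal sketch for running this module directly. It assumes config supplies
# a valid OPENAI_API_KEY and model names; `question` and `paragraph` below are
# illustrative placeholders, not real data, and the paragraph keys mirror those
# read by validate_relevance_openai.
if __name__ == "__main__":
    async def _demo() -> None:
        ready, status = get_openai_status()
        print(status)
        if not ready:
            return
        question = "Example question"
        paragraph = {"hebrew_text": "", "english_text": "Example paragraph text."}
        print(await validate_relevance_openai(paragraph, question, 0))
        async for piece in generate_openai_stream(
            [{"role": "user", "content": question}],
            [paragraph],
        ):
            print(piece, end="")
        print()

    asyncio.run(_demo())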