File size: 5,317 Bytes
8f78c8f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e214413
8f78c8f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
import json
import os
from typing import Generator, List, Optional

import pandas as pd
import requests
from dotenv import load_dotenv
from openai import OpenAI

load_dotenv()


def query_llm(
    messages,
    history: List,
    df: Optional[pd.DataFrame],
    llm_type: str,
    api_key: str,
    system_prompt: str,
) -> Generator[str, None, None]:
    """Stream a chat response from the selected LLM backend.

    Args:
        messages (str or list): User input message(s); a bare string is
            wrapped into a single user message.
        history (list): Conversation history; only the last two entries
            are forwarded as context.
        df (pd.DataFrame, optional): A representation of the data already
            obtained, serialized to JSON and injected as context.
        llm_type (str): Backend to use ("OpenAI" or "Perplexity").
        api_key (str): API key; when empty, falls back to the matching
            environment variable.
        system_prompt (str): The system prompt.

    Yields:
        str: Chunks of the assistant's response, or an error message.
    """
    if not api_key:
        if llm_type == "OpenAI":
            api_key = os.environ.get("OPENAI_API_KEY")
        elif llm_type == "Perplexity":
            api_key = os.environ.get("PERPLEXITY_API_KEY")
        if not api_key:
            # No explicit key and no env fallback: stop here instead of
            # crashing on len(None) in the debug print below.
            yield "No API key provided for the selected LLM type."
            return

    print(f"LLM Type: {llm_type}, API Key len: {len(api_key)}")  # Debugging

    if isinstance(messages, str):
        messages = [{"role": "user", "content": messages}]

    # Extract last 2 messages from history (if available)
    history = history[-2:] if history else []

    # Build message history (prepend system prompt)
    full_messages = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": f"Past interactions: {history}"},
        {
            "role": "assistant",
            "content": f"Dataset: {df.to_json() if df is not None else {}}",
        },
    ] + messages

    if llm_type == "Perplexity":
        yield from query_perplexity(full_messages, api_key=api_key)
    elif llm_type == "OpenAI":
        yield from query_openai(full_messages, api_key=api_key)
    else:
        yield "Unsupported LLM type. Please choose either 'OpenAI' or 'Perplexity'."


def query_perplexity(
    full_messages,
    api_key: str,
    url="https://api.perplexity.ai/chat/completions",
    model="sonar-pro",
):
    """Stream a response from the Perplexity AI chat-completions API.

    Args:
        full_messages (list): List of messages in the conversation.
        api_key (str): Perplexity API key.
        url (str): API endpoint URL.
        model (str): Model to use for the query.

    Yields:
        str: Content chunks parsed from the server-sent event stream,
        or an error message when the request or parsing fails.
    """
    payload = {
        "model": model,
        "messages": full_messages,
        "stream": True,
    }

    headers = {
        "Authorization": f"Bearer {api_key}",
        "Content-Type": "application/json",
    }

    with requests.post(url, json=payload, headers=headers, stream=True) as response:
        if response.status_code != 200:
            yield f"API request failed with status code {response.status_code}, details: {response.text}"
            return
        for raw_line in response.iter_lines():
            if not raw_line:
                continue
            line = raw_line.decode("utf-8").strip()
            if line.startswith("data: "):
                line = line[len("data: ") :]  # Remove SSE "data: " prefix
            if line == "[DONE]":
                # SSE end-of-stream sentinel: not JSON, just stop.
                break
            try:
                data = json.loads(line)
            except json.JSONDecodeError:
                yield f"Error decoding JSON: {line}"
                continue
            choices = data.get("choices") or []
            if choices:
                # NOTE(review): assumes each chunk carries content under
                # "message" (as the original did); falls back to the
                # OpenAI-style incremental "delta" key — confirm against
                # the Perplexity streaming docs.
                chunk = choices[0].get("message") or choices[0].get("delta") or {}
                content = chunk.get("content")
                if content:
                    yield content


def query_openai(full_messages, api_key: str) -> Generator[str, None, None]:
    """Stream a chat completion from the OpenAI API.

    Args:
        full_messages (list): List of messages in the conversation.
        api_key (str): OpenAI API key.

    Yields:
        str: The accumulated response text after each streamed chunk.
    """
    client = OpenAI(api_key=api_key)

    stream = client.chat.completions.create(
        model="gpt-4o",
        messages=full_messages,
        stream=True,  # Enable streaming
    )

    accumulated = ""
    for chunk in stream:
        delta = chunk.choices[0].delta.content
        if delta:
            # Yield the running total so the caller can re-render the
            # whole message as it grows.
            accumulated += delta
            yield accumulated


def llm_extract_table(chat_output, llm_type, api_key) -> str:
    """Ask the LLM to extract a structured medication dataset as JSON.

    Args:
        chat_output: Prior conversation output expected to contain a
            dataset formatted as JSON or markdown.
        llm_type (str): Backend to use ("OpenAI" or "Perplexity").
        api_key (str): API key for the selected backend.

    Returns:
        str: The model's response with any surrounding markdown code
        fences removed; expected (but not guaranteed) to be valid JSON.
    """
    system_prompt = """
    You are a pharmacology assistant specialized in analyzing and structuring medical data.
    Your role is to extract information in either markdown, JSON or text, and turn it structured information.
    You will be given output from a conversation with an LLM. This conversation should have a dataset formatted
    as either json or markdown. Extract the dataset and return a JSON object.
    The dataset should be a JSON object with a dict per medication, with the following format:
    ```json
    {
        "Medications": [
            {"Name": "Medication Name", "key1": "value1", "key2": "value2",..},
            {"Name": "Medication Name", "key1": "value1", "key2": "value2",..}
        ]
    }
    
    Guidelines:
    - Make sure the response contains only a valid JSON
    - Avoid adding text before or after
    """

    response = query_llm(
        messages=chat_output,
        history=None,
        df=None,
        llm_type=llm_type,
        api_key=api_key,
        system_prompt=system_prompt,
    )
    # NOTE(review): joining assumes each yield is an incremental delta;
    # if a backend yields cumulative text per chunk, this duplicates
    # content — verify against the backends' streaming behavior.
    json_str = "".join(response).strip()

    # Models often wrap JSON in ``` fences despite the prompt; strip them
    # so downstream json.loads still works.
    if json_str.startswith("```json"):
        json_str = json_str[len("```json"):].strip()
    elif json_str.startswith("```"):
        json_str = json_str[3:].strip()
    if json_str.endswith("```"):
        json_str = json_str[:-3].strip()
    return json_str