File size: 4,233 Bytes
8bae60c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
from typing import Dict, List

from openai import OpenAI

import requests
import json
import simplejson

from pydantic import BaseModel

class AnswerFormat(BaseModel):
    """Pydantic schema for the structured LLM answer.

    Its JSON schema (``model_json_schema()``) is passed to the Perplexity
    API's ``response_format`` so the model replies in exactly this shape.
    """
    # Rows of the (possibly updated) dataset, one dict per record.
    dataset: List[Dict]
    # Free-text explanation accompanying the dataset.
    explanations: str
    # Citations/sources the model consulted.
    references: str


def query_perplexity(
        system_prompt: str,
        user_prompt: str,
        json_data: str,
        api_key: str,
        url="https://api.perplexity.ai/chat/completions",
        model="sonar-pro",
):
    """Query the Perplexity AI chat-completions API.

    The request constrains the model's reply to the ``AnswerFormat`` JSON
    schema (dataset / explanations / references).

    Args:
        system_prompt (str): System message providing AI context.
        user_prompt (str): User's query.
        json_data (str): JSON data representing the current dataset.
        api_key (str): Perplexity AI API key.
        url (str): API endpoint.
        model (str): Perplexity AI model to use.

    Returns:
        str: The assistant message content on success, otherwise a
        human-readable error string describing the HTTP failure.
    """

    payload = {
        "model": model,
        "messages": [
            {"role": "system", "content": f"{system_prompt}\n"
                                          f"Make sure you add the citations found to the references key"},
            {"role": "user", "content": f"Here is the dataset: {json_data}\n\n"
                                        f"User query:\n"
                                        f"{user_prompt}"},
        ],
        # Ask the API to emit JSON matching the AnswerFormat schema.
        "response_format": {
            "type": "json_schema",
            "json_schema": {"schema": AnswerFormat.model_json_schema()},
        },
    }

    headers = {
        "Authorization": f"Bearer {api_key}",
        "Content-Type": "application/json",
    }

    response = requests.post(url, json=payload, headers=headers)

    if response.status_code == 200:
        response_json = response.json()
        return response_json["choices"][0]["message"]["content"]
    # NOTE: failures are reported as a string rather than raised, so callers
    # must check the return value before treating it as JSON content.
    return f"API request failed with status code {response.status_code}, details: {response.text}"



def query_openai(system_prompt: str, user_prompt: str, json_data: str, openai_client: OpenAI) -> str:
    """Query OpenAI API for a response.

    Args:
        system_prompt (str): System prompt providing context to the AI.
        user_prompt (str): User's query.
        json_data (str): JSON data representing the current dataset.
        openai_client (OpenAI): OpenAI client instance with API key set.

    Returns:
        str: JSON response from the API.
    """
    # Dataset and query are sent as two separate user messages, after the
    # system prompt.
    messages = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": f"Here is the dataset: {json_data}"},
        {"role": "user", "content": user_prompt},
    ]

    completion = openai_client.chat.completions.create(
        model="gpt-4-turbo",
        messages=messages,
        response_format={"type": "json_object"},
    )

    # Guard clause: an empty choice list means the API gave us nothing usable.
    if not completion.choices:
        return "Bad response from OpenAI"
    return completion.choices[0].message.content


def validate_llm_response(response: str) -> dict:
    """Parse an LLM response string and validate its structure.

    Args:
        response (str): Raw response text, expected to be a JSON object with
            the keys declared by ``AnswerFormat``: "dataset", "explanations",
            and "references".

    Returns:
        dict: The parsed dictionary, or None when the text cannot be parsed
        as JSON at all.

    Raises:
        ValueError: If the JSON parses but any required key is missing.
    """
    # extract dict from json
    try:
        parsed = json.loads(response)
    except json.JSONDecodeError:
        try:
            parsed = simplejson.loads(response)  # More forgiving JSON parser
        except simplejson.JSONDecodeError:
            return None  # JSON is too broken to fix

    # Validate expected keys.
    # Bug fix: the original returned inside the try above, so this check was
    # unreachable — and it looked for "explanation" while AnswerFormat
    # declares "explanations".
    required_keys = {"dataset", "explanations", "references"}
    if not required_keys.issubset(parsed.keys()):
        raise ValueError(f"Missing required keys: {required_keys - parsed.keys()}")

    return parsed  # Return as a structured dictionary