File size: 3,523 Bytes
c151c44
 
 
 
ed5b42d
c151c44
 
ed5b42d
c151c44
 
 
 
 
 
 
 
ed5b42d
 
 
 
 
 
 
 
c151c44
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ed5b42d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c151c44
 
 
 
ed5b42d
 
 
 
 
 
 
 
c151c44
 
ed5b42d
 
c151c44
ed5b42d
 
 
 
 
 
 
 
 
 
 
c151c44
 
 
dcefa44
c151c44
 
 
 
 
 
ed5b42d
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
from openai import AsyncOpenAI, OpenAI
import asyncio
import json
from typing import AsyncIterator
from typing import Union, List, Dict, Literal
from dotenv import load_dotenv
import os
from pydantic import BaseModel
load_dotenv()  # populate os.environ from a .env file before reading settings

# Module-wide async client for an OpenAI-compatible endpoint. The endpoint URL
# and the key are looked up in the environment once, at import time.
client = AsyncOpenAI(
    api_key=os.environ.get("API_KEY"),
    base_url=os.environ.get("BASE_URL"),
)

class Message(BaseModel):
    # One chat turn, validated by pydantic: who spoke and what they said.
    # Only the "user" and "assistant" roles are accepted; any other value
    # fails validation when the model is constructed.
    role: Literal["user", "assistant"]
    # The text of the turn.
    content: str

# Helper function to flatten chat messages into a single string prompt
def flatten_messages(messages: List[Union["Message", Dict[str, str]]]) -> str:
    """Collapse a chat history into a single newline-separated transcript.

    Each message becomes one ``"role: content"`` line, in input order.

    Args:
        messages: Chat turns. Each item may be a plain dict with ``"role"``
            and ``"content"`` keys (the shape ``get_completions`` is typed to
            receive) or an object exposing ``.role`` / ``.content``
            attributes, such as ``Message``.

    Returns:
        The joined transcript; an empty string for an empty list.

    Raises:
        KeyError: If a dict item lacks ``"role"`` or ``"content"``.
        AttributeError: If a non-dict item lacks ``.role`` or ``.content``.
    """
    lines = []
    for m in messages:
        # Dicts are read by key; everything else (e.g. pydantic models,
        # which are not dict subclasses) is read by attribute.
        if isinstance(m, dict):
            role, content = m["role"], m["content"]
        else:
            role, content = m.role, m.content
        lines.append(f"{role}: {content}")
    return "\n".join(lines)

def process_input(data: Union[str, List[Dict[str, str]]]) -> Union[str, List[Dict[str, str]]]:
    """
    Normalise a prompt by trimming surrounding whitespace.

    Args:
        data: Either a prompt string, or a list of message dicts.

    Returns:
        For a string: the string with leading/trailing whitespace removed.
        For a list: a new list holding only the items that are dicts with a
        ``"content"`` key. Each kept item is shallow-copied, and its
        ``"content"`` is stripped when it is a string. Non-string content
        (e.g. ``None``) is kept unchanged. The input dicts are not mutated.

    Raises:
        TypeError: If ``data`` is neither a ``str`` nor a ``list``.
    """
    if isinstance(data, str):
        return data.strip()

    if isinstance(data, list):
        cleaned = []
        for item in data:
            # NOTE(review): items that are not dicts or lack "content" are
            # dropped silently, although the TypeError message below implies
            # they are invalid input. Confirm whether they should raise.
            if not isinstance(item, dict) or "content" not in item:
                continue
            content = item["content"]
            # Only strings can be stripped; other content is passed through
            # unchanged instead of raising AttributeError.
            if isinstance(content, str):
                content = content.strip()
            cleaned.append({**item, "content": content})
        return cleaned

    raise TypeError("Input must be a string or a list of dictionaries with a 'content' field")


# async def get_completions(
#     prompt: Union[str, List[Dict[str, str]]],
#     instructions: str
# ) -> str:
#     processed_prompt = process_input(prompt)  # Ensures the input format is correct

#     if isinstance(processed_prompt, str):
#         messages = [
#             {"role": "system", "content": instructions},
#             {"role": "user", "content": processed_prompt}
#         ]
#     elif isinstance(processed_prompt, list):
#         messages = [{"role": "system", "content": instructions}] + processed_prompt
#     else:
#         raise TypeError("Unexpected processed input type.")

#     response = await client.chat.completions.create(
#         model=os.getenv("MODEL"),
#         messages=messages,
#         response_format={"type": "json_object"}
#     )

#     output: str = response.choices[0].message.content
#     return output

async def get_completions(
    prompt: Union[str, List[Dict[str, str]]],
    instructions: str
) -> str:
    """Request a JSON-mode chat completion from the configured model.

    Args:
        prompt: A prompt string, or a list of chat messages. A list is
            collapsed by ``flatten_messages`` into a single
            ``"role: content"`` transcript. Either way, the model receives
            exactly one system message followed by one user message.
        instructions: System prompt sent ahead of the user message.

    Returns:
        The content of the first choice in the API response.

    Raises:
        TypeError: From ``process_input``, if ``prompt`` is neither a
            ``str`` nor a ``list``.

    NOTE(review): per the OpenAI JSON-mode documentation, the instructions or
    prompt must ask for JSON or the request is rejected. The SDK also types
    ``message.content`` as optional, so this may return ``None`` (e.g. on a
    refusal) despite the ``str`` annotation. Confirm how callers handle that.
    """
    # Lists are flattened to text first, so process_input always sees a str
    # here and only trims whitespace.
    query = flatten_messages(prompt) if isinstance(prompt, list) else prompt

    messages = [
        {"role": "system", "content": instructions},
        {"role": "user", "content": process_input(query)},
    ]

    response = await client.chat.completions.create(
        model=os.getenv("MODEL"),
        messages=messages,
        response_format={"type": "json_object"},
    )

    return response.choices[0].message.content