from openai import AsyncOpenAI, OpenAI
import asyncio
import json
from typing import AsyncIterator
from typing import Union, List, Dict, Literal
from dotenv import load_dotenv
import os
from pydantic import BaseModel
# Load environment variables (BASE_URL, API_KEY, MODEL) from a local .env file.
load_dotenv()
# Initialize the async client
# Module-level singleton; endpoint and credentials are read from the environment
# at import time, so the .env must be in place before this module is imported.
client = AsyncOpenAI(
    base_url=os.getenv("BASE_URL"),
    api_key=os.getenv("API_KEY"),
)
class Message(BaseModel):
    """A single chat turn validated by pydantic."""
    # Which side of the conversation produced this message; pydantic rejects
    # any value other than these two literals (no "system" role here).
    role: Literal["user", "assistant"]
    # Raw message text.
    content: str
def flatten_messages(messages: List[Message]) -> str:
    """Serialize a chat history into one newline-separated prompt string.

    Each message is rendered as ``"<role>: <content>"`` in its original order.
    """
    return "\n".join(f"{msg.role}: {msg.content}" for msg in messages)
def process_input(data: Union[str, List[Dict[str, str]]]) -> Union[str, List[Dict[str, str]]]:
    """Normalize prompt input before it is sent to the model.

    Behavior by type:
      - str: returns the string with surrounding whitespace stripped.
        (The previous docstring claimed the string was uppercased; it never was.)
      - list: returns a NEW list containing only the dict items that have a
        'content' key, each with its 'content' value whitespace-stripped.
        All other keys are preserved. Non-dict items and dicts without a
        'content' key are silently dropped.

    Args:
        data: A raw prompt string or a list of chat-message dicts.

    Returns:
        The cleaned string, or the filtered/cleaned list of message dicts.

    Raises:
        TypeError: If data is neither a str nor a list.
    """
    if isinstance(data, str):
        return data.strip()
    if isinstance(data, list):
        # Copy each dict ({**item, ...}) so the caller's objects are never mutated.
        return [
            {**item, "content": item["content"].strip()}
            for item in data
            if isinstance(item, dict) and "content" in item
        ]
    raise TypeError("Input must be a string or a list of dictionaries with a 'content' field")
async def get_completions(
    prompt: Union[str, List[Dict[str, str]]],
    instructions: str
) -> str:
    """Send a prompt (plain string or chat history) to the model and return its reply.

    Fixes over the previous version:
      - List prompts were routed through ``flatten_messages``, which does
        attribute access (``m.role``) and therefore crashed with
        AttributeError on the documented ``List[Dict[str, str]]`` input;
        the list branch below was consequently unreachable. The prompt is
        now passed straight to ``process_input``, which handles both types.
      - An empty message list now raises a clear ValueError instead of an
        IndexError.

    Args:
        prompt: Either a raw user prompt string, or a chat history as a list
            of ``{"role": ..., "content": ...}`` dicts whose final entry must
            be from the user.
        instructions: System-prompt text prepended to the conversation.

    Returns:
        The model's message content (expected to be a JSON string, since the
        request sets ``response_format={"type": "json_object"}``).

    Raises:
        TypeError: If the prompt is neither a str nor a list.
        ValueError: If a list prompt is empty or its last message is not
            from the user.
    """
    # process_input strips whitespace and, for lists, drops malformed items.
    processed_prompt = process_input(prompt)

    messages: List[Dict[str, str]] = [{"role": "system", "content": instructions}]
    if isinstance(processed_prompt, str):
        messages.append({"role": "user", "content": processed_prompt})
    elif isinstance(processed_prompt, list):
        if not processed_prompt:
            raise ValueError("Prompt message list must not be empty.")
        # Keep the earlier turns as context; the final turn must be the
        # user's current query.
        *history, last_user_msg = processed_prompt
        if last_user_msg.get("role") != "user":
            raise ValueError("Last message must be from the user.")
        messages += history
        messages.append(last_user_msg)
    else:
        raise TypeError("Unexpected processed input type.")

    response = await client.chat.completions.create(
        model=os.getenv("MODEL"),
        messages=messages,
        response_format={"type": "json_object"}
    )
    return response.choices[0].message.content