from openai import AsyncOpenAI, OpenAI | |
import asyncio | |
import json | |
from typing import AsyncIterator | |
from typing import Union, List, Dict, Literal | |
from dotenv import load_dotenv | |
import os | |
from pydantic import BaseModel | |
load_dotenv() | |
# Initialize the async client.
# Endpoint and credentials are read from the environment (populated by
# load_dotenv() above); NOTE(review): a custom BASE_URL presumably targets
# an OpenAI-compatible provider — confirm against deployment config.
client = AsyncOpenAI(
    base_url=os.getenv("BASE_URL"),
    api_key=os.getenv("API_KEY"),
)
class Message(BaseModel):
    """A single chat turn; only user/assistant roles are modeled here (the system message is constructed separately)."""
    # role: which party produced the turn
    role: Literal["user", "assistant"]
    # content: the message text
    content: str
# Helper function to flatten chat messages into a single string prompt | |
def flatten_messages(messages: List[Message]) -> str:
    """Render a conversation as one string, one "role: content" line per message."""
    return "\n".join(f"{msg.role}: {msg.content}" for msg in messages)
def process_input(data: Union[str, List[Dict[str, str]]]) -> Union[str, List[Dict[str, str]]]:
    """Normalize prompt input before it is sent to the model.

    A plain string is returned with surrounding whitespace stripped. A list
    of message dicts is filtered down to the dict entries that carry a
    "content" key, each with its "content" value stripped (all other keys
    are preserved unchanged).

    Args:
        data: A raw prompt string, or a list of message dicts.

    Returns:
        The cleaned string, or the cleaned list of message dicts.

    Raises:
        TypeError: If ``data`` is neither a str nor a list.
    """
    if isinstance(data, str):
        return data.strip()
    if isinstance(data, list):
        # Drop malformed entries (non-dicts / missing "content") instead of crashing.
        return [
            {**item, "content": item["content"].strip()}
            for item in data
            if isinstance(item, dict) and "content" in item
        ]
    raise TypeError("Input must be a string or a list of dictionaries with a 'content' field")
# NOTE: an earlier commented-out draft of get_completions was removed here;
# the maintained implementation follows below.
async def get_completions(
    prompt: Union[str, List[Dict[str, str]]],
    instructions: str,
) -> str:
    """Send a chat-completions request (JSON mode) and return the reply text.

    Args:
        prompt: Either a single user query string, or a list of message
            dicts ("role"/"content") forming the conversation history,
            which must end with a user message.
        instructions: System-prompt text prepended to the conversation.

    Returns:
        The model's message content — a JSON string, per ``response_format``.

    Raises:
        ValueError: If a message list is empty after cleaning, or does not
            end with a user message.
        TypeError: If ``prompt`` is neither a str nor a list (from
            ``process_input``).
    """
    # BUGFIX: the previous version routed list input through
    # flatten_messages(), which reads .role/.content attributes and raises
    # AttributeError on the plain dicts this signature accepts; it also made
    # the list-handling branch below unreachable. Lists now go straight to
    # process_input(), which handles dicts.
    processed_prompt = process_input(prompt)
    messages: List[Dict[str, str]] = [{"role": "system", "content": instructions}]
    if isinstance(processed_prompt, str):
        messages.append({"role": "user", "content": processed_prompt})
    else:
        if not processed_prompt:
            raise ValueError("Message list is empty after normalization.")
        # The completion must be a reply to the user, so the final turn
        # has to come from them.
        if processed_prompt[-1].get("role") != "user":
            raise ValueError("Last message must be from the user.")
        messages.extend(processed_prompt)
    response = await client.chat.completions.create(
        model=os.getenv("MODEL"),
        messages=messages,
        response_format={"type": "json_object"},
    )
    return response.choices[0].message.content