|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import random
from collections import defaultdict
from dataclasses import dataclass, field
from enum import Enum

import requests
from openai import AsyncOpenAI, AsyncAzureOpenAI
from openai.types.chat.chat_completion import ChatCompletion

from ten.async_ten_env import AsyncTenEnv
from ten_ai_base.config import BaseConfig
|
|
|
|
|
@dataclass |
|
class OpenAIChatGPTConfig(BaseConfig): |
|
api_key: str = "" |
|
base_url: str = "https://api.openai.com/v1" |
|
model: str = ( |
|
"gpt-4o" |
|
) |
|
prompt: str = ( |
|
"You are a voice assistant who talks in a conversational way and can chat with me like my friends. I will speak to you in English or Chinese, and you will answer in the corrected and improved version of my text with the language I use. Don’t talk like a robot, instead I would like you to talk like a real human with emotions. I will use your answer for text-to-speech, so don’t return me any meaningless characters. I want you to be helpful, when I’m asking you for advice, give me precise, practical and useful advice instead of being vague. When giving me a list of options, express the options in a narrative way instead of bullet points." |
|
) |
|
frequency_penalty: float = 0.9 |
|
presence_penalty: float = 0.9 |
|
top_p: float = 1.0 |
|
temperature: float = 0.1 |
|
max_tokens: int = 512 |
|
seed: int = random.randint(0, 10000) |
|
proxy_url: str = "" |
|
greeting: str = "Hello, how can I help you today?" |
|
max_memory_length: int = 10 |
|
vendor: str = "openai" |
|
azure_endpoint: str = "" |
|
azure_api_version: str = "" |
|
|
|
|
|
class ReasoningMode(str, Enum): |
|
ModeV1= "v1" |
|
|
|
class ThinkParser: |
|
def __init__(self): |
|
self.state = 'NORMAL' |
|
self.think_content = "" |
|
self.content = "" |
|
|
|
def process(self, new_chars): |
|
if new_chars == "<think>": |
|
self.state = 'THINK' |
|
return True |
|
elif new_chars == "</think>": |
|
self.state = 'NORMAL' |
|
return True |
|
else: |
|
if self.state == "THINK": |
|
self.think_content += new_chars |
|
return False |
|
|
|
def process_by_reasoning_content(self, reasoning_content): |
|
state_changed = False |
|
if reasoning_content: |
|
if self.state == 'NORMAL': |
|
self.state = 'THINK' |
|
state_changed = True |
|
self.think_content += reasoning_content |
|
elif self.state == 'THINK': |
|
self.state = 'NORMAL' |
|
state_changed = True |
|
return state_changed |
|
|
|
|
|
class OpenAIChatGPT:
    """Async wrapper around the OpenAI / Azure OpenAI chat-completions API.

    Builds request payloads from an `OpenAIChatGPTConfig` and exposes a
    blocking completion call plus a streaming variant that demultiplexes
    content, reasoning ("think") text, and tool calls onto a listener.
    """

    # Set per-instance in __init__; kept at class level for backward compatibility.
    client = None

    def __init__(self, ten_env: AsyncTenEnv, config: OpenAIChatGPTConfig):
        self.config = config
        self.ten_env = ten_env
        # SECURITY FIX: never log the raw API key; show only a short prefix.
        masked_key = f"{config.api_key[:4]}***" if config.api_key else "<empty>"
        ten_env.log_info(f"OpenAIChatGPT initialized with config: {masked_key}")
        if self.config.vendor == "azure":
            self.client = AsyncAzureOpenAI(
                api_key=config.api_key,
                api_version=self.config.azure_api_version,
                azure_endpoint=config.azure_endpoint,
            )
            ten_env.log_info(
                f"Using Azure OpenAI with endpoint: {config.azure_endpoint}, api_version: {config.azure_api_version}"
            )
        else:
            # Both header styles are sent so OpenAI-compatible gateways that
            # expect either `api-key` or `Authorization` work unchanged.
            self.client = AsyncOpenAI(
                api_key=config.api_key,
                base_url=config.base_url,
                default_headers={
                    "api-key": config.api_key,
                    "Authorization": f"Bearer {config.api_key}",
                },
            )
        self.session = requests.Session()
        if config.proxy_url:
            proxies = {
                "http": config.proxy_url,
                "https": config.proxy_url,
            }
            ten_env.log_info(f"Setting proxies: {proxies}")
            self.session.proxies.update(proxies)
        # NOTE(review): AsyncOpenAI is httpx-based and does not consult this
        # requests.Session, so the proxy configured above likely has no effect
        # on completion calls — confirm and pass an httpx client via
        # `http_client=` if proxying is required.
        self.client.session = self.session

    def _build_request(self, messages, tools, stream):
        """Assemble the chat-completions payload shared by both call paths."""
        req = {
            "model": self.config.model,
            "messages": [
                {
                    "role": "system",
                    "content": self.config.prompt,
                },
                *messages,
            ],
            "tools": tools,
            "temperature": self.config.temperature,
            "top_p": self.config.top_p,
            "presence_penalty": self.config.presence_penalty,
            "frequency_penalty": self.config.frequency_penalty,
            "max_tokens": self.config.max_tokens,
            "seed": self.config.seed,
        }
        if stream:
            req["stream"] = True
        return req

    async def get_chat_completions(self, messages, tools=None) -> ChatCompletion:
        """Run a non-streaming completion; raises RuntimeError on API failure."""
        req = self._build_request(messages, tools, stream=False)
        try:
            response = await self.client.chat.completions.create(**req)
        except Exception as e:
            raise RuntimeError(f"CreateChatCompletion failed, err: {e}") from e
        return response

    async def get_chat_completions_stream(self, messages, tools=None, listener=None):
        """Stream a completion, emitting events on `listener` as chunks arrive.

        Events emitted (when `listener` is given):
        - "reasoning_update" / "reasoning_update_finish" for think text,
        - "content_update" for normal text,
        - "tool_call" for each fully assembled tool call,
        - "content_finished" with the concatenated normal content at the end.
        """
        req = self._build_request(messages, tools, stream=True)
        try:
            response = await self.client.chat.completions.create(**req)
        except Exception as e:
            raise RuntimeError(f"CreateChatCompletionStream failed, err: {e}") from e

        full_content = ""
        # Partial tool calls are accumulated here keyed by their stream index,
        # since arguments arrive fragmented across many deltas.
        tool_calls_dict = defaultdict(
            lambda: {
                "id": None,
                "function": {"arguments": "", "name": None},
                "type": None,
            }
        )

        parser = ThinkParser()
        # Latched to ModeV1 the first time real `reasoning_content` arrives;
        # stays None for models that use <think> tags (or no reasoning at all).
        reasoning_mode = None

        async for chat_completion in response:
            if len(chat_completion.choices) == 0:
                continue
            choice = chat_completion.choices[0]
            delta = choice.delta

            content = delta.content if delta and delta.content else ""
            reasoning_content = delta.reasoning_content if delta and hasattr(delta, "reasoning_content") and delta.reasoning_content else ""

            # BUGFIX: `reasoning_content` falls back to "" (never None), so the
            # original `is not None` check latched ModeV1 on the very first
            # chunk of every stream and made <think>-tag parsing unreachable.
            # Latch the mode only when actual reasoning text arrives.
            if reasoning_mode is None and reasoning_content:
                reasoning_mode = ReasoningMode.ModeV1

            if listener and (content or reasoning_mode == ReasoningMode.ModeV1):
                prev_state = parser.state

                if reasoning_mode == ReasoningMode.ModeV1:
                    self.ten_env.log_info("process_by_reasoning_content")
                    think_state_changed = parser.process_by_reasoning_content(reasoning_content)
                else:
                    think_state_changed = parser.process(content)

                if not think_state_changed:
                    if parser.state == "THINK":
                        listener.emit("reasoning_update", parser.think_content)
                    elif parser.state == "NORMAL":
                        listener.emit("content_update", content)

                # THINK -> NORMAL transition: flush the completed reasoning span.
                if prev_state == "THINK" and parser.state == "NORMAL":
                    listener.emit("reasoning_update_finish", parser.think_content)
                    parser.think_content = ""

            full_content += content

            # Guarded: tool-call-only chunks may arrive with delta otherwise empty,
            # and `delta` itself may be None on some keep-alive chunks.
            if delta and delta.tool_calls:
                for tool_call in delta.tool_calls:
                    entry = tool_calls_dict[tool_call.index]
                    if tool_call.id is not None:
                        entry["id"] = tool_call.id
                    if tool_call.function.name is not None:
                        entry["function"]["name"] = tool_call.function.name
                    # Arguments stream in fragments; some chunks carry None.
                    if tool_call.function.arguments:
                        entry["function"]["arguments"] += tool_call.function.arguments
                    if tool_call.type is not None:
                        entry["type"] = tool_call.type

        tool_calls_list = list(tool_calls_dict.values())

        if listener and tool_calls_list:
            for tool_call in tool_calls_list:
                listener.emit("tool_call", tool_call)

        if listener:
            listener.emit("content_finished", full_content)
|