File size: 9,253 Bytes
87337b1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
#
#
# Agora Real Time Engagement
# Created by Wei Hu in 2024-08.
# Copyright (c) 2024 Agora IO. All rights reserved.
#
#
from collections import defaultdict
from dataclasses import dataclass
from enum import Enum
import random
import requests
from openai import AsyncOpenAI, AsyncAzureOpenAI
from openai.types.chat.chat_completion import ChatCompletion

from ten.async_ten_env import AsyncTenEnv
from ten_ai_base.config import BaseConfig


@dataclass
class OpenAIChatGPTConfig(BaseConfig):
    """Configuration values consumed by the OpenAI / Azure OpenAI chat client.

    Every field has a default, so the dataclass can be instantiated without
    arguments. Field order is significant for positional construction and
    must not be changed.
    """

    # --- Credentials and endpoint -------------------------------------
    api_key: str = ""
    base_url: str = "https://api.openai.com/v1"

    # --- Model and system prompt --------------------------------------
    # Adjust this to match the equivalent of `openai.GPT4o` in the Python library
    model: str = "gpt-4o"
    # Sent as the "system" message ahead of the conversation messages.
    prompt: str = "You are a voice assistant who talks in a conversational way and can chat with me like my friends. I will speak to you in English or Chinese, and you will answer in the corrected and improved version of my text with the language I use. Don’t talk like a robot, instead I would like you to talk like a real human with emotions. I will use your answer for text-to-speech, so don’t return me any meaningless characters. I want you to be helpful, when I’m asking you for advice, give me precise, practical and useful advice instead of being vague. When giving me a list of options, express the options in a narrative way instead of bullet points."

    # --- Sampling parameters forwarded with each completion request ----
    frequency_penalty: float = 0.9
    presence_penalty: float = 0.9
    top_p: float = 1.0
    temperature: float = 0.1
    max_tokens: int = 512
    # NOTE: this default is computed once, when the class body is executed,
    # so every instance created without an explicit seed shares the same value.
    seed: int = random.randint(0, 10000)

    # --- Networking ----------------------------------------------------
    # When non-empty, used for both the "http" and "https" proxy entries.
    proxy_url: str = ""

    # --- Conversation behaviour ----------------------------------------
    greeting: str = "Hello, how can I help you today?"
    max_memory_length: int = 10

    # --- Vendor selection ----------------------------------------------
    # "azure" selects the Azure OpenAI client; any other value selects the
    # standard OpenAI client.
    vendor: str = "openai"
    azure_endpoint: str = ""
    azure_api_version: str = ""


class ReasoningMode(str, Enum):
    """String-valued enum naming how reasoning output is detected in a stream."""

    # Reasoning text is delivered through the delta's `reasoning_content` field.
    ModeV1 = "v1"

class ThinkParser:
    """Two-state tracker for reasoning ("think") sections in streamed output.

    Attributes:
        state: Either 'NORMAL' or 'THINK'.
        think_content: Reasoning text accumulated while in the 'THINK' state.
        content: Initialised to an empty string; not modified by this class.
    """

    def __init__(self):
        self.content = ""
        self.think_content = ""
        self.state = 'NORMAL'  # States: 'NORMAL', 'THINK'

    def process(self, new_chars):
        """Consume one streamed chunk, using literal think tags as markers.

        A chunk equal to "<think>" switches to 'THINK'; a chunk equal to
        "</think>" switches back to 'NORMAL'. Any other chunk is appended to
        `think_content` when the parser is currently in 'THINK'.

        Returns:
            True when the chunk was one of the two tags, False otherwise.
        """
        if new_chars == "<think>":
            self.state = 'THINK'
            return True

        if new_chars == "</think>":
            self.state = 'NORMAL'
            return True

        if self.state == "THINK":
            self.think_content += new_chars
        return False

    def process_by_reasoning_content(self, reasoning_content):
        """Update the state from a chunk's separate reasoning text.

        Non-empty reasoning text is appended to `think_content`, entering
        'THINK' first if the parser was in 'NORMAL'. Empty reasoning text
        leaves 'THINK' (if the parser was in it) and changes nothing else.

        Returns:
            True when this call changed `state`, False otherwise.
        """
        if not reasoning_content:
            # No reasoning in this chunk: only a THINK parser has anything to do.
            if self.state != 'THINK':
                return False
            self.state = 'NORMAL'
            return True

        entered_think = self.state == 'NORMAL'
        if entered_think:
            self.state = 'THINK'
        self.think_content += reasoning_content
        return entered_think
        

class OpenAIChatGPT:
    """Thin async wrapper around the OpenAI / Azure OpenAI chat-completion clients.

    The wrapper prepends the configured system prompt to every request and,
    for streaming requests, reports progress to an optional listener object
    through `listener.emit(event_name, payload)`.
    """

    client = None

    def __init__(self, ten_env: AsyncTenEnv, config: OpenAIChatGPTConfig):
        """Create the SDK client selected by `config.vendor`.

        Args:
            ten_env: Environment object used for logging.
            config: Client settings (credentials, endpoint, sampling values).
        """
        self.config = config
        self.ten_env = ten_env
        # Never log the API key itself: log only non-secret settings.
        ten_env.log_info(
            f"OpenAIChatGPT initialized with vendor: {config.vendor}, model: {config.model}"
        )
        if self.config.vendor == "azure":
            self.client = AsyncAzureOpenAI(
                api_key=config.api_key,
                api_version=self.config.azure_api_version,
                azure_endpoint=config.azure_endpoint,
            )
            ten_env.log_info(
                f"Using Azure OpenAI with endpoint: {config.azure_endpoint}, api_version: {config.azure_api_version}"
            )
        else:
            # The key is sent both as "api-key" and as a bearer token.
            self.client = AsyncOpenAI(api_key=config.api_key, base_url=config.base_url, default_headers={
                "api-key": config.api_key,
                "Authorization": f"Bearer {config.api_key}"
            })
        self.session = requests.Session()
        if config.proxy_url:
            proxies = {
                "http": config.proxy_url,
                "https": config.proxy_url,
            }
            ten_env.log_info(f"Setting proxies: {proxies}")
            self.session.proxies.update(proxies)
        # NOTE(review): the session is only attached as an attribute on the SDK
        # client. Confirm the client actually uses it for outgoing requests;
        # the OpenAI SDK documents an `http_client` argument for proxy setup.
        self.client.session = self.session

    def _build_request(self, messages, tools, stream=False):
        """Assemble the keyword arguments shared by both completion calls."""
        req = {
            "model": self.config.model,
            "messages": [
                {
                    "role": "system",
                    "content": self.config.prompt,
                },
                *messages,
            ],
            "tools": tools,
            "temperature": self.config.temperature,
            "top_p": self.config.top_p,
            "presence_penalty": self.config.presence_penalty,
            "frequency_penalty": self.config.frequency_penalty,
            "max_tokens": self.config.max_tokens,
            "seed": self.config.seed,
        }
        if stream:
            req["stream"] = True
        return req

    async def get_chat_completions(self, messages, tools=None) -> ChatCompletion:
        """Request a single, non-streamed chat completion.

        Args:
            messages: Conversation messages appended after the system prompt.
            tools: Optional tool definitions forwarded unchanged.

        Returns:
            The response object returned by the SDK client.

        Raises:
            RuntimeError: If the SDK call raises; the cause is chained.
        """
        req = self._build_request(messages, tools)

        try:
            response = await self.client.chat.completions.create(**req)
        except Exception as e:
            raise RuntimeError(f"CreateChatCompletion failed, err: {e}") from e

        return response

    async def get_chat_completions_stream(self, messages, tools=None, listener=None):
        """Request a streamed chat completion and report progress to `listener`.

        Events emitted through `listener.emit(name, payload)`:
            - "content_update": the content text of one chunk.
            - "reasoning_update": the reasoning text accumulated so far.
            - "reasoning_update_finish": the full reasoning text, once the
              reasoning section ends.
            - "tool_call": one dict per tool call, with keys "id",
              "function" ({"name", "arguments"}) and "type".
            - "content_finished": all content text concatenated, emitted
              after the stream is exhausted.

        Args:
            messages: Conversation messages appended after the system prompt.
            tools: Optional tool definitions forwarded unchanged.
            listener: Optional object exposing `emit(name, payload)`; when it
                is falsy no events are emitted.

        Raises:
            RuntimeError: If creating the stream raises; the cause is chained.
        """
        req = self._build_request(messages, tools, stream=True)

        try:
            response = await self.client.chat.completions.create(**req)
        except Exception as e:
            raise RuntimeError(f"CreateChatCompletionStream failed, err: {e}") from e

        content_parts = []
        # Tool calls arrive in fragments keyed by index; merge them here.
        tool_calls_dict = defaultdict(
            lambda: {
                "id": None,
                "function": {"arguments": "", "name": None},
                "type": None,
            }
        )

        # Tracks whether the stream is currently inside a reasoning section.
        parser = ThinkParser()
        reasoning_mode = None

        async for chat_completion in response:
            # self.ten_env.log_info(f"Chat completion: {chat_completion}")
            if len(chat_completion.choices) == 0:
                continue
            delta = chat_completion.choices[0].delta
            if delta is None:
                # Nothing to read from this chunk (no text, no tool calls).
                continue

            content = delta.content or ""
            # `reasoning_content` is not present on every delta object.
            reasoning_content = getattr(delta, "reasoning_content", None) or ""

            # Switch to ModeV1 only once reasoning text has actually been seen;
            # otherwise "<think>" / "</think>" tag parsing stays in effect.
            if reasoning_mode is None and reasoning_content:
                reasoning_mode = ReasoningMode.ModeV1

            # Emit content update event (fire-and-forget)
            if listener and (content or reasoning_mode == ReasoningMode.ModeV1):
                prev_state = parser.state

                if reasoning_mode == ReasoningMode.ModeV1:
                    self.ten_env.log_info("process_by_reasoning_content")
                    parser.process_by_reasoning_content(reasoning_content)
                    # In this mode the chunk never is a bare tag, so its text
                    # must still be forwarded even when the state changed.
                    is_tag_chunk = False
                else:
                    is_tag_chunk = parser.process(content)

                if prev_state == "THINK" and parser.state == "NORMAL":
                    listener.emit("reasoning_update_finish", parser.think_content)
                    parser.think_content = ""

                if not is_tag_chunk:
                    # self.ten_env.log_info(f"state: {parser.state}, content: {content}, think: {parser.think_content}")
                    if parser.state == "THINK":
                        listener.emit("reasoning_update", parser.think_content)
                    elif parser.state == "NORMAL" and content:
                        listener.emit("content_update", content)

            content_parts.append(content)

            for tool_call in delta.tool_calls or []:
                entry = tool_calls_dict[tool_call.index]

                if tool_call.id is not None:
                    entry["id"] = tool_call.id

                function = tool_call.function
                if function is not None:
                    # If the function name is not None, set it
                    if function.name is not None:
                        entry["function"]["name"] = function.name

                    # Append the arguments; a fragment may be None, which
                    # would otherwise break the string concatenation.
                    entry["function"]["arguments"] += function.arguments or ""

                # If the type is not None, set it
                if tool_call.type is not None:
                    entry["type"] = tool_call.type

        full_content = "".join(content_parts)

        # Convert the dictionary to a list
        tool_calls_list = list(tool_calls_dict.values())

        # Emit tool calls event (fire-and-forget)
        if listener and tool_calls_list:
            for tool_call in tool_calls_list:
                listener.emit("tool_call", tool_call)

        # Emit content finished event after the loop completes
        if listener:
            listener.emit("content_finished", full_content)