ten/agents/ten_packages/extension/openai_chatgpt/openai_chatgpt_extension.go
/**
*
* Agora Real Time Engagement
* Created by lixinhui in 2024.
* Copyright (c) 2024 Agora IO. All rights reserved.
*
*/
// Note that this is just an example extension written in the Go programming
// language, so the package name does not match the containing directory
// name, which is uncommon in Go.
package extension
import (
"errors"
"fmt"
"io"
"sync"
"sync/atomic"
"time"
"ten_framework/ten"
openai "github.com/sashabaranov/go-openai"
)
type openaiChatGPTExtension struct {
ten.DefaultExtension
openaiChatGPT *openaiChatGPT
}
const (
cmdInFlush = "flush"
cmdOutFlush = "flush"
dataInTextDataPropertyText = "text"
dataInTextDataPropertyIsFinal = "is_final"
dataOutTextDataPropertyText = "text"
dataOutTextDataPropertyTextEndOfSegment = "end_of_segment"
propertyBaseUrl = "base_url" // Optional
propertyApiKey = "api_key" // Required
propertyModel = "model" // Optional
propertyPrompt = "prompt" // Optional
propertyFrequencyPenalty = "frequency_penalty" // Optional
propertyPresencePenalty = "presence_penalty" // Optional
propertyTemperature = "temperature" // Optional
propertyTopP = "top_p" // Optional
propertyMaxTokens = "max_tokens" // Optional
propertyGreeting = "greeting" // Optional
propertyProxyUrl = "proxy_url" // Optional
propertyMaxMemoryLength = "max_memory_length" // Optional
)
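// A minimal, illustrative property configuration for this extension. Only the
// key names come from the constants above; every value shown is a placeholder:
//
// {
//   "api_key": "<your OpenAI API key>",
//   "model": "gpt-4o",
//   "prompt": "You are a helpful assistant.",
//   "max_tokens": 512,
//   "greeting": "Hello, how can I help you today?",
//   "max_memory_length": 10
// }
//
// Package-level state shared between OnData and the streaming goroutines:
// memory holds the rolling chat history, memoryChan queues assistant replies
// produced asynchronously, outdateTs records the latest flush so in-flight
// requests can detect they are stale, and wg tracks running completion streams.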
var (
memory []openai.ChatCompletionMessage
memoryChan chan openai.ChatCompletionMessage
maxMemoryLength = 10
outdateTs atomic.Int64
wg sync.WaitGroup
)
func newChatGPTExtension(name string) ten.Extension {
return &openaiChatGPTExtension{}
}
// OnStart will be called when the extension is starting,
// properties can be read here to initialize and start the extension.
// current supported properties:
// - api_key (required)
// - model
// - prompt
// - frequency_penalty
// - presence_penalty
// - temperature
// - top_p
// - max_tokens
// - greeting
// - proxy_url
// - base_url
// - max_memory_length
func (p *openaiChatGPTExtension) OnStart(tenEnv ten.TenEnv) {
tenEnv.LogInfo("OnStart")
// prepare configuration
openaiChatGPTConfig := defaultOpenaiChatGPTConfig()
if baseUrl, err := tenEnv.GetPropertyString(propertyBaseUrl); err != nil {
tenEnv.LogWarn(fmt.Sprintf("GetProperty optional %s failed, err: %v", propertyBaseUrl, err))
} else {
if len(baseUrl) > 0 {
openaiChatGPTConfig.BaseUrl = baseUrl
}
}
if apiKey, err := tenEnv.GetPropertyString(propertyApiKey); err != nil {
tenEnv.LogError(fmt.Sprintf("GetProperty required %s failed, err: %v", propertyApiKey, err))
return
} else {
openaiChatGPTConfig.ApiKey = apiKey
}
if model, err := tenEnv.GetPropertyString(propertyModel); err != nil {
tenEnv.LogWarn(fmt.Sprintf("GetProperty optional %s error:%v", propertyModel, err))
} else {
if len(model) > 0 {
openaiChatGPTConfig.Model = model
}
}
if prompt, err := tenEnv.GetPropertyString(propertyPrompt); err != nil {
tenEnv.LogWarn(fmt.Sprintf("GetProperty optional %s error:%v", propertyPrompt, err))
} else {
if len(prompt) > 0 {
openaiChatGPTConfig.Prompt = prompt
}
}
if frequencyPenalty, err := tenEnv.GetPropertyFloat64(propertyFrequencyPenalty); err != nil {
tenEnv.LogWarn(fmt.Sprintf("GetProperty optional %s failed, err: %v", propertyFrequencyPenalty, err))
} else {
openaiChatGPTConfig.FrequencyPenalty = float32(frequencyPenalty)
}
if presencePenalty, err := tenEnv.GetPropertyFloat64(propertyPresencePenalty); err != nil {
tenEnv.LogWarn(fmt.Sprintf("GetProperty optional %s failed, err: %v", propertyPresencePenalty, err))
} else {
openaiChatGPTConfig.PresencePenalty = float32(presencePenalty)
}
if temperature, err := tenEnv.GetPropertyFloat64(propertyTemperature); err != nil {
tenEnv.LogWarn(fmt.Sprintf("GetProperty optional %s failed, err: %v", propertyTemperature, err))
} else {
openaiChatGPTConfig.Temperature = float32(temperature)
}
if topP, err := tenEnv.GetPropertyFloat64(propertyTopP); err != nil {
tenEnv.LogWarn(fmt.Sprintf("GetProperty optional %s failed, err: %v", propertyTopP, err))
} else {
openaiChatGPTConfig.TopP = float32(topP)
}
if maxTokens, err := tenEnv.GetPropertyInt64(propertyMaxTokens); err != nil {
tenEnv.LogWarn(fmt.Sprintf("GetProperty optional %s failed, err: %v", propertyMaxTokens, err))
} else {
if maxTokens > 0 {
openaiChatGPTConfig.MaxTokens = int(maxTokens)
}
}
if proxyUrl, err := tenEnv.GetPropertyString(propertyProxyUrl); err != nil {
tenEnv.LogWarn(fmt.Sprintf("GetProperty optional %s failed, err: %v", propertyProxyUrl, err))
} else {
openaiChatGPTConfig.ProxyUrl = proxyUrl
}
greeting, err := tenEnv.GetPropertyString(propertyGreeting)
if err != nil {
tenEnv.LogWarn(fmt.Sprintf("GetProperty optional %s failed, err: %v", propertyGreeting, err))
}
if propMaxMemoryLength, err := tenEnv.GetPropertyInt64(propertyMaxMemoryLength); err != nil {
tenEnv.LogWarn(fmt.Sprintf("GetProperty optional %s failed, err: %v", propertyMaxMemoryLength, err))
} else {
if propMaxMemoryLength > 0 {
maxMemoryLength = int(propMaxMemoryLength)
}
}
// create openaiChatGPT instance
openaiChatgpt, err := newOpenaiChatGPT(openaiChatGPTConfig)
if err != nil {
tenEnv.LogError(fmt.Sprintf("newOpenaiChatGPT failed, err: %v", err))
return
}
tenEnv.LogInfo(fmt.Sprintf("newOpenaiChatGPT succeed with max_tokens: %d, model: %s",
openaiChatGPTConfig.MaxTokens, openaiChatGPTConfig.Model))
p.openaiChatGPT = openaiChatgpt
memoryChan = make(chan openai.ChatCompletionMessage, maxMemoryLength*2)
// send greeting if available
if len(greeting) > 0 {
outputData, _ := ten.NewData("text_data")
outputData.SetProperty(dataOutTextDataPropertyText, greeting)
outputData.SetProperty(dataOutTextDataPropertyTextEndOfSegment, true)
if err := tenEnv.SendData(outputData, nil); err != nil {
tenEnv.LogError(fmt.Sprintf("greeting [%s] send failed, err: %v", greeting, err))
} else {
tenEnv.LogInfo(fmt.Sprintf("greeting [%s] sent", greeting))
}
}
tenEnv.OnStartDone()
}
// OnCmd receives cmd from ten graph.
// current supported cmd:
// - name: flush
// example:
// {"name": "flush"}
func (p *openaiChatGPTExtension) OnCmd(
tenEnv ten.TenEnv,
cmd ten.Cmd,
) {
cmdName, err := cmd.GetName()
if err != nil {
tenEnv.LogError(fmt.Sprintf("OnCmd get name failed, err: %v", err))
cmdResult, _ := ten.NewCmdResult(ten.StatusCodeError)
tenEnv.ReturnResult(cmdResult, cmd, nil)
return
}
tenEnv.LogInfo(fmt.Sprintf("OnCmd %s", cmdInFlush))
switch cmdName {
case cmdInFlush:
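// Mark every in-flight completion request as outdated: the streaming
// goroutines compare their start timestamp against outdateTs and stop
// emitting output once a newer flush has been recorded.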
outdateTs.Store(time.Now().UnixMicro())
wg.Wait() // wait for chat completion stream to finish
// send out
outCmd, err := ten.NewCmd(cmdOutFlush)
if err != nil {
tenEnv.LogError(fmt.Sprintf("new cmd %s failed, err: %v", cmdOutFlush, err))
cmdResult, _ := ten.NewCmdResult(ten.StatusCodeError)
tenEnv.ReturnResult(cmdResult, cmd, nil)
return
}
if err := tenEnv.SendCmd(outCmd, nil); err != nil {
tenEnv.LogError(fmt.Sprintf("send cmd %s failed, err: %v", cmdOutFlush, err))
cmdResult, _ := ten.NewCmdResult(ten.StatusCodeError)
tenEnv.ReturnResult(cmdResult, cmd, nil)
return
} else {
tenEnv.LogInfo(fmt.Sprintf("cmd %s sent", cmdOutFlush))
}
}
cmdResult, _ := ten.NewCmdResult(ten.StatusCodeOk)
tenEnv.ReturnResult(cmdResult, cmd, nil)
}
// OnData receives data from ten graph.
// current supported data:
// - name: text_data
// example:
// {"name": "text_data", "properties": {"text": "hello", "is_final": true}
func (p *openaiChatGPTExtension) OnData(
tenEnv ten.TenEnv,
data ten.Data,
) {
// Get isFinal
isFinal, err := data.GetPropertyBool(dataInTextDataPropertyIsFinal)
if err != nil {
tenEnv.LogWarn(fmt.Sprintf("OnData GetProperty %s failed, err: %v", dataInTextDataPropertyIsFinal, err))
return
}
if !isFinal { // ignore non-final
tenEnv.LogDebug("ignore non-final input")
return
}
// Get input text
inputText, err := data.GetPropertyString(dataInTextDataPropertyText)
if err != nil {
tenEnv.LogError(fmt.Sprintf("OnData GetProperty %s failed, err: %v", dataInTextDataPropertyText, err))
return
}
if len(inputText) == 0 {
tenEnv.LogDebug("ignore empty text")
return
}
tenEnv.LogInfo(fmt.Sprintf("OnData input text: [%s]", inputText))
// prepare memory
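// Drain assistant replies queued on memoryChan by earlier streaming
// goroutines, appending them to the shared history while keeping at most
// maxMemoryLength entries.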
for len(memoryChan) > 0 {
m, ok := <-memoryChan
if !ok {
break
}
memory = append(memory, m)
if len(memory) > maxMemoryLength {
memory = memory[1:]
}
}
memory = append(memory, openai.ChatCompletionMessage{
Role: openai.ChatMessageRoleUser,
Content: inputText,
})
if len(memory) > maxMemoryLength {
memory = memory[1:]
}
// start goroutine to request and read responses from openai
wg.Add(1)
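// The goroutine receives the request start time (compared against outdateTs
// to detect flushes) and its own copy of memory, so later OnData calls cannot
// mutate the slice while this request is still streaming.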
go func(startTime time.Time, inputText string, memory []openai.ChatCompletionMessage) {
defer wg.Done()
tenEnv.LogInfo(fmt.Sprintf("GetChatCompletionsStream for input text: [%s] memory: %v", inputText, memory))
// Get result from ai
resp, err := p.openaiChatGPT.getChatCompletionsStream(memory)
if err != nil {
tenEnv.LogError(fmt.Sprintf("GetChatCompletionsStream for input text: [%s] failed, err: %v", inputText, err))
return
}
defer func() {
if resp != nil { // Close stream object
resp.Close()
}
}()
tenEnv.LogDebug(fmt.Sprintf("GetChatCompletionsStream start to recv for input text: [%s]", inputText))
var sentence, fullContent string
var firstSentenceSent bool
for {
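// Read streamed completion deltas until io.EOF or until this request is
// superseded by a flush.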
if startTime.UnixMicro() < outdateTs.Load() { // Check whether to interrupt
tenEnv.LogInfo(fmt.Sprintf("GetChatCompletionsStream recv interrupt and flushing for input text: [%s], startTs: %d, outdateTs: %d",
inputText, startTime.UnixMicro(), outdateTs.Load()))
break
}
chatCompletions, err := resp.Recv()
if errors.Is(err, io.EOF) {
tenEnv.LogDebug(fmt.Sprintf("GetChatCompletionsStream recv for input text: [%s], io.EOF break", inputText))
break
}
if err != nil {
tenEnv.LogError(fmt.Sprintf("GetChatCompletionsStream recv for input text: [%s] failed, err: %v", inputText, err))
break
}
var content string
if len(chatCompletions.Choices) > 0 && chatCompletions.Choices[0].Delta.Content != "" {
content = chatCompletions.Choices[0].Delta.Content
}
fullContent += content
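// Emit complete sentences as soon as parseSentence can split them off the
// buffered content; fullContent keeps the whole reply for the memory queue.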
for {
// feed content and check whether sentence is available
var sentenceIsFinal bool
sentence, content, sentenceIsFinal = parseSentence(sentence, content)
if len(sentence) == 0 || !sentenceIsFinal {
tenEnv.LogDebug(fmt.Sprintf("sentence %s is empty or not final", sentence))
break
}
tenEnv.LogDebug(fmt.Sprintf("GetChatCompletionsStream recv for input text: [%s] got sentence: [%s]", inputText, sentence))
// send sentence
outputData, err := ten.NewData("text_data")
if err != nil {
tenEnv.LogError(fmt.Sprintf("NewData failed, err: %v", err))
break
}
outputData.SetProperty(dataOutTextDataPropertyText, sentence)
outputData.SetProperty(dataOutTextDataPropertyTextEndOfSegment, false)
if err := tenEnv.SendData(outputData, nil); err != nil {
tenEnv.LogError(fmt.Sprintf("GetChatCompletionsStream recv for input text: [%s] send sentence [%s] failed, err: %v", inputText, sentence, err))
break
} else {
tenEnv.LogInfo(fmt.Sprintf("GetChatCompletionsStream recv for input text: [%s] sent sentence [%s]", inputText, sentence))
}
sentence = ""
if !firstSentenceSent {
firstSentenceSent = true
tenEnv.LogInfo(fmt.Sprintf("GetChatCompletionsStream recv for input text: [%s] first sentence sent, first_sentency_latency %dms",
inputText, time.Since(startTime).Milliseconds()))
}
}
}
// remember response as assistant content in memory
memoryChan <- openai.ChatCompletionMessage{
Role: openai.ChatMessageRoleAssistant,
Content: fullContent,
}
// send end of segment
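// Any partial sentence left when the stream ends is sent with
// end_of_segment=true so downstream consumers see the segment boundary.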
outputData, _ := ten.NewData("text_data")
outputData.SetProperty(dataOutTextDataPropertyText, sentence)
outputData.SetProperty(dataOutTextDataPropertyTextEndOfSegment, true)
if err := tenEnv.SendData(outputData, nil); err != nil {
tenEnv.LogError(fmt.Sprintf("GetChatCompletionsStream for input text: [%s] end of segment with sentence [%s] send failed, err: %v", inputText, sentence, err))
} else {
tenEnv.LogInfo(fmt.Sprintf("GetChatCompletionsStream for input text: [%s] end of segment with sentence [%s] sent", inputText, sentence))
}
}(time.Now(), inputText, append([]openai.ChatCompletionMessage{}, memory...))
}
func init() {
// Register addon
ten.RegisterAddonAsExtension(
"openai_chatgpt",
ten.NewDefaultExtensionAddon(newChatGPTExtension),
)
}