package extension

import (
    "errors"
    "fmt"
    "io"
    "sync"
    "sync/atomic"
    "time"

    "ten_framework/ten"

    openai "github.com/sashabaranov/go-openai"
)
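
// openaiChatGPTExtension receives text input, queries OpenAI for a streamed
// chat completion, and sends the reply back into the graph as "text_data"
// packets, sentence by sentence.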
type openaiChatGPTExtension struct {
    ten.DefaultExtension
    openaiChatGPT *openaiChatGPT
}
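
// Command names, data-packet property names, and configuration property keys.
// Only api_key is treated as required in OnStart; all other properties fall
// back to defaults when missing or unreadable.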
const (
    cmdInFlush                              = "flush"
    cmdOutFlush                             = "flush"
    dataInTextDataPropertyText              = "text"
    dataInTextDataPropertyIsFinal           = "is_final"
    dataOutTextDataPropertyText             = "text"
    dataOutTextDataPropertyTextEndOfSegment = "end_of_segment"

    propertyBaseUrl          = "base_url"
    propertyApiKey           = "api_key"
    propertyModel            = "model"
    propertyPrompt           = "prompt"
    propertyFrequencyPenalty = "frequency_penalty"
    propertyPresencePenalty  = "presence_penalty"
    propertyTemperature      = "temperature"
    propertyTopP             = "top_p"
    propertyMaxTokens        = "max_tokens"
    propertyGreeting         = "greeting"
    propertyProxyUrl         = "proxy_url"
    propertyMaxMemoryLength  = "max_memory_length"
)
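
// Conversation state shared across turns. memory holds the rolling chat
// history; worker goroutines hand finished assistant replies back through
// memoryChan, and OnData merges them into memory on the next turn. outdateTs
// records the timestamp of the latest flush so in-flight streams can stop
// early, and wg lets OnCmd wait for those streams to drain.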
var (
    memory          []openai.ChatCompletionMessage
    memoryChan      chan openai.ChatCompletionMessage
    maxMemoryLength = 10

    outdateTs atomic.Int64
    wg        sync.WaitGroup
)
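
// newChatGPTExtension is the factory passed to ten.NewDefaultExtensionAddon
// in init. The name argument is supplied by the framework and unused here.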
func newChatGPTExtension(name string) ten.Extension {
    return &openaiChatGPTExtension{}
}
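
// OnStart reads the extension's configuration properties, creates the OpenAI
// client, and, if a greeting is configured, sends it as the first complete
// segment. Startup is aborted only when api_key is missing or the client
// cannot be created.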
func (p *openaiChatGPTExtension) OnStart(tenEnv ten.TenEnv) {
    tenEnv.LogInfo("OnStart")

    openaiChatGPTConfig := defaultOpenaiChatGPTConfig()

    if baseUrl, err := tenEnv.GetPropertyString(propertyBaseUrl); err != nil {
        tenEnv.LogWarn(fmt.Sprintf("GetProperty optional %s failed, err: %v", propertyBaseUrl, err))
    } else {
        if len(baseUrl) > 0 {
            openaiChatGPTConfig.BaseUrl = baseUrl
        }
    }

    // api_key is the only required property; abort startup if it is missing.
    if apiKey, err := tenEnv.GetPropertyString(propertyApiKey); err != nil {
        tenEnv.LogError(fmt.Sprintf("GetProperty required %s failed, err: %v", propertyApiKey, err))
        return
    } else {
        openaiChatGPTConfig.ApiKey = apiKey
    }

    if model, err := tenEnv.GetPropertyString(propertyModel); err != nil {
        tenEnv.LogWarn(fmt.Sprintf("GetProperty optional %s failed, err: %v", propertyModel, err))
    } else {
        if len(model) > 0 {
            openaiChatGPTConfig.Model = model
        }
    }

    if prompt, err := tenEnv.GetPropertyString(propertyPrompt); err != nil {
        tenEnv.LogWarn(fmt.Sprintf("GetProperty optional %s failed, err: %v", propertyPrompt, err))
    } else {
        if len(prompt) > 0 {
            openaiChatGPTConfig.Prompt = prompt
        }
    }

    if frequencyPenalty, err := tenEnv.GetPropertyFloat64(propertyFrequencyPenalty); err != nil {
        tenEnv.LogWarn(fmt.Sprintf("GetProperty optional %s failed, err: %v", propertyFrequencyPenalty, err))
    } else {
        openaiChatGPTConfig.FrequencyPenalty = float32(frequencyPenalty)
    }

    if presencePenalty, err := tenEnv.GetPropertyFloat64(propertyPresencePenalty); err != nil {
        tenEnv.LogWarn(fmt.Sprintf("GetProperty optional %s failed, err: %v", propertyPresencePenalty, err))
    } else {
        openaiChatGPTConfig.PresencePenalty = float32(presencePenalty)
    }

    if temperature, err := tenEnv.GetPropertyFloat64(propertyTemperature); err != nil {
        tenEnv.LogWarn(fmt.Sprintf("GetProperty optional %s failed, err: %v", propertyTemperature, err))
    } else {
        openaiChatGPTConfig.Temperature = float32(temperature)
    }

    if topP, err := tenEnv.GetPropertyFloat64(propertyTopP); err != nil {
        tenEnv.LogWarn(fmt.Sprintf("GetProperty optional %s failed, err: %v", propertyTopP, err))
    } else {
        openaiChatGPTConfig.TopP = float32(topP)
    }

    if maxTokens, err := tenEnv.GetPropertyInt64(propertyMaxTokens); err != nil {
        tenEnv.LogWarn(fmt.Sprintf("GetProperty optional %s failed, err: %v", propertyMaxTokens, err))
    } else {
        if maxTokens > 0 {
            openaiChatGPTConfig.MaxTokens = int(maxTokens)
        }
    }

    if proxyUrl, err := tenEnv.GetPropertyString(propertyProxyUrl); err != nil {
        tenEnv.LogWarn(fmt.Sprintf("GetProperty optional %s failed, err: %v", propertyProxyUrl, err))
    } else {
        openaiChatGPTConfig.ProxyUrl = proxyUrl
    }

    greeting, err := tenEnv.GetPropertyString(propertyGreeting)
    if err != nil {
        tenEnv.LogWarn(fmt.Sprintf("GetProperty optional %s failed, err: %v", propertyGreeting, err))
    }

    if propMaxMemoryLength, err := tenEnv.GetPropertyInt64(propertyMaxMemoryLength); err != nil {
        tenEnv.LogWarn(fmt.Sprintf("GetProperty optional %s failed, err: %v", propertyMaxMemoryLength, err))
    } else {
        if propMaxMemoryLength > 0 {
            maxMemoryLength = int(propMaxMemoryLength)
        }
    }

    openaiChatgpt, err := newOpenaiChatGPT(openaiChatGPTConfig)
    if err != nil {
        tenEnv.LogError(fmt.Sprintf("newOpenaiChatGPT failed, err: %v", err))
        return
    }
    tenEnv.LogInfo(fmt.Sprintf("newOpenaiChatGPT succeeded with max_tokens: %d, model: %s",
        openaiChatGPTConfig.MaxTokens, openaiChatGPTConfig.Model))

    p.openaiChatGPT = openaiChatgpt

    // Buffered so worker goroutines can queue replies without blocking.
    memoryChan = make(chan openai.ChatCompletionMessage, maxMemoryLength*2)

    if len(greeting) > 0 {
        outputData, _ := ten.NewData("text_data")
        outputData.SetProperty(dataOutTextDataPropertyText, greeting)
        outputData.SetProperty(dataOutTextDataPropertyTextEndOfSegment, true)
        if err := tenEnv.SendData(outputData, nil); err != nil {
            tenEnv.LogError(fmt.Sprintf("greeting [%s] send failed, err: %v", greeting, err))
        } else {
            tenEnv.LogInfo(fmt.Sprintf("greeting [%s] sent", greeting))
        }
    }

    tenEnv.OnStartDone()
}
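
// OnCmd handles the "flush" command: it marks all in-flight completion
// streams as outdated, waits for their goroutines to finish, then forwards a
// "flush" command downstream and returns an OK result.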
func (p *openaiChatGPTExtension) OnCmd(
    tenEnv ten.TenEnv,
    cmd ten.Cmd,
) {
    cmdName, err := cmd.GetName()
    if err != nil {
        tenEnv.LogError(fmt.Sprintf("OnCmd get name failed, err: %v", err))
        cmdResult, _ := ten.NewCmdResult(ten.StatusCodeError)
        tenEnv.ReturnResult(cmdResult, cmd, nil)
        return
    }
    tenEnv.LogInfo(fmt.Sprintf("OnCmd %s", cmdName))

    switch cmdName {
    case cmdInFlush:
        // Invalidate in-flight streams, then wait for their goroutines to
        // observe the new timestamp and exit before flushing downstream.
        outdateTs.Store(time.Now().UnixMicro())
        wg.Wait()

        outCmd, err := ten.NewCmd(cmdOutFlush)
        if err != nil {
            tenEnv.LogError(fmt.Sprintf("new cmd %s failed, err: %v", cmdOutFlush, err))
            cmdResult, _ := ten.NewCmdResult(ten.StatusCodeError)
            tenEnv.ReturnResult(cmdResult, cmd, nil)
            return
        }
        if err := tenEnv.SendCmd(outCmd, nil); err != nil {
            tenEnv.LogError(fmt.Sprintf("send cmd %s failed, err: %v", cmdOutFlush, err))
            cmdResult, _ := ten.NewCmdResult(ten.StatusCodeError)
            tenEnv.ReturnResult(cmdResult, cmd, nil)
            return
        }
        tenEnv.LogInfo(fmt.Sprintf("cmd %s sent", cmdOutFlush))
    }
    cmdResult, _ := ten.NewCmdResult(ten.StatusCodeOk)
    tenEnv.ReturnResult(cmdResult, cmd, nil)
}
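
// OnData consumes "text_data" packets from upstream. Non-final or empty text
// is ignored; otherwise the text is appended to the rolling chat memory and a
// worker goroutine streams the completion back out sentence by sentence.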
func (p *openaiChatGPTExtension) OnData(
    tenEnv ten.TenEnv,
    data ten.Data,
) {
    isFinal, err := data.GetPropertyBool(dataInTextDataPropertyIsFinal)
    if err != nil {
        tenEnv.LogWarn(fmt.Sprintf("OnData GetProperty %s failed, err: %v", dataInTextDataPropertyIsFinal, err))
        return
    }
    if !isFinal {
        tenEnv.LogDebug("ignore non-final input")
        return
    }

    inputText, err := data.GetPropertyString(dataInTextDataPropertyText)
    if err != nil {
        tenEnv.LogError(fmt.Sprintf("OnData GetProperty %s failed, err: %v", dataInTextDataPropertyText, err))
        return
    }
    if len(inputText) == 0 {
        tenEnv.LogDebug("ignore empty text")
        return
    }
    tenEnv.LogInfo(fmt.Sprintf("OnData input text: [%s]", inputText))

    // Drain assistant replies queued by earlier worker goroutines, then append
    // the new user turn, trimming the history to maxMemoryLength entries.
    for len(memoryChan) > 0 {
        m, ok := <-memoryChan
        if !ok {
            break
        }
        memory = append(memory, m)
        if len(memory) > maxMemoryLength {
            memory = memory[1:]
        }
    }
    memory = append(memory, openai.ChatCompletionMessage{
        Role:    openai.ChatMessageRoleUser,
        Content: inputText,
    })
    if len(memory) > maxMemoryLength {
        memory = memory[1:]
    }

    // Stream the completion on a worker goroutine so OnData returns quickly;
    // the goroutine receives a snapshot copy of memory.
    wg.Add(1)
    go func(startTime time.Time, inputText string, memory []openai.ChatCompletionMessage) {
        defer wg.Done()
        tenEnv.LogInfo(fmt.Sprintf("GetChatCompletionsStream for input text: [%s] memory: %v", inputText, memory))

        resp, err := p.openaiChatGPT.getChatCompletionsStream(memory)
        if err != nil {
            tenEnv.LogError(fmt.Sprintf("GetChatCompletionsStream for input text: [%s] failed, err: %v", inputText, err))
            return
        }
        defer func() {
            if resp != nil {
                resp.Close()
            }
        }()
        tenEnv.LogDebug(fmt.Sprintf("GetChatCompletionsStream start to recv for input text: [%s]", inputText))

        var sentence, fullContent string
        var firstSentenceSent bool
        for {
            // Stop early if a flush arrived after this request started.
            if startTime.UnixMicro() < outdateTs.Load() {
                tenEnv.LogInfo(fmt.Sprintf("GetChatCompletionsStream recv interrupt and flushing for input text: [%s], startTs: %d, outdateTs: %d",
                    inputText, startTime.UnixMicro(), outdateTs.Load()))
                break
            }

            chatCompletions, err := resp.Recv()
            if errors.Is(err, io.EOF) {
                tenEnv.LogDebug(fmt.Sprintf("GetChatCompletionsStream recv for input text: [%s], io.EOF break", inputText))
                break
            }
            if err != nil {
                // Any other stream error is unrecoverable for this turn;
                // without this check the loop would spin on a failed stream.
                tenEnv.LogError(fmt.Sprintf("GetChatCompletionsStream recv for input text: [%s] failed, err: %v", inputText, err))
                break
            }

            var content string
            if len(chatCompletions.Choices) > 0 && chatCompletions.Choices[0].Delta.Content != "" {
                content = chatCompletions.Choices[0].Delta.Content
            }
            fullContent += content

            // Emit each complete sentence as soon as it is available.
            for {
                var sentenceIsFinal bool
                sentence, content, sentenceIsFinal = parseSentence(sentence, content)
                if len(sentence) == 0 || !sentenceIsFinal {
                    tenEnv.LogDebug(fmt.Sprintf("sentence [%s] is empty or not final", sentence))
                    break
                }
                tenEnv.LogDebug(fmt.Sprintf("GetChatCompletionsStream recv for input text: [%s] got sentence: [%s]", inputText, sentence))

                outputData, err := ten.NewData("text_data")
                if err != nil {
                    tenEnv.LogError(fmt.Sprintf("NewData failed, err: %v", err))
                    break
                }
                outputData.SetProperty(dataOutTextDataPropertyText, sentence)
                outputData.SetProperty(dataOutTextDataPropertyTextEndOfSegment, false)
                if err := tenEnv.SendData(outputData, nil); err != nil {
                    tenEnv.LogError(fmt.Sprintf("GetChatCompletionsStream recv for input text: [%s] send sentence [%s] failed, err: %v", inputText, sentence, err))
                    break
                }
                tenEnv.LogInfo(fmt.Sprintf("GetChatCompletionsStream recv for input text: [%s] sent sentence [%s]", inputText, sentence))
                sentence = ""

                if !firstSentenceSent {
                    firstSentenceSent = true
                    tenEnv.LogInfo(fmt.Sprintf("GetChatCompletionsStream recv for input text: [%s] first sentence sent, first_sentence_latency %dms",
                        inputText, time.Since(startTime).Milliseconds()))
                }
            }
        }

        // Queue the full assistant reply so the next OnData call can merge it
        // into memory, then flush any trailing partial sentence as the end of
        // the segment.
        memoryChan <- openai.ChatCompletionMessage{
            Role:    openai.ChatMessageRoleAssistant,
            Content: fullContent,
        }

        outputData, _ := ten.NewData("text_data")
        outputData.SetProperty(dataOutTextDataPropertyText, sentence)
        outputData.SetProperty(dataOutTextDataPropertyTextEndOfSegment, true)
        if err := tenEnv.SendData(outputData, nil); err != nil {
            tenEnv.LogError(fmt.Sprintf("GetChatCompletionsStream for input text: [%s] end of segment with sentence [%s] send failed, err: %v", inputText, sentence, err))
        } else {
            tenEnv.LogInfo(fmt.Sprintf("GetChatCompletionsStream for input text: [%s] end of segment with sentence [%s] sent", inputText, sentence))
        }
    }(time.Now(), inputText, append([]openai.ChatCompletionMessage{}, memory...))
}
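
// parseSentence is implemented elsewhere in this package. The streaming loop
// above relies on the following contract, sketched here as an assumption
// rather than a definitive signature:
//
//    // parseSentence appends content onto the pending sentence until it
//    // reaches a sentence boundary, returning the accumulated sentence, the
//    // unconsumed remainder of content, and whether the sentence is complete.
//    func parseSentence(sentence, content string) (sentenceOut, remain string, final bool)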
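
// init registers this extension with the TEN runtime under the addon name
// "openai_chatgpt".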
func init() {
    ten.RegisterAddonAsExtension(
        "openai_chatgpt",
        ten.NewDefaultExtensionAddon(newChatGPTExtension),
    )
}