import type { TextGenerationStreamOutput } from "@huggingface/inference";
import type OpenAI from "openai";
import type { Stream } from "openai/streaming";
import type { ToolCall } from "$lib/types/Tool";

type ToolCallWithParameters = {
	toolCall: ToolCall;
	parameterJsonString: string;
};

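/**
 * Merge the accumulated parameter JSON of each streamed tool call into its
 * ToolCall object and wrap the completed calls in a final stream-output token.
 */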
function prepareToolCalls(toolCallsWithParameters: ToolCallWithParameters[], tokenId: number) {
	const toolCalls: ToolCall[] = [];

	for (const toolCallWithParameters of toolCallsWithParameters) {
		// HACK: sometimes gpt4 via azure returns the JSON with literal newlines in it,
		// like {\n "foo": "bar" }. replace("\n", "") only strips the first newline,
		// so remove them all with a global regex before parsing.
		const s = toolCallWithParameters.parameterJsonString.replace(/\n/g, "");
		const params = JSON.parse(s);

		const toolCall = toolCallWithParameters.toolCall;
		for (const name in params) {
			toolCall.parameters[name] = params[name];
		}

		toolCalls.push(toolCall);
	}

	const output = {
		token: {
			id: tokenId,
			text: "",
			logprob: 0,
			special: false,
			toolCalls,
		},
		generated_text: null,
		details: null,
	};

	return output;
}

/**
 * Transform a stream of OpenAI.Chat.Completions.ChatCompletionChunk into a stream of TextGenerationStreamOutput
 */
export async function* openAIChatToTextGenerationStream(
	completionStream: Stream<OpenAI.Chat.Completions.ChatCompletionChunk>
) {
	let generatedText = "";
	let tokenId = 0;
	const toolCalls: ToolCallWithParameters[] = [];
	let toolBuffer = ""; // XXX: hack because tools seem broken on tgi openai endpoints?

	for await (const completion of completionStream) {
		const { choices } = completion;
		const content = choices[0]?.delta?.content ?? "";
		const last = choices[0]?.finish_reason === "stop" || choices[0]?.finish_reason === "length";

		// if the last token is a stop and the tool buffer is not empty, yield it as a generated_text
		if (choices[0]?.finish_reason === "stop" && toolBuffer.length > 0) {
			yield {
				token: {
					id: tokenId++,
					special: true,
					logprob: 0,
					text: "",
				},
				generated_text: toolBuffer,
				details: null,
			} as TextGenerationStreamOutput;
			break;
		}

		// Workaround: some OpenAI-compatible endpoints (e.g. TGI) stream tool parameters
		// as anonymous deltas (empty id, null function name); buffer those instead of emitting them
		if (choices[0]?.delta?.tool_calls) {
			const calls = Array.isArray(choices[0].delta.tool_calls)
				? choices[0].delta.tool_calls
				: [choices[0].delta.tool_calls];

			if (
				calls.length === 1 &&
				calls[0].index === 0 &&
				calls[0].id === "" &&
				calls[0].type === "function" &&
				!!calls[0].function &&
				calls[0].function.name === null
			) {
				// arguments may be undefined on some chunks; avoid appending "undefined"
				toolBuffer += calls[0].function.arguments ?? "";
				continue;
			}
		}

		if (content) {
			generatedText = generatedText + content;
		}
		const output: TextGenerationStreamOutput = {
			token: {
				id: tokenId++,
				text: content ?? "",
				logprob: 0,
				special: last,
			},
			generated_text: last ? generatedText : null,
			details: null,
		};
		yield output;

		const tools = completion.choices[0]?.delta?.tool_calls || [];
		for (const tool of tools) {
			if (tool.id) {
				if (!tool.function?.name) {
					throw new Error("Tool call without function name");
				}
				const toolCallWithParameters: ToolCallWithParameters = {
					toolCall: {
						name: tool.function.name,
						parameters: {},
					},
					parameterJsonString: "",
				};
				toolCalls.push(toolCallWithParameters);
			}

			if (toolCalls.length > 0 && tool.function?.arguments) {
				toolCalls[toolCalls.length - 1].parameterJsonString += tool.function.arguments;
			}
		}

		if (choices[0]?.finish_reason === "tool_calls") {
			yield prepareToolCalls(toolCalls, tokenId++);
		}
	}
}
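
/*
 * Usage sketch: one way a caller might consume this generator. The `openai`
 * client and `messages` array below are assumptions for illustration, not
 * part of this module.
 *
 *   const stream = await openai.chat.completions.create({
 *     model: "gpt-4",
 *     messages,
 *     stream: true,
 *   });
 *   for await (const output of openAIChatToTextGenerationStream(stream)) {
 *     if (output.generated_text !== null) console.log(output.generated_text);
 *     else process.stdout.write(output.token.text);
 *   }
 */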

/**
 * Transform a non-streaming OpenAI chat completion into a stream of TextGenerationStreamOutput
 */
export async function* openAIChatToTextGenerationSingle(
	completion: OpenAI.Chat.Completions.ChatCompletion
) {
	const content = completion.choices[0]?.message?.content || "";
	const tokenId = 0;

	// Yield the content as a single token
	yield {
		token: {
			id: tokenId,
			text: content,
			logprob: 0,
			special: false,
		},
		generated_text: content,
		details: null,
	} as TextGenerationStreamOutput;
}
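
/*
 * Usage sketch (illustrative; `openai` and `messages` are assumed, not defined here):
 *
 *   const completion = await openai.chat.completions.create({ model: "gpt-4", messages });
 *   for await (const output of openAIChatToTextGenerationSingle(completion)) {
 *     console.log(output.generated_text);
 *   }
 */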