nsarrazin HF Staff committed on
Commit
791e118
·
unverified ·
1 Parent(s): 2c61d4a

feat(endpoints): Add conv ID to headers passed to TGI (#1511)

Browse files
src/lib/server/endpoints/endpoints.ts CHANGED
@@ -29,6 +29,7 @@ import endpointLangserve, {
29
  } from "./langserve/endpointLangserve";
30
 
31
  import type { Tool, ToolCall, ToolResult } from "$lib/types/Tool";
 
32
 
33
  export type EndpointMessage = Omit<Message, "id">;
34
 
@@ -41,6 +42,7 @@ export interface EndpointParameters {
41
  tools?: Tool[];
42
  toolResults?: ToolResult[];
43
  isMultimodal?: boolean;
 
44
  }
45
 
46
  interface CommonEndpoint {
 
29
  } from "./langserve/endpointLangserve";
30
 
31
  import type { Tool, ToolCall, ToolResult } from "$lib/types/Tool";
32
+ import type { ObjectId } from "mongodb";
33
 
34
  export type EndpointMessage = Omit<Message, "id">;
35
 
 
42
  tools?: Tool[];
43
  toolResults?: ToolResult[];
44
  isMultimodal?: boolean;
45
+ conversationId?: ObjectId;
46
  }
47
 
48
  interface CommonEndpoint {
src/lib/server/endpoints/openai/endpointOai.ts CHANGED
@@ -149,7 +149,7 @@ export async function endpointOai(
149
  "Tools are not supported for 'completions' mode, switch to 'chat_completions' instead"
150
  );
151
  }
152
- return async ({ messages, preprompt, continueMessage, generateSettings }) => {
153
  const prompt = await buildPrompt({
154
  messages,
155
  continueMessage,
@@ -171,12 +171,22 @@ export async function endpointOai(
171
 
172
  const openAICompletion = await openai.completions.create(body, {
173
  body: { ...body, ...extraBody },
 
 
 
174
  });
175
 
176
  return openAICompletionToTextGenerationStream(openAICompletion);
177
  };
178
  } else if (completion === "chat_completions") {
179
- return async ({ messages, preprompt, generateSettings, tools, toolResults }) => {
 
 
 
 
 
 
 
180
  let messagesOpenAI: OpenAI.Chat.Completions.ChatCompletionMessageParam[] =
181
  await prepareMessages(messages, imageProcessor, !model.tools && model.multimodal);
182
 
@@ -240,6 +250,9 @@ export async function endpointOai(
240
 
241
  const openChatAICompletion = await openai.chat.completions.create(body, {
242
  body: { ...body, ...extraBody },
 
 
 
243
  });
244
 
245
  return openAIChatToTextGenerationStream(openChatAICompletion);
 
149
  "Tools are not supported for 'completions' mode, switch to 'chat_completions' instead"
150
  );
151
  }
152
+ return async ({ messages, preprompt, continueMessage, generateSettings, conversationId }) => {
153
  const prompt = await buildPrompt({
154
  messages,
155
  continueMessage,
 
171
 
172
  const openAICompletion = await openai.completions.create(body, {
173
  body: { ...body, ...extraBody },
174
+ headers: {
175
+ "ChatUI-Conversation-ID": conversationId?.toString() ?? "",
176
+ },
177
  });
178
 
179
  return openAICompletionToTextGenerationStream(openAICompletion);
180
  };
181
  } else if (completion === "chat_completions") {
182
+ return async ({
183
+ messages,
184
+ preprompt,
185
+ generateSettings,
186
+ tools,
187
+ toolResults,
188
+ conversationId,
189
+ }) => {
190
  let messagesOpenAI: OpenAI.Chat.Completions.ChatCompletionMessageParam[] =
191
  await prepareMessages(messages, imageProcessor, !model.tools && model.multimodal);
192
 
 
250
 
251
  const openChatAICompletion = await openai.chat.completions.create(body, {
252
  body: { ...body, ...extraBody },
253
+ headers: {
254
+ "ChatUI-Conversation-ID": conversationId?.toString() ?? "",
255
+ },
256
  });
257
 
258
  return openAIChatToTextGenerationStream(openChatAICompletion);
src/lib/server/endpoints/tgi/endpointTgi.ts CHANGED
@@ -43,6 +43,7 @@ export function endpointTgi(input: z.input<typeof endpointTgiParametersSchema>):
43
  tools,
44
  toolResults,
45
  isMultimodal,
 
46
  }) => {
47
  const messagesWithResizedFiles = await Promise.all(
48
  messages.map((message) => prepareMessage(Boolean(isMultimodal), message, imageProcessor))
@@ -72,6 +73,7 @@ export function endpointTgi(input: z.input<typeof endpointTgiParametersSchema>):
72
  info.headers = {
73
  ...info.headers,
74
  Authorization: authorization,
 
75
  };
76
  }
77
  return fetch(endpointUrl, info);
 
43
  tools,
44
  toolResults,
45
  isMultimodal,
46
+ conversationId,
47
  }) => {
48
  const messagesWithResizedFiles = await Promise.all(
49
  messages.map((message) => prepareMessage(Boolean(isMultimodal), message, imageProcessor))
 
73
  info.headers = {
74
  ...info.headers,
75
  Authorization: authorization,
76
+ "ChatUI-Conversation-ID": conversationId?.toString() ?? "",
77
  };
78
  }
79
  return fetch(endpointUrl, info);
src/lib/server/textGeneration/generate.ts CHANGED
@@ -18,6 +18,7 @@ export async function* generate(
18
  generateSettings: assistant?.generateSettings,
19
  toolResults,
20
  isMultimodal: model.multimodal,
 
21
  })) {
22
  // text generation completed
23
  if (output.generated_text) {
 
18
  generateSettings: assistant?.generateSettings,
19
  toolResults,
20
  isMultimodal: model.multimodal,
21
+ conversationId: conv._id,
22
  })) {
23
  // text generation completed
24
  if (output.generated_text) {
src/lib/server/textGeneration/tools.ts CHANGED
@@ -196,6 +196,7 @@ export async function* runTools(
196
  type: input.type === "file" ? "str" : input.type,
197
  })),
198
  })),
 
199
  })) {
200
  // model natively supports tool calls
201
  if (output.token.toolCalls) {
 
196
  type: input.type === "file" ? "str" : input.type,
197
  })),
198
  })),
199
+ conversationId: conv._id,
200
  })) {
201
  // model natively supports tool calls
202
  if (output.token.toolCalls) {