Spaces:
Sleeping
Sleeping
File size: 1,635 Bytes
564e576 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 |
import type { ToolResult } from "$lib/types/Tool";
import { MessageUpdateType, type MessageUpdate } from "$lib/types/MessageUpdate";
import { AbortedGenerations } from "../abortedGenerations";
import type { TextGenerationContext } from "./types";
import type { EndpointMessage } from "../endpoints/endpoints";
/**
 * Context accepted by `generate`: every field of `TextGenerationContext` except
 * `messages`, which is replaced by a list of `EndpointMessage`.
 */
type GenerateContext = { messages: EndpointMessage[] } & Omit<
  TextGenerationContext,
  "messages"
>;
/**
 * Streams the endpoint's output as a sequence of message updates.
 *
 * The endpoint is called with the messages, preprompt, continuation flag, the
 * assistant's generation settings and the tool results. Each chunk it produces
 * is then handled as follows:
 * - a chunk carrying `generated_text` is emitted as a `FinalAnswer` update;
 * - a chunk whose token is special is skipped;
 * - any other chunk is emitted as a `Stream` update with the token text.
 *
 * Iteration stops early when the endpoint yields a falsy chunk, or when
 * `AbortedGenerations` holds a date for this conversation that is later than
 * `promptedAt` (checked after each streamed token).
 *
 * @param context - Generation context; `model.parameters.stop` supplies the stop tokens.
 * @param toolResults - Tool results forwarded to the endpoint.
 * @param preprompt - Optional preprompt forwarded to the endpoint.
 * @returns An async iterable of `Stream` and `FinalAnswer` updates.
 */
export async function* generate(
  { model, endpoint, conv, messages, assistant, isContinue, promptedAt }: GenerateContext,
  toolResults: ToolResult[],
  preprompt?: string
): AsyncIterable<MessageUpdate> {
  for await (const output of await endpoint({
    messages,
    preprompt,
    continueMessage: isContinue,
    generateSettings: assistant?.generateSettings,
    toolResults,
  })) {
    // no output check: must run before any property of `output` is read,
    // otherwise a falsy chunk throws instead of ending the stream
    if (!output) break;

    // text generation completed
    if (output.generated_text) {
      // considered interrupted unless the last token is special or is itself a stop token
      let interrupted =
        !output.token.special && !model.parameters.stop?.includes(output.token.text);

      let text = output.generated_text.trimEnd();
      for (const stopToken of model.parameters.stop ?? []) {
        if (!text.endsWith(stopToken)) continue;

        // text that ends with a stop token finished normally; drop the token from the answer
        interrupted = false;
        text = text.slice(0, text.length - stopToken.length);
      }

      yield { type: MessageUpdateType.FinalAnswer, text, interrupted };
      continue;
    }

    // ignore special tokens
    if (output.token.special) continue;

    // pass down normal token
    yield { type: MessageUpdateType.Stream, token: output.token.text };

    // abort check: stop once an abort was recorded after this generation was prompted
    const date = AbortedGenerations.getInstance().getList().get(conv._id.toString());
    if (date && date > promptedAt) break;
  }
}
|