from typing import List

from openai.types.chat import ChatCompletionMessageParam
from transformers import PreTrainedTokenizer

from api.generation.utils import parse_messages
from api.utils.protocol import Role

def build_baichuan_chat_input(
    tokenizer: PreTrainedTokenizer,
    messages: List[ChatCompletionMessageParam],
    context_len: int = 4096,
    max_new_tokens: int = 256,
) -> List[int]:
    """
    Builds the input tokens for the Baichuan chat model based on the given messages.

    Refs:
        https://huggingface.co/baichuan-inc/Baichuan-13B-Chat/blob/main/generation_utils.py

    Args:
        tokenizer: The PreTrainedTokenizer object.
        messages: A list of ChatCompletionMessageParam objects representing the chat messages.
        context_len: The maximum length of the context, in tokens (default: 4096).
        max_new_tokens: The maximum number of new tokens to generate (default: 256).

    Returns:
        List[int]: The input tokens for the Baichuan chat model.
    """
    max_input_tokens = context_len - max_new_tokens

    system, rounds = parse_messages(messages)
    system_tokens = tokenizer.encode(system)
    max_history_tokens = max_input_tokens - len(system_tokens)

    # Walk the conversation rounds from newest to oldest, prepending each
    # round while it still fits into the remaining token budget.
    history_tokens = []
    for r in rounds[::-1]:
        round_tokens = []
        for message in r:
            if message["role"] == Role.USER:
                round_tokens.append(195)  # Baichuan-13B-Chat user_token_id
            else:
                round_tokens.append(196)  # Baichuan-13B-Chat assistant_token_id
            round_tokens.extend(tokenizer.encode(message["content"]))

        if len(history_tokens) == 0 or len(history_tokens) + len(round_tokens) <= max_history_tokens:
            history_tokens = round_tokens + history_tokens  # concat left
            if len(history_tokens) < max_history_tokens:
                continue
        break

    input_tokens = system_tokens + history_tokens
    if messages[-1]["role"] != Role.ASSISTANT:
        # Prime the model to respond as the assistant.
        input_tokens.append(196)
    return input_tokens[-max_input_tokens:]  # truncate left
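
# A minimal usage sketch (not from the original module; assumes the
# Baichuan-13B-Chat tokenizer is available from the Hugging Face Hub and that
# messages follow the OpenAI chat format used throughout this API):
#
#     from transformers import AutoTokenizer
#
#     tokenizer = AutoTokenizer.from_pretrained(
#         "baichuan-inc/Baichuan-13B-Chat", trust_remote_code=True
#     )
#     input_ids = build_baichuan_chat_input(
#         tokenizer, [{"role": "user", "content": "Hello!"}]
#     )
#     # The prompt ends with 196, priming the model to answer as the assistant.
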
def check_is_baichuan(model) -> bool:
    """
    Checks if the given model is a Baichuan model.

    Args:
        model: The model to be checked.

    Returns:
        bool: True if the model is a Baichuan model, False otherwise.
    """
    # Baichuan's custom modeling code lists "BaichuanLayer" in `_no_split_modules`.
    return "BaichuanLayer" in getattr(model, "_no_split_modules", [])