# Eagle2-Demo / eagle_vl/serve/inference.py
import logging
import re
from threading import Thread
from typing import List, Optional
import os
import torch
from transformers import (
AutoModel,
AutoProcessor,
AutoConfig,
StoppingCriteria,
StoppingCriteriaList,
TextIteratorStreamer,
)
from PIL import Image
from .chat_utils import Conversation, get_conv_template
logger = logging.getLogger(__name__)
def load_model_from_nv(model_path: str = "nvidia/Eagle-2-8B"):
token = os.environ.get("HF_TOKEN")
# hotfix the model to use flash attention 2
config = AutoConfig.from_pretrained(model_path, trust_remote_code=True, token=token)
config._attn_implementation = "flash_attention_2"
config.vision_config._attn_implementation = "flash_attention_2"
config.text_config._attn_implementation = "flash_attention_2"
print("Successfully set the attn_implementation to flash_attention_2")
logger.info(f"token = {token[:4]}***{token[-2:]}")
model = AutoModel.from_pretrained(
model_path,
trust_remote_code=True,
torch_dtype=torch.bfloat16,
attn_implementation="flash_attention_2",
token=token
)
model.to("cuda")
processor = AutoProcessor.from_pretrained(model_path, config=config, trust_remote_code=True, use_fast=True, token=token)
return model, processor
def load_model_from_eagle(model_path: str = "NVEagle/Eagle2-8B"):
token = os.environ.get("HF_TOKEN")
logger.info(f"token = {token[:4]}***{token[-2:]}")
# hotfix the model to use flash attention 2
config = AutoConfig.from_pretrained(model_path, trust_remote_code=True, token=token)
config._attn_implementation = "flash_attention_2"
config.vision_config._attn_implementation = "flash_attention_2"
config.text_config._attn_implementation = "flash_attention_2"
print("Successfully set the attn_implementation to flash_attention_2")
model = AutoModel.from_pretrained(
model_path,
trust_remote_code=True,
torch_dtype=torch.bfloat16,
attn_implementation="flash_attention_2",
token=token
)
model.to("cuda")
processor = AutoProcessor.from_pretrained(model_path, config=config, trust_remote_code=True, use_fast=True, token=token)
return model, processor
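# Both loaders read the Hugging Face access token from the HF_TOKEN environment
# variable; for gated checkpoints it must be set before loading, e.g. (sketch,
# the token value is a placeholder):
#
#   os.environ["HF_TOKEN"] = "hf_..."
#   model, processor = load_model()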
def load_model(model_path: str = "nvidia/Eagle2-8B"):
try:
model, processor = load_model_from_nv(model_path)
except Exception as e:
logger.error(f"Failed to load model from HF, trying to load from eagle: {e}")
model, processor = load_model_from_eagle()
return model, processor
class StoppingCriteriaSub(StoppingCriteria):
    """Stop generation once any of the given stop-token sequences appears at the end of the output."""

    def __init__(self, stops=None):
        super().__init__()
        # Avoid a mutable default argument; move the stop sequences to the GPU once up front.
        self.stops = [stop.to("cuda") for stop in (stops or [])]
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs):
for stop in self.stops:
if input_ids.shape[-1] < len(stop):
continue
if torch.all((stop == input_ids[0][-len(stop) :])).item():
return True
return False
def preprocess(
messages: list[dict],
processor,
video_nframes: int = 16,
):
"""
Build messages from the conversations and images.
"""
# get images from conversations
results = [
{
"role": "system",
"content": """You are Eagle 2, a cutting-edge large language model developed by NVIDIA. You are highly capable, efficient, and aligned, specialized in understanding complex multimodal inputs and providing expert-level responses across domains.
Always be concise, accurate, and helpful. You respond like a reliable co-pilot to researchers, developers, and engineers, offering deep technical insight, step-by-step reasoning, and practical suggestions.
You can interpret long contexts, follow nuanced instructions, and dynamically adjust your tone to match the user's intent. If the user does not specify a tone, default to a professional, technical, yet friendly style.
You understand you are Eagle 2, and may refer to yourself as such when asked."""}
]
# get texts from conversations
# converstion = get_conv_template(sft_format)
# only use the last 3 round of messages
# latest_messages = messages[-3:]
all_images_num = 0
for mid, message in enumerate(messages):
if message["role"] == "user":
record = {
"role": message["role"],
"content": [],
}
if "images" in message:
per_round_images = message["images"]
for image in per_round_images:
if isinstance(image, Image.Image) and all_images_num < 128:
record["content"].append(
{
"type": "image",
"image": image,
}
)
                        all_images_num += 1
                    elif isinstance(image, str) and image.lower().endswith((".jpeg", ".jpg", ".png", ".gif")) and all_images_num < 128:
record["content"].append(
{
"type": "image",
"image": image,
}
)
                        all_images_num += 1
                    elif isinstance(image, str) and image.lower().endswith((".mp4", ".mov", ".avi", ".webm")) and all_images_num < 128 - video_nframes:
record["content"].append(
{
"type": "video",
"video": image,
"nframes": video_nframes,
}
)
                        all_images_num += video_nframes
            if "content" in message:
record["content"].append(
{
"type": "text",
"text": str(message["content"]).strip(),
}
)
results.append(record)
elif message["role"] == "assistant":
formatted_answer = message["content"].strip()
            # Example assistant output containing thinking tokens:
            # ◁think▷The user said "你好" ("hello"), a very simple greeting usually used to open a
            # conversation. I need to judge the user's intent. Possibility one: the user is just
            # greeting me politely and wants to start a conversation; possibility two: the user may
            # have a more specific need, such as asking about my capabilities or needing help. Since
            # the user gave no further information, I should stay open while guiding them to clarify
            # what they need. My reply should be friendly and open, neither overly formal nor cold,
            # and it should avoid assuming a specific need, instead offering a relaxed response that
            # encourages the conversation to continue.◁/think▷Hello! Nice to meet you. Is there
            # anything I can help you with?
            #
            # Strip everything between ◁think▷ and ◁/think▷ before the reply is fed back to the model.
            # FIXME: this is a hack to remove the thinking text
            # formatted_answer = re.sub(r"◁think▷.*◁/think▷", "", formatted_answer)
think_end_token = '◁/think▷'
formatted_answer = formatted_answer.split(think_end_token)[-1]
results.append(
{
"role": message["role"],
"content": [
{
"type": "text",
"text": formatted_answer,
}
],
}
)
assert (
formatted_answer.count(processor.image_token) == 0
), f"there should be no {processor.image_token} in the assistant's reply, but got {messages}"
# print(f"messages = {results}")
text = processor.apply_chat_template(results, add_generation_prompt=False)
# print(f"raw text = {text}")
image_inputs, video_inputs, video_kwargs = processor.process_vision_info(results, return_video_kwargs=True)
inputs = processor(
images=image_inputs,
videos=video_inputs,
text=[text],
return_tensors="pt",
padding=True,
truncation=True,
videos_kwargs=video_kwargs,
)
return inputs
@torch.no_grad()
@torch.inference_mode()
def eagle_vl_generate(
model: torch.nn.Module,
processor: AutoProcessor,
conversations: list[Conversation],
stop_words: list,
max_length: int = 256,
temperature: float = 1.0,
top_p: float = 1.0,
chunk_size: int = -1,
video_nframes: int = 16,
):
# convert conversation to inputs
print(f"conversations = {conversations}")
inputs = preprocess(conversations, processor=processor, video_nframes=video_nframes)
inputs = inputs.to(model.device)
return generate(
model,
processor,
inputs,
max_gen_len=max_length,
temperature=temperature,
top_p=top_p,
stop_words=stop_words,
chunk_size=chunk_size,
)
def generate(
model,
processor,
inputs,
    max_gen_len: int = 256,
    temperature: float = 0.0,
    top_p: float = 0.95,
    stop_words: Optional[List[str]] = None,
    chunk_size: int = -1,
):
"""Stream the text output from the multimodality model with prompt and image inputs."""
tokenizer = processor.tokenizer
    stop_words_ids = [torch.tensor(tokenizer.encode(stop_word)) for stop_word in (stop_words or [])]
stopping_criteria = StoppingCriteriaList([StoppingCriteriaSub(stops=stop_words_ids)])
streamer = TextIteratorStreamer(tokenizer, skip_prompt=True)
    kwargs = dict(
        **inputs,
        max_new_tokens=max_gen_len,
        streamer=streamer,
        stopping_criteria=stopping_criteria,
    )
if temperature > 0:
kwargs.update(
{
"do_sample": True,
"top_p": top_p,
"temperature": temperature,
}
)
else:
kwargs["do_sample"] = False
thread = Thread(target=model.generate, kwargs=kwargs)
thread.start()
yield from streamer
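# Minimal usage sketch (kept out of the serving path). The model id, image path, and
# stop word below are illustrative assumptions; adjust them for your setup and export
# HF_TOKEN first if the checkpoint is gated. Messages are passed as plain dicts in the
# format preprocess() expects.
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    model, processor = load_model("nvidia/Eagle2-8B")
    conversations = [
        {
            "role": "user",
            "content": "Describe this image in one sentence.",
            "images": [Image.open("example.jpg")],  # hypothetical local image path
        }
    ]
    # eagle_vl_generate streams decoded text chunks as they are generated.
    for chunk in eagle_vl_generate(
        model,
        processor,
        conversations,
        stop_words=["<|im_end|>"],  # assumed stop token; depends on the chat template in use
        max_length=256,
        temperature=0.7,
        top_p=0.9,
    ):
        print(chunk, end="", flush=True)
    print()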