import torch
from qwen_vl_utils import process_vision_info
from transformers import (
    AutoProcessor,
    Qwen2_5_VLForConditionalGeneration,
)
from torchvision.transforms import ToPILImage

to_pil = ToPILImage()
Qwen25VL_7b_PREFIX = '''Given a user prompt, generate an "Enhanced prompt" that provides detailed visual descriptions suitable for image generation. Evaluate the level of detail in the user prompt:
- If the prompt is simple, focus on adding specifics about colors, shapes, sizes, textures, and spatial relationships to create vivid and concrete scenes.
- If the prompt is already detailed, refine and enhance the existing details slightly without overcomplicating.\n
Here are examples of how to transform or refine prompts:
- User Prompt: A cat sleeping -> Enhanced: A small, fluffy white cat curled up in a round shape, sleeping peacefully on a warm sunny windowsill, surrounded by pots of blooming red flowers.
- User Prompt: A busy city street -> Enhanced: A bustling city street scene at dusk, featuring glowing street lamps, a diverse crowd of people in colorful clothing, and a double-decker bus passing by towering glass skyscrapers.\n
Please generate only the enhanced description for the prompt below and avoid including any additional commentary or evaluations:
User Prompt:'''
def split_string(s):
    # Normalize Chinese full-width quotes to English quotes.
    s = s.replace("“", '"').replace("”", '"')
    result = []
    in_quotes = False  # tracks whether we are inside a quoted span
    temp = ""
    for idx, char in enumerate(s):
        # Only treat quotes after index 155 as delimiters, so quotes inside
        # the fixed chat-template/prefix header are left untouched.
        if char == '"' and idx > 155:
            temp += char
            if not in_quotes:
                # Opening quote: flush the accumulated plain text (quote included).
                result.append(temp)
                temp = ""
            in_quotes = not in_quotes
            continue
        if in_quotes:
            # Wrap each character (spaces included) in full-width quotes so the
            # tokenizer later encodes quoted text character by character.
            result.append("“" + char + "”")
        else:
            temp += char
    # Flush any trailing plain text.
    if temp:
        result.append(temp)
    return result
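
# Illustrative example of the splitting behavior (hypothetical input; the 156
# filler characters stand in for the chat-template/prefix header that quotes
# are skipped over):
#   split_string("x" * 156 + 'Write "Hi"')
#   -> ["x" * 156 + 'Write "', '“H”', '“i”', '"']
# Each quoted character becomes its own segment, which forward() below
# re-tokenizes one character at a time.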
class Qwen25VL_7b_Embedder(torch.nn.Module):
    def __init__(self, model_path, max_length=640, dtype=torch.bfloat16, device="cuda"):
        super().__init__()
        self.max_length = max_length
        self.dtype = dtype
        self.device = device

        # Frozen Qwen2.5-VL backbone used purely as a feature extractor.
        self.model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
            model_path,
            torch_dtype=dtype,
            attn_implementation="eager",
        ).to(torch.cuda.current_device())
        self.model.requires_grad_(False)

        # Bound the image resolution seen by the vision tower.
        self.processor = AutoProcessor.from_pretrained(
            model_path, min_pixels=256 * 28 * 28, max_pixels=324 * 28 * 28
        )

        self.prefix = Qwen25VL_7b_PREFIX
    def forward(self, caption, ref_images):
        text_list = caption
        # Fixed-size output buffers: per-sample hidden states and attention masks.
        embs = torch.zeros(
            len(text_list),
            self.max_length,
            self.model.config.hidden_size,
            dtype=torch.bfloat16,
            device=torch.cuda.current_device(),
        )
        masks = torch.zeros(
            len(text_list),
            self.max_length,
            dtype=torch.long,
            device=torch.cuda.current_device(),
        )
        # Local variant of the module-level split_string (shadows it); this one
        # additionally normalizes single quotes to double quotes.
        def split_string(s):
            s = s.replace("“", '"').replace("”", '"').replace("'", '"')
            result = []
            in_quotes = False
            temp = ""
            for idx, char in enumerate(s):
                # Quotes within the first 155 characters belong to the fixed
                # prefix and are not treated as delimiters.
                if char == '"' and idx > 155:
                    temp += char
                    if not in_quotes:
                        result.append(temp)
                        temp = ""
                    in_quotes = not in_quotes
                    continue
                if in_quotes:
                    # Wrap each character (spaces included) in full-width quotes.
                    result.append("“" + char + "”")
                else:
                    temp += char
            if temp:
                result.append(temp)
            return result
        for idx, (txt, imgs) in enumerate(zip(text_list, ref_images)):
            # Build a single-turn chat message: prefix, reference image, then the caption.
            messages = [{"role": "user", "content": []}]
            messages[0]["content"].append({"type": "text", "text": f"{self.prefix}"})
            messages[0]["content"].append({"type": "image", "image": to_pil(imgs)})
            messages[0]["content"].append({"type": "text", "text": f"{txt}"})

            # Preparation for inference.
            text = self.processor.apply_chat_template(
                messages, tokenize=False, add_generation_prompt=True, add_vision_id=True
            )
            image_inputs, video_inputs = process_vision_info(messages)
            inputs = self.processor(
                text=[text],
                images=image_inputs,
                padding=True,
                return_tensors="pt",
            )
            old_input_ids = inputs.input_ids
            text_split_list = split_string(text)

            # Re-tokenize each segment; quoted characters were wrapped in
            # full-width quotes so they tokenize one character at a time.
            token_list = []
            for text_each in text_split_list:
                txt_inputs = self.processor(
                    text=text_each,
                    images=None,
                    videos=None,
                    padding=True,
                    return_tensors="pt",
                )
                token_each = txt_inputs.input_ids
                # Strip the wrapping full-width quote tokens (“ -> 2073, ” -> 854)
                # so only the single enclosed character remains.
                if token_each[0][0] == 2073 and token_each[0][-1] == 854:
                    token_each = token_each[:, 1:-1]
                token_list.append(token_each)
            new_txt_ids = torch.cat(token_list, dim=1).to(old_input_ids.device)

            # Splice the sequences at <|vision_end|> (id 151653): keep the vision
            # region from the original ids and the text region from the
            # re-tokenized ids.
            idx1 = (old_input_ids == 151653).nonzero(as_tuple=True)[1][0]
            idx2 = (new_txt_ids == 151653).nonzero(as_tuple=True)[1][0]
            inputs.input_ids = (
                torch.cat([old_input_ids[0, :idx1], new_txt_ids[0, idx2:]], dim=0)
                .unsqueeze(0)
                .to("cuda")
            )
            inputs.attention_mask = (inputs.input_ids > 0).long().to("cuda")
            outputs = self.model(
                input_ids=inputs.input_ids,
                attention_mask=inputs.attention_mask,
                pixel_values=inputs.pixel_values.to("cuda"),
                image_grid_thw=inputs.image_grid_thw.to("cuda"),
                output_hidden_states=True,
            )

            # Last-layer hidden states; the hard-coded offset of 217 skips the
            # tokens covering the template, prefix, and image region, so only
            # the user-prompt portion is kept, truncated to max_length.
            emb = outputs["hidden_states"][-1]
            valid_len = min(self.max_length, emb.shape[1] - 217)
            embs[idx, :valid_len] = emb[0, 217:][: self.max_length]
            masks[idx, :valid_len] = 1

        return embs, masks
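

# Minimal usage sketch (illustrative, not part of the original module): the
# checkpoint path and image below are assumptions; any local Qwen2.5-VL-7B
# directory and CHW image tensor in [0, 1] should work.
if __name__ == "__main__":
    embedder = Qwen25VL_7b_Embedder("Qwen/Qwen2.5-VL-7B-Instruct")
    image = torch.rand(3, 448, 448)  # converted to PIL inside forward()
    with torch.no_grad():
        embs, masks = embedder(["A cat sleeping on a windowsill"], [image])
    print(embs.shape, masks.shape)  # (1, 640, hidden_size) and (1, 640)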