import torch
from qwen_vl_utils import process_vision_info
from transformers import (
    AutoProcessor,
    Qwen2_5_VLForConditionalGeneration,
)
from torchvision.transforms import ToPILImage

to_pil = ToPILImage()

Qwen25VL_7b_PREFIX = '''Given a user prompt, generate an "Enhanced prompt" that provides detailed visual descriptions suitable for image generation. Evaluate the level of detail in the user prompt:
- If the prompt is simple, focus on adding specifics about colors, shapes, sizes, textures, and spatial relationships to create vivid and concrete scenes.
- If the prompt is already detailed, refine and enhance the existing details slightly without overcomplicating.\n
Here are examples of how to transform or refine prompts:
- User Prompt: A cat sleeping -> Enhanced: A small, fluffy white cat curled up in a round shape, sleeping peacefully on a warm sunny windowsill, surrounded by pots of blooming red flowers.
- User Prompt: A busy city street -> Enhanced: A bustling city street scene at dusk, featuring glowing street lamps, a diverse crowd of people in colorful clothing, and a double-decker bus passing by towering glass skyscrapers.\n
Please generate only the enhanced description for the prompt below and avoid including any additional commentary or evaluations:
User Prompt:'''


def split_string(s):
    # Replace full-width Chinese quotes with English double quotes
    s = s.replace("“", '"').replace("”", '"')
    result = []
    # Track whether we are currently inside a quoted span
    in_quotes = False
    temp = ""

    # Walk over every character together with its index
    for idx, char in enumerate(s):
        # Only treat a double quote as a delimiter past index 155, so quotes
        # inside the fixed instruction prefix are ignored
        if char == '"' and idx > 155:
            # Keep the quote character in the buffer
            temp += char
            if not in_quotes:
                # Opening quote: flush the buffered text into the result
                result.append(temp)
                temp = ""

            # Toggle the in-quotes state
            in_quotes = not in_quotes
            continue
        if in_quotes:
            if char.isspace():
                pass  # spaces are kept so they still get their own token

            # Emit each quoted character individually, wrapped in full-width
            # quotes, so the tokenizer encodes it as a single character
            result.append("“" + char + "”")
        else:
            # Outside quotes: keep accumulating plain text
            temp += char

    # Flush any trailing text
    if temp:
        result.append(temp)

    return result
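
# Illustrative trace of split_string (long_prefix is a hypothetical string long
# enough that the quotes fall past index 155): characters inside double quotes
# are re-emitted one at a time, each wrapped in full-width quotes, e.g.
#
#   split_string(long_prefix + 'a sign that reads "Hi"')
#   -> [long_prefix + 'a sign that reads "', '“H”', '“i”', '"']
#
# Downstream, the wrapping quote tokens are stripped again, so quoted text ends
# up tokenized character by character.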


class Qwen25VL_7b_Embedder(torch.nn.Module):
    def __init__(self, model_path, max_length=640, dtype=torch.bfloat16, device="cuda"):
        super().__init__()
        self.max_length = max_length
        self.dtype = dtype
        self.device = device

        self.model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
            model_path,
            torch_dtype=dtype,
            attn_implementation="eager",
        ).to(torch.cuda.current_device())

        self.model.requires_grad_(False)
        self.processor = AutoProcessor.from_pretrained(
            model_path, min_pixels=256 * 28 * 28, max_pixels=324 * 28 * 28
        )

        self.prefix = Qwen25VL_7b_PREFIX

    def forward(self, caption, ref_images):
        """Encode (caption, reference image) pairs with Qwen2.5-VL and return
        fixed-length last-layer hidden states plus their attention masks."""
        text_list = caption
        embs = torch.zeros(
            len(text_list),
            self.max_length,
            self.model.config.hidden_size,
            dtype=torch.bfloat16,
            device=torch.cuda.current_device(),
        )
        masks = torch.zeros(
            len(text_list),
            self.max_length,
            dtype=torch.long,
            device=torch.cuda.current_device(),
        )

        def split_string(s):
            # Local variant of the module-level split_string; it additionally
            # normalizes apostrophes to English double quotes before splitting.
            s = s.replace("“", '"').replace("”", '"').replace("'", '"')
            result = []
            in_quotes = False
            temp = ""

            for idx, char in enumerate(s):
                # Quotes inside the fixed prefix (index <= 155) are ignored
                if char == '"' and idx > 155:
                    temp += char
                    if not in_quotes:
                        result.append(temp)
                        temp = ""

                    in_quotes = not in_quotes
                    continue
                if in_quotes:
                    if char.isspace():
                        pass  # spaces are kept so they still get their own token

                    # Wrap each quoted character in full-width quotes so it is
                    # tokenized as a single character
                    result.append("“" + char + "”")
                else:
                    temp += char

            if temp:
                result.append(temp)

            return result

        for idx, (txt, imgs) in enumerate(zip(text_list, ref_images)):

            messages = [{"role": "user", "content": []}]

            messages[0]["content"].append({"type": "text", "text": f"{self.prefix}"})

            messages[0]["content"].append({"type": "image", "image": to_pil(imgs)})

            # then append the user text
            messages[0]["content"].append({"type": "text", "text": f"{txt}"})

            # Preparation for inference
            text = self.processor.apply_chat_template(
                messages, tokenize=False, add_generation_prompt=True, add_vision_id=True
            )

            image_inputs, video_inputs = process_vision_info(messages)

            inputs = self.processor(
                text=[text],
                images=image_inputs,
                padding=True,
                return_tensors="pt",
            )

            old_inputs_ids = inputs.input_ids
            text_split_list = split_string(text)

            token_list = []
            for text_each in text_split_list:
                txt_inputs = self.processor(
                    text=text_each,
                    images=None,
                    videos=None,
                    padding=True,
                    return_tensors="pt",
                )
                token_each = txt_inputs.input_ids
                # 2073 and 854 appear to be the token ids of the wrapping
                # full-width quotes; strip them so only the single quoted
                # character's token remains
                if token_each[0][0] == 2073 and token_each[0][-1] == 854:
                    token_each = token_each[:, 1:-1]
                token_list.append(token_each)

            new_txt_ids = torch.cat(token_list, dim=1).to(old_inputs_ids.device)

            # 151653 is the <|vision_end|> token id in the Qwen2.5-VL tokenizer:
            # keep the original ids (with the real image tokens) up to
            # <|vision_end|>, then append the re-tokenized ids from
            # <|vision_end|> onward so quoted characters stay one token each
            idx1 = (old_inputs_ids == 151653).nonzero(as_tuple=True)[1][0]
            idx2 = (new_txt_ids == 151653).nonzero(as_tuple=True)[1][0]
            inputs.input_ids = (
                torch.cat([old_inputs_ids[0, :idx1], new_txt_ids[0, idx2:]], dim=0)
                .unsqueeze(0)
                .to("cuda")
            )
            inputs.attention_mask = (inputs.input_ids > 0).long().to("cuda")
            outputs = self.model(
                input_ids=inputs.input_ids,
                attention_mask=inputs.attention_mask,
                pixel_values=inputs.pixel_values.to("cuda"),
                image_grid_thw=inputs.image_grid_thw.to("cuda"),
                output_hidden_states=True,
            )

            emb = outputs["hidden_states"][-1]

            # The fixed offset of 217 appears to skip the templated prefix and
            # image-token region, keeping only embeddings of the trailing text
            embs[idx, : min(self.max_length, emb.shape[1] - 217)] = emb[0, 217:][
                : self.max_length
            ]

            masks[idx, : min(self.max_length, emb.shape[1] - 217)] = torch.ones(
                (min(self.max_length, emb.shape[1] - 217)),
                dtype=torch.long,
                device=torch.cuda.current_device(),
            )

        return embs, masks
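

if __name__ == "__main__":
    # Minimal usage sketch (illustrative, not part of the training pipeline).
    # The checkpoint path is an assumption: any Qwen2.5-VL-7B checkpoint with a
    # matching processor should work. Requires a CUDA device.
    embedder = Qwen25VL_7b_Embedder("Qwen/Qwen2.5-VL-7B-Instruct")
    image = torch.rand(3, 448, 448)  # CHW float in [0, 1], as ToPILImage expects
    embs, masks = embedder(["A cat sleeping on a windowsill"], [image])
    print(embs.shape, masks.shape)  # (1, 640, hidden_size) and (1, 640)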