LPX55 committed (verified)
Commit 803363f · 1 Parent(s): f7ef47a

Create mini.py

Files changed (1)
  1. mini.py +436 -0
mini.py ADDED
@@ -0,0 +1,436 @@
+ import gradio as gr
+ import torch
+ import spaces
+ from PIL import Image
+ import os
+ from transformers import CLIPTokenizer, CLIPTextModel, AutoProcessor, T5EncoderModel, T5TokenizerFast
+ from diffusers import AutoencoderKL, FlowMatchEulerDiscreteScheduler
+ from flux.transformer_flux import FluxTransformer2DModel
+ from flux.pipeline_flux_chameleon import FluxPipeline
+ import torch.nn as nn
+ import math
+ import logging
+ import sys
+ from qwen2_vl.modeling_qwen2_vl import Qwen2VLSimplifiedModel
+ from huggingface_hub import snapshot_download
+
+ HF_TOKEN = os.getenv("HF_TOKEN")
+
+ # Set up logging
+ logging.basicConfig(
+     level=logging.INFO,
+     format='%(asctime)s - %(levelname)s - %(message)s',
+     handlers=[logging.StreamHandler(sys.stdout)]
+ )
+ logger = logging.getLogger(__name__)
+
+ MODEL_ID = "Djrango/Qwen2vl-Flux"
+ MODEL_CACHE_DIR = "model_cache"
+ DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
+ DTYPE = torch.bfloat16
+
+ # Aspect ratio options
+ ASPECT_RATIOS = {
+     "1:1": (1024, 1024),
+     "16:9": (1344, 768),
+     "9:16": (768, 1344),
+     "2.4:1": (1536, 640),
+     "3:4": (896, 1152),
+     "4:3": (1152, 896),
+ }
+
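+ # Lightweight projection that maps Qwen2-VL hidden states (3584-dim) into the
+ # 4096-dim prompt-embedding space that the Flux pipeline below consumes.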
+ class Qwen2Connector(nn.Module):
+     def __init__(self, input_dim=3584, output_dim=4096):
+         super().__init__()
+         self.linear = nn.Linear(input_dim, output_dim)
+
+     def forward(self, x):
+         return self.linear(x)
+
+ # Download models if not present
+ if not os.path.exists(MODEL_CACHE_DIR):
+     logger.info("Starting model download...")
+     try:
+         snapshot_download(
+             repo_id=MODEL_ID,
+             local_dir=MODEL_CACHE_DIR,
+             local_dir_use_symlinks=False,
+             token=HF_TOKEN
+         )
+         logger.info("Model download completed successfully")
+     except Exception as e:
+         logger.error(f"Error downloading models: {str(e)}")
+         raise
+
+ # Initialize models in global context
+ logger.info("Starting model loading...")
+
+ # Load smaller models to GPU
+ tokenizer = CLIPTokenizer.from_pretrained(os.path.join(MODEL_CACHE_DIR, "flux/tokenizer"))
+ text_encoder = CLIPTextModel.from_pretrained(
+     os.path.join(MODEL_CACHE_DIR, "flux/text_encoder")
+ ).to(DTYPE).to(DEVICE)
+
+ text_encoder_two = T5EncoderModel.from_pretrained(
+     os.path.join(MODEL_CACHE_DIR, "flux/text_encoder_2")
+ ).to(DTYPE).to(DEVICE)
+
+ tokenizer_two = T5TokenizerFast.from_pretrained(
+     os.path.join(MODEL_CACHE_DIR, "flux/tokenizer_2")
+ )
+
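+ # The heavy components (Flux transformer, VAE, Qwen2-VL) start on the CPU and are
+ # moved onto the GPU only while they are needed, keeping idle GPU memory low.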
+ # Load larger models to CPU
+ vae = AutoencoderKL.from_pretrained(
+     os.path.join(MODEL_CACHE_DIR, "flux/vae")
+ ).to(DTYPE).cpu()
+
+ transformer = FluxTransformer2DModel.from_pretrained(
+     os.path.join(MODEL_CACHE_DIR, "flux/transformer")
+ ).to(DTYPE).cpu()
+
+ scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(
+     os.path.join(MODEL_CACHE_DIR, "flux/scheduler"),
+     shift=1
+ )
+
+ # Load Qwen2VL to CPU
+ qwen2vl = Qwen2VLSimplifiedModel.from_pretrained(
+     os.path.join(MODEL_CACHE_DIR, "qwen2-vl")
+ ).to(DTYPE).cpu()
+
+ # Load connector and embedder
+ connector = Qwen2Connector().to(DTYPE).cpu()
+ connector_path = os.path.join(MODEL_CACHE_DIR, "qwen2-vl/connector.pt")
+ connector_state = torch.load(connector_path, map_location='cpu')
+ connector_state = {k.replace('module.', ''): v.to(DTYPE) for k, v in connector_state.items()}
+ connector.load_state_dict(connector_state)
+
+ t5_context_embedder = nn.Linear(4096, 3072).to(DTYPE).cpu()
+ t5_embedder_path = os.path.join(MODEL_CACHE_DIR, "qwen2-vl/t5_embedder.pt")
+ t5_embedder_state = torch.load(t5_embedder_path, map_location='cpu')
+ t5_embedder_state = {k: v.to(DTYPE) for k, v in t5_embedder_state.items()}
+ t5_context_embedder.load_state_dict(t5_embedder_state)
+
+ # Set all models to eval mode
+ for model in [text_encoder, text_encoder_two, vae, transformer, qwen2vl, connector, t5_context_embedder]:
+     model.requires_grad_(False)
+     model.eval()
+
+ logger.info("All models loaded successfully")
+
+ # Initialize processors and pipeline
+ qwen2vl_processor = AutoProcessor.from_pretrained(
+     MODEL_ID,
+     subfolder="qwen2-vl",
+     min_pixels=256*28*28,
+     max_pixels=256*28*28
+ )
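+ # With min_pixels == max_pixels the processor resizes every image to the same pixel
+ # budget (256 * 28 * 28), so each upload yields a comparable number of vision tokens.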
+
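+ # The pipeline is constructed with only the CLIP text encoder; the T5 and Qwen2-VL
+ # embeddings are computed separately and passed in at call time inside generate().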
+ pipeline = FluxPipeline(
+     transformer=transformer,
+     scheduler=scheduler,
+     vae=vae,
+     text_encoder=text_encoder,
+     tokenizer=tokenizer,
+ )
+
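+ # Runs Qwen2-VL on the uploaded image and returns the connector-projected hidden
+ # states (plus the image grid shape) that condition Flux in generate().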
+ def process_image(image):
+     """Process image with Qwen2VL model"""
+     try:
+         # Move Qwen2VL models to GPU
+         logger.info("Moving Qwen2VL models to GPU...")
+         qwen2vl.to(DEVICE)
+         connector.to(DEVICE)
+
+         message = [
+             {
+                 "role": "user",
+                 "content": [
+                     {"type": "image", "image": image},
+                     {"type": "text", "text": "Describe this image."},
+                 ]
+             }
+         ]
+         text = qwen2vl_processor.apply_chat_template(
+             message,
+             tokenize=False,
+             add_generation_prompt=True
+         )
+
+         with torch.no_grad():
+             inputs = qwen2vl_processor(
+                 text=[text],
+                 images=[image],
+                 padding=True,
+                 return_tensors="pt"
+             ).to(DEVICE)
+
+             output_hidden_state, image_token_mask, image_grid_thw = qwen2vl(**inputs)
+             image_hidden_state = output_hidden_state[image_token_mask].view(1, -1, output_hidden_state.size(-1))
+             image_hidden_state = connector(image_hidden_state)
+
+         result = (image_hidden_state.cpu(), image_grid_thw)
+
+         # Move models back to CPU
+         qwen2vl.cpu()
+         connector.cpu()
+         torch.cuda.empty_cache()
+
+         return result
+
+     except Exception as e:
+         logger.error(f"Error in process_image: {str(e)}")
+         raise
+
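+ # Caps very large uploads at roughly 1.05 megapixels and snaps both sides to
+ # multiples of 8, the usual requirement for VAE-based latent models.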
+ def resize_image(img, max_pixels=1050000):
+     if not isinstance(img, Image.Image):
+         img = Image.fromarray(img)
+
+     width, height = img.size
+     num_pixels = width * height
+
+     if num_pixels > max_pixels:
+         scale = math.sqrt(max_pixels / num_pixels)
+         new_width = int(width * scale)
+         new_height = int(height * scale)
+         new_width = new_width - (new_width % 8)
+         new_height = new_height - (new_height % 8)
+         img = img.resize((new_width, new_height), Image.LANCZOS)
+
+     return img
+
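+ # Encodes the optional text prompt with T5, then projects it from 4096-dim down to
+ # 3072-dim via the learned t5_context_embedder before handing it to the pipeline.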
+ def compute_t5_text_embeddings(prompt):
+     """Compute T5 embeddings for text prompt"""
+     if prompt == "":
+         return None
+
+     text_inputs = tokenizer_two(
+         prompt,
+         padding="max_length",
+         max_length=256,
+         truncation=True,
+         return_tensors="pt"
+     ).to(DEVICE)
+
+     prompt_embeds = text_encoder_two(text_inputs.input_ids)[0]
+     prompt_embeds = t5_context_embedder.to(DEVICE)(prompt_embeds)
+     t5_context_embedder.cpu()
+
+     return prompt_embeds
+
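+ # CLIP contributes only the pooled sentence embedding; the per-token conditioning
+ # comes from the Qwen2-VL image features (and the T5 embeddings when a prompt is given).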
+ def compute_text_embeddings(prompt=""):
+     with torch.no_grad():
+         text_inputs = tokenizer(
+             prompt,
+             padding="max_length",
+             max_length=77,
+             truncation=True,
+             return_tensors="pt"
+         ).to(DEVICE)
+
+         prompt_embeds = text_encoder(
+             text_inputs.input_ids,
+             output_hidden_states=False
+         )
+         pooled_prompt_embeds = prompt_embeds.pooler_output
+     return pooled_prompt_embeds
+
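+ # On ZeroGPU Spaces, @spaces.GPU borrows a GPU only for the duration of the call
+ # (up to ~75 s here); the large models are moved onto it inside this function.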
+ @spaces.GPU(duration=75)
+ def generate(input_image, prompt="", guidance_scale=3.5, num_inference_steps=28, num_images=2, seed=None, aspect_ratio="1:1", progress=gr.Progress(track_tqdm=True)):
+     try:
+         logger.info(f"Starting generation with prompt: {prompt}")
+
+         if input_image is None:
+             raise ValueError("No input image provided")
+
+         if seed is not None:
+             torch.manual_seed(seed)
+             logger.info(f"Set random seed to: {seed}")
+
+         # Process image with Qwen2VL
+         logger.info("Processing input image with Qwen2VL...")
+         qwen2_hidden_state, image_grid_thw = process_image(input_image)
+         logger.info("Image processing completed")
+
+         # Compute text embeddings
+         logger.info("Computing text embeddings...")
+         pooled_prompt_embeds = compute_text_embeddings(prompt)
+         t5_prompt_embeds = compute_t5_text_embeddings(prompt)
+         logger.info("Text embeddings computed")
+
+         # Move Transformer and VAE to GPU
+         logger.info("Moving Transformer and VAE to GPU...")
+         transformer.to(DEVICE)
+         vae.to(DEVICE)
+
+         # Update pipeline models
+         pipeline.transformer = transformer
+         pipeline.vae = vae
+         logger.info("Models moved to GPU")
+
+         # Get dimensions
+         width, height = ASPECT_RATIOS[aspect_ratio]
+         logger.info(f"Using dimensions: {width}x{height}")
+
+         try:
+             logger.info("Starting image generation...")
+             output_images = pipeline(
+                 prompt_embeds=qwen2_hidden_state.to(DEVICE).repeat(num_images, 1, 1),
+                 pooled_prompt_embeds=pooled_prompt_embeds,
+                 t5_prompt_embeds=t5_prompt_embeds.repeat(num_images, 1, 1) if t5_prompt_embeds is not None else None,
+                 num_inference_steps=num_inference_steps,
+                 guidance_scale=guidance_scale,
+                 height=height,
+                 width=width,
+             ).images
+             logger.info("Image generation completed")
+
+             return output_images
+
+         except Exception as e:
+             raise RuntimeError(f"Error generating images: {str(e)}")
+
+     except Exception as e:
+         logger.error(f"Error during generation: {str(e)}")
+         raise gr.Error(f"Generation failed: {str(e)}")
+
+ # Create Gradio interface
+ with gr.Blocks(
+     theme=gr.themes.Soft(),
+     css="""
+     .container {
+         max-width: 1200px;
+         margin: auto;
+     }
+     .header {
+         text-align: center;
+         margin: 20px 0 40px 0;
+         padding: 20px;
+         background: #f7f7f7;
+         border-radius: 12px;
+     }
+     .param-row {
+         padding: 10px 0;
+     }
+     footer {
+         margin-top: 40px;
+         padding: 20px;
+         border-top: 1px solid #eee;
+     }
+     """
+ ) as demo:
+     with gr.Column(elem_classes="container"):
+         gr.Markdown(
+             """# 🎨 Qwen2vl-Flux Image Variation Demo
+             Generate creative variations of your images with optional text guidance"""
+         )
+
+         with gr.Row(equal_height=True):
+             with gr.Column(scale=1):
+                 input_image = gr.Image(
+                     label="Upload Your Image",
+                     type="pil",
+                     height=384,
+                     sources=["upload", "clipboard"]
+                 )
+                 prompt = gr.Textbox(
+                     label="Text Prompt (Optional)",
+                     placeholder="Describe what you want to see; longer, more detailed prompts work best...",
+                     lines=3
+                 )
+                 with gr.Accordion("Advanced Settings", open=False):
+                     with gr.Group():
+
+                         with gr.Row(elem_classes="param-row"):
+                             guidance = gr.Slider(
+                                 minimum=1,
+                                 maximum=10,
+                                 value=3.5,
+                                 step=0.5,
+                                 label="Guidance Scale",
+                                 info="Higher values follow prompt more closely"
+                             )
+                             steps = gr.Slider(
+                                 minimum=1,
+                                 maximum=50,
+                                 value=28,
+                                 step=1,
+                                 label="Sampling Steps",
+                                 info="More steps = better quality but slower"
+                             )
+
+                         with gr.Row(elem_classes="param-row"):
+                             num_images = gr.Slider(
+                                 minimum=1,
+                                 maximum=4,
+                                 value=1,
+                                 step=1,
+                                 label="Number of Images",
+                                 info="Generate multiple variations at once"
+                             )
+                             seed = gr.Number(
+                                 label="Random Seed",
+                                 value=None,
+                                 precision=0,
+                                 info="Set for reproducible results"
+                             )
+                             aspect_ratio = gr.Radio(
+                                 label="Aspect Ratio",
+                                 choices=["1:1", "16:9", "9:16", "2.4:1", "3:4", "4:3"],
+                                 value="1:1",
+                                 info="Choose aspect ratio for generated images"
+                             )
+
+                 submit_btn = gr.Button(
+                     "🎨 Generate Variations",
+                     variant="primary",
+                     size="lg"
+                 )
+
+             with gr.Column(scale=1):
+                 # Output Section
+                 output_gallery = gr.Gallery(
+                     label="Generated Variations",
+                     columns=2,
+                     rows=2,
+                     height=700,
+                     object_fit="contain",
+                     show_label=True,
+                     allow_preview=True,
+                     preview=True
+                 )
+                 error_message = gr.Textbox(visible=False)
+
+         with gr.Row(elem_classes="footer"):
+             gr.Markdown("""
+             ### Tips:
+             - 📸 Upload any image to get started
+             - 💡 Add an optional text prompt to guide the generation
+             - 🎯 Adjust guidance scale to control prompt influence
+             - ⚙️ Increase steps for higher quality
+             - 🎲 Use seeds for reproducible results
+             """)
+
+     submit_btn.click(
+         fn=generate,
+         inputs=[
+             input_image,
+             prompt,
+             guidance,
+             steps,
+             num_images,
+             seed,
+             aspect_ratio
+         ],
+         outputs=[output_gallery],
+         show_progress=True
+     )
+
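+ # Example of calling the generator without the UI (hypothetical paths and values):
+ #   imgs = generate(Image.open("input.jpg"), prompt="a watercolor variation", num_images=1, seed=42)
+ #   imgs[0].save("variation_0.png")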
+ # Launch the app
+ if __name__ == "__main__":
+     demo.launch(
+         server_name="0.0.0.0",  # Listen on all network interfaces
+         server_port=7860,       # Use a specific port
+         share=False,            # Disable public URL sharing
+         ssr_mode=False          # Fixes bug for some users
+     )