prithivMLmods committed on
Commit 098da17 · verified · 1 Parent(s): 7d378b0

Update app.py

Files changed (1):
  1. app.py  +525 −319
app.py CHANGED
@@ -1,329 +1,535 @@
 import os
-import random
-import uuid
 import json
 import time
-import asyncio
-import re
-from threading import Thread

-import gradio as gr
-import spaces
 import torch
-import numpy as np
 from PIL import Image
-import cv2
-
-from transformers import (
-    AutoProcessor,
-    Gemma3ForConditionalGeneration,
-    Qwen2VLForConditionalGeneration,
-    TextIteratorStreamer,
-)
-from transformers.image_utils import load_image
-
-# Constants
-MAX_MAX_NEW_TOKENS = 2048
-DEFAULT_MAX_NEW_TOKENS = 1024
-MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096"))
-MAX_SEED = np.iinfo(np.int32).max
-
-device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
-
-# Helper function to return a progress bar HTML snippet.
-def progress_bar_html(label: str) -> str:
-    return f'''
-<div style="display: flex; align-items: center;">
-    <span style="margin-right: 10px; font-size: 14px;">{label}</span>
-    <div style="width: 110px; height: 5px; background-color: #F0FFF0; border-radius: 2px; overflow: hidden;">
-        <div style="width: 100%; height: 100%; background-color: #00FF00; animation: loading 1.5s linear infinite;"></div>
-    </div>
-</div>
-<style>
-@keyframes loading {{
-    0% {{ transform: translateX(-100%); }}
-    100% {{ transform: translateX(100%); }}
-}}
-</style>
-    '''
-
-# Qwen2-VL (for optional image inference)
-
-MODEL_ID_VL = "prithivMLmods/Qwen2-VL-OCR-2B-Instruct"
-processor = AutoProcessor.from_pretrained(MODEL_ID_VL, trust_remote_code=True)
-model_m = Qwen2VLForConditionalGeneration.from_pretrained(
-    MODEL_ID_VL,
-    trust_remote_code=True,
-    torch_dtype=torch.float16
-).to("cuda").eval()
-
-def clean_chat_history(chat_history):
-    cleaned = []
-    for msg in chat_history:
-        if isinstance(msg, dict) and isinstance(msg.get("content"), str):
-            cleaned.append(msg)
-    return cleaned
-
-bad_words = json.loads(os.getenv('BAD_WORDS', "[]"))
-bad_words_negative = json.loads(os.getenv('BAD_WORDS_NEGATIVE', "[]"))
-default_negative = os.getenv("default_negative", "")
-
-def check_text(prompt, negative=""):
-    for i in bad_words:
-        if i in prompt:
-            return True
-    for i in bad_words_negative:
-        if i in negative:
-            return True
-    return False
-
-def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
-    if randomize_seed:
-        seed = random.randint(0, MAX_SEED)
-    return seed
-
-CACHE_EXAMPLES = torch.cuda.is_available() and os.getenv("CACHE_EXAMPLES", "0") == "1"
-MAX_IMAGE_SIZE = int(os.getenv("MAX_IMAGE_SIZE", "2048"))
-USE_TORCH_COMPILE = os.getenv("USE_TORCH_COMPILE", "0") == "1"
-ENABLE_CPU_OFFLOAD = os.getenv("ENABLE_CPU_OFFLOAD", "0") == "1"
-
-dtype = torch.float16 if device.type == "cuda" else torch.float32
-
-
-# Gemma3 Model (default for text, image, & video inference)
-
-gemma3_model_id = "google/gemma-3-4b-it"  # [or] Duplicate the space to use 12b
-gemma3_model = Gemma3ForConditionalGeneration.from_pretrained(
-    gemma3_model_id, device_map="auto"
-).eval()
-gemma3_processor = AutoProcessor.from_pretrained(gemma3_model_id)
-
-# VIDEO PROCESSING HELPER
-
-def downsample_video(video_path):
-    vidcap = cv2.VideoCapture(video_path)
-    total_frames = int(vidcap.get(cv2.CAP_PROP_FRAME_COUNT))
-    fps = vidcap.get(cv2.CAP_PROP_FPS)
-    frames = []
-    # Sample 10 evenly spaced frames.
-    frame_indices = np.linspace(0, total_frames - 1, 10, dtype=int)
-    for i in frame_indices:
-        vidcap.set(cv2.CAP_PROP_POS_FRAMES, i)
-        success, image = vidcap.read()
-        if success:
-            # Convert from BGR to RGB and then to PIL Image.
-            image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
-            pil_image = Image.fromarray(image)
-            timestamp = round(i / fps, 2)
-            frames.append((pil_image, timestamp))
-    vidcap.release()
-    return frames
-
-# MAIN GENERATION FUNCTION
-
-@spaces.GPU
-def generate(
-    input_dict: dict,
-    chat_history: list[dict],
-    max_new_tokens: int = DEFAULT_MAX_NEW_TOKENS,
-    temperature: float = 0.6,
-    top_p: float = 0.9,
-    top_k: int = 50,
-    repetition_penalty: float = 1.2,
 ):
-    text = input_dict["text"]
-    files = input_dict.get("files", [])
-    lower_text = text.lower().strip()
-
-    # ----- Qwen2-VL branch (triggered with @qwen2-vl) -----
-    if lower_text.startswith("@qwen2-vl"):
-        prompt_clean = re.sub(r"@qwen2-vl", "", text, flags=re.IGNORECASE).strip().strip('"')
-        if files:
-            images = [load_image(f) for f in files]
-            messages = [{
-                "role": "user",
-                "content": [
-                    *[{"type": "image", "image": image} for image in images],
-                    {"type": "text", "text": prompt_clean},
-                ]
-            }]
-            prompt = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
-            inputs = processor(text=[prompt], images=images, return_tensors="pt", padding=True).to("cuda")
         else:
-            messages = [
-                {"role": "system", "content": [{"type": "text", "text": "You are a helpful assistant."}]},
-                {"role": "user", "content": [{"type": "text", "text": prompt_clean}]}
-            ]
-            inputs = processor.apply_chat_template(
-                messages, add_generation_prompt=True, tokenize=True,
-                return_dict=True, return_tensors="pt"
-            ).to("cuda", dtype=torch.float16)
-        streamer = TextIteratorStreamer(processor.tokenizer, timeout=20.0, skip_prompt=True, skip_special_tokens=True)
-        generation_kwargs = {
-            **inputs,
-            "streamer": streamer,
-            "max_new_tokens": max_new_tokens,
-            "do_sample": True,
-            "temperature": temperature,
-            "top_p": top_p,
-            "top_k": top_k,
-            "repetition_penalty": repetition_penalty,
-        }
-        thread = Thread(target=model_m.generate, kwargs=generation_kwargs)
-        thread.start()
-        buffer = ""
-        yield progress_bar_html("Processing with Qwen2VL")
-        for new_text in streamer:
-            buffer += new_text
-            buffer = buffer.replace("<|im_end|>", "")
-            time.sleep(0.01)
-            yield buffer
-        return
-
-    # ----- Default branch: Gemma3 (for text, image, & video inference) -----
-    if files:
-        # Check if any provided file is a video based on extension.
-        video_extensions = (".mp4", ".mov", ".avi", ".mkv", ".webm")
-        if any(str(f).lower().endswith(video_extensions) for f in files):
-            # Video inference branch.
-            prompt_clean = re.sub(r"@video-infer", "", text, flags=re.IGNORECASE).strip().strip('"')
-            video_path = files[0]
-            frames = downsample_video(video_path)
-            messages = [
-                {"role": "system", "content": [{"type": "text", "text": "You are a helpful assistant."}]},
-                {"role": "user", "content": [{"type": "text", "text": prompt_clean}]}
-            ]
-            # Append each frame (with its timestamp) to the conversation.
-            for frame in frames:
-                image, timestamp = frame
-                image_path = f"video_frame_{uuid.uuid4().hex}.png"
-                image.save(image_path)
-                messages[1]["content"].append({"type": "text", "text": f"Frame {timestamp}:"})
-                messages[1]["content"].append({"type": "image", "url": image_path})
-            inputs = gemma3_processor.apply_chat_template(
-                messages, add_generation_prompt=True, tokenize=True,
-                return_dict=True, return_tensors="pt"
-            ).to(gemma3_model.device, dtype=torch.bfloat16)
-            streamer = TextIteratorStreamer(gemma3_processor.tokenizer, timeout=20.0, skip_prompt=True, skip_special_tokens=True)
-            generation_kwargs = {
-                **inputs,
-                "streamer": streamer,
-                "max_new_tokens": max_new_tokens,
-                "do_sample": True,
-                "temperature": temperature,
-                "top_p": top_p,
-                "top_k": top_k,
-                "repetition_penalty": repetition_penalty,
-            }
-            thread = Thread(target=gemma3_model.generate, kwargs=generation_kwargs)
-            thread.start()
-            buffer = ""
-            yield progress_bar_html("Processing video with Gemma3")
-            for new_text in streamer:
-                buffer += new_text
-                time.sleep(0.01)
-                yield buffer
-            return
         else:
-            # Image inference branch.
-            prompt_clean = re.sub(r"@gemma3", "", text, flags=re.IGNORECASE).strip().strip('"')
-            images = [load_image(f) for f in files]
-            messages = [{
-                "role": "user",
-                "content": [
-                    *[{"type": "image", "image": image} for image in images],
-                    {"type": "text", "text": prompt_clean},
-                ]
-            }]
-            inputs = gemma3_processor.apply_chat_template(
-                messages, tokenize=True, add_generation_prompt=True,
-                return_dict=True, return_tensors="pt"
-            ).to(gemma3_model.device, dtype=torch.bfloat16)
-            streamer = TextIteratorStreamer(gemma3_processor.tokenizer, timeout=20.0, skip_prompt=True, skip_special_tokens=True)
-            generation_kwargs = {
-                **inputs,
-                "streamer": streamer,
-                "max_new_tokens": max_new_tokens,
-                "do_sample": True,
-                "temperature": temperature,
-                "top_p": top_p,
-                "top_k": top_k,
-                "repetition_penalty": repetition_penalty,
-            }
-            thread = Thread(target=gemma3_model.generate, kwargs=generation_kwargs)
-            thread.start()
-            buffer = ""
-            yield progress_bar_html("Processing with Gemma3")
-            for new_text in streamer:
-                buffer += new_text
-                time.sleep(0.01)
-                yield buffer
-            return
    else:
-        # Text-only inference branch.
-        messages = [
-            {"role": "system", "content": [{"type": "text", "text": "You are a helpful assistant."}]},
-            {"role": "user", "content": [{"type": "text", "text": text}]}
-        ]
-        inputs = gemma3_processor.apply_chat_template(
-            messages, add_generation_prompt=True, tokenize=True,
-            return_dict=True, return_tensors="pt"
-        ).to(gemma3_model.device, dtype=torch.bfloat16)
-        streamer = TextIteratorStreamer(gemma3_processor.tokenizer, timeout=20.0, skip_prompt=True, skip_special_tokens=True)
-        generation_kwargs = {
-            **inputs,
-            "streamer": streamer,
-            "max_new_tokens": max_new_tokens,
-            "do_sample": True,
-            "temperature": temperature,
-            "top_p": top_p,
-            "top_k": top_k,
-            "repetition_penalty": repetition_penalty,
-        }
-        thread = Thread(target=gemma3_model.generate, kwargs=generation_kwargs)
-        thread.start()
-        outputs = []
-        for new_text in streamer:
-            outputs.append(new_text)
-            yield "".join(outputs)
-        final_response = "".join(outputs)
-        yield final_response
-
-
-# Gradio Interface
-
-demo = gr.ChatInterface(
-    fn=generate,
-    additional_inputs=[
-        gr.Slider(label="Max new tokens", minimum=1, maximum=MAX_MAX_NEW_TOKENS, step=1, value=DEFAULT_MAX_NEW_TOKENS),
-        gr.Slider(label="Temperature", minimum=0.1, maximum=4.0, step=0.1, value=0.6),
-        gr.Slider(label="Top-p (nucleus sampling)", minimum=0.05, maximum=1.0, step=0.05, value=0.9),
-        gr.Slider(label="Top-k", minimum=1, maximum=1000, step=1, value=50),
-        gr.Slider(label="Repetition penalty", minimum=1.0, maximum=2.0, step=0.05, value=1.2),
-    ],
-    examples=[
-        [{"text": "Create a short story based on the image.", "files": ["examples/1111.jpg"]}],
-        [{"text": "Explain the Image", "files": ["examples/3.jpg"]}],
-        [{"text": "Explain the content of the Advertisement", "files": ["examples/videoplayback.mp4"]}],
-        [{"text": "Which movie character is this?", "files": ["examples/9999.jpg"]}],
-        ["Explain Critical Temperature of Substance"],
-        [{"text": "@qwen2-vl Transcription of the letter", "files": ["examples/222.png"]}],
-        [{"text": "Explain the content of the video in detail", "files": ["examples/breakfast.mp4"]}],
-        [{"text": "Describe the video", "files": ["examples/Missing.mp4"]}],
-        [{"text": "Explain what is happening in this video ?", "files": ["examples/oreo.mp4"]}],
-        [{"text": "Summarize the events in this video", "files": ["examples/sky.mp4"]}],
-        [{"text": "What is in the video ?", "files": ["examples/redlight.mp4"]}],
-        ["Python Program for Array Rotation"],
-        ["Explain Critical Temperature of Substance"]
-    ],
-    cache_examples=False,
-    type="messages",
-    description="# **Gemma 3 Multimodal** \n`Use @qwen2-vl to switch to Qwen2-VL OCR for image inference and @video-infer for video input`",
-    fill_height=True,
-    textbox=gr.MultimodalTextbox(label="Query Input", file_types=["image", "video"], file_count="multiple", placeholder="Tag with @qwen2-vl for Qwen2-VL inference if needed."),
-    stop_btn="Stop Generation",
-    multimodal=True,
-)
-
-if __name__ == "__main__":
-    demo.queue(max_size=20).launch(share=True)
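Note on the removed implementation: every branch of the old generate() streams its reply the same way, by launching model.generate() on a worker thread and iterating a TextIteratorStreamer. A condensed sketch of that shared pattern (the helper name stream_reply is illustrative and not part of the original file; model/processor stand in for the Qwen2-VL or Gemma3 pairs loaded above):

```python
# Sketch of the streaming pattern used by the removed generate() branches.
from threading import Thread
from transformers import TextIteratorStreamer

def stream_reply(model, processor, inputs, max_new_tokens=1024):
    # The streamer receives tokens as generate() produces them on the worker thread.
    streamer = TextIteratorStreamer(
        processor.tokenizer, skip_prompt=True, skip_special_tokens=True
    )
    Thread(
        target=model.generate,
        kwargs={**inputs, "streamer": streamer, "max_new_tokens": max_new_tokens},
    ).start()
    buffer = ""
    for new_text in streamer:   # iterate as partial text arrives
        buffer += new_text
        yield buffer            # the Gradio ChatInterface re-renders the growing reply
```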
 import os
 import json
+import copy
 import time
+import random
+import logging
+import numpy as np
+from typing import Any, Dict, List, Optional, Union

 import torch
 from PIL import Image
+import gradio as gr
+
+from diffusers import (
+    DiffusionPipeline,
+    AutoencoderTiny,
+    AutoencoderKL,
+    AutoPipelineForImage2Image,
+    FluxPipeline,
+    FlowMatchEulerDiscreteScheduler)
+
+from huggingface_hub import (
+    hf_hub_download,
+    HfFileSystem,
+    ModelCard,
+    snapshot_download)
+
+from diffusers.utils import load_image
+
+import spaces
+
+# --- if workspace = local or colab ---
+
+# Authenticate with Hugging Face
+# from huggingface_hub import login
+
+# Log in to Hugging Face using the provided token
+# hf_token = 'hf-token-authentication'
+# login(hf_token)
+
+def calculate_shift(
+    image_seq_len,
+    base_seq_len: int = 256,
+    max_seq_len: int = 4096,
+    base_shift: float = 0.5,
+    max_shift: float = 1.16,
+):
+    m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
+    b = base_shift - m * base_seq_len
+    mu = image_seq_len * m + b
+    return mu
+
+def retrieve_timesteps(
+    scheduler,
+    num_inference_steps: Optional[int] = None,
+    device: Optional[Union[str, torch.device]] = None,
+    timesteps: Optional[List[int]] = None,
+    sigmas: Optional[List[float]] = None,
+    **kwargs,
+):
+    if timesteps is not None and sigmas is not None:
+        raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
+    if timesteps is not None:
+        scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
+        timesteps = scheduler.timesteps
+        num_inference_steps = len(timesteps)
+    elif sigmas is not None:
+        scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
+        timesteps = scheduler.timesteps
+        num_inference_steps = len(timesteps)
+    else:
+        scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
+        timesteps = scheduler.timesteps
+    return timesteps, num_inference_steps
+
+# FLUX pipeline
+@torch.inference_mode()
+def flux_pipe_call_that_returns_an_iterable_of_images(
+    self,
+    prompt: Union[str, List[str]] = None,
+    prompt_2: Optional[Union[str, List[str]]] = None,
+    height: Optional[int] = None,
+    width: Optional[int] = None,
+    num_inference_steps: int = 28,
+    timesteps: List[int] = None,
+    guidance_scale: float = 3.5,
+    num_images_per_prompt: Optional[int] = 1,
+    generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+    latents: Optional[torch.FloatTensor] = None,
+    prompt_embeds: Optional[torch.FloatTensor] = None,
+    pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
+    output_type: Optional[str] = "pil",
+    return_dict: bool = True,
+    joint_attention_kwargs: Optional[Dict[str, Any]] = None,
+    max_sequence_length: int = 512,
+    good_vae: Optional[Any] = None,
 ):
+    height = height or self.default_sample_size * self.vae_scale_factor
+    width = width or self.default_sample_size * self.vae_scale_factor
+
+    self.check_inputs(
+        prompt,
+        prompt_2,
+        height,
+        width,
+        prompt_embeds=prompt_embeds,
+        pooled_prompt_embeds=pooled_prompt_embeds,
+        max_sequence_length=max_sequence_length,
+    )
+
+    self._guidance_scale = guidance_scale
+    self._joint_attention_kwargs = joint_attention_kwargs
+    self._interrupt = False
+
+    batch_size = 1 if isinstance(prompt, str) else len(prompt)
+    device = self._execution_device
+
+    lora_scale = joint_attention_kwargs.get("scale", None) if joint_attention_kwargs is not None else None
+    prompt_embeds, pooled_prompt_embeds, text_ids = self.encode_prompt(
+        prompt=prompt,
+        prompt_2=prompt_2,
+        prompt_embeds=prompt_embeds,
+        pooled_prompt_embeds=pooled_prompt_embeds,
+        device=device,
+        num_images_per_prompt=num_images_per_prompt,
+        max_sequence_length=max_sequence_length,
+        lora_scale=lora_scale,
+    )
+
+    num_channels_latents = self.transformer.config.in_channels // 4
+    latents, latent_image_ids = self.prepare_latents(
+        batch_size * num_images_per_prompt,
+        num_channels_latents,
+        height,
+        width,
+        prompt_embeds.dtype,
+        device,
+        generator,
+        latents,
+    )
+
+    sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
+    image_seq_len = latents.shape[1]
+    mu = calculate_shift(
+        image_seq_len,
+        self.scheduler.config.base_image_seq_len,
+        self.scheduler.config.max_image_seq_len,
+        self.scheduler.config.base_shift,
+        self.scheduler.config.max_shift,
+    )
+    timesteps, num_inference_steps = retrieve_timesteps(
+        self.scheduler,
+        num_inference_steps,
+        device,
+        timesteps,
+        sigmas,
+        mu=mu,
+    )
+    self._num_timesteps = len(timesteps)
+
+    guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32).expand(latents.shape[0]) if self.transformer.config.guidance_embeds else None
+
+    for i, t in enumerate(timesteps):
+        if self.interrupt:
+            continue
+
+        timestep = t.expand(latents.shape[0]).to(latents.dtype)
+
+        noise_pred = self.transformer(
+            hidden_states=latents,
+            timestep=timestep / 1000,
+            guidance=guidance,
+            pooled_projections=pooled_prompt_embeds,
+            encoder_hidden_states=prompt_embeds,
+            txt_ids=text_ids,
+            img_ids=latent_image_ids,
+            joint_attention_kwargs=self.joint_attention_kwargs,
+            return_dict=False,
+        )[0]
+
+        latents_for_image = self._unpack_latents(latents, height, width, self.vae_scale_factor)
+        latents_for_image = (latents_for_image / self.vae.config.scaling_factor) + self.vae.config.shift_factor
+        image = self.vae.decode(latents_for_image, return_dict=False)[0]
+        yield self.image_processor.postprocess(image, output_type=output_type)[0]
+        latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
+        torch.cuda.empty_cache()
+
+    latents = self._unpack_latents(latents, height, width, self.vae_scale_factor)
+    latents = (latents / good_vae.config.scaling_factor) + good_vae.config.shift_factor
+    image = good_vae.decode(latents, return_dict=False)[0]
+    self.maybe_free_model_hooks()
+    torch.cuda.empty_cache()
+    yield self.image_processor.postprocess(image, output_type=output_type)[0]
+
+# --------------------------------------------------------------------------------------------------#
+loras = [
+    # 1
+    {
+        "image": "https://huggingface.co/strangerzonehf/CMS-3D-Art/resolve/main/images/33.png",
+        "title": "CMS 3D Art",
+        "repo": "strangerzonehf/CMS-3D-Art",
+        "weights": "CMS-3D-Art.safetensors",
+        "trigger_word": "CMS 3D Art"
+    },
+    # 2
+]
+
+# -------------------------------------------------- Model Initialization --------------------------------------------------#
+
+dtype = torch.bfloat16
+device = "cuda" if torch.cuda.is_available() else "cpu"
+base_model = "black-forest-labs/FLUX.1-dev"
+
+# TAEF1 is a very tiny autoencoder that uses the same "latent API" as FLUX.1's VAE; it is useful for real-time previewing of the FLUX.1 generation process.
+taef1 = AutoencoderTiny.from_pretrained("madebyollin/taef1", torch_dtype=dtype).to(device)
+good_vae = AutoencoderKL.from_pretrained(base_model, subfolder="vae", torch_dtype=dtype).to(device)
+pipe = DiffusionPipeline.from_pretrained(base_model, torch_dtype=dtype, vae=taef1).to(device)
+pipe_i2i = AutoPipelineForImage2Image.from_pretrained(base_model,
+    vae=good_vae,
+    transformer=pipe.transformer,
+    text_encoder=pipe.text_encoder,
+    tokenizer=pipe.tokenizer,
+    text_encoder_2=pipe.text_encoder_2,
+    tokenizer_2=pipe.tokenizer_2,
+    torch_dtype=dtype
+)
+
+MAX_SEED = 2**32-1
+
+pipe.flux_pipe_call_that_returns_an_iterable_of_images = flux_pipe_call_that_returns_an_iterable_of_images.__get__(pipe)
+
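The __get__ assignment directly above is what turns the module-level generator into a bound method of this one pipeline instance, so self inside flux_pipe_call_that_returns_an_iterable_of_images resolves to pipe. A minimal, self-contained illustration of that binding trick (the class and function below are demo-only and not part of app.py):

```python
# Illustration only: attach a free function as a method of a single instance via __get__.
class Pipeline:
    def __init__(self, name):
        self.name = name

def describe(self):
    # `self` is whatever instance the function was bound to.
    return f"pipeline: {self.name}"

pipe_demo = Pipeline("flux-demo")
pipe_demo.describe = describe.__get__(pipe_demo)   # same mechanism as the line above
print(pipe_demo.describe())                        # -> "pipeline: flux-demo"
```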
+class calculateDuration:
+    def __init__(self, activity_name=""):
+        self.activity_name = activity_name
+
+    def __enter__(self):
+        self.start_time = time.time()
+        return self
+
+    def __exit__(self, exc_type, exc_value, traceback):
+        self.end_time = time.time()
+        self.elapsed_time = self.end_time - self.start_time
+        if self.activity_name:
+            print(f"Elapsed time for {self.activity_name}: {self.elapsed_time:.6f} seconds")
         else:
+            print(f"Elapsed time: {self.elapsed_time:.6f} seconds")
+
+def update_selection(evt: gr.SelectData, width, height):
+    selected_lora = loras[evt.index]
+    new_placeholder = f"Type a prompt for {selected_lora['title']}"
+    lora_repo = selected_lora["repo"]
+    updated_text = f"### Selected: [{lora_repo}](https://huggingface.co/{lora_repo}) ✅"
+    if "aspect" in selected_lora:
+        if selected_lora["aspect"] == "portrait":
+            width = 768
+            height = 1024
+        elif selected_lora["aspect"] == "landscape":
+            width = 1024
+            height = 768
         else:
+            width = 1024
+            height = 1024
+    return (
+        gr.update(placeholder=new_placeholder),
+        updated_text,
+        evt.index,
+        width,
+        height,
+    )
+
+@spaces.GPU(duration=100)
+def generate_image(prompt_mash, steps, seed, cfg_scale, width, height, lora_scale, progress):
+    pipe.to("cuda")
+    generator = torch.Generator(device="cuda").manual_seed(seed)
+    with calculateDuration("Generating image"):
+        # Generate image
+        for img in pipe.flux_pipe_call_that_returns_an_iterable_of_images(
+            prompt=prompt_mash,
+            num_inference_steps=steps,
+            guidance_scale=cfg_scale,
+            width=width,
+            height=height,
+            generator=generator,
+            joint_attention_kwargs={"scale": lora_scale},
+            output_type="pil",
+            good_vae=good_vae,
+        ):
+            yield img
+
+def generate_image_to_image(prompt_mash, image_input_path, image_strength, steps, cfg_scale, width, height, lora_scale, seed):
+    generator = torch.Generator(device="cuda").manual_seed(seed)
+    pipe_i2i.to("cuda")
+    image_input = load_image(image_input_path)
+    final_image = pipe_i2i(
+        prompt=prompt_mash,
+        image=image_input,
+        strength=image_strength,
+        num_inference_steps=steps,
+        guidance_scale=cfg_scale,
+        width=width,
+        height=height,
+        generator=generator,
+        joint_attention_kwargs={"scale": lora_scale},
+        output_type="pil",
+    ).images[0]
+    return final_image
+
+@spaces.GPU(duration=100)
+def run_lora(prompt, image_input, image_strength, cfg_scale, steps, selected_index, randomize_seed, seed, width, height, lora_scale, progress=gr.Progress(track_tqdm=True)):
+    if selected_index is None:
+        raise gr.Error("You must select a LoRA before proceeding.🧨")
+    selected_lora = loras[selected_index]
+    lora_path = selected_lora["repo"]
+    trigger_word = selected_lora["trigger_word"]
+    if(trigger_word):
+        if "trigger_position" in selected_lora:
+            if selected_lora["trigger_position"] == "prepend":
+                prompt_mash = f"{trigger_word} {prompt}"
+            else:
+                prompt_mash = f"{prompt} {trigger_word}"
+        else:
+            prompt_mash = f"{trigger_word} {prompt}"
+    else:
+        prompt_mash = prompt
+
+    with calculateDuration("Unloading LoRA"):
+        pipe.unload_lora_weights()
+        pipe_i2i.unload_lora_weights()
+
+    # LoRA weights flow
+    with calculateDuration(f"Loading LoRA weights for {selected_lora['title']}"):
+        pipe_to_use = pipe_i2i if image_input is not None else pipe
+        weight_name = selected_lora.get("weights", None)
+
+        pipe_to_use.load_lora_weights(
+            lora_path,
+            weight_name=weight_name,
+            low_cpu_mem_usage=True
+        )
+
+    with calculateDuration("Randomizing seed"):
+        if randomize_seed:
+            seed = random.randint(0, MAX_SEED)
+
+    if(image_input is not None):
+
+        final_image = generate_image_to_image(prompt_mash, image_input, image_strength, steps, cfg_scale, width, height, lora_scale, seed)
+        yield final_image, seed, gr.update(visible=False)
     else:
+        image_generator = generate_image(prompt_mash, steps, seed, cfg_scale, width, height, lora_scale, progress)
+
+        final_image = None
+        step_counter = 0
+        for image in image_generator:
+            step_counter += 1
+            final_image = image
+            progress_bar = f'<div class="progress-container"><div class="progress-bar" style="--current: {step_counter}; --total: {steps};"></div></div>'
+            yield image, seed, gr.update(value=progress_bar, visible=True)
+
+        yield final_image, seed, gr.update(value=progress_bar, visible=False)
+
+def get_huggingface_safetensors(link):
+    split_link = link.split("/")
+    if(len(split_link) == 2):
+        model_card = ModelCard.load(link)
+        base_model = model_card.data.get("base_model")
+        print(base_model)
+
+        # Allows both
+        if((base_model != "black-forest-labs/FLUX.1-dev") and (base_model != "black-forest-labs/FLUX.1-schnell")):
+            raise Exception("Flux LoRA Not Found!")
+
+        # Only allow "black-forest-labs/FLUX.1-dev"
+        # if base_model != "black-forest-labs/FLUX.1-dev":
+        #     raise Exception("Only FLUX.1-dev is supported, other LoRA models are not allowed!")
+
+        image_path = model_card.data.get("widget", [{}])[0].get("output", {}).get("url", None)
+        trigger_word = model_card.data.get("instance_prompt", "")
+        image_url = f"https://huggingface.co/{link}/resolve/main/{image_path}" if image_path else None
+        fs = HfFileSystem()
+        try:
+            list_of_files = fs.ls(link, detail=False)
+            for file in list_of_files:
+                if(file.endswith(".safetensors")):
+                    safetensors_name = file.split("/")[-1]
+                if (not image_url and file.lower().endswith((".jpg", ".jpeg", ".png", ".webp"))):
+                    image_elements = file.split("/")
+                    image_url = f"https://huggingface.co/{link}/resolve/main/{image_elements[-1]}"
+        except Exception as e:
+            print(e)
+            gr.Warning(f"You didn't include a link neither a valid Hugging Face repository with a *.safetensors LoRA")
+            raise Exception(f"You didn't include a link neither a valid Hugging Face repository with a *.safetensors LoRA")
+        return split_link[1], link, safetensors_name, trigger_word, image_url
+
+def check_custom_model(link):
+    if(link.startswith("https://")):
+        if(link.startswith("https://huggingface.co") or link.startswith("https://www.huggingface.co")):
+            link_split = link.split("huggingface.co/")
+            return get_huggingface_safetensors(link_split[1])
+    else:
+        return get_huggingface_safetensors(link)
+
+def add_custom_lora(custom_lora):
+    global loras
+    if(custom_lora):
+        try:
+            title, repo, path, trigger_word, image = check_custom_model(custom_lora)
+            print(f"Loaded custom LoRA: {repo}")
+            card = f'''
+            <div class="custom_lora_card">
+                <span>Loaded custom LoRA:</span>
+                <div class="card_internal">
+                    <img src="{image}" />
+                    <div>
+                        <h3>{title}</h3>
+                        <small>{"Using: <code><b>"+trigger_word+"</code></b> as the trigger word" if trigger_word else "No trigger word found. If there's a trigger word, include it in your prompt"}<br></small>
+                    </div>
+                </div>
+            </div>
+            '''
+            existing_item_index = next((index for (index, item) in enumerate(loras) if item['repo'] == repo), None)
+            if(not existing_item_index):
+                new_item = {
+                    "image": image,
+                    "title": title,
+                    "repo": repo,
+                    "weights": path,
+                    "trigger_word": trigger_word
+                }
+                print(new_item)
+                existing_item_index = len(loras)
+                loras.append(new_item)
+
+            return gr.update(visible=True, value=card), gr.update(visible=True), gr.Gallery(selected_index=None), f"Custom: {path}", existing_item_index, trigger_word
+        except Exception as e:
+            gr.Warning(f"Invalid LoRA: either you entered an invalid link, or a non-FLUX LoRA")
+            return gr.update(visible=True, value=f"Invalid LoRA: either you entered an invalid link, a non-FLUX LoRA"), gr.update(visible=False), gr.update(), "", None, ""
+    else:
+        return gr.update(visible=False), gr.update(visible=False), gr.update(), "", None, ""
+
+def remove_custom_lora():
+    return gr.update(visible=False), gr.update(visible=False), gr.update(), "", None, ""
+
+run_lora.zerogpu = True
+
+css = '''
+#gen_btn{height: 100%}
+#gen_column{align-self: stretch}
+#title{text-align: center}
+#title h1{font-size: 3em; display:inline-flex; align-items:center}
+#title img{width: 100px; margin-right: 0.5em}
+#gallery .grid-wrap{height: 10vh}
+#lora_list{background: var(--block-background-fill);padding: 0 1em .3em; font-size: 90%}
+.card_internal{display: flex;height: 100px;margin-top: .5em}
+.card_internal img{margin-right: 1em}
+.styler{--form-gap-width: 0px !important}
+#progress{height:30px}
+#progress .generating{display:none}
+.progress-container {width: 100%;height: 30px;background-color: #f0f0f0;border-radius: 15px;overflow: hidden;margin-bottom: 20px}
+.progress-bar {height: 100%;background-color: #4f46e5;width: calc(var(--current) / var(--total) * 100%);transition: width 0.5s ease-in-out}
+'''
+
+with gr.Blocks(theme=gr.themes.Soft(), css=css, delete_cache=(60, 60)) as app:
+    title = gr.HTML(
+        """<h1>FLUX LoRA DLC🥳</h1>""",
+        elem_id="title",
+    )
+    selected_index = gr.State(None)
+    with gr.Row():
+        with gr.Column(scale=3):
+            prompt = gr.Textbox(label="Prompt", lines=1, placeholder=":/ choose the LoRA and type the prompt ")
+        with gr.Column(scale=1, elem_id="gen_column"):
+            generate_button = gr.Button("Generate", variant="primary", elem_id="gen_btn")
+    with gr.Row():
+        with gr.Column():
+            selected_info = gr.Markdown("")
+            gallery = gr.Gallery(
+                [(item["image"], item["title"]) for item in loras],
+                label="250+ LoRA DLC's",
+                allow_preview=False,
+                columns=3,
+                elem_id="gallery",
+                show_share_button=False
+            )
+            with gr.Group():
+                custom_lora = gr.Textbox(label="Enter Custom LoRA", placeholder="prithivMLmods/Canopus-LoRA-Flux-Anime")
+                gr.Markdown("[Check the list of FLUX LoRA's](https://huggingface.co/models?other=base_model:adapter:black-forest-labs/FLUX.1-dev)", elem_id="lora_list")
+            custom_lora_info = gr.HTML(visible=False)
+            custom_lora_button = gr.Button("Remove custom LoRA", visible=False)
+        with gr.Column():
+            progress_bar = gr.Markdown(elem_id="progress", visible=False)
+            result = gr.Image(label="Generated Image", format="png")
+
+    with gr.Row():
+        with gr.Accordion("Advanced Settings", open=False):
+            with gr.Row():
+                input_image = gr.Image(label="Input image", type="filepath")
+                image_strength = gr.Slider(label="Denoise Strength", info="Lower means more image influence", minimum=0.1, maximum=1.0, step=0.01, value=0.75)
+            with gr.Column():
+                with gr.Row():
+                    cfg_scale = gr.Slider(label="CFG Scale", minimum=1, maximum=20, step=0.5, value=3.5)
+                    steps = gr.Slider(label="Steps", minimum=1, maximum=50, step=1, value=28)
+
+                with gr.Row():
+                    width = gr.Slider(label="Width", minimum=256, maximum=1536, step=64, value=1024)
+                    height = gr.Slider(label="Height", minimum=256, maximum=1536, step=64, value=1024)
+
+                with gr.Row():
+                    randomize_seed = gr.Checkbox(True, label="Randomize seed")
+                    seed = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=0, randomize=True)
+                    lora_scale = gr.Slider(label="LoRA Scale", minimum=0, maximum=3, step=0.01, value=0.95)
+
+    gallery.select(
+        update_selection,
+        inputs=[width, height],
+        outputs=[prompt, selected_info, selected_index, width, height]
+    )
+    custom_lora.input(
+        add_custom_lora,
+        inputs=[custom_lora],
+        outputs=[custom_lora_info, custom_lora_button, gallery, selected_info, selected_index, prompt]
+    )
+    custom_lora_button.click(
+        remove_custom_lora,
+        outputs=[custom_lora_info, custom_lora_button, gallery, selected_info, selected_index, custom_lora]
+    )
+    gr.on(
+        triggers=[generate_button.click, prompt.submit],
+        fn=run_lora,
+        inputs=[prompt, input_image, image_strength, cfg_scale, steps, selected_index, randomize_seed, seed, width, height, lora_scale],
+        outputs=[result, seed, progress_bar]
+    )
+
+app.queue()
+app.launch(ssr_mode=False)
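Taken together, the UI wiring above drives the patched pipeline roughly as follows. A hedged sketch of the same flow outside Gradio: load one adapter from the loras list, prepend its trigger word, and iterate the per-step previews until the final image decoded with the full VAE. The prompt text and output filename below are illustrative only, not part of the Space:

```python
# Sketch (not part of app.py): drive the patched pipeline directly.
entry = loras[0]                                   # e.g. the "CMS 3D Art" adapter defined above
pipe.load_lora_weights(entry["repo"], weight_name=entry.get("weights"))
prompt = f"{entry['trigger_word']} a tiny diorama of a lighthouse"  # trigger word prepended

previews = pipe.flux_pipe_call_that_returns_an_iterable_of_images(
    prompt=prompt,
    num_inference_steps=28,
    guidance_scale=3.5,
    width=1024,
    height=1024,
    generator=torch.Generator(device=device).manual_seed(0),
    output_type="pil",
    good_vae=good_vae,                             # full VAE decodes only the final frame
)
final = None
for step_image in previews:                        # yields one PIL preview per denoising step
    final = step_image                             # last yield is the good_vae decode
final.save("flux_lora_preview.png")                # illustrative output path
```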