prithivMLmods committed on
Commit 74ba6ce · verified · 1 Parent(s): 54875b8

Update app.py

Files changed (1)
  1. app.py +211 -344
app.py CHANGED
@@ -1,7 +1,6 @@
1
  import os
2
  import random
3
  import uuid
4
- import json
5
  import time
6
  import asyncio
7
  from threading import Thread
@@ -11,7 +10,6 @@ import spaces
11
  import torch
12
  import numpy as np
13
  from PIL import Image
14
- import edge_tts
15
  import cv2
16
 
17
  from transformers import (
@@ -24,31 +22,107 @@ from transformers import (
24
  from transformers.image_utils import load_image
25
  from diffusers import StableDiffusionXLPipeline, EulerAncestralDiscreteScheduler
26
 
27
- # --------- Global Config and Model Loading ---------
28
  MAX_MAX_NEW_TOKENS = 2048
29
  DEFAULT_MAX_NEW_TOKENS = 1024
30
  MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096"))
31
- MAX_SEED = np.iinfo(np.int32).max
32
-
33
  device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
34
 
35
- # For text-only generation (chat)
36
- model_id = "prithivMLmods/FastThink-0.5B-Tiny"
37
- tokenizer = AutoTokenizer.from_pretrained(model_id)
38
  model = AutoModelForCausalLM.from_pretrained(
39
- model_id,
40
  device_map="auto",
41
  torch_dtype=torch.bfloat16,
42
  )
43
  model.eval()
44
 
45
- # For TTS
46
- TTS_VOICES = [
47
- "en-US-JennyNeural", # @tts1
48
- "en-US-GuyNeural", # @tts2
49
- ]
50
 
51
- # For multimodal Qwen2VL (OCR / video/text)
52
  MODEL_ID_QWEN = "prithivMLmods/Qwen2-VL-OCR-2B-Instruct"
53
  processor = AutoProcessor.from_pretrained(MODEL_ID_QWEN, trust_remote_code=True)
54
  model_m = Qwen2VLForConditionalGeneration.from_pretrained(
@@ -57,8 +131,46 @@ model_m = Qwen2VLForConditionalGeneration.from_pretrained(
57
  torch_dtype=torch.float16
58
  ).to("cuda").eval()
59
 
60
- # For SDXL Image Generation
61
- MODEL_ID_SD = os.getenv("MODEL_VAL_PATH") # Set your SDXL model repository path via env variable
62
  USE_TORCH_COMPILE = os.getenv("USE_TORCH_COMPILE", "0") == "1"
63
  ENABLE_CPU_OFFLOAD = os.getenv("ENABLE_CPU_OFFLOAD", "0") == "1"
64
  BATCH_SIZE = int(os.getenv("BATCH_SIZE", "1"))
@@ -77,7 +189,7 @@ if USE_TORCH_COMPILE:
77
  if ENABLE_CPU_OFFLOAD:
78
  sd_pipe.enable_model_cpu_offload()
79
 
80
- # For SDXL quality styles and LoRA options (used in the image-gen tab)
81
  LORA_OPTIONS = {
82
  "Realism (face/character)👦🏻": ("prithivMLmods/Canopus-Realism-LoRA", "Canopus-Realism-LoRA.safetensors", "rlms"),
83
  "Pixar (art/toons)🙀": ("prithivMLmods/Canopus-Pixar-Art", "Canopus-Pixar-Art.safetensors", "pixar"),
@@ -93,6 +205,8 @@ LORA_OPTIONS = {
93
  "Pencil Art (characteristic/creative)✏️": ("prithivMLmods/Canopus-Pencil-Art-LoRA", "Canopus-Pencil-Art-LoRA.safetensors", "Pencil Art"),
94
  "Art Minimalistic (paint/semireal)🎨": ("prithivMLmods/Canopus-Art-Medium-LoRA", "Canopus-Art-Medium-LoRA.safetensors", "mdm"),
95
  }
 
 
96
  style_list = [
97
  {
98
  "name": "3840 x 2160",
@@ -119,351 +233,104 @@ styles = {k["name"]: (k["prompt"], k["negative_prompt"]) for k in style_list}
119
  DEFAULT_STYLE_NAME = "3840 x 2160"
120
  STYLE_NAMES = list(styles.keys())
121
 
122
- # --------- Utility Functions ---------
123
- def text_to_speech(text: str, voice: str, output_file="output.mp3"):
124
- """Convert text to speech using Edge TTS and save as MP3"""
125
- async def run_tts():
126
- communicate = edge_tts.Communicate(text, voice)
127
- await communicate.save(output_file)
128
- return output_file
129
- return asyncio.run(run_tts())
130
-
131
- def clean_chat_history(chat_history):
132
- """Remove non-string content from the chat history."""
133
- return [msg for msg in chat_history if isinstance(msg, dict) and isinstance(msg.get("content"), str)]
134
-
135
- def save_image(img: Image.Image) -> str:
136
- """Save a PIL image to a file with a unique filename."""
137
- unique_name = str(uuid.uuid4()) + ".png"
138
- img.save(unique_name)
139
- return unique_name
140
-
141
- def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
142
- return random.randint(0, MAX_SEED) if randomize_seed else seed
143
-
144
- def progress_bar_html(label: str) -> str:
145
- """Return an HTML snippet for a progress bar."""
146
- return f'''
147
- <div style="display: flex; align-items: center;">
148
- <span style="margin-right: 10px; font-size: 14px;">{label}</span>
149
- <div style="width: 110px; height: 5px; background-color: #FFF0F5; border-radius: 2px; overflow: hidden;">
150
- <div style="width: 100%; height: 100%; background-color: #FF69B4; animation: loading 1.5s linear infinite;"></div>
151
- </div>
152
- </div>
153
- <style>
154
- @keyframes loading {{
155
- 0% {{ transform: translateX(-100%); }}
156
- 100% {{ transform: translateX(100%); }}
157
- }}
158
- </style>
159
- '''
160
-
161
- def downsample_video(video_path):
162
- """Extract 10 evenly spaced frames from a video."""
163
- vidcap = cv2.VideoCapture(video_path)
164
- total_frames = int(vidcap.get(cv2.CAP_PROP_FRAME_COUNT))
165
- fps = vidcap.get(cv2.CAP_PROP_FPS)
166
- frames = []
167
- frame_indices = np.linspace(0, total_frames - 1, 10, dtype=int)
168
- for i in frame_indices:
169
- vidcap.set(cv2.CAP_PROP_POS_FRAMES, i)
170
- success, image = vidcap.read()
171
- if success:
172
- image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
173
- pil_image = Image.fromarray(image)
174
- timestamp = round(i / fps, 2)
175
- frames.append((pil_image, timestamp))
176
- vidcap.release()
177
- return frames
178
-
179
  def apply_style(style_name: str, positive: str, negative: str = ""):
180
- """Apply a chosen quality style to the prompt."""
181
- p, n = styles.get(style_name, styles[DEFAULT_STYLE_NAME])
182
- return p.replace("{prompt}", positive), n + negative
183
-
184
- # --------- Tab 1: Chat Interface (Multimodal) ---------
185
- def chat_generate(input_dict: dict, chat_history: list,
186
- max_new_tokens: int = DEFAULT_MAX_NEW_TOKENS,
187
- temperature: float = 0.6, top_p: float = 0.9, top_k: int = 50, repetition_penalty: float = 1.2):
188
- text = input_dict["text"]
189
- files = input_dict.get("files", [])
190
- lower_text = text.strip().lower()
191
-
192
- # If image generation command
193
- if lower_text.startswith("@image"):
194
- prompt = text[len("@image"):].strip()
195
- yield progress_bar_html("Generating Image")
196
- image_paths, used_seed = generate_image_fn(
197
- prompt=prompt,
198
- negative_prompt="",
199
- use_negative_prompt=False,
200
- seed=1,
201
- width=1024,
202
- height=1024,
203
- guidance_scale=3,
204
- num_inference_steps=25,
205
- randomize_seed=True,
206
- use_resolution_binning=True,
207
- num_images=1,
208
- )
209
- yield gr.Image.update(value=image_paths[0])
210
- return
211
-
212
- # If video inference command
213
- if lower_text.startswith("@video-infer"):
214
- prompt = text[len("@video-infer"):].strip()
215
- if files:
216
- video_path = files[0]
217
- frames = downsample_video(video_path)
218
- messages = [
219
- {"role": "system", "content": [{"type": "text", "text": "You are a helpful assistant."}]},
220
- {"role": "user", "content": [{"type": "text", "text": prompt}]}
221
- ]
222
- for frame in frames:
223
- image, timestamp = frame
224
- image_path = f"video_frame_{uuid.uuid4().hex}.png"
225
- image.save(image_path)
226
- messages[1]["content"].append({"type": "text", "text": f"Frame {timestamp}:"})
227
- messages[1]["content"].append({"type": "image", "url": image_path})
228
- else:
229
- messages = [
230
- {"role": "system", "content": [{"type": "text", "text": "You are a helpful assistant."}]},
231
- {"role": "user", "content": [{"type": "text", "text": prompt}]}
232
- ]
233
- inputs = processor.apply_chat_template(messages, tokenize=True, add_generation_prompt=True, return_dict=True, return_tensors="pt").to("cuda")
234
- streamer = TextIteratorStreamer(processor, skip_prompt=True, skip_special_tokens=True)
235
- generation_kwargs = {
236
- **inputs,
237
- "streamer": streamer,
238
- "max_new_tokens": max_new_tokens,
239
- "do_sample": True,
240
- "temperature": temperature,
241
- "top_p": top_p,
242
- "top_k": top_k,
243
- "repetition_penalty": repetition_penalty,
244
- }
245
- thread = Thread(target=model_m.generate, kwargs=generation_kwargs)
246
- thread.start()
247
- buffer = ""
248
- yield progress_bar_html("Processing video with Qwen2VL")
249
- for new_text in streamer:
250
- buffer += new_text.replace("<|im_end|>", "")
251
- time.sleep(0.01)
252
- yield buffer
253
- return
254
-
255
- # Check for TTS command
256
- tts_prefix = "@tts"
257
- is_tts = any(text.strip().lower().startswith(f"{tts_prefix}{i}") for i in range(1, 3))
258
- voice_index = next((i for i in range(1, 3) if text.strip().lower().startswith(f"{tts_prefix}{i}")), None)
259
-
260
- if is_tts and voice_index:
261
- voice = TTS_VOICES[voice_index - 1]
262
- text = text.replace(f"{tts_prefix}{voice_index}", "").strip()
263
- conversation = [{"role": "user", "content": text}]
264
- else:
265
- voice = None
266
- text = text.replace(tts_prefix, "").strip()
267
- conversation = clean_chat_history(chat_history)
268
- conversation.append({"role": "user", "content": text})
269
-
270
- if files:
271
- # Handle multimodal chat with images
272
- images = [load_image(f) for f in files]
273
- messages = [{
274
- "role": "user",
275
- "content": [{"type": "image", "image": image} for image in images] + [{"type": "text", "text": text}]
276
- }]
277
- prompt_full = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
278
- inputs = processor(text=[prompt_full], images=images, return_tensors="pt", padding=True).to("cuda")
279
- streamer = TextIteratorStreamer(processor, skip_prompt=True, skip_special_tokens=True)
280
- generation_kwargs = {**inputs, "streamer": streamer, "max_new_tokens": max_new_tokens}
281
- thread = Thread(target=model_m.generate, kwargs=generation_kwargs)
282
- thread.start()
283
- buffer = ""
284
- yield progress_bar_html("Thinking...")
285
- for new_text in streamer:
286
- buffer += new_text.replace("<|im_end|>", "")
287
- time.sleep(0.01)
288
- yield buffer
289
  else:
290
- input_ids = tokenizer.apply_chat_template(conversation, add_generation_prompt=True, return_tensors="pt")
291
- if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
292
- input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
293
- gr.Warning(f"Trimmed input as it exceeded {MAX_INPUT_TOKEN_LENGTH} tokens.")
294
- input_ids = input_ids.to(model.device)
295
- streamer = TextIteratorStreamer(tokenizer, timeout=20.0, skip_prompt=True, skip_special_tokens=True)
296
- generation_kwargs = {
297
- "input_ids": input_ids,
298
- "streamer": streamer,
299
- "max_new_tokens": max_new_tokens,
300
- "do_sample": True,
301
- "top_p": top_p,
302
- "top_k": top_k,
303
- "temperature": temperature,
304
- "num_beams": 1,
305
- "repetition_penalty": repetition_penalty,
306
- }
307
- t = Thread(target=model.generate, kwargs=generation_kwargs)
308
- t.start()
309
- outputs = []
310
- yield progress_bar_html("Processing...")
311
- for new_text in streamer:
312
- outputs.append(new_text)
313
- yield "".join(outputs)
314
- final_response = "".join(outputs)
315
- yield final_response
316
- if is_tts and voice:
317
- output_file = text_to_speech(final_response, voice)
318
- yield gr.Audio.update(value=output_file)
319
 
320
- # Helper function for image generation (used in chat @image branch)
321
- @spaces.GPU(duration=60, enable_queue=True)
322
- def generate_image_fn(prompt: str, negative_prompt: str = "", use_negative_prompt: bool = False,
323
- seed: int = 1, width: int = 1024, height: int = 1024,
324
- guidance_scale: float = 3, num_inference_steps: int = 25,
325
- randomize_seed: bool = False, use_resolution_binning: bool = True,
326
- num_images: int = 1, progress=None):
327
  seed = int(randomize_seed_fn(seed, randomize_seed))
328
- generator = torch.Generator(device=device).manual_seed(seed)
329
  options = {
330
- "prompt": [prompt] * num_images,
331
- "negative_prompt": [negative_prompt] * num_images if use_negative_prompt else None,
332
  "width": width,
333
  "height": height,
334
  "guidance_scale": guidance_scale,
335
- "num_inference_steps": num_inference_steps,
336
- "generator": generator,
 
337
  "output_type": "pil",
338
  }
339
- if use_resolution_binning:
340
- options["use_resolution_binning"] = True
341
-
342
- images = []
343
- for i in range(0, num_images, BATCH_SIZE):
344
- batch_options = options.copy()
345
- batch_options["prompt"] = options["prompt"][i:i+BATCH_SIZE]
346
- if batch_options.get("negative_prompt") is not None:
347
- batch_options["negative_prompt"] = options["negative_prompt"][i:i+BATCH_SIZE]
348
- if device.type == "cuda":
349
- with torch.autocast("cuda", dtype=torch.float16):
350
- outputs = sd_pipe(**batch_options)
351
- else:
352
- outputs = sd_pipe(**batch_options)
353
- images.extend(outputs.images)
354
  image_paths = [save_image(img) for img in images]
355
  return image_paths, seed
356
 
357
- # --------- Tab 2: SDXL Image Generation ---------
358
- @spaces.GPU(duration=180, enable_queue=True)
359
- def sdxl_generate(prompt: str, negative_prompt: str = "", use_negative_prompt: bool = True,
360
- seed: int = 0, width: int = 1024, height: int = 1024, guidance_scale: float = 3,
361
- randomize_seed: bool = False, style_name: str = DEFAULT_STYLE_NAME,
362
- lora_model: str = "Realism (face/character)👦🏻", progress=None):
363
- seed = int(randomize_seed_fn(seed, randomize_seed))
364
- positive_prompt, effective_negative_prompt = apply_style(style_name, prompt, negative_prompt)
365
- if not use_negative_prompt:
366
- effective_negative_prompt = ""
367
- model_name, weight_name, adapter_name = LORA_OPTIONS[lora_model]
368
- # Set the adapter for the current generation
369
- sd_pipe.load_lora_weights(model_name, weight_name=weight_name, adapter_name=adapter_name)
370
- sd_pipe.set_adapters(adapter_name)
371
- images = sd_pipe(
372
- prompt=positive_prompt,
373
- negative_prompt=effective_negative_prompt,
374
- width=width,
375
- height=height,
376
- guidance_scale=guidance_scale,
377
- num_inference_steps=20,
378
- num_images_per_prompt=1,
379
- cross_attention_kwargs={"scale": 0.65},
380
- output_type="pil",
381
- ).images
382
- image_paths = [save_image(img) for img in images]
383
- return image_paths, seed
384
-
385
- # --------- Tab 3: Qwen2VL OCR & Text Generation ---------
386
- def qwen2vl_ocr_textgen(prompt: str, image_file):
387
- if image_file is None:
388
- return "Please upload an image."
389
- # Load the image
390
- image = load_image(image_file)
391
- messages = [
392
- {"role": "system", "content": [{"type": "text", "text": "You are a helpful assistant."}]},
393
- {"role": "user", "content": [{"type": "text", "text": prompt}, {"type": "image", "image": image}]}
394
- ]
395
- inputs = processor.apply_chat_template(messages, tokenize=True, add_generation_prompt=True,
396
- return_dict=True, return_tensors="pt").to("cuda")
397
- outputs = model_m.generate(
398
- **inputs,
399
- max_new_tokens=1024,
400
- do_sample=True,
401
- temperature=0.6,
402
- top_p=0.9,
403
- top_k=50,
404
- repetition_penalty=1.2,
405
- )
406
- response = processor.batch_decode(outputs, skip_special_tokens=True)[0]
407
- return response
408
 
409
- # --------- Building the Gradio Interface with Tabs ---------
410
- with gr.Blocks(title="Combined Demo") as demo:
411
- gr.Markdown("# Combined Demo: Chat, SDXL Image Gen & Qwen2VL OCR/TextGen")
412
  with gr.Tabs():
413
- # --- Tab 1: Chat Interface ---
414
  with gr.Tab("Chat Interface"):
415
- chat_interface = gr.ChatInterface(
416
- fn=chat_generate,
417
- additional_inputs=[
418
- gr.Slider(label="Max new tokens", minimum=1, maximum=MAX_MAX_NEW_TOKENS, step=1, value=DEFAULT_MAX_NEW_TOKENS),
419
- gr.Slider(label="Temperature", minimum=0.1, maximum=4.0, step=0.1, value=0.6),
420
- gr.Slider(label="Top-p (nucleus sampling)", minimum=0.05, maximum=1.0, step=0.05, value=0.9),
421
- gr.Slider(label="Top-k", minimum=1, maximum=1000, step=1, value=50),
422
- gr.Slider(label="Repetition penalty", minimum=1.0, maximum=2.0, step=0.05, value=1.2),
423
- ],
424
- examples=[
425
- ["Write the Python Program for Array Rotation"],
426
- [{"text": "summarize the letter", "files": ["examples/1.png"]}],
427
- [{"text": "@video-infer Describe the Ad", "files": ["examples/coca.mp4"]}],
428
- ["@image Chocolate dripping from a donut"],
429
- ["@tts1 Who is Nikola Tesla, and why did he die?"],
430
- ],
431
- cache_examples=False,
432
- type="messages",
433
- description="Use commands like **@image**, **@video-infer**, **@tts1**, or plain text.",
434
- textbox=gr.MultimodalTextbox(label="Query Input", file_types=["image", "video"], file_count="multiple",
435
- placeholder="Type your query (e.g., @tts1 for TTS, @image for image gen, etc.)"),
436
- stop_btn="Stop Generation",
437
- multimodal=True,
438
- )
439
- # --- Tab 2: SDXL Image Generation ---
440
- with gr.Tab("SDXL Gen Image"):
441
  with gr.Row():
442
- prompt_in = gr.Textbox(label="Prompt", placeholder="Enter prompt for image generation")
443
- negative_prompt_in = gr.Textbox(label="Negative prompt", placeholder="Enter negative prompt", lines=2)
444
  with gr.Row():
445
- seed_in = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=0)
446
- randomize_in = gr.Checkbox(label="Randomize seed", value=True)
447
  with gr.Row():
448
- width_in = gr.Slider(label="Width", minimum=512, maximum=2048, step=8, value=1024)
449
- height_in = gr.Slider(label="Height", minimum=512, maximum=2048, step=8, value=1024)
450
- guidance_in = gr.Slider(label="Guidance Scale", minimum=0.1, maximum=20.0, step=0.1, value=3.0)
451
- style_in = gr.Radio(choices=STYLE_NAMES, value=DEFAULT_STYLE_NAME, label="Quality Style")
452
- lora_in = gr.Dropdown(choices=list(LORA_OPTIONS.keys()), value="Realism (face/character)👦🏻", label="LoRA Selection")
453
- run_button_img = gr.Button("Generate Image")
454
- output_gallery = gr.Gallery(label="Generated Image", columns=1, preview=True)
455
- seed_output = gr.Number(label="Seed used")
456
- run_button_img.click(fn=sdxl_generate,
457
- inputs=[prompt_in, negative_prompt_in, randomize_in, seed_in, width_in, height_in, guidance_in, randomize_in, style_in, lora_in],
458
- outputs=[output_gallery, seed_output])
459
- # --- Tab 3: Qwen2VL OCR & Text Generation ---
460
- with gr.Tab("Qwen2VL OCR/TextGen"):
461
  with gr.Row():
462
- qwen_prompt = gr.Textbox(label="Prompt", placeholder="Enter prompt for OCR / text generation")
463
- qwen_image = gr.Image(label="Upload Image", type="filepath")
464
- run_button_qwen = gr.Button("Run Qwen2VL")
465
- qwen_output = gr.Textbox(label="Output")
466
- run_button_qwen.click(fn=qwen2vl_ocr_textgen, inputs=[qwen_prompt, qwen_image], outputs=qwen_output)
467
 
468
  if __name__ == "__main__":
469
- demo.queue(max_size=30).launch(share=True)
 
1
  import os
2
  import random
3
  import uuid
 
4
  import time
5
  import asyncio
6
  from threading import Thread
 
10
  import torch
11
  import numpy as np
12
  from PIL import Image
 
13
  import cv2
14
 
15
  from transformers import (
 
22
  from transformers.image_utils import load_image
23
  from diffusers import StableDiffusionXLPipeline, EulerAncestralDiscreteScheduler
24
 
25
+ # ---------------------------
26
+ # Global Settings & Utilities
27
+ # ---------------------------
28
+
29
  MAX_MAX_NEW_TOKENS = 2048
30
  DEFAULT_MAX_NEW_TOKENS = 1024
31
  MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096"))
 
 
32
  device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
33
 
34
+ def save_image(img: Image.Image) -> str:
35
+ """Save a PIL image with a unique filename and return the path."""
36
+ unique_name = str(uuid.uuid4()) + ".png"
37
+ img.save(unique_name)
38
+ return unique_name
39
+
40
+ def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
41
+ MAX_SEED = np.iinfo(np.int32).max
42
+ if randomize_seed:
43
+ seed = random.randint(0, MAX_SEED)
44
+ return seed
45
+
46
+ def progress_bar_html(label: str) -> str:
47
+ """Returns an HTML snippet for a thin progress bar with a label."""
48
+ return f'''
49
+ <div style="display: flex; align-items: center;">
50
+ <span style="margin-right: 10px; font-size: 14px;">{label}</span>
51
+ <div style="width: 110px; height: 5px; background-color: #FFF0F5; border-radius: 2px; overflow: hidden;">
52
+ <div style="width: 100%; height: 100%; background-color: #FF69B4; animation: loading 1.5s linear infinite;"></div>
53
+ </div>
54
+ </div>
55
+ <style>
56
+ @keyframes loading {{
57
+ 0% {{ transform: translateX(-100%); }}
58
+ 100% {{ transform: translateX(100%); }}
59
+ }}
60
+ </style>
61
+ '''
62
+
63
+ # ---------------------------
64
+ # 1. Chat Interface Tab
65
+ # ---------------------------
66
+ # Uses a text-only model: FastThink-0.5B-Tiny
67
+
68
+ model_id_text = "prithivMLmods/FastThink-0.5B-Tiny"
69
+ tokenizer = AutoTokenizer.from_pretrained(model_id_text)
70
  model = AutoModelForCausalLM.from_pretrained(
71
+ model_id_text,
72
  device_map="auto",
73
  torch_dtype=torch.bfloat16,
74
  )
75
  model.eval()
76
 
77
+ def clean_chat_history(chat_history):
78
+ """
79
+ Filter out any chat entries whose "content" is not a string.
80
+ """
81
+ cleaned = []
82
+ for msg in chat_history:
83
+ if isinstance(msg, dict) and isinstance(msg.get("content"), str):
84
+ cleaned.append(msg)
85
+ return cleaned
86
+
87
+ def chat_generate(input_text: str, chat_history: list, max_new_tokens: int, temperature: float, top_p: float, top_k: int, repetition_penalty: float):
88
+ """
89
+ Chat generation using a text-only model.
90
+ """
91
+ # Prepare conversation by cleaning history and appending the new user message.
92
+ conversation = clean_chat_history(chat_history)
93
+ conversation.append({"role": "user", "content": input_text})
94
+ input_ids = tokenizer.apply_chat_template(conversation, add_generation_prompt=True, return_tensors="pt")
95
+ if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
96
+ input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
97
+ input_ids = input_ids.to(model.device)
98
+ streamer = TextIteratorStreamer(tokenizer, timeout=20.0, skip_prompt=True, skip_special_tokens=True)
99
+ generation_kwargs = {
100
+ "input_ids": input_ids,
101
+ "streamer": streamer,
102
+ "max_new_tokens": max_new_tokens,
103
+ "do_sample": True,
104
+ "top_p": top_p,
105
+ "top_k": top_k,
106
+ "temperature": temperature,
107
+ "num_beams": 1,
108
+ "repetition_penalty": repetition_penalty,
109
+ }
110
+ thread = Thread(target=model.generate, kwargs=generation_kwargs)
111
+ thread.start()
112
+ outputs = []
113
+ # Collect the generated text from the streamer.
114
+ for new_text in streamer:
115
+ outputs.append(new_text)
116
+ final_response = "".join(outputs)
117
+ # Append assistant reply to chat history.
118
+ updated_history = conversation + [{"role": "assistant", "content": final_response}]
119
+ return final_response, updated_history
120
+
121
+ # ---------------------------
122
+ # 2. Qwen 2 VL OCR Tab
123
+ # ---------------------------
124
+ # Uses Qwen2VL OCR model for multimodal input (text + image)
125
 
 
126
  MODEL_ID_QWEN = "prithivMLmods/Qwen2-VL-OCR-2B-Instruct"
127
  processor = AutoProcessor.from_pretrained(MODEL_ID_QWEN, trust_remote_code=True)
128
  model_m = Qwen2VLForConditionalGeneration.from_pretrained(
 
131
  torch_dtype=torch.float16
132
  ).to("cuda").eval()
133
 
134
+ def generate_qwen_ocr(input_text: str, image):
135
+ """
136
+ Uses the Qwen2VL OCR model to process an image along with text.
137
+ """
138
+ if image is None:
139
+ return "No image provided."
140
+ # Build message with system and user content.
141
+ messages = [
142
+ {"role": "system", "content": [{"type": "text", "text": "You are a helpful assistant."}]},
143
+ {"role": "user", "content": [{"type": "text", "text": input_text}, {"type": "image", "image": image}]}
144
+ ]
145
+ # Apply chat template.
146
+ prompt_full = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
147
+ inputs = processor(text=[prompt_full], images=[image], return_tensors="pt", padding=True).to("cuda")
148
+ streamer = TextIteratorStreamer(processor, skip_prompt=True, skip_special_tokens=True)
149
+ generation_kwargs = {
150
+ **inputs,
151
+ "streamer": streamer,
152
+ "max_new_tokens": DEFAULT_MAX_NEW_TOKENS,
153
+ "do_sample": True,
154
+ "temperature": 0.6,
155
+ "top_p": 0.9,
156
+ "top_k": 50,
157
+ "repetition_penalty": 1.2,
158
+ }
159
+ thread = Thread(target=model_m.generate, kwargs=generation_kwargs)
160
+ thread.start()
161
+ outputs = []
162
+ for new_text in streamer:
163
+ outputs.append(new_text.replace("<|im_end|>", ""))
164
+ final_response = "".join(outputs)
165
+ return final_response
166
+
167
+ # ---------------------------
168
+ # 3. Image Gen LoRA Tab
169
+ # ---------------------------
170
+ # Uses the SDXL pipeline with LoRA options.
171
+
172
+ MODEL_ID_SD = os.getenv("MODEL_VAL_PATH") # set your SDXL model path via env variable
173
+ MAX_IMAGE_SIZE = int(os.getenv("MAX_IMAGE_SIZE", "4096"))
174
  USE_TORCH_COMPILE = os.getenv("USE_TORCH_COMPILE", "0") == "1"
175
  ENABLE_CPU_OFFLOAD = os.getenv("ENABLE_CPU_OFFLOAD", "0") == "1"
176
  BATCH_SIZE = int(os.getenv("BATCH_SIZE", "1"))
 
189
  if ENABLE_CPU_OFFLOAD:
190
  sd_pipe.enable_model_cpu_offload()
191
 
192
+ # LoRA options dictionary.
193
  LORA_OPTIONS = {
194
  "Realism (face/character)👦🏻": ("prithivMLmods/Canopus-Realism-LoRA", "Canopus-Realism-LoRA.safetensors", "rlms"),
195
  "Pixar (art/toons)🙀": ("prithivMLmods/Canopus-Pixar-Art", "Canopus-Pixar-Art.safetensors", "pixar"),
 
205
  "Pencil Art (characteristic/creative)✏️": ("prithivMLmods/Canopus-Pencil-Art-LoRA", "Canopus-Pencil-Art-LoRA.safetensors", "Pencil Art"),
206
  "Art Minimalistic (paint/semireal)🎨": ("prithivMLmods/Canopus-Art-Medium-LoRA", "Canopus-Art-Medium-LoRA.safetensors", "mdm"),
207
  }
208
+
209
+ # Style options.
210
  style_list = [
211
  {
212
  "name": "3840 x 2160",
 
233
  DEFAULT_STYLE_NAME = "3840 x 2160"
234
  STYLE_NAMES = list(styles.keys())
235
 
236
  def apply_style(style_name: str, positive: str, negative: str = ""):
237
+ if style_name in styles:
238
+ p, n = styles.get(style_name, styles[DEFAULT_STYLE_NAME])
239
  else:
240
+ p, n = styles[DEFAULT_STYLE_NAME]
241
+ return p.replace("{prompt}", positive), n + (negative if negative else "")
242
 
243
+ def generate_image_lora(prompt: str, negative_prompt: str, use_negative_prompt: bool, seed: int, width: int, height: int, guidance_scale: float, randomize_seed: bool, style_name: str, lora_model: str):
244
  seed = int(randomize_seed_fn(seed, randomize_seed))
245
+ positive_prompt, effective_negative_prompt = apply_style(style_name, prompt, negative_prompt)
246
+ if not use_negative_prompt:
247
+ effective_negative_prompt = ""
248
+ # Set the desired LoRA adapter.
249
+ model_name, weight_name, adapter_name = LORA_OPTIONS[lora_model]
250
+ sd_pipe.set_adapters(adapter_name)
251
+ # Generate image(s)
252
  options = {
253
+ "prompt": [positive_prompt],
254
+ "negative_prompt": [effective_negative_prompt],
255
  "width": width,
256
  "height": height,
257
  "guidance_scale": guidance_scale,
258
+ "num_inference_steps": 20,
259
+ "num_images_per_prompt": 1,
260
+ "cross_attention_kwargs": {"scale": 0.65},
261
  "output_type": "pil",
262
  }
263
+ outputs = sd_pipe(**options)
264
+ images = outputs.images
265
  image_paths = [save_image(img) for img in images]
266
  return image_paths, seed
267
 
268
+ # ---------------------------
269
+ # Build Gradio Interface with Three Tabs
270
+ # ---------------------------
271
+ with gr.Blocks(css=".gradio-container {max-width: 900px; margin: auto;}") as demo:
272
+ gr.Markdown("## Multi-Functional Demo: Chat Interface | Qwen 2 VL OCR | Image Gen LoRA")
273
 
274
  with gr.Tabs():
275
+ # Tab 1: Chat Interface
276
  with gr.Tab("Chat Interface"):
277
+ chat_output = gr.Chatbot(label="Chat Conversation")
278
  with gr.Row():
279
+ chat_inp = gr.Textbox(label="Enter your message", placeholder="Type your message here...", lines=2)
280
+ send_btn = gr.Button("Send")
281
  with gr.Row():
282
+ max_tokens_slider = gr.Slider(label="Max New Tokens", minimum=1, maximum=MAX_MAX_NEW_TOKENS, step=1, value=DEFAULT_MAX_NEW_TOKENS)
283
+ temperature_slider = gr.Slider(label="Temperature", minimum=0.1, maximum=4.0, step=0.1, value=0.6)
284
+ top_p_slider = gr.Slider(label="Top-p", minimum=0.05, maximum=1.0, step=0.05, value=0.9)
285
+ top_k_slider = gr.Slider(label="Top-k", minimum=1, maximum=1000, step=1, value=50)
286
+ rep_penalty_slider = gr.Slider(label="Repetition Penalty", minimum=1.0, maximum=2.0, step=0.05, value=1.2)
287
+ state = gr.State([])
288
+
289
+ def chat_step(user_message, history, max_tokens, temp, top_p, top_k, rep_penalty):
290
+ response, updated_history = chat_generate(user_message, history, max_tokens, temp, top_p, top_k, rep_penalty)
291
+ return updated_history, updated_history
292
+
293
+ send_btn.click(chat_step,
294
+ inputs=[chat_inp, state, max_tokens_slider, temperature_slider, top_p_slider, top_k_slider, rep_penalty_slider],
295
+ outputs=[chat_output, state])
296
+ chat_inp.submit(chat_step,
297
+ inputs=[chat_inp, state, max_tokens_slider, temperature_slider, top_p_slider, top_k_slider, rep_penalty_slider],
298
+ outputs=[chat_output, state])
299
+
300
+ # Tab 2: Qwen 2 VL OCR
301
+ with gr.Tab("Qwen 2 VL OCR"):
302
+ gr.Markdown("Upload an image and enter a prompt. The model will return OCR/extraction or descriptive text from the image.")
303
+ ocr_inp = gr.Textbox(label="Enter prompt", placeholder="Describe what you want to extract...", lines=2)
304
+ image_inp = gr.Image(label="Upload Image", type="pil")
305
+ ocr_output = gr.Textbox(label="Output", placeholder="Model output will appear here...", lines=5)
306
+ ocr_btn = gr.Button("Run Qwen 2 VL OCR")
307
+ ocr_btn.click(generate_qwen_ocr, inputs=[ocr_inp, image_inp], outputs=ocr_output)
308
+
309
+ # Tab 3: Image Gen LoRA
310
+ with gr.Tab("Image Gen LoRA"):
311
+ gr.Markdown("Generate images with SDXL using various LoRA models and quality styles.")
312
  with gr.Row():
313
+ prompt_img = gr.Textbox(label="Prompt", placeholder="Enter prompt for image generation...", lines=2)
314
+ negative_prompt_img = gr.Textbox(label="Negative Prompt", placeholder="(optional) negative prompt", lines=2)
315
+ use_neg_checkbox = gr.Checkbox(label="Use Negative Prompt", value=True)
316
  with gr.Row():
317
+ seed_slider = gr.Slider(label="Seed", minimum=0, maximum=np.iinfo(np.int32).max, step=1, value=0)
318
+ randomize_seed_checkbox = gr.Checkbox(label="Randomize Seed", value=True)
319
+ with gr.Row():
320
+ width_slider = gr.Slider(label="Width", minimum=512, maximum=2048, step=8, value=1024)
321
+ height_slider = gr.Slider(label="Height", minimum=512, maximum=2048, step=8, value=1024)
322
+ guidance_slider = gr.Slider(label="Guidance Scale", minimum=0.1, maximum=20.0, step=0.1, value=3.0)
323
+ style_radio = gr.Radio(label="Quality Style", choices=STYLE_NAMES, value=DEFAULT_STYLE_NAME)
324
+ lora_dropdown = gr.Dropdown(label="LoRA Selection", choices=list(LORA_OPTIONS.keys()), value="Realism (face/character)👦🏻")
325
+ img_output = gr.Gallery(label="Generated Images", columns=1, preview=True)
326
+ seed_output = gr.Number(label="Used Seed")
327
+ run_img_btn = gr.Button("Generate Image")
328
+ run_img_btn.click(generate_image_lora,
329
+ inputs=[prompt_img, negative_prompt_img, use_neg_checkbox, seed_slider, width_slider, height_slider, guidance_slider, randomize_seed_checkbox, style_radio, lora_dropdown],
330
+ outputs=[img_output, seed_output])
331
+
332
+ gr.Markdown("### Adjustments")
333
+ gr.Markdown("Each tab has been implemented separately. Feel free to adjust parameters and layout as needed in each tab.")
334
 
335
  if __name__ == "__main__":
336
+ demo.queue(max_size=20).launch(share=True)