# Hugging Face Space -- status at capture time: running on ZeroGPU ("Zero") hardware.
import asyncio
import os
import random
import time
import uuid
from threading import Thread

import gradio as gr
import spaces
import torch
import numpy as np
from PIL import Image
import cv2
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    TextIteratorStreamer,
    Qwen2VLForConditionalGeneration,
    AutoProcessor,
)
from transformers.image_utils import load_image
from diffusers import StableDiffusionXLPipeline, EulerAncestralDiscreteScheduler
# ---------------------------
# Global Settings & Utilities
# ---------------------------
# Token budgets for text generation: an absolute ceiling and the default value.
MAX_MAX_NEW_TOKENS = 2048
DEFAULT_MAX_NEW_TOKENS = 1024

# Maximum prompt length in tokens; overridable through the environment.
MAX_INPUT_TOKEN_LENGTH = int(os.environ.get("MAX_INPUT_TOKEN_LENGTH", "4096"))

# Prefer the first CUDA device when one is visible, otherwise run on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
def save_image(img: Image.Image) -> str:
    """Write *img* to a uniquely named PNG file and return that file name.

    The name is a random UUID4 followed by ``.png``; because it is a bare
    relative name, the file lands in the current working directory.
    """
    path = f"{uuid.uuid4()}.png"
    img.save(path)
    return path
def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
    """Return the seed to use for a generation run.

    When *randomize_seed* is false the caller's *seed* is returned untouched;
    otherwise a new seed is drawn uniformly from ``[0, int32 max]``.
    """
    if not randomize_seed:
        return seed
    upper_bound = np.iinfo(np.int32).max
    return random.randint(0, upper_bound)
def progress_bar_html(label: str) -> str:
    """Build the HTML for a thin animated progress bar preceded by *label*.

    The result is a flex row (label + 110px track with a sliding pink fill)
    followed by a ``<style>`` element defining the ``loading`` keyframes.
    *label* is interpolated as-is.
    """
    # The stylesheet is constant, so it lives in a plain string where the CSS
    # braces need no f-string escaping.
    animation_css = '''<style>
@keyframes loading {
0% { transform: translateX(-100%); }
100% { transform: translateX(100%); }
}
</style>
'''
    markup = f'''
<div style="display: flex; align-items: center;">
<span style="margin-right: 10px; font-size: 14px;">{label}</span>
<div style="width: 110px; height: 5px; background-color: #FFF0F5; border-radius: 2px; overflow: hidden;">
<div style="width: 100%; height: 100%; background-color: #FF69B4; animation: loading 1.5s linear infinite;"></div>
</div>
</div>
'''
    return markup + animation_css
# ---------------------------
# 1. Chat Interface Tab
# ---------------------------
# Text-only chat model: FastThink-0.5B-Tiny
model_id_text = "prithivMLmods/FastThink-0.5B-Tiny"
tokenizer = AutoTokenizer.from_pretrained(model_id_text)
# Load in bfloat16 with device placement delegated to `device_map="auto"`,
# then switch to inference mode (`eval()` returns the same module).
model = AutoModelForCausalLM.from_pretrained(
    model_id_text,
    torch_dtype=torch.bfloat16,
    device_map="auto",
).eval()
def clean_chat_history(chat_history):
    """Return a new list holding only the well-formed text messages.

    An entry is kept when it is a dict whose ``"content"`` value is a string;
    everything else (non-dicts, missing or non-string content) is dropped.
    Order is preserved and the input list is not modified.
    """
    return [
        entry
        for entry in chat_history
        if isinstance(entry, dict) and isinstance(entry.get("content"), str)
    ]
@spaces.GPU  # ZeroGPU attaches a GPU only inside decorated functions; effect-free elsewhere.
def chat_generate(input_text: str, chat_history: list, max_new_tokens: int, temperature: float, top_p: float, top_k: int, repetition_penalty: float):
    """
    Generate one assistant reply with the text-only chat model.

    Args:
        input_text: The new user message.
        chat_history: Prior messages as ``{"role", "content"}`` dicts; entries
            whose content is not a string are ignored.
        max_new_tokens: Upper bound on the number of generated tokens.
        temperature, top_p, top_k, repetition_penalty: Sampling controls
            forwarded to ``model.generate``.

    Returns:
        ``(final_response, updated_history)`` where ``updated_history`` is the
        cleaned history plus the new user message and the assistant reply.
    """
    # Prepare conversation by cleaning history and appending the new user message.
    conversation = clean_chat_history(chat_history)
    conversation.append({"role": "user", "content": input_text})

    input_ids = tokenizer.apply_chat_template(conversation, add_generation_prompt=True, return_tensors="pt")
    # Keep only the most recent tokens when the prompt exceeds the budget.
    if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
        input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
    input_ids = input_ids.to(model.device)

    streamer = TextIteratorStreamer(tokenizer, timeout=20.0, skip_prompt=True, skip_special_tokens=True)
    generation_kwargs = {
        "input_ids": input_ids,
        "streamer": streamer,
        # UI sliders may deliver numbers as floats; these two must be ints.
        "max_new_tokens": int(max_new_tokens),
        "do_sample": True,
        "top_p": top_p,
        "top_k": int(top_k),
        "temperature": temperature,
        "num_beams": 1,
        "repetition_penalty": repetition_penalty,
    }
    # generate() blocks, so it runs in a worker thread while this thread
    # drains the streamer.
    thread = Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()

    # Collect the generated text from the streamer.
    final_response = "".join(streamer)
    # The streamer is exhausted, so generation is finishing; wait for the
    # worker so it does not outlive this call.
    thread.join()

    # Append assistant reply to chat history.
    updated_history = conversation + [{"role": "assistant", "content": final_response}]
    return final_response, updated_history
# ---------------------------
# 2. Qwen 2 VL OCR Tab
# ---------------------------
# Multimodal (text + image) OCR model based on Qwen2-VL.
MODEL_ID_QWEN = "prithivMLmods/Qwen2-VL-OCR-2B-Instruct"
processor = AutoProcessor.from_pretrained(MODEL_ID_QWEN, trust_remote_code=True)
# Load in float16, move to the CUDA device, then switch to inference mode.
model_m = Qwen2VLForConditionalGeneration.from_pretrained(
    MODEL_ID_QWEN,
    torch_dtype=torch.float16,
    trust_remote_code=True,
)
model_m = model_m.to("cuda")
model_m = model_m.eval()
@spaces.GPU  # ZeroGPU attaches a GPU only inside decorated functions; effect-free elsewhere.
def generate_qwen_ocr(input_text: str, image):
    """
    Run the Qwen2-VL OCR model on an image together with a text prompt.

    Args:
        input_text: The instruction or question about the image.
        image: The image to analyse (the UI supplies a PIL image), or ``None``.

    Returns:
        The generated text, or ``"No image provided."`` when *image* is ``None``.
    """
    if image is None:
        return "No image provided."

    # Build message with system and user content.
    messages = [
        {"role": "system", "content": [{"type": "text", "text": "You are a helpful assistant."}]},
        {"role": "user", "content": [{"type": "text", "text": input_text}, {"type": "image", "image": image}]}
    ]
    # Apply chat template.
    prompt_full = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    inputs = processor(text=[prompt_full], images=[image], return_tensors="pt", padding=True).to("cuda")

    # Same 20 s timeout as the chat tab: if generation dies in the worker
    # thread, the loop below raises instead of blocking forever.
    streamer = TextIteratorStreamer(processor, timeout=20.0, skip_prompt=True, skip_special_tokens=True)
    generation_kwargs = {
        **inputs,
        "streamer": streamer,
        "max_new_tokens": DEFAULT_MAX_NEW_TOKENS,
        "do_sample": True,
        "temperature": 0.6,
        "top_p": 0.9,
        "top_k": 50,
        "repetition_penalty": 1.2,
    }
    # generate() blocks, so it runs in a worker thread while this thread
    # drains the streamer.
    thread = Thread(target=model_m.generate, kwargs=generation_kwargs)
    thread.start()

    # Strip any end-of-turn marker that survives special-token skipping.
    final_response = "".join(chunk.replace("<|im_end|>", "") for chunk in streamer)
    thread.join()
    return final_response
# ---------------------------
# 3. Image Gen LoRA Tab
# ---------------------------
# Uses the SDXL pipeline with LoRA options.
MODEL_ID_SD = os.getenv("MODEL_VAL_PATH")  # set your SDXL model path via env variable
if not MODEL_ID_SD:
    # Fail with an actionable message instead of an obscure error from
    # from_pretrained(None).
    raise RuntimeError(
        "Environment variable MODEL_VAL_PATH is not set; it must name the SDXL model to load."
    )
MAX_IMAGE_SIZE = int(os.getenv("MAX_IMAGE_SIZE", "4096"))
USE_TORCH_COMPILE = os.getenv("USE_TORCH_COMPILE", "0") == "1"
ENABLE_CPU_OFFLOAD = os.getenv("ENABLE_CPU_OFFLOAD", "0") == "1"
BATCH_SIZE = int(os.getenv("BATCH_SIZE", "1"))

# Half precision on GPU, full precision on CPU.
sd_pipe = StableDiffusionXLPipeline.from_pretrained(
    MODEL_ID_SD,
    torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
    use_safetensors=True,
    add_watermarker=False,
).to(device)
# Swap the scheduler for Euler-ancestral, reusing the existing scheduler config.
sd_pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(sd_pipe.scheduler.config)
if torch.cuda.is_available():
    sd_pipe.text_encoder = sd_pipe.text_encoder.half()
if USE_TORCH_COMPILE:
    # The pipeline object has no compile() method; compile the UNet module
    # with torch.compile instead.
    sd_pipe.unet = torch.compile(sd_pipe.unet)
if ENABLE_CPU_OFFLOAD:
    sd_pipe.enable_model_cpu_offload()
# LoRA catalogue: UI label -> (repository id, weight file name, adapter name).
LORA_OPTIONS = {
    "Realism (face/character)👦🏻": (
        "prithivMLmods/Canopus-Realism-LoRA",
        "Canopus-Realism-LoRA.safetensors",
        "rlms",
    ),
    "Pixar (art/toons)🙀": (
        "prithivMLmods/Canopus-Pixar-Art",
        "Canopus-Pixar-Art.safetensors",
        "pixar",
    ),
    "Photoshoot (camera/film)📸": (
        "prithivMLmods/Canopus-Photo-Shoot-Mini-LoRA",
        "Canopus-Photo-Shoot-Mini-LoRA.safetensors",
        "photo",
    ),
    "Clothing (hoodies/pant/shirts)👔": (
        "prithivMLmods/Canopus-Clothing-Adp-LoRA",
        "Canopus-Dress-Clothing-LoRA.safetensors",
        "clth",
    ),
    "Interior Architecture (house/hotel)🏠": (
        "prithivMLmods/Canopus-Interior-Architecture-0.1",
        "Canopus-Interior-Architecture-0.1δ.safetensors",
        "arch",
    ),
    "Fashion Product (wearing/usable)👜": (
        "prithivMLmods/Canopus-Fashion-Product-Dilation",
        "Canopus-Fashion-Product-Dilation.safetensors",
        "fashion",
    ),
    "Minimalistic Image (minimal/detailed)🏞️": (
        "prithivMLmods/Pegasi-Minimalist-Image-Style",
        "Pegasi-Minimalist-Image-Style.safetensors",
        "minimalist",
    ),
    "Modern Clothing (trend/new)👕": (
        "prithivMLmods/Canopus-Modern-Clothing-Design",
        "Canopus-Modern-Clothing-Design.safetensors",
        "mdrnclth",
    ),
    "Animaliea (farm/wild)🫎": (
        "prithivMLmods/Canopus-Animaliea-Artism",
        "Canopus-Animaliea-Artism.safetensors",
        "Animaliea",
    ),
    "Liquid Wallpaper (minimal/illustration)🖼️": (
        "prithivMLmods/Canopus-Liquid-Wallpaper-Art",
        "Canopus-Liquid-Wallpaper-Minimalize-LoRA.safetensors",
        "liquid",
    ),
    "Canes Cars (realistic/futurecars)🚘": (
        "prithivMLmods/Canes-Cars-Model-LoRA",
        "Canes-Cars-Model-LoRA.safetensors",
        "car",
    ),
    "Pencil Art (characteristic/creative)✏️": (
        "prithivMLmods/Canopus-Pencil-Art-LoRA",
        "Canopus-Pencil-Art-LoRA.safetensors",
        "Pencil Art",
    ),
    "Art Minimalistic (paint/semireal)🎨": (
        "prithivMLmods/Canopus-Art-Medium-LoRA",
        "Canopus-Art-Medium-LoRA.safetensors",
        "mdm",
    ),
}
# Style options: each entry wraps the user's prompt in a quality template and
# supplies a matching negative prompt. "{prompt}" marks where the user text goes.
style_list = [
    {
        "name": "3840 x 2160",
        "prompt": "hyper-realistic 8K image of {prompt}. ultra-detailed, lifelike, high-resolution, sharp, vibrant colors, photorealistic",
        "negative_prompt": "cartoonish, low resolution, blurry, simplistic, abstract, deformed, ugly",
    },
    {
        "name": "2560 x 1440",
        "prompt": "hyper-realistic 4K image of {prompt}. ultra-detailed, lifelike, high-resolution, sharp, vibrant colors, photorealistic",
        "negative_prompt": "cartoonish, low resolution, blurry, simplistic, abstract, deformed, ugly",
    },
    {
        "name": "HD+",
        "prompt": "hyper-realistic 2K image of {prompt}. ultra-detailed, lifelike, high-resolution, sharp, vibrant colors, photorealistic",
        "negative_prompt": "cartoonish, low resolution, blurry, simplistic, abstract, deformed, ugly",
    },
    {
        "name": "Style Zero",
        "prompt": "{prompt}",
        "negative_prompt": "",
    },
]
# Lookup table: style name -> (prompt template, negative prompt).
styles = {k["name"]: (k["prompt"], k["negative_prompt"]) for k in style_list}
DEFAULT_STYLE_NAME = "3840 x 2160"
STYLE_NAMES = list(styles.keys())


def apply_style(style_name: str, positive: str, negative: str = ""):
    """Expand a prompt pair with the named quality style.

    Args:
        style_name: Key into ``styles``; unknown names fall back to
            ``DEFAULT_STYLE_NAME``.
        positive: The user's prompt, substituted for ``{prompt}`` in the
            style's template.
        negative: Optional extra negative prompt (empty or ``None`` means none).

    Returns:
        ``(positive_prompt, negative_prompt)``. The negative prompt is the
        style's own negative text and the user's, joined with ``", "``;
        empty parts are skipped so no stray separator appears.
    """
    template, style_negative = styles.get(style_name, styles[DEFAULT_STYLE_NAME])
    # Previously the two negatives were glued together with no separator
    # ("...deformed, uglyblurry"), merging the last and first terms.
    negative_parts = [part for part in (style_negative, negative) if part]
    return template.replace("{prompt}", positive), ", ".join(negative_parts)
@spaces.GPU  # ZeroGPU attaches a GPU only inside decorated functions; effect-free elsewhere.
def generate_image_lora(prompt: str, negative_prompt: str, use_negative_prompt: bool, seed: int, width: int, height: int, guidance_scale: float, randomize_seed: bool, style_name: str, lora_model: str):
    """Generate an image with the SDXL pipeline and the selected LoRA adapter.

    Args:
        prompt: User prompt, expanded by the chosen quality style.
        negative_prompt: Extra negative prompt appended to the style's own.
        use_negative_prompt: When false, the negative prompt is emptied.
        seed: Seed to use unless *randomize_seed* is set.
        width, height: Output size in pixels.
        guidance_scale: Classifier-free guidance scale.
        randomize_seed: Draw a fresh random seed instead of using *seed*.
        style_name: Key of the quality style to apply.
        lora_model: Key into ``LORA_OPTIONS`` selecting the adapter.

    Returns:
        ``(image_paths, seed)``: paths of the saved PNG files and the seed
        that actually drove the sampling.
    """
    seed = int(randomize_seed_fn(seed, randomize_seed))
    positive_prompt, effective_negative_prompt = apply_style(style_name, prompt, negative_prompt)
    if not use_negative_prompt:
        effective_negative_prompt = ""

    # Set the desired LoRA adapter, loading its weights first if the pipeline
    # does not hold them yet (set_adapters only accepts loaded adapters).
    model_name, weight_name, adapter_name = LORA_OPTIONS[lora_model]
    loaded_adapters = {
        name
        for component_adapters in sd_pipe.get_list_adapters().values()
        for name in component_adapters
    }
    if adapter_name not in loaded_adapters:
        sd_pipe.load_lora_weights(model_name, weight_name=weight_name, adapter_name=adapter_name)
    sd_pipe.set_adapters(adapter_name)

    # Seed the sampler; without an explicit generator the reported seed has
    # no influence on the output. A CPU generator is accepted for any
    # pipeline device.
    generator = torch.Generator(device="cpu").manual_seed(seed)

    # Generate image(s)
    options = {
        "prompt": [positive_prompt],
        "negative_prompt": [effective_negative_prompt],
        # UI sliders may deliver numbers as floats; sizes must be ints.
        "width": int(width),
        "height": int(height),
        "guidance_scale": guidance_scale,
        "num_inference_steps": 20,
        "num_images_per_prompt": 1,
        "generator": generator,
        "cross_attention_kwargs": {"scale": 0.65},
        "output_type": "pil",
    }
    outputs = sd_pipe(**options)
    image_paths = [save_image(img) for img in outputs.images]
    return image_paths, seed
# ---------------------------
# Build Gradio Interface with Three Tabs
# ---------------------------
with gr.Blocks(css=".gradio-container {max-width: 900px; margin: auto;}") as demo:
    gr.Markdown("## Multi-Functional Demo: Chat Interface | Qwen 2 VL OCR | Image Gen LoRA")
    with gr.Tabs():
        # Tab 1: Chat Interface
        with gr.Tab("Chat Interface"):
            # chat_step feeds this component a list of {"role", "content"}
            # dicts, so it must be declared in "messages" format.
            chat_output = gr.Chatbot(label="Chat Conversation", type="messages")
            with gr.Row():
                chat_inp = gr.Textbox(label="Enter your message", placeholder="Type your message here...", lines=2)
                send_btn = gr.Button("Send")
            with gr.Row():
                max_tokens_slider = gr.Slider(label="Max New Tokens", minimum=1, maximum=MAX_MAX_NEW_TOKENS, step=1, value=DEFAULT_MAX_NEW_TOKENS)
                temperature_slider = gr.Slider(label="Temperature", minimum=0.1, maximum=4.0, step=0.1, value=0.6)
                top_p_slider = gr.Slider(label="Top-p", minimum=0.05, maximum=1.0, step=0.05, value=0.9)
                top_k_slider = gr.Slider(label="Top-k", minimum=1, maximum=1000, step=1, value=50)
                rep_penalty_slider = gr.Slider(label="Repetition Penalty", minimum=1.0, maximum=2.0, step=0.05, value=1.2)
            # Per-session conversation history (list of message dicts).
            state = gr.State([])

            def chat_step(user_message, history, max_tokens, temp, top_p, top_k, rep_penalty):
                """Run one chat turn; the new history feeds both the chatbot and the state."""
                _response, updated_history = chat_generate(user_message, history, max_tokens, temp, top_p, top_k, rep_penalty)
                return updated_history, updated_history

            # The button and the textbox's Enter key trigger the same handler.
            send_btn.click(chat_step,
                           inputs=[chat_inp, state, max_tokens_slider, temperature_slider, top_p_slider, top_k_slider, rep_penalty_slider],
                           outputs=[chat_output, state])
            chat_inp.submit(chat_step,
                            inputs=[chat_inp, state, max_tokens_slider, temperature_slider, top_p_slider, top_k_slider, rep_penalty_slider],
                            outputs=[chat_output, state])
        # Tab 2: Qwen 2 VL OCR
        with gr.Tab("Qwen 2 VL OCR"):
            gr.Markdown("Upload an image and enter a prompt. The model will return OCR/extraction or descriptive text from the image.")
            ocr_inp = gr.Textbox(label="Enter prompt", placeholder="Describe what you want to extract...", lines=2)
            image_inp = gr.Image(label="Upload Image", type="pil")
            ocr_output = gr.Textbox(label="Output", placeholder="Model output will appear here...", lines=5)
            ocr_btn = gr.Button("Run Qwen 2 VL OCR")
            ocr_btn.click(generate_qwen_ocr, inputs=[ocr_inp, image_inp], outputs=ocr_output)
        # Tab 3: Image Gen LoRA
        with gr.Tab("Image Gen LoRA"):
            gr.Markdown("Generate images with SDXL using various LoRA models and quality styles.")
            with gr.Row():
                prompt_img = gr.Textbox(label="Prompt", placeholder="Enter prompt for image generation...", lines=2)
                negative_prompt_img = gr.Textbox(label="Negative Prompt", placeholder="(optional) negative prompt", lines=2)
            use_neg_checkbox = gr.Checkbox(label="Use Negative Prompt", value=True)
            with gr.Row():
                seed_slider = gr.Slider(label="Seed", minimum=0, maximum=np.iinfo(np.int32).max, step=1, value=0)
                randomize_seed_checkbox = gr.Checkbox(label="Randomize Seed", value=True)
            with gr.Row():
                width_slider = gr.Slider(label="Width", minimum=512, maximum=2048, step=8, value=1024)
                height_slider = gr.Slider(label="Height", minimum=512, maximum=2048, step=8, value=1024)
            guidance_slider = gr.Slider(label="Guidance Scale", minimum=0.1, maximum=20.0, step=0.1, value=3.0)
            style_radio = gr.Radio(label="Quality Style", choices=STYLE_NAMES, value=DEFAULT_STYLE_NAME)
            lora_dropdown = gr.Dropdown(label="LoRA Selection", choices=list(LORA_OPTIONS.keys()), value="Realism (face/character)👦🏻")
            img_output = gr.Gallery(label="Generated Images", columns=1, preview=True)
            seed_output = gr.Number(label="Used Seed")
            run_img_btn = gr.Button("Generate Image")
            # Input order must match generate_image_lora's parameter order.
            run_img_btn.click(generate_image_lora,
                              inputs=[prompt_img, negative_prompt_img, use_neg_checkbox, seed_slider, width_slider, height_slider, guidance_slider, randomize_seed_checkbox, style_radio, lora_dropdown],
                              outputs=[img_output, seed_output])
    gr.Markdown("### Adjustments")
    gr.Markdown("Each tab has been implemented separately. Feel free to adjust parameters and layout as needed in each tab.")
if __name__ == "__main__":
    # Enable request queuing (at most 20 waiting jobs), then start the app
    # asking for a public share link. queue() returns the same Blocks object.
    demo.queue(max_size=20)
    demo.launch(share=True)