import os
import random
import uuid
from threading import Thread
import gradio as gr
import spaces
import torch
import numpy as np
from PIL import Image
from transformers import (
AutoModelForCausalLM,
AutoTokenizer,
TextIteratorStreamer,
Qwen2VLForConditionalGeneration,
AutoProcessor,
)
from diffusers import StableDiffusionXLPipeline, EulerAncestralDiscreteScheduler
# ---------------------------
# Global Settings & Utilities
# ---------------------------
MAX_MAX_NEW_TOKENS = 2048
DEFAULT_MAX_NEW_TOKENS = 1024
MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096"))
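# Run on the first CUDA device when available; otherwise fall back to CPU.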
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
def save_image(img: Image.Image) -> str:
"""Save a PIL image with a unique filename and return the path."""
unique_name = str(uuid.uuid4()) + ".png"
img.save(unique_name)
return unique_name
def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
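    """Return a fresh random seed when randomize_seed is True; otherwise return seed unchanged."""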
MAX_SEED = np.iinfo(np.int32).max
if randomize_seed:
seed = random.randint(0, MAX_SEED)
return seed
def progress_bar_html(label: str) -> str:
"""Returns an HTML snippet for a thin progress bar with a label."""
return f'''
<div style="display: flex; align-items: center;">
<span style="margin-right: 10px; font-size: 14px;">{label}</span>
<div style="width: 110px; height: 5px; background-color: #FFF0F5; border-radius: 2px; overflow: hidden;">
<div style="width: 100%; height: 100%; background-color: #FF69B4; animation: loading 1.5s linear infinite;"></div>
</div>
</div>
<style>
@keyframes loading {{
0% {{ transform: translateX(-100%); }}
100% {{ transform: translateX(100%); }}
}}
</style>
'''
# ---------------------------
# 1. Chat Interface Tab
# ---------------------------
# Uses a text-only model: FastThink-0.5B-Tiny
model_id_text = "prithivMLmods/FastThink-0.5B-Tiny"
tokenizer = AutoTokenizer.from_pretrained(model_id_text)
model = AutoModelForCausalLM.from_pretrained(
model_id_text,
device_map="auto",
torch_dtype=torch.bfloat16,
)
model.eval()
def clean_chat_history(chat_history):
"""
Filter out any chat entries whose "content" is not a string.
"""
cleaned = []
for msg in chat_history:
if isinstance(msg, dict) and isinstance(msg.get("content"), str):
cleaned.append(msg)
return cleaned
@spaces.GPU
def chat_generate(input_text: str, chat_history: list, max_new_tokens: int, temperature: float, top_p: float, top_k: int, repetition_penalty: float):
"""
Chat generation using a text-only model.
"""
# Prepare conversation by cleaning history and appending the new user message.
conversation = clean_chat_history(chat_history)
conversation.append({"role": "user", "content": input_text})
input_ids = tokenizer.apply_chat_template(conversation, add_generation_prompt=True, return_tensors="pt")
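    # Truncate from the left so only the most recent MAX_INPUT_TOKEN_LENGTH tokens are kept.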
if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
input_ids = input_ids.to(model.device)
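    # Stream decoded tokens as generate() produces them on a background thread.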
streamer = TextIteratorStreamer(tokenizer, timeout=20.0, skip_prompt=True, skip_special_tokens=True)
generation_kwargs = {
"input_ids": input_ids,
"streamer": streamer,
"max_new_tokens": max_new_tokens,
"do_sample": True,
"top_p": top_p,
"top_k": top_k,
"temperature": temperature,
"num_beams": 1,
"repetition_penalty": repetition_penalty,
}
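    # Run generation on a worker thread so the streamer can be consumed here.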
thread = Thread(target=model.generate, kwargs=generation_kwargs)
thread.start()
outputs = []
# Collect the generated text from the streamer.
for new_text in streamer:
outputs.append(new_text)
final_response = "".join(outputs)
# Append assistant reply to chat history.
updated_history = conversation + [{"role": "assistant", "content": final_response}]
return final_response, updated_history
# ---------------------------
# 2. Qwen 2 VL OCR Tab
# ---------------------------
# Uses Qwen2VL OCR model for multimodal input (text + image)
MODEL_ID_QWEN = "prithivMLmods/Qwen2-VL-OCR-2B-Instruct"
processor = AutoProcessor.from_pretrained(MODEL_ID_QWEN, trust_remote_code=True)
model_m = Qwen2VLForConditionalGeneration.from_pretrained(
MODEL_ID_QWEN,
trust_remote_code=True,
torch_dtype=torch.float16
).to("cuda").eval()
@spaces.GPU
def generate_qwen_ocr(input_text: str, image):
"""
Uses the Qwen2VL OCR model to process an image along with text.
"""
if image is None:
return "No image provided."
# Build message with system and user content.
messages = [
{"role": "system", "content": [{"type": "text", "text": "You are a helpful assistant."}]},
{"role": "user", "content": [{"type": "text", "text": input_text}, {"type": "image", "image": image}]}
]
# Apply chat template.
prompt_full = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
inputs = processor(text=[prompt_full], images=[image], return_tensors="pt", padding=True).to("cuda")
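    # Stream the decoded output, skipping the prompt and special tokens.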
    streamer = TextIteratorStreamer(processor.tokenizer, skip_prompt=True, skip_special_tokens=True)
generation_kwargs = {
**inputs,
"streamer": streamer,
"max_new_tokens": DEFAULT_MAX_NEW_TOKENS,
"do_sample": True,
"temperature": 0.6,
"top_p": 0.9,
"top_k": 50,
"repetition_penalty": 1.2,
}
thread = Thread(target=model_m.generate, kwargs=generation_kwargs)
thread.start()
outputs = []
for new_text in streamer:
outputs.append(new_text.replace("<|im_end|>", ""))
final_response = "".join(outputs)
return final_response
# ---------------------------
# 3. Image Gen LoRA Tab
# ---------------------------
# Uses the SDXL pipeline with LoRA options.
MODEL_ID_SD = os.getenv("MODEL_VAL_PATH") # set your SDXL model path via env variable
MAX_IMAGE_SIZE = int(os.getenv("MAX_IMAGE_SIZE", "4096"))
USE_TORCH_COMPILE = os.getenv("USE_TORCH_COMPILE", "0") == "1"
ENABLE_CPU_OFFLOAD = os.getenv("ENABLE_CPU_OFFLOAD", "0") == "1"
BATCH_SIZE = int(os.getenv("BATCH_SIZE", "1"))
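# Load the SDXL pipeline once at startup (fp16 on GPU, fp32 on CPU).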
sd_pipe = StableDiffusionXLPipeline.from_pretrained(
MODEL_ID_SD,
torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
use_safetensors=True,
add_watermarker=False,
).to(device)
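# Swap in the Euler Ancestral scheduler, reusing the pipeline's existing scheduler config.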
sd_pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(sd_pipe.scheduler.config)
if torch.cuda.is_available():
sd_pipe.text_encoder = sd_pipe.text_encoder.half()
if USE_TORCH_COMPILE:
    # Diffusers pipelines have no .compile() method; compile the UNet,
    # which dominates inference time, with torch.compile instead.
    sd_pipe.unet = torch.compile(sd_pipe.unet, mode="reduce-overhead", fullgraph=True)
if ENABLE_CPU_OFFLOAD:
sd_pipe.enable_model_cpu_offload()
# LoRA options dictionary.
LORA_OPTIONS = {
"Realism (face/character)👦🏻": ("prithivMLmods/Canopus-Realism-LoRA", "Canopus-Realism-LoRA.safetensors", "rlms"),
"Pixar (art/toons)🙀": ("prithivMLmods/Canopus-Pixar-Art", "Canopus-Pixar-Art.safetensors", "pixar"),
"Photoshoot (camera/film)📸": ("prithivMLmods/Canopus-Photo-Shoot-Mini-LoRA", "Canopus-Photo-Shoot-Mini-LoRA.safetensors", "photo"),
"Clothing (hoodies/pant/shirts)👔": ("prithivMLmods/Canopus-Clothing-Adp-LoRA", "Canopus-Dress-Clothing-LoRA.safetensors", "clth"),
"Interior Architecture (house/hotel)🏠": ("prithivMLmods/Canopus-Interior-Architecture-0.1", "Canopus-Interior-Architecture-0.1δ.safetensors", "arch"),
"Fashion Product (wearing/usable)👜": ("prithivMLmods/Canopus-Fashion-Product-Dilation", "Canopus-Fashion-Product-Dilation.safetensors", "fashion"),
"Minimalistic Image (minimal/detailed)🏞️": ("prithivMLmods/Pegasi-Minimalist-Image-Style", "Pegasi-Minimalist-Image-Style.safetensors", "minimalist"),
"Modern Clothing (trend/new)👕": ("prithivMLmods/Canopus-Modern-Clothing-Design", "Canopus-Modern-Clothing-Design.safetensors", "mdrnclth"),
"Animaliea (farm/wild)🫎": ("prithivMLmods/Canopus-Animaliea-Artism", "Canopus-Animaliea-Artism.safetensors", "Animaliea"),
"Liquid Wallpaper (minimal/illustration)🖼️": ("prithivMLmods/Canopus-Liquid-Wallpaper-Art", "Canopus-Liquid-Wallpaper-Minimalize-LoRA.safetensors", "liquid"),
"Canes Cars (realistic/futurecars)🚘": ("prithivMLmods/Canes-Cars-Model-LoRA", "Canes-Cars-Model-LoRA.safetensors", "car"),
"Pencil Art (characteristic/creative)✏️": ("prithivMLmods/Canopus-Pencil-Art-LoRA", "Canopus-Pencil-Art-LoRA.safetensors", "Pencil Art"),
"Art Minimalistic (paint/semireal)🎨": ("prithivMLmods/Canopus-Art-Medium-LoRA", "Canopus-Art-Medium-LoRA.safetensors", "mdm"),
}
# Download and register each LoRA adapter once at startup so that
# set_adapters() can switch between them at generation time.
for _repo, _weight, _adapter in LORA_OPTIONS.values():
    sd_pipe.load_lora_weights(_repo, weight_name=_weight, adapter_name=_adapter)
# Style options.
style_list = [
{
"name": "3840 x 2160",
"prompt": "hyper-realistic 8K image of {prompt}. ultra-detailed, lifelike, high-resolution, sharp, vibrant colors, photorealistic",
"negative_prompt": "cartoonish, low resolution, blurry, simplistic, abstract, deformed, ugly",
},
{
"name": "2560 x 1440",
"prompt": "hyper-realistic 4K image of {prompt}. ultra-detailed, lifelike, high-resolution, sharp, vibrant colors, photorealistic",
"negative_prompt": "cartoonish, low resolution, blurry, simplistic, abstract, deformed, ugly",
},
{
"name": "HD+",
"prompt": "hyper-realistic 2K image of {prompt}. ultra-detailed, lifelike, high-resolution, sharp, vibrant colors, photorealistic",
"negative_prompt": "cartoonish, low resolution, blurry, simplistic, abstract, deformed, ugly",
},
{
"name": "Style Zero",
"prompt": "{prompt}",
"negative_prompt": "",
},
]
styles = {k["name"]: (k["prompt"], k["negative_prompt"]) for k in style_list}
DEFAULT_STYLE_NAME = "3840 x 2160"
STYLE_NAMES = list(styles.keys())
def apply_style(style_name: str, positive: str, negative: str = ""):
    # Fall back to the default style when the requested name is unknown.
    p, n = styles.get(style_name, styles[DEFAULT_STYLE_NAME])
    return p.replace("{prompt}", positive), n + (negative or "")
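# Example: apply_style("Style Zero", "a red fox", "blurry") -> ("a red fox", "blurry")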
@spaces.GPU
def generate_image_lora(prompt: str, negative_prompt: str, use_negative_prompt: bool, seed: int, width: int, height: int, guidance_scale: float, randomize_seed: bool, style_name: str, lora_model: str):
seed = int(randomize_seed_fn(seed, randomize_seed))
positive_prompt, effective_negative_prompt = apply_style(style_name, prompt, negative_prompt)
if not use_negative_prompt:
effective_negative_prompt = ""
# Set the desired LoRA adapter.
model_name, weight_name, adapter_name = LORA_OPTIONS[lora_model]
sd_pipe.set_adapters(adapter_name)
# Generate image(s)
options = {
"prompt": [positive_prompt],
"negative_prompt": [effective_negative_prompt],
"width": width,
"height": height,
"guidance_scale": guidance_scale,
"num_inference_steps": 20,
"num_images_per_prompt": 1,
"cross_attention_kwargs": {"scale": 0.65},
"output_type": "pil",
}
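    # cross_attention_kwargs' "scale" weights the active LoRA's contribution (1.0 = full strength).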
outputs = sd_pipe(**options)
images = outputs.images
image_paths = [save_image(img) for img in images]
return image_paths, seed
# ---------------------------
# Build Gradio Interface with Three Tabs
# ---------------------------
with gr.Blocks(css=".gradio-container {max-width: 900px; margin: auto;}") as demo:
gr.Markdown("## Multi-Functional Demo: Chat Interface | Qwen 2 VL OCR | Image Gen LoRA")
with gr.Tabs():
# Tab 1: Chat Interface
with gr.Tab("Chat Interface"):
            chat_output = gr.Chatbot(label="Chat Conversation", type="messages")
with gr.Row():
chat_inp = gr.Textbox(label="Enter your message", placeholder="Type your message here...", lines=2)
send_btn = gr.Button("Send")
with gr.Row():
max_tokens_slider = gr.Slider(label="Max New Tokens", minimum=1, maximum=MAX_MAX_NEW_TOKENS, step=1, value=DEFAULT_MAX_NEW_TOKENS)
temperature_slider = gr.Slider(label="Temperature", minimum=0.1, maximum=4.0, step=0.1, value=0.6)
top_p_slider = gr.Slider(label="Top-p", minimum=0.05, maximum=1.0, step=0.05, value=0.9)
top_k_slider = gr.Slider(label="Top-k", minimum=1, maximum=1000, step=1, value=50)
rep_penalty_slider = gr.Slider(label="Repetition Penalty", minimum=1.0, maximum=2.0, step=0.05, value=1.2)
state = gr.State([])
def chat_step(user_message, history, max_tokens, temp, top_p, top_k, rep_penalty):
response, updated_history = chat_generate(user_message, history, max_tokens, temp, top_p, top_k, rep_penalty)
return updated_history, updated_history
send_btn.click(chat_step,
inputs=[chat_inp, state, max_tokens_slider, temperature_slider, top_p_slider, top_k_slider, rep_penalty_slider],
outputs=[chat_output, state])
chat_inp.submit(chat_step,
inputs=[chat_inp, state, max_tokens_slider, temperature_slider, top_p_slider, top_k_slider, rep_penalty_slider],
outputs=[chat_output, state])
# Tab 2: Qwen 2 VL OCR
with gr.Tab("Qwen 2 VL OCR"):
            gr.Markdown("Upload an image and enter a prompt. The model returns OCR text or a description extracted from the image.")
ocr_inp = gr.Textbox(label="Enter prompt", placeholder="Describe what you want to extract...", lines=2)
image_inp = gr.Image(label="Upload Image", type="pil")
ocr_output = gr.Textbox(label="Output", placeholder="Model output will appear here...", lines=5)
ocr_btn = gr.Button("Run Qwen 2 VL OCR")
ocr_btn.click(generate_qwen_ocr, inputs=[ocr_inp, image_inp], outputs=ocr_output)
# Tab 3: Image Gen LoRA
with gr.Tab("Image Gen LoRA"):
gr.Markdown("Generate images with SDXL using various LoRA models and quality styles.")
with gr.Row():
prompt_img = gr.Textbox(label="Prompt", placeholder="Enter prompt for image generation...", lines=2)
negative_prompt_img = gr.Textbox(label="Negative Prompt", placeholder="(optional) negative prompt", lines=2)
use_neg_checkbox = gr.Checkbox(label="Use Negative Prompt", value=True)
with gr.Row():
seed_slider = gr.Slider(label="Seed", minimum=0, maximum=np.iinfo(np.int32).max, step=1, value=0)
randomize_seed_checkbox = gr.Checkbox(label="Randomize Seed", value=True)
with gr.Row():
width_slider = gr.Slider(label="Width", minimum=512, maximum=2048, step=8, value=1024)
height_slider = gr.Slider(label="Height", minimum=512, maximum=2048, step=8, value=1024)
guidance_slider = gr.Slider(label="Guidance Scale", minimum=0.1, maximum=20.0, step=0.1, value=3.0)
style_radio = gr.Radio(label="Quality Style", choices=STYLE_NAMES, value=DEFAULT_STYLE_NAME)
lora_dropdown = gr.Dropdown(label="LoRA Selection", choices=list(LORA_OPTIONS.keys()), value="Realism (face/character)👦🏻")
img_output = gr.Gallery(label="Generated Images", columns=1, preview=True)
seed_output = gr.Number(label="Used Seed")
run_img_btn = gr.Button("Generate Image")
run_img_btn.click(generate_image_lora,
inputs=[prompt_img, negative_prompt_img, use_neg_checkbox, seed_slider, width_slider, height_slider, guidance_slider, randomize_seed_checkbox, style_radio, lora_dropdown],
outputs=[img_output, seed_output])
    gr.Markdown("### Adjustments")
    gr.Markdown("Each tab is implemented independently; adjust its parameters and layout as needed.")
if __name__ == "__main__":
demo.queue(max_size=20).launch(share=True)