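"""Gradio demo: generate images from a text prompt (optionally img2img from a
reference image) across several Stable Diffusion checkpoints, keeping a
per-session history of results. (Module docstring added for orientation;
it summarizes the code below.)"""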
import gradio as gr
import numpy as np
import random
from diffusers import DiffusionPipeline, StableDiffusionImg2ImgPipeline
import torch
import os
from huggingface_hub import HfFolder
import logging
from typing import List, Tuple, Optional
import gc
from PIL import Image, ImageDraw, ImageFont

# Logging setup
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Let cuDNN pick the fastest convolution algorithms for fixed input sizes
torch.backends.cudnn.benchmark = True

# Hugging Face authentication
token = HfFolder.get_token() or os.getenv("HF_TOKEN")
if token:
    logger.info("Hugging Face token loaded.")
else:
    logger.warning("No HF_TOKEN found; public models only.")

# Device selection
device = "cuda" if torch.cuda.is_available() else "cpu"
logger.info(f"Using device: {device}")

# Models config
model_configs = [
    {"id": "runwayml/stable-diffusion-v1-5", "name": "Stable Diffusion v1.5"},
    {"id": "stabilityai/stable-diffusion-2-1", "name": "Stable Diffusion 2.1"},
    {"id": "Lykon/DreamShaper", "name": "DreamShaper"},
]
model_names = [m["name"] for m in model_configs]

# Constants
MAX_SEED = np.iinfo(np.int32).max
MAX_DIM = 1024


# Helper: create a simple error image
def create_error_image(msg: str, w: int = 512, h: int = 512) -> Image.Image:
    img = Image.new("RGB", (w, h), (50, 0, 0))
    draw = ImageDraw.Draw(img)
    try:
        font = ImageFont.truetype("arial.ttf", 18)
    except OSError:  # font not installed; fall back to PIL's built-in font
        font = ImageFont.load_default()
    draw.text((10, 10), msg, fill=(255, 255, 255), font=font)
    return img


# Infer: returns an image list for the gallery, the seed actually used, and history
def infer(
    prompt: str,
    negative_prompt: str,
    ref_image: Optional[Image.Image],
    strength: float,
    explicitness: int,
    seed: int,
    rand_seed: bool,
    width: int,
    height: int,
    guidance: float,
    steps: int,
    selected: List[str],
    history: List[Tuple[str, Image.Image]],
    progress=gr.Progress(),
) -> Tuple[List[Image.Image], int, List[Tuple[str, Image.Image]]]:
    # Validate
    if not prompt:
        err = create_error_image("Enter a prompt.")
        history.append(("Error: no prompt", err))
        return [err], seed, history
    if not selected:
        err = create_error_image("Select a model.")
        history.append(("Error: no model", err))
        return [err], seed, history

    # Seed
    if rand_seed:
        seed = random.randint(0, MAX_SEED)

    # Prompt modifier
    mod = ""
    if explicitness >= 8:
        mod = ", explicit"
    elif explicitness >= 5:
        mod = ", suggestive"
    fp = prompt + mod

    torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32

    images = []
    for cfg in model_configs:
        if cfg["name"] not in selected:
            continue
        try:
            # Re-seed per model so every pipeline starts from the same noise;
            # reusing one generator would make later models depend on earlier runs
            gen = torch.Generator(device=device).manual_seed(seed)

            # Select pipeline: img2img when a reference image is supplied
            if ref_image is not None:
                pipe = StableDiffusionImg2ImgPipeline.from_pretrained(
                    cfg["id"], torch_dtype=torch_dtype, token=token
                ).to(device)
            else:
                pipe = DiffusionPipeline.from_pretrained(
                    cfg["id"], torch_dtype=torch_dtype, token=token
                ).to(device)

            # Disable the safety checker *before* generating; setting it to None
            # after the pipeline call (as before) had no effect
            if hasattr(pipe, "safety_checker") and pipe.safety_checker:
                pipe.safety_checker = None

            if ref_image is not None:
                img_in = ref_image.resize((width, height))
                out = pipe(
                    prompt=fp,
                    negative_prompt=negative_prompt,
                    image=img_in,
                    strength=strength,
                    guidance_scale=guidance,
                    num_inference_steps=steps,
                    generator=gen,
                ).images[0]
            else:
                out = pipe(
                    prompt=fp,
                    negative_prompt=negative_prompt,
                    width=width,
                    height=height,
                    guidance_scale=guidance,
                    num_inference_steps=steps,
                    generator=gen,
                ).images[0]

            images.append(out)
            history.append((f"{cfg['name']}: {prompt}", out))
        except Exception as e:
            err_img = create_error_image(str(e))
            images.append(err_img)
            history.append((f"{cfg['name']}: ERROR", err_img))
        finally:
            # Free GPU memory before loading the next model
            if "pipe" in locals():
                pipe.to("cpu")
                del pipe
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
            gc.collect()

    return images, seed, history


# Build UI (gr.Sidebar and gr.State.change require a recent Gradio release)
def build_ui():
    with gr.Blocks() as demo:
        history_state = gr.State([])  # stores (description, image) tuples

        with gr.Sidebar():
            ref = gr.Image(label="Reference (optional)", type="pil")
            select = gr.CheckboxGroup(label="Models", choices=model_names, value=[model_names[0]])
            history_dd = gr.Dropdown(choices=[], label="History")

        with gr.Column():
            prompt = gr.Textbox(label="Prompt")
            neg = gr.Textbox(label="Negative Prompt")
            strength = gr.Slider(0.1, 1.0, 0.6, label="Ref Strength")
            explicit = gr.Slider(1, 10, 5, step=1, label="Explicitness")
            seed = gr.Slider(0, MAX_SEED, 0, step=1, label="Seed")
            rand = gr.Checkbox(True, label="Randomize Seed")
            w = gr.Slider(256, MAX_DIM, 512, step=64, label="Width")
            h = gr.Slider(256, MAX_DIM, 512, step=64, label="Height")
            guidance = gr.Slider(0.0, 20.0, 7.5, step=0.1, label="Guidance")
            steps = gr.Slider(1, 100, 30, step=1, label="Steps")
            run = gr.Button("Generate")
            gallery = gr.Gallery(label="Results", columns=3)

        run.click(
            infer,
            inputs=[prompt, neg, ref, strength, explicit, seed, rand, w, h, guidance, steps, select, history_state],
            outputs=[gallery, seed, history_state],
        )

        # Refresh the dropdown's *choices* when history changes; returning a bare
        # list (as before) would only set the dropdown's value
        history_state.change(
            lambda h: gr.update(choices=[d for d, _ in h]),
            history_state,
            history_dd,
        )
        # Show the picked history entry; Gallery expects a list, and a list
        # comprehension avoids StopIteration when nothing matches
        history_dd.change(
            lambda choice, h: [img for d, img in h if d == choice],
            [history_dd, history_state],
            gallery,
        )

    return demo


if __name__ == "__main__":
    build_ui().launch(debug=True)
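
# A minimal sketch of driving infer() without the UI (argument values are
# illustrative, not part of the app; the function itself is defined above):
#
#     imgs, used_seed, hist = infer(
#         prompt="a watercolor fox", negative_prompt="", ref_image=None,
#         strength=0.6, explicitness=1, seed=42, rand_seed=False,
#         width=512, height=512, guidance=7.5, steps=30,
#         selected=["Stable Diffusion v1.5"], history=[],
#     )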