File size: 2,612 Bytes
816c445
3fe3a3a
 
 
 
 
941af41
3fe3a3a
 
 
 
 
 
20bbc42
 
 
3fe3a3a
20bbc42
 
 
3fe3a3a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
fe1cda6
3fe3a3a
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
import torch
import gradio as gr
from diffusers import DDIMScheduler, StableDiffusionXLPipeline

# --- Settings and paths ---
# Base SDXL checkpoint in full Diffusers format. "ByteDance/Hyper-SD" cannot be
# used here: that repo only ships LoRA / UNet weight files (no model_index.json),
# so it is applied further down as an acceleration LoRA instead.
BASE_MODEL = "stabilityai/stable-diffusion-xl-base-1.0"
# Hyper-SD acceleration LoRA. The "CFG" 8-step variant keeps classifier-free
# guidance usable, which the guidance slider in the UI relies on.
HYPER_SD_REPO = "ByteDance/Hyper-SD"
HYPER_SD_WEIGHTS = "Hyper-SDXL-8steps-CFG-lora.safetensors"
# Style LoRA: a Hugging Face repo id or local path in a format Diffusers can load.
LORA_PATH = "fofr/sdxl-emoji"

# --- Pick device and precision together ---
# fp16 on GPU for speed and memory; fp32 on CPU, where half precision is slow
# or unsupported for many ops.
device = "cuda" if torch.cuda.is_available() else "cpu"
dtype = torch.float16 if device == "cuda" else torch.float32

# --- Load the base pipeline ---
# SDXL pipelines have no safety-checker component, so no `safety_checker` kwarg.
pipe = StableDiffusionXLPipeline.from_pretrained(
    BASE_MODEL,
    torch_dtype=dtype,
    variant="fp16",  # download the fp16 weight files; they are cast to `dtype` on load
    use_safetensors=True,
)
pipe.to(device)

# --- Enable fast attention if available ---
# xFormers kernels are CUDA-only, so the attempt is skipped on CPU.
if device == "cuda":
    try:
        pipe.enable_xformers_memory_efficient_attention()
    except Exception as e:  # optional speed-up: report and carry on
        print("xFormers not enabled:", e)

# --- Apply the LoRA weights ---
# `pipe.load_lora_weights` (instead of the older `pipe.unet.load_attn_procs`)
# loads LoRA weights for the UNet and, when present, the SDXL text encoders.
# Each LoRA gets its own adapter name so both can be active at once.
active_adapters = []
for adapter_name, lora_source, weight_name in (
    ("hyper_sd", HYPER_SD_REPO, HYPER_SD_WEIGHTS),
    ("emoji", LORA_PATH, None),  # None: let Diffusers pick the repo's weight file
):
    try:
        pipe.load_lora_weights(lora_source, weight_name=weight_name, adapter_name=adapter_name)
    except Exception as e:  # best effort: keep serving with whatever did load
        print(f"Error applying LoRA weights from {lora_source}:", e)
    else:
        active_adapters.append(adapter_name)

if active_adapters:
    try:
        # Activating several named adapters needs the PEFT backend (`peft` installed).
        pipe.set_adapters(active_adapters, adapter_weights=[1.0] * len(active_adapters))
    except Exception as e:
        print("Error activating LoRA adapters:", e)

if "hyper_sd" in active_adapters:
    # The Hyper-SD model card pairs its LoRAs with DDIM using "trailing" timestep spacing.
    pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config, timestep_spacing="trailing")

# --- Define the image generation function ---
def generate_image(prompt: str, steps: int = 30, guidance: float = 7.5):
    """
    Generate an image from a text prompt.

    Args:
        prompt (str): The text prompt.
        steps (int): Number of inference steps. Gradio sliders may deliver this
            as a float, so it is coerced to an int (minimum 1).
        guidance (float): Classifier-free guidance scale (higher values
            encourage the image to follow the prompt).

    Returns:
        The first image produced by the module-level `pipe` (a PIL image).
    """
    num_steps = max(1, int(steps))
    guidance_scale = float(guidance)
    # Mixed precision only on CUDA. With enabled=False this is a silent no-op on
    # CPU, unlike the deprecated torch.cuda.amp.autocast(), which warns there.
    with torch.autocast(device_type="cuda", enabled=torch.cuda.is_available()):
        result = pipe(prompt, num_inference_steps=num_steps, guidance_scale=guidance_scale)
    return result.images[0]

# --- Build the Gradio interface ---
# The steps slider caps at 8 (few-step "super fast" sampling), so its default
# must also lie within [minimum, maximum].
demo = gr.Interface(
    fn=generate_image,
    inputs=[
        gr.Textbox(lines=2, placeholder="Enter your prompt here...", label="Prompt"),
        gr.Slider(minimum=1, maximum=8, step=1, value=8, label="Inference Steps"),
        gr.Slider(minimum=1.0, maximum=15.0, step=0.5, value=7.5, label="Guidance Scale"),
    ],
    outputs=gr.Image(type="pil", label="Generated Image"),
    title="Super Fast SDXL-Emoji Generator",
    description=(
        "This demo uses a Stable Diffusion XL model enhanced with a custom LoRA "
        "to generate images quickly. Adjust the prompt and settings below, then hit 'Submit'!"
    ),
)

# --- Launch the demo ---
# Guarded so that importing this module does not start a web server.
if __name__ == "__main__":
    demo.launch()