# StarryXL-Demo / app.py
# Gradio demo for the eienmojiki/Starry-XL-v5.2 Stable Diffusion XL model.
import os
import random
from typing import Dict, Optional

import gradio as gr
import numpy as np
import spaces
import torch
from transformers import CLIPTextModel
from diffusers import (
    DDIMScheduler,
    DPMSolverMultistepScheduler,
    DPMSolverSinglestepScheduler,
    EulerAncestralDiscreteScheduler,
    EulerDiscreteScheduler,
    StableDiffusionXLPipeline,
)

MODEL = "eienmojiki/Starry-XL-v5.2"
HF_TOKEN = os.getenv("HF_TOKEN")
MIN_IMAGE_SIZE = int(os.getenv("MIN_IMAGE_SIZE", "512"))
MAX_IMAGE_SIZE = int(os.getenv("MAX_IMAGE_SIZE", "2048"))
MAX_SEED = np.iinfo(np.int32).max
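
# Sampler names exposed in the UI, mapped to diffusers schedulers in get_scheduler().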
sampler_list = [
    "DPM++ 2M Karras",
    "DPM++ SDE Karras",
    "DPM++ 2M SDE Karras",
    "Euler",
    "Euler a",
    "DDIM",
]
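
# Favor reproducibility over raw speed: deterministic cuDNN kernels, no autotuning.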
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
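
# Return a fresh random seed when "Randomize seed" is checked; otherwise pass the seed through.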
def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
    if randomize_seed:
        seed = random.randint(0, MAX_SEED)
    return seed
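
# Seed every RNG in play (torch CPU/CUDA, numpy) and return a torch.Generator
# that the pipeline can use for reproducible sampling.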
def seed_everything(seed: int) -> torch.Generator:
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    generator = torch.Generator()
    generator.manual_seed(seed)
    return generator
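
# Map a UI sampler name to a configured diffusers scheduler instance.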
def get_scheduler(scheduler_config: Dict, name: str):
    scheduler_factory_map = {
        "DPM++ 2M Karras": lambda: DPMSolverMultistepScheduler.from_config(
            scheduler_config, use_karras_sigmas=True
        ),
        "DPM++ SDE Karras": lambda: DPMSolverSinglestepScheduler.from_config(
            scheduler_config, use_karras_sigmas=True
        ),
        "DPM++ 2M SDE Karras": lambda: DPMSolverMultistepScheduler.from_config(
            scheduler_config, use_karras_sigmas=True, algorithm_type="sde-dpmsolver++"
        ),
        "Euler": lambda: EulerDiscreteScheduler.from_config(scheduler_config),
        "Euler a": lambda: EulerAncestralDiscreteScheduler.from_config(scheduler_config),
        "DDIM": lambda: DDIMScheduler.from_config(scheduler_config),
    }
    # Unknown names fall back to the UI default ("Euler a") instead of None,
    # which would break the later `pipe.scheduler = ...` assignment.
    return scheduler_factory_map.get(
        name, lambda: EulerAncestralDiscreteScheduler.from_config(scheduler_config)
    )()
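
# Build the SDXL pipeline once and move it to the target device.
# custom_pipeline="lpw_stable_diffusion_xl" enables long-prompt weighting,
# so prompts beyond 77 tokens and (attention:1.2)-style emphasis work.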
def load_pipeline(model_name: str) -> StableDiffusionXLPipeline:
    pipe = StableDiffusionXLPipeline.from_pretrained(
        model_name,
        torch_dtype=torch.float16,
        custom_pipeline="lpw_stable_diffusion_xl",
        use_safetensors=True,
        add_watermarker=False,
        token=HF_TOKEN,
    )
    pipe.to(device)
    return pipe
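
# ZeroGPU entry point: @spaces.GPU allocates a GPU for the duration of each call.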
@spaces.GPU
def generate(
    prompt: str,
    negative_prompt: Optional[str] = None,
    seed: int = 0,
    width: int = 1024,
    height: int = 1024,
    guidance_scale: float = 5.0,
    num_inference_steps: int = 24,
    sampler: str = "Euler a",
    clip_skip: int = 1,
    progress=gr.Progress(track_tqdm=True),
):
    generator = seed_everything(seed)
    pipe.scheduler = get_scheduler(pipe.scheduler.config, sampler)
    # Emulate clip skip by reloading the text encoder with its last
    # (clip_skip - 1) hidden layers dropped; the SDXL base text encoder
    # (CLIP ViT-L) has 12 layers. Move it onto the pipeline's device so
    # all components run on the same hardware.
    pipe.text_encoder = CLIPTextModel.from_pretrained(
        MODEL,
        subfolder="text_encoder",
        num_hidden_layers=12 - (clip_skip - 1),
        torch_dtype=torch.float16,
    ).to(device)
    try:
        img = pipe(
            prompt=prompt,
            negative_prompt=negative_prompt,
            width=width,
            height=height,
            guidance_scale=guidance_scale,
            num_inference_steps=num_inference_steps,
            generator=generator,
            output_type="pil",
        ).images[0]
        return img, seed
    except Exception as e:
        # Surface the failure in the UI instead of silently returning None.
        print(f"An error occurred: {e}")
        raise gr.Error(f"Generation failed: {e}") from e
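
# Load the pipeline once at startup. Without CUDA the UI still renders,
# but `pipe` stays None and generation is unavailable.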
if torch.cuda.is_available():
    pipe = load_pipeline(MODEL)
    print("Loaded on Device!")
else:
    pipe = None
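
# Build the Gradio UI.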
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("# Starry XL 5.2 Demo")
    with gr.Group():
        prompt = gr.Text(
            label="Prompt",
            placeholder="Enter your prompt here...",
        )
        negative_prompt = gr.Text(
            label="Negative Prompt",
            placeholder="(Optional) Enter your negative prompt here...",
        )
        with gr.Row():
            width = gr.Slider(
                label="Width",
                minimum=256,
                maximum=MAX_IMAGE_SIZE,
                step=32,
                value=1024,
            )
            height = gr.Slider(
                label="Height",
                minimum=256,
                maximum=MAX_IMAGE_SIZE,
                step=32,
                value=1024,
            )
        sampler = gr.Dropdown(
            label="Sampler",
            choices=sampler_list,
            interactive=True,
            value="Euler a",
        )
        seed = gr.Slider(
            label="Seed",
            minimum=0,
            maximum=MAX_SEED,
            step=1,
            value=0,
        )
        randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
        with gr.Row():
            guidance_scale = gr.Slider(
                label="Guidance scale",
                minimum=1,
                maximum=20,
                step=0.1,
                value=5.0,
            )
            num_inference_steps = gr.Slider(
                label="Steps",
                minimum=10,
                maximum=100,
                step=1,
                value=25,
            )
            clip_skip = gr.Slider(
                label="Clip Skip",
                minimum=1,
                maximum=2,
                step=1,
                value=1,
            )
        run_button = gr.Button("Run")
    result = gr.Image(
        label="Result",
        show_label=False,
    )
    with gr.Group():
        used_seed = gr.Number(label="Used Seed", interactive=False)
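
    # On Enter in either prompt box or a click on Run: first resolve the seed
    # (randomize it if requested), then run generate with the resolved seed.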
    gr.on(
        triggers=[
            prompt.submit,
            negative_prompt.submit,
            run_button.click,
        ],
        fn=randomize_seed_fn,
        inputs=[seed, randomize_seed],
        outputs=seed,
        queue=False,
        api_name=False,
    ).then(
        fn=generate,
        inputs=[
            prompt,
            negative_prompt,
            seed,
            width,
            height,
            guidance_scale,
            num_inference_steps,
            sampler,
            clip_skip,
        ],
        outputs=[result, used_seed],
        api_name="run",
    )

if __name__ == "__main__":
    demo.queue(max_size=20).launch(show_error=True)