# Hugging Face Spaces status: Running on Zero (ZeroGPU hardware)
import gradio as gr | |
import numpy as np | |
import random | |
import torch | |
from PIL import Image | |
import os | |
import spaces | |
from transformers import CLIPVisionModelWithProjection, CLIPImageProcessor | |
from kolors.pipelines.pipeline_stable_diffusion_xl_chatglm_256_ipadapter import StableDiffusionXLPipeline | |
from kolors.models.modeling_chatglm import ChatGLMModel | |
from kolors.models.tokenization_chatglm import ChatGLMTokenizer | |
from kolors.models.unet_2d_condition import UNet2DConditionModel | |
from diffusers import AutoencoderKL, EulerDiscreteScheduler | |
from huggingface_hub import snapshot_download | |
# Every model below is moved to this device; the Space expects a CUDA GPU.
device = "cuda"

# Weights are stored under `<root_dir>/weights/`, where root_dir is the parent
# of the directory that contains this script.
_script_dir = os.path.dirname(os.path.abspath(__file__))
root_dir = os.path.dirname(_script_dir)
ckpt_dir = f"{root_dir}/weights/Kolors"

# Fetch the base Kolors checkpoint first, then the IP-Adapter-Plus weights.
for _repo_id, _local_dir in (
    ("Kwai-Kolors/Kolors", ckpt_dir),
    ("Kwai-Kolors/Kolors-IP-Adapter-Plus", f"{root_dir}/weights/Kolors-IP-Adapter-Plus"),
):
    snapshot_download(repo_id=_repo_id, local_dir=_local_dir)
# Load the Kolors components. Each weight-bearing model is cast to fp16 and
# moved to `device`.
def _kolors_subdir(component):
    """Return the local path of one Kolors checkpoint component, e.g. "vae"."""
    return f"{ckpt_dir}/{component}"


text_encoder = (
    ChatGLMModel.from_pretrained(_kolors_subdir("text_encoder"), torch_dtype=torch.float16)
    .half()
    .to(device)
)
tokenizer = ChatGLMTokenizer.from_pretrained(_kolors_subdir("text_encoder"))
vae = AutoencoderKL.from_pretrained(_kolors_subdir("vae"), revision=None).half().to(device)
scheduler = EulerDiscreteScheduler.from_pretrained(_kolors_subdir("scheduler"))
unet = UNet2DConditionModel.from_pretrained(_kolors_subdir("unet"), revision=None).half().to(device)

# CLIP vision encoder used by the IP-Adapter. Weight-shape mismatches are
# tolerated instead of raising (ignore_mismatched_sizes=True).
image_encoder = CLIPVisionModelWithProjection.from_pretrained(
    f"{root_dir}/weights/Kolors-IP-Adapter-Plus/image_encoder",
    ignore_mismatched_sizes=True,
).to(dtype=torch.float16, device=device)

# The reference-image preprocessor is configured for 336 px (size and crop).
ip_img_size = 336
clip_image_processor = CLIPImageProcessor(size=ip_img_size, crop_size=ip_img_size)
# Assemble the Kolors SDXL pipeline from the components loaded above.
pipe = StableDiffusionXLPipeline(
    vae=vae,
    text_encoder=text_encoder,
    tokenizer=tokenizer,
    unet=unet,
    scheduler=scheduler,
    image_encoder=image_encoder,
    feature_extractor=clip_image_processor,
    force_zeros_for_empty_prompt=False,
)
pipe = pipe.to(device)

# Before loading the IP-Adapter, keep the UNet's current encoder_hid_proj
# reachable as text_encoder_hid_proj.
# NOTE(review): presumably load_ip_adapter replaces encoder_hid_proj with the
# image projection, and the Kolors UNet reads the text projection from
# text_encoder_hid_proj. Confirm against kolors' unet_2d_condition.
unet_module = pipe.unet
if hasattr(unet_module, "encoder_hid_proj"):
    unet_module.text_encoder_hid_proj = unet_module.encoder_hid_proj

pipe.load_ip_adapter(
    f"{root_dir}/weights/Kolors-IP-Adapter-Plus",
    subfolder="",
    weight_name=["ip_adapter_plus_general.bin"],
)
# Largest seed the UI offers: the int32 maximum, 2**31 - 1.
MAX_SEED = np.iinfo("int32").max
# Upper bound, in pixels, for the generated image's width and height.
MAX_IMAGE_SIZE = 1024
# On ZeroGPU hardware a GPU is attached only while a @spaces.GPU-decorated
# function is running; without the decorator the CUDA work below has no device.
@spaces.GPU
def infer(
    prompt,
    ip_adapter_image,
    ip_adapter_scale=0.5,
    negative_prompt="",
    seed=100,
    randomize_seed=False,
    width=1024,
    height=1024,
    guidance_scale=5.0,
    num_inference_steps=50,
    progress=gr.Progress(track_tqdm=True),
):
    """Generate one image from a text prompt, guided by an IP-Adapter reference image.

    Args:
        prompt: Text prompt for the Kolors pipeline.
        ip_adapter_image: Reference image (the UI passes a PIL image).
        ip_adapter_scale: How strongly the reference image steers generation.
        negative_prompt: Text describing what to avoid.
        seed: RNG seed; ignored when ``randomize_seed`` is true.
        randomize_seed: Draw a fresh seed in ``[0, MAX_SEED]``.
        width: Output width in pixels.
        height: Output height in pixels.
        guidance_scale: Classifier-free guidance strength.
        num_inference_steps: Number of denoising steps.
        progress: Gradio progress tracker; ``track_tqdm=True`` mirrors the
            pipeline's tqdm progress in the UI.

    Returns:
        tuple: ``(image, seed)``, the first generated image and the seed that
        was actually used, so the UI can show it.

    Raises:
        gr.Error: If no reference image was supplied.
    """
    # Fail with a readable UI message instead of an opaque pipeline error.
    if ip_adapter_image is None:
        raise gr.Error("Please upload an IP-Adapter reference image.")

    if randomize_seed:
        seed = random.randint(0, MAX_SEED)
    # Slider values may arrive as floats; manual_seed and the pipeline need ints.
    seed = int(seed)
    generator = torch.Generator(device=device).manual_seed(seed)

    # Re-assert device placement on every call. This is presumably a no-op once
    # the weights are already on the GPU.
    pipe.to(device)
    image_encoder.to(device)
    pipe.image_encoder = image_encoder
    # One scale per loaded IP-Adapter (exactly one adapter is loaded).
    pipe.set_ip_adapter_scale([ip_adapter_scale])

    image = pipe(
        prompt=prompt,
        ip_adapter_image=[ip_adapter_image],
        negative_prompt=negative_prompt,
        height=int(height),
        width=int(width),
        num_inference_steps=int(num_inference_steps),
        guidance_scale=guidance_scale,
        num_images_per_prompt=1,
        generator=generator,
    ).images[0]
    return image, seed
# Prompt shared by every bundled example; only the reference image differs.
# The trailing "aged baby" is deliberately wrapped in literal double quotes,
# so the outer literal for that piece uses single quotes. The original
# nested unescaped double quotes inside double-quoted strings, which is a
# SyntaxError.
GHIBLI_EXAMPLE_PROMPT = (
    "Studio Ghibli animation style, featuring whimsical characters with expressive "
    "eyes and fluid movements. Lush, detailed natural environments with ethereal "
    "lighting and soft color palettes of blues, greens, and warm earth tones. "
    '"aged baby"'
)

# Each row is [prompt, reference image path, IP-Adapter scale]. The image paths
# are relative; presumably the files ship next to the app. Verify on deploy.
examples = [
    [GHIBLI_EXAMPLE_PROMPT, image_path, 0.5]
    for image_path in ("gh1.jpg", "gh2.jpg", "gh3.webp", "gh4.webp")
]
# Page styling: centre the main column (max 720px wide), anchor the result image
# to the top, and give the result's image container full height.
# The stylesheet text is restored without the pasted "| |" residue, and the last
# declaration is now terminated with a semicolon.
css = """
#col-container {
    margin: 0 auto;
    max-width: 720px;
}
#result img{
    object-position: top;
}
#result .image-container{
    height: 100%;
}
"""
# Gradio UI: a prompt and a reference image go in; the generated image and the
# seed that was used come out.
with gr.Blocks(css=css) as demo:
    with gr.Column(elem_id="col-container"):
        # Plain string: the original f-string had no placeholders.
        gr.Markdown("# Beyond Ghibli Reimagined")

        with gr.Row():
            prompt = gr.Text(
                label="Prompt",
                show_label=False,
                max_lines=1,
                placeholder="Enter your prompt",
                container=False,
            )
            run_button = gr.Button("Run", scale=0)

        # Reference image and influence slider on the left, result on the right.
        with gr.Row():
            with gr.Column():
                ip_adapter_image = gr.Image(label="IP-Adapter Image", type="pil")
                ip_adapter_scale = gr.Slider(
                    label="Image influence scale",
                    info="Use 1 for creating variations",
                    minimum=0.0,
                    maximum=1.0,
                    step=0.05,
                    value=0.5,
                )
            result = gr.Image(label="Result", elem_id="result")

        with gr.Accordion("Advanced Settings", open=False):
            negative_prompt = gr.Text(
                label="Negative prompt",
                max_lines=1,
                # Hint text only: the submitted value is whatever the user types
                # (empty by default). The stray "Copy" prefix, pasted from a
                # copy button, has been removed.
                placeholder=(
                    "(worst quality, low quality:1.4), bad anatomy, bad hands, text, error, "
                    "missing fingers, extra digit, fewer digits, cropped, worst quality, "
                    "low quality, normal quality, jpeg artifacts, signature, watermark, "
                    "username, blurry, artist name, (deformed iris, deformed pupils:1.2), "
                    "(semi-realistic, cgi, 3d, render:1.1), amateur, "
                    "(poorly drawn hands, poorly drawn face:1.2)"
                ),
            )
            seed = gr.Slider(
                label="Seed",
                minimum=0,
                maximum=MAX_SEED,
                step=1,
                value=0,
            )
            randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
            with gr.Row():
                width = gr.Slider(
                    label="Width",
                    minimum=256,
                    maximum=MAX_IMAGE_SIZE,
                    step=32,
                    value=1024,
                )
                height = gr.Slider(
                    label="Height",
                    minimum=256,
                    maximum=MAX_IMAGE_SIZE,
                    step=32,
                    value=1024,
                )
            with gr.Row():
                guidance_scale = gr.Slider(
                    label="Guidance scale",
                    minimum=0.0,
                    maximum=10.0,
                    step=0.1,
                    value=5.0,
                )
                num_inference_steps = gr.Slider(
                    label="Number of inference steps",
                    minimum=1,
                    maximum=100,
                    step=1,
                    value=50,
                )

        # Examples fill only the first three inputs; infer's defaults cover the
        # rest. Outputs are computed lazily, on first use.
        gr.Examples(
            examples=examples,
            fn=infer,
            inputs=[prompt, ip_adapter_image, ip_adapter_scale],
            outputs=[result, seed],
            cache_examples="lazy",
        )

    # The Run button and Enter in the prompt box both trigger generation.
    gr.on(
        triggers=[run_button.click, prompt.submit],
        fn=infer,
        inputs=[
            prompt,
            ip_adapter_image,
            ip_adapter_scale,
            negative_prompt,
            seed,
            randomize_seed,
            width,
            height,
            guidance_scale,
            num_inference_steps,
        ],
        outputs=[result, seed],
    )

demo.queue().launch()