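# (Hugging Face file-page header captured together with this file; not part of the program)
# ginipick's picture
# Update app.py
# 6bf8825 verified
# raw
# history blame
# 9.54 kB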
import gradio as gr
import numpy as np
import random
import torch
from PIL import Image
import os
import spaces
from transformers import CLIPVisionModelWithProjection, CLIPImageProcessor
from kolors.pipelines.pipeline_stable_diffusion_xl_chatglm_256_ipadapter import StableDiffusionXLPipeline
from kolors.models.modeling_chatglm import ChatGLMModel
from kolors.models.tokenization_chatglm import ChatGLMTokenizer
from kolors.models.unet_2d_condition import UNet2DConditionModel
from diffusers import AutoencoderKL, EulerDiscreteScheduler
from huggingface_hub import snapshot_download
# Every model below is moved to this device; the app assumes a CUDA GPU.
device = "cuda"
# Weights are stored under "<parent of this file's directory>/weights".
root_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
ckpt_dir = f'{root_dir}/weights/Kolors'
# Download the Kolors base checkpoint and the Kolors IP-Adapter-Plus weights from the
# Hugging Face Hub into the local weights directory. This runs at import time.
snapshot_download(repo_id="Kwai-Kolors/Kolors", local_dir=ckpt_dir)
snapshot_download(repo_id="Kwai-Kolors/Kolors-IP-Adapter-Plus", local_dir=f"{root_dir}/weights/Kolors-IP-Adapter-Plus")
# Load models
# ChatGLM text encoder + tokenizer, VAE, scheduler and UNet come from the Kolors checkpoint;
# the neural modules are cast to float16 and placed on `device`.
# NOTE(review): torch_dtype=torch.float16 already requests fp16 weights, so the trailing
# .half() on the text encoder looks redundant (harmless).
text_encoder = ChatGLMModel.from_pretrained(f'{ckpt_dir}/text_encoder', torch_dtype=torch.float16).half().to(device)
tokenizer = ChatGLMTokenizer.from_pretrained(f'{ckpt_dir}/text_encoder')
vae = AutoencoderKL.from_pretrained(f"{ckpt_dir}/vae", revision=None).half().to(device)
scheduler = EulerDiscreteScheduler.from_pretrained(f"{ckpt_dir}/scheduler")
unet = UNet2DConditionModel.from_pretrained(f"{ckpt_dir}/unet", revision=None).half().to(device)
# CLIP vision encoder that embeds the IP-Adapter reference image (fp16, on `device`).
# ignore_mismatched_sizes=True lets loading proceed when checkpoint and config shapes differ;
# presumably required for this checkpoint. Confirm before removing it.
image_encoder = CLIPVisionModelWithProjection.from_pretrained(
    f'{root_dir}/weights/Kolors-IP-Adapter-Plus/image_encoder',
    ignore_mismatched_sizes=True
).to(dtype=torch.float16, device=device)
# Resize / crop size (pixels) the CLIP processor applies to reference images.
ip_img_size = 336
clip_image_processor = CLIPImageProcessor(size=ip_img_size, crop_size=ip_img_size)
# Assemble the Kolors SDXL pipeline (ChatGLM text encoder + IP-Adapter image encoder).
# NOTE(review): force_zeros_for_empty_prompt=False presumably keeps an empty negative prompt
# encoded as text rather than replaced by zero embeddings. Confirm in the pipeline implementation.
pipe = StableDiffusionXLPipeline(
    vae=vae,
    text_encoder=text_encoder,
    tokenizer=tokenizer,
    unet=unet,
    scheduler=scheduler,
    image_encoder=image_encoder,
    feature_extractor=clip_image_processor,
    force_zeros_for_empty_prompt=False
).to(device)
# Keep a reference to the UNet's existing encoder_hid_proj under the name
# text_encoder_hid_proj before load_ip_adapter runs.
# NOTE(review): presumably load_ip_adapter replaces encoder_hid_proj with the image
# projection, and the Kolors UNet reads text_encoder_hid_proj for the text embeddings. Confirm.
if hasattr(pipe.unet, 'encoder_hid_proj'):
    pipe.unet.text_encoder_hid_proj = pipe.unet.encoder_hid_proj
# Load the general-purpose IP-Adapter-Plus weights from the downloaded folder.
pipe.load_ip_adapter(f'{root_dir}/weights/Kolors-IP-Adapter-Plus', subfolder="", weight_name=["ip_adapter_plus_general.bin"])
# Largest int32: upper bound for random seeds and for the seed slider.
MAX_SEED = np.iinfo(np.int32).max
# Upper bound (pixels) for the width/height sliders.
MAX_IMAGE_SIZE = 1024
# ----------------------------------------------
# infer function (original generation logic kept intact)
# ----------------------------------------------
@spaces.GPU(duration=80)
def infer(
    user_prompt,
    ip_adapter_image,
    ip_adapter_scale=0.5,
    negative_prompt="",
    seed=100,
    randomize_seed=False,
    width=1024,
    height=1024,
    guidance_scale=5.0,
    num_inference_steps=50,
    progress=gr.Progress(track_tqdm=True)
):
    """Generate one Ghibli-style image from a text prompt and a reference image.

    Args:
        user_prompt: Free-form text appended to the fixed style prompt. May be
            blank or None, in which case only the style prompt is used.
        ip_adapter_image: Reference image for the IP-Adapter (a PIL image when
            called from the UI). Required.
        ip_adapter_scale: Influence of the reference image on the result.
        negative_prompt: Text describing what the image should avoid.
        seed: RNG seed; replaced by a random one when ``randomize_seed`` is True.
        randomize_seed: Draw a fresh seed in ``[0, MAX_SEED]`` instead of using ``seed``.
        width: Output width in pixels.
        height: Output height in pixels.
        guidance_scale: Classifier-free guidance scale passed to the pipeline.
        num_inference_steps: Number of denoising steps.
        progress: Gradio progress tracker; ``track_tqdm`` mirrors the pipeline's
            tqdm progress bar in the UI.

    Returns:
        tuple: ``(image, seed)``, the generated image and the integer seed that
        was actually used (so the UI can display it).

    Raises:
        gr.Error: If no reference image was supplied.
    """
    # Fail with a readable message instead of an obscure pipeline error.
    if ip_adapter_image is None:
        raise gr.Error("Please upload an IP-Adapter image.")

    # Hidden (always-on) style prompt that every request is prefixed with.
    hidden_prompt = (
        "Studio Ghibli animation style, featuring whimsical characters with expressive eyes "
        "and fluid movements. Lush, detailed natural environments with ethereal lighting "
        "and soft color palettes of blues, greens, and warm earth tones."
    )
    # Final prompt actually passed to the pipeline. A blank user prompt (e.g. the " "
    # example) would otherwise leave a dangling ", " at the end.
    extra_prompt = (user_prompt or "").strip()
    prompt = f"{hidden_prompt}, {extra_prompt}" if extra_prompt else hidden_prompt

    if randomize_seed:
        seed = random.randint(0, MAX_SEED)
    # UI sliders may deliver floats; manual_seed and the pipeline's size/step
    # arguments expect integers.
    seed = int(seed)
    generator = torch.Generator(device=device).manual_seed(seed)

    # Re-assert device placement inside the spaces.GPU call and re-attach the
    # image encoder before setting the adapter strength.
    pipe.to(device)
    image_encoder.to(device)
    pipe.image_encoder = image_encoder
    pipe.set_ip_adapter_scale([ip_adapter_scale])

    image = pipe(
        prompt=prompt,
        ip_adapter_image=[ip_adapter_image],
        negative_prompt=negative_prompt,
        height=int(height),
        width=int(width),
        num_inference_steps=int(num_inference_steps),
        guidance_scale=guidance_scale,
        num_images_per_prompt=1,
        generator=generator,
    ).images[0]
    return image, seed
# Example rows for gr.Examples, one per line:
#   [user prompt, reference image file, image influence scale]
# The " " row supplies a blank user prompt, so the result is driven by the style prompt and the image.
examples = [
    ["background alps", "gh0.webp", 0.5],
    ["dancing", "gh5.jpg", 0.5],
    ["smile", "gh2.jpg", 0.5],
    ["3d style", "gh3.webp", 0.6],
    ["with Pikachu", "gh4.jpg", 0.5],
    [" ", "gh7.jpg", 0.6],
    ["sunglass", "gh1.jpg", 0.95],
]
# --------------------------
# CSS for the improved UI
# --------------------------
# Custom stylesheet passed to gr.Blocks(css=...). The id selectors (#col-container,
# #header-title, #prompt-row, #prompt-text, #result, #advanced-settings) match elem_id
# values / HTML ids used in the Blocks layout.
# NOTE(review): the .gr-button, .gr-slider and .gr-box class selectors look like
# older Gradio class names and may not match anything in the installed Gradio
# version. The Run button also has elem_id="run-button", which no rule targets. Confirm.
css = """
body {
background: linear-gradient(135deg, #f5f7fa, #c3cfe2);
font-family: 'Helvetica Neue', Arial, sans-serif;
color: #333;
margin: 0;
padding: 0;
}
#col-container {
margin: 0 auto !important;
max-width: 720px;
background: rgba(255,255,255,0.85);
border-radius: 16px;
padding: 2rem;
box-shadow: 0 8px 24px rgba(0,0,0,0.1);
}
#header-title {
text-align: center;
font-size: 2rem;
font-weight: bold;
margin-bottom: 1rem;
}
#prompt-row {
display: flex;
gap: 0.5rem;
align-items: center;
margin-bottom: 1rem;
}
#prompt-text {
flex: 1;
}
#result img {
object-position: top;
border-radius: 8px;
}
#result .image-container {
height: 100%;
}
.gr-button {
background-color: #2E8BFB !important;
color: white !important;
border: none !important;
transition: background-color 0.2s ease;
}
.gr-button:hover {
background-color: #186EDB !important;
}
.gr-slider input[type=range] {
accent-color: #2E8BFB !important;
}
.gr-box {
background-color: #fafafa !important;
border: 1px solid #ddd !important;
border-radius: 8px !important;
padding: 1rem !important;
}
#advanced-settings {
margin-top: 1rem;
border-radius: 8px;
}
"""
# Build the Gradio UI: prompt row, reference image + strength slider next to the
# result, an "Advanced Settings" accordion, example rows, and event wiring to infer.
with gr.Blocks(css=css) as demo:
    with gr.Column(elem_id="col-container"):
        gr.Markdown("<div id='header-title'>Beyond Ghibli Reimagined</div>")
        # Top: prompt input + Run button
        with gr.Row(elem_id="prompt-row"):
            prompt = gr.Text(
                label="Prompt",
                show_label=False,
                max_lines=1,
                placeholder="Enter your prompt",
                elem_id="prompt-text",
            )
            run_button = gr.Button("Run", elem_id="run-button")
        # Middle: reference image input and strength slider, with the result image beside them
        with gr.Row():
            with gr.Column():
                # type="pil" hands infer a PIL image (or None when nothing is uploaded).
                ip_adapter_image = gr.Image(label="IP-Adapter Image", type="pil")
                ip_adapter_scale = gr.Slider(
                    label="Image influence scale",
                    info="Use 1 for creating variations",
                    minimum=0.0,
                    maximum=1.0,
                    step=0.05,
                    value=0.5,
                )
            result = gr.Image(label="Result", elem_id="result")
        # Bottom: advanced settings (collapsed accordion)
        with gr.Accordion("Advanced Settings", open=False, elem_id="advanced-settings"):
            # The placeholder is only a hint shown in the empty box; an empty box sends "".
            # NOTE(review): the placeholder text starts with "Copy", which looks like a paste artifact. Confirm.
            negative_prompt = gr.Text(
                label="Negative prompt",
                max_lines=2,
                placeholder=(
                    "Copy(worst quality, low quality:1.4), bad anatomy, bad hands, text, error, "
                    "missing fingers, extra digit, fewer digits, cropped, worst quality, low quality, "
                    "normal quality, jpeg artifacts, signature, watermark, username, blurry, "
                    "artist name, (deformed iris, deformed pupils:1.2), (semi-realistic, cgi, "
                    "3d, render:1.1), amateur, (poorly drawn hands, poorly drawn face:1.2)"
                ),
            )
            # infer returns the seed it used, and that value is written back to this slider.
            seed = gr.Slider(
                label="Seed",
                minimum=0,
                maximum=MAX_SEED,
                step=1,
                value=0,
            )
            # Checked by default, so infer draws a random seed unless the user unchecks it.
            randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
            with gr.Row():
                width = gr.Slider(
                    label="Width",
                    minimum=256,
                    maximum=MAX_IMAGE_SIZE,
                    step=32,
                    value=1024,
                )
                height = gr.Slider(
                    label="Height",
                    minimum=256,
                    maximum=MAX_IMAGE_SIZE,
                    step=32,
                    value=1024,
                )
            with gr.Row():
                guidance_scale = gr.Slider(
                    label="Guidance scale",
                    minimum=0.0,
                    maximum=10.0,
                    step=0.1,
                    value=5.0,
                )
                num_inference_steps = gr.Slider(
                    label="Number of inference steps",
                    minimum=1,
                    maximum=100,
                    step=1,
                    value=50,
                )
        # Examples: only prompt, image and scale are supplied, so infer uses its own
        # defaults for the rest (seed=100, randomize_seed=False, 1024x1024, etc.).
        # cache_examples="lazy" defers computing each example's output until it is first requested.
        gr.Examples(
            examples=examples,
            fn=infer,
            inputs=[prompt, ip_adapter_image, ip_adapter_scale],
            outputs=[result, seed],
            cache_examples="lazy"
        )
    # Run on Run-button click or on Enter in the prompt box; the input order below
    # matches infer's positional parameters.
    gr.on(
        triggers=[run_button.click, prompt.submit],
        fn=infer,
        inputs=[
            prompt,
            ip_adapter_image,
            ip_adapter_scale,
            negative_prompt,
            seed,
            randomize_seed,
            width,
            height,
            guidance_scale,
            num_inference_steps
        ],
        outputs=[result, seed]
    )
# Enable request queuing on the Blocks app, then start the web server.
demo.queue()
demo.launch()