# ponix-generator / app.py
# Last commit d1ddaf9 by cwhuh: "chore : update color" (file size 8.46 kB)
import gradio as gr
import numpy as np
import random
import spaces
import torch
from diffusers import DiffusionPipeline, FlowMatchEulerDiscreteScheduler, AutoencoderTiny, AutoencoderKL
from transformers import CLIPTextModel, CLIPTokenizer,T5EncoderModel, T5TokenizerFast
from live_preview_helpers import calculate_shift, retrieve_timesteps, flux_pipe_call_that_returns_an_iterable_of_images
from llm_wrapper import run_gemini
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file
import subprocess
def _clear_zerogpu_offload(directory="/data-nvme/zerogpu-offload"):
    """Delete every non-hidden entry inside *directory*, ignoring all errors.

    Equivalent to the shell command ``rm -rf <directory>/*`` but without
    spawning a shell or relying on a PATH lookup for ``rm``. Like that command:
    dotfiles are skipped (glob ``*`` does not match them), symlinks are removed
    rather than followed, and nothing is raised when the directory is missing
    or an entry cannot be deleted.
    """
    import contextlib
    import glob
    import os
    import shutil

    for entry in glob.glob(os.path.join(directory, "*")):
        if os.path.isdir(entry) and not os.path.islink(entry):
            shutil.rmtree(entry, ignore_errors=True)
        else:
            with contextlib.suppress(OSError):
                os.remove(entry)


# Clear the ZeroGPU offload directory before loading any model weights.
_clear_zerogpu_offload()

dtype = torch.bfloat16
device = "cuda" if torch.cuda.is_available() else "cpu"

# Two VAEs: the tiny autoencoder is installed as the pipeline's VAE (vae=taef1
# below), while the full FLUX.1-dev VAE is handed separately to the preview
# helper as `good_vae` (see infer). Presumably the tiny one decodes the fast
# intermediate previews and the full one the final image — confirm in
# live_preview_helpers.
taef1 = AutoencoderTiny.from_pretrained("madebyollin/taef1", torch_dtype=dtype).to(device)
good_vae = AutoencoderKL.from_pretrained("black-forest-labs/FLUX.1-dev", subfolder="vae", torch_dtype=dtype).to(device)
pipe = DiffusionPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=dtype, vae=taef1).to(device)

# PONIX mode load: LoRA weights plus textual-inversion embeddings for the
# <s0>, <s1>, <s2> tokens, both from the cwhuh/ponix-generator-v0.1.0 repo.
pipe.load_lora_weights('cwhuh/ponix-generator-v0.1.0', weight_name='pytorch_lora_weights.safetensors')
# NOTE(review): the leading "./" in `filename` is unusual for hf_hub_download;
# it evidently resolves today, but consider dropping it.
embedding_path = hf_hub_download(repo_id='cwhuh/ponix-generator-v0.1.0', filename='./ponix-generator-v0.1.0_emb.safetensors', repo_type="model")
state_dict = load_file(embedding_path)
# Only the "clip_l" embeddings are loaded, into the first text encoder/tokenizer.
pipe.load_textual_inversion(state_dict["clip_l"], token=["<s0>", "<s1>", "<s2>"], text_encoder=pipe.text_encoder, tokenizer=pipe.tokenizer)
torch.cuda.empty_cache()

MAX_SEED = np.iinfo(np.int32).max
MAX_IMAGE_SIZE = 2048

# Bind the helper to `pipe` as a method so it receives the pipeline as `self`.
pipe.flux_pipe_call_that_returns_an_iterable_of_images = flux_pipe_call_that_returns_an_iterable_of_images.__get__(pipe)
@spaces.GPU(duration=50)
def infer(
    prompt,
    seed=42,
    randomize_seed=False,
    width=1024,
    height=1024,
    guidance_scale=3.5,
    num_inference_steps=28,
    use_prompt_refinement=True,
    progress=gr.Progress(track_tqdm=True),
):
    """Generate a PONIX image, yielding every image the preview pipeline produces.

    Args:
        prompt: User prompt (the UI asks for a short Korean description).
        seed: Seed for the CPU ``torch.Generator``; replaced when
            ``randomize_seed`` is true.
        randomize_seed: Draw a fresh seed from ``[0, MAX_SEED]`` (inclusive)
            instead of using ``seed``.
        width, height: Output size in pixels (the UI sliders step by 32).
        guidance_scale: Passed straight through to the FLUX pipeline.
        num_inference_steps: Number of denoising steps.
        use_prompt_refinement: When true (the default, matching the previous
            always-on behavior), rewrite ``prompt`` with ``run_gemini`` before
            generation; when false, send ``prompt`` to the pipeline unchanged.
            Wired to the "auto-improve prompt" checkbox in the UI.
        progress: Gradio progress tracker; ``track_tqdm=True`` surfaces the
            pipeline's tqdm progress in the UI.

    Yields:
        ``(image, seed)`` for each image yielded by
        ``pipe.flux_pipe_call_that_returns_an_iterable_of_images``; ``seed`` is
        the value actually used, so the UI can show it.
    """
    if randomize_seed:
        seed = random.randint(0, MAX_SEED)
    # Gradio types slider values as float, but torch.Generator.manual_seed only
    # accepts an int (and the pipeline presumably expects integer sizes/steps).
    seed = int(seed)
    generator = torch.Generator().manual_seed(seed)

    if use_prompt_refinement:
        # Rewrite the prompt through the Gemini wrapper using the prompt.json
        # template (exact behavior lives in llm_wrapper.run_gemini).
        final_prompt = run_gemini(
            target_prompt=prompt,
            prompt_in_path="prompt.json",
        )
        print(f"Refined prompt: {final_prompt}")
    else:
        # NOTE(review): the raw prompt is used as-is. If prompt.json is what adds
        # the <s0><s1><s2> textual-inversion tokens, unrefined prompts will not
        # trigger the PONIX embedding unless the user types them — confirm.
        final_prompt = prompt

    for img in pipe.flux_pipe_call_that_returns_an_iterable_of_images(
        prompt=final_prompt,
        guidance_scale=guidance_scale,
        num_inference_steps=int(num_inference_steps),
        width=int(width),
        height=int(height),
        generator=generator,
        output_type="pil",
        good_vae=good_vae,
    ):
        yield img, seed


# Example prompts shown under the form (Korean: "mechanical engineering (rocket)
# PONIX", "PONIX playing the violin", "PONIX studying physics").
examples = [
    "기계공학과(로켓) 포닉스",
    "바이올린을 연주하는 포닉스",
    "물리학을 연구하는 포닉스",
]

# Custom CSS: applies the POSTECH red accent rgb(200, 1, 80) to primary buttons,
# checkboxes, labels, headers, links, sliders and the prompt/image boxes.
css = """
#col-container {
margin: 0 auto;
max-width: 580px;
}
.footer {
text-align: center;
margin-top: 20px;
font-size: 0.8em;
color: #666;
}
/* 포스텍 레드 색상 적용 */
button.primary {
background-color: rgb(200, 1, 80) !important;
border-color: rgb(200, 1, 80) !important;
}
button.primary:hover {
background-color: rgba(200, 1, 80, 0.8) !important;
border-color: rgb(200, 1, 80) !important;
}
.postech-red {
color: rgb(200, 1, 80);
}
.accordion .label-wrap {
color: rgb(200, 1, 80) !important;
}
.examples .icon {
color: rgb(200, 1, 80) !important;
}
input[type="checkbox"]:checked {
background-color: rgb(200, 1, 80) !important;
border-color: rgb(200, 1, 80) !important;
}
/* 프롬프트 입력 박스와 이미지 박스 배경색 변경 */
.gradio-container .prose input[type="text"],
.gradio-container .input-box,
.gradio-container .output-box {
background-color: rgb(200, 1, 80) !important;
border-color: rgb(200, 1, 80) !important;
color: white !important;
}
/* 플레이스홀더 텍스트 색상 */
.gradio-container input::placeholder {
color: rgba(255, 255, 255, 0.7) !important;
}
/* 레이블 색상 */
label span {
color: rgb(200, 1, 80) !important;
}
/* Examples 헤더 색상 */
h3, .examples-header {
color: rgb(200, 1, 80) !important;
}
/* 링크 스타일 변경 */
a {
color: #666 !important;
text-decoration: underline !important;
transition: color 0.2s;
}
a:hover {
color: rgb(200, 1, 80) !important;
}
/* Examples 버튼 테두리 색상 */
.examples .gr-button {
border-color: rgb(200, 1, 80) !important;
}
.examples .gr-button:hover {
background-color: rgba(200, 1, 80, 0.1) !important;
}
/* 슬라이더 색상 */
input[type="range"]::-webkit-slider-thumb {
background: rgb(200, 1, 80) !important;
}
input[type="range"]::-moz-range-thumb {
background: rgb(200, 1, 80) !important;
}
/* Examples 텍스트 색상 강제 적용 */
.examples-header h3 {
color: rgb(200, 1, 80) !important;
}
"""

with gr.Blocks(css=css, theme="soft") as demo:
    with gr.Column(elem_id="col-container"):
        gr.Markdown("""# 🔍 <span class="postech-red">[POSTECH]</span> PONIX Generator
[[based on FLUX.1-dev](https://huggingface.co/black-forest-labs/FLUX.1-dev/blob/main/LICENSE.md)]
""")

        # Usage guide (Korean): describe the image briefly in Korean; the image
        # emerges gradually from noise (~40-50 s); contact by email.
        # NOTE(review): "[email protected]" looks like an email-obfuscation
        # placeholder from the page this file was copied from — restore the
        # real contact address.
        with gr.Group():
            gr.Markdown("""
### 🔍 사용 가이드
- 생성하고 싶은 이미지를 한글로 간단하게 작성해주세요.
- 이미지는 노이즈에서 점차적으로 생성됩니다. (40~50초 소요)
- 문의는 이메일로 부탁드립니다: [email protected]
""")

        with gr.Group():
            prompt = gr.Text(
                label="프롬프트 입력",
                max_lines=1,
                placeholder="원하는 포닉스 이미지를 한글로 설명해주세요",
                container=True,
            )
            run_button = gr.Button("🚀 생성하기", variant="primary")

        result = gr.Image(label="생성된 이미지")

        # Advanced settings; every control here is passed to `infer` by gr.on below.
        with gr.Accordion("🛠️ 고급 설정", open=False):
            with gr.Group():
                use_prompt_refinement = gr.Checkbox(
                    label="프롬프트 자동 개선",
                    value=True,
                    info="AI가 입력한 프롬프트를 자동으로 개선합니다.",
                )
            with gr.Row():
                seed = gr.Slider(
                    label="시드 값",
                    minimum=0,
                    maximum=MAX_SEED,
                    step=1,
                    value=0,
                )
                randomize_seed = gr.Checkbox(label="랜덤 시드 사용", value=True)
            with gr.Row():
                width = gr.Slider(
                    label="너비",
                    minimum=256,
                    maximum=MAX_IMAGE_SIZE,
                    step=32,
                    value=1024,
                )
                height = gr.Slider(
                    label="높이",
                    minimum=256,
                    maximum=MAX_IMAGE_SIZE,
                    step=32,
                    value=1024,
                )
            with gr.Row():
                guidance_scale = gr.Slider(
                    label="가이던스 스케일",
                    minimum=1,
                    maximum=15,
                    step=0.1,
                    value=3.5,
                )
                num_inference_steps = gr.Slider(
                    label="추론 단계 수",
                    minimum=1,
                    maximum=50,
                    step=1,
                    value=28,
                )

        # Examples header carries the `examples-header` class targeted by the CSS.
        gr.Markdown("<div class='examples-header'><h3>예시 프롬프트</h3></div>", elem_classes=["examples-header"])
        # Only the prompt is supplied for examples, so the remaining `infer`
        # parameters use their defaults; "lazy" caches each example's output the
        # first time it is requested.
        gr.Examples(
            examples=examples,
            fn=infer,
            inputs=[prompt],
            outputs=[result, seed],
            cache_examples="lazy",
        )

        gr.HTML("""
<div class="footer">
PONIX Generator by 허채원 | UG @ POSTECH
</div>
""")

    # Generate on button click or Enter in the prompt box. The inputs follow
    # infer's positional parameter order; the checkbox is now included so
    # toggling prompt refinement actually takes effect.
    gr.on(
        triggers=[run_button.click, prompt.submit],
        fn=infer,
        inputs=[prompt, seed, randomize_seed, width, height, guidance_scale, num_inference_steps, use_prompt_refinement],
        outputs=[result, seed],
    )
# Serve the Gradio UI built in `demo`.
demo.launch()