import os
import gradio as gr
import numpy as np
import random
import spaces
import torch
from diffusers import DiffusionPipeline
from huggingface_hub import InferenceClient
import io
import requests
from PIL import Image  # used to load an image from a URL in the img2img section below
dtype = torch.bfloat16
device = "cuda" if torch.cuda.is_available() else "cpu"
#pipe = DiffusionPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16).to(device)
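# Two generation paths: text-to-image runs remotely on the serverless Inference API (SDXL base),
# while the SDXL refiner is loaded locally for the image-to-image pass further down.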
sdxl = InferenceClient(model="stabilityai/stable-diffusion-xl-base-1.0", token=os.environ['HF_TOKEN'])
pipeline2Image = DiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-xl-refiner-1.0", torch_dtype=torch.bfloat16).to(device)
MAX_SEED = np.iinfo(np.int32).max
MAX_IMAGE_SIZE = 2048
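# Upper bounds for the seed slider and the width/height sliders below.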
# (duration=190)
#@spaces.GPU
def infer(prompt, seed=42, randomize_seed=False, width=1024, height=1024, guidance_scale=5.0, num_inference_steps=28, progress=gr.Progress(track_tqdm=True)):
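    # Text-to-image with the SDXL base model via the Inference API; returns the image and the seed that was used.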
if randomize_seed:
seed = random.randint(0, MAX_SEED)
# generator = torch.Generator().manual_seed(seed)
# image = pipe(
# prompt=prompt,
# width=width,
# height=height,
# num_inference_steps=num_inference_steps,
# generator=generator,
# guidance_scale=guidance_scale
# ).images[0]
    image = sdxl.text_to_image(
        prompt,
        guidance_scale=guidance_scale,
        num_inference_steps=num_inference_steps,
        seed=seed,
        width=width,
        height=height,
    )
return image, seed
examples = [
"a tiny astronaut hatching from an egg on the moon",
"a cat holding a sign that says hello world",
"an anime illustration of a wiener schnitzel",
]
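# Prompts surfaced as clickable examples under the prompt box.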
css="""
#col-container {
margin: 0 auto;
max-width: 520px;
}
"""
with gr.Blocks(css=css) as demo:
with gr.Column(elem_id="col-container"):
gr.Markdown(f"""# FLUX.1 [dev]
12B param rectified flow transformer guidance-distilled from [FLUX.1 [pro]](https://blackforestlabs.ai/)
[[non-commercial license](https://huggingface.co/black-forest-labs/FLUX.1-dev/blob/main/LICENSE.md)] [[blog](https://blackforestlabs.ai/announcing-black-forest-labs/)] [[model](https://huggingface.co/black-forest-labs/FLUX.1-dev)]
""")
with gr.Row():
prompt = gr.Text(
label="Prompt",
show_label=False,
max_lines=1,
placeholder="Enter your prompt",
container=False,
)
run_button = gr.Button("Run", scale=0)
result = gr.Image(label="Result", show_label=False)
with gr.Accordion("Advanced Settings", open=False):
seed = gr.Slider(
label="Seed",
minimum=0,
maximum=MAX_SEED,
step=1,
value=0,
)
randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
with gr.Row():
width = gr.Slider(
label="Width",
minimum=256,
maximum=MAX_IMAGE_SIZE,
step=32,
value=1024,
)
height = gr.Slider(
label="Height",
minimum=256,
maximum=MAX_IMAGE_SIZE,
step=32,
value=1024,
)
with gr.Row():
guidance_scale = gr.Slider(
label="Guidance Scale",
minimum=1,
maximum=15,
step=0.1,
value=3.5,
)
num_inference_steps = gr.Slider(
label="Number of inference steps",
minimum=1,
maximum=50,
step=1,
value=28,
)
gr.Examples(
examples=examples,
fn=infer,
inputs=[prompt],
outputs=[result, seed],
cache_examples="lazy"
)
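    # Generation runs from either the Run button or pressing Enter in the prompt box.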
gr.on(
triggers=[run_button.click, prompt.submit],
fn=infer,
inputs=[prompt, seed, randomize_seed, width, height, guidance_scale, num_inference_steps],
outputs=[result, seed]
)
# Adding image input options at the bottom
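    # The image chosen here (upload, URL, or the result generated above) feeds the local SDXL refiner via infer2 below.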
gr.Markdown("## Upload or select an additional image")
with gr.Row():
uploaded_image = gr.Image(label="Upload Image", type="pil")
image_url = gr.Textbox(label="Image URL", placeholder="Enter image URL")
use_generated_image = gr.Button("Use Generated Image")
with gr.Accordion("Advanced Settings", open=False):
seed2 = gr.Slider(
label="Seed",
minimum=0,
maximum=MAX_SEED,
step=1,
value=0,
)
randomize_seed2 = gr.Checkbox(label="Randomize seed", value=True)
with gr.Row():
width2 = gr.Slider(
label="Width",
minimum=256,
maximum=MAX_IMAGE_SIZE,
step=32,
value=1024,
)
height2 = gr.Slider(
label="Height",
minimum=256,
maximum=MAX_IMAGE_SIZE,
step=32,
value=1024,
)
with gr.Row():
strength2 = gr.Slider(
label="Strength",
minimum=.1,
maximum=1,
step=0.1,
value=.5,
)
guidance_scale2 = gr.Slider(
label="Guidance Scale",
minimum=1,
maximum=15,
step=0.1,
value=3.5,
)
num_inference_steps2 = gr.Slider(
label="Number of inference steps",
minimum=1,
maximum=50,
step=1,
value=28,
)
prompt2 = gr.Text(
label="Prompt",
show_label=False,
max_lines=1,
placeholder="Enter your prompt",
container=False,
)
run2_button = gr.Button("Run", scale=0)
additional_image_output = gr.Image(label="Selected Image", show_label=False)
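    # This preview doubles as the input image for the refiner run (infer2) below.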
    def select_image(uploaded_image, image_url, use_generated=False, generated_image=None):
        # Prefer the previously generated image when requested, then an upload, then a URL.
        if use_generated and generated_image is not None:
            return generated_image
        elif uploaded_image is not None:
            return uploaded_image
        elif image_url:
            try:
                # gr.Image has no URL loader; fetch the file and open it with PIL instead.
                response = requests.get(image_url, timeout=30)
                response.raise_for_status()
                return Image.open(io.BytesIO(response.content))
            except Exception as e:
                print(f"Failed to load image from URL: {e}")
                return None
        return None
    def image2image(uploaded_image, image_url, use_generated=False, generated_image=None):
        image = select_image(uploaded_image, image_url, use_generated=use_generated, generated_image=generated_image)
        #prompt = "one awesome dude"
        #generator = torch.Generator(device=device).manual_seed(1024)
        #image = pipeline2Image(prompt=prompt, image=image, strength=0.75, guidance_scale=7.5, generator=generator).images[0]
        return image
    # Pass the generated image component as an input: reading result.value inside the handler
    # would only return the component's initial value, not the image produced at runtime.
    use_generated_image.click(fn=lambda generated: image2image(None, None, True, generated), inputs=[result], outputs=additional_image_output)
    uploaded_image.change(fn=image2image, inputs=[uploaded_image, image_url, gr.State(False)], outputs=additional_image_output)
    image_url.submit(fn=image2image, inputs=[uploaded_image, image_url, gr.State(False)], outputs=additional_image_output)
@spaces.GPU(duration=190)
def infer2(prompt, image, seed=42, randomize_seed=False, width=1024, height=1024, strength=.5, guidance_scale=5.0, num_inference_steps=28):
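        # Image-to-image with the local SDXL refiner. width/height mirror the UI but are not used here;
        # the refiner derives the output size from the input image.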
if randomize_seed:
seed = random.randint(0, MAX_SEED)
generator = torch.Generator(device=device).manual_seed(seed)
image2 = pipeline2Image(prompt=prompt, image=image, strength=strength, guidance_scale=guidance_scale, num_inference_steps=num_inference_steps, generator=generator).images[0]
# generator = torch.Generator().manual_seed(seed)
# image = pipe(
# prompt=prompt,
# width=width,
# height=height,
# num_inference_steps=num_inference_steps,
# generator=generator,
# guidance_scale=guidance_scale
# ).images[0]
        return image2, seed
final_image_output = gr.Image(label="Final Image", show_label=False)
gr.on(
triggers=[run2_button.click, prompt2.submit],
fn=infer2,
inputs=[prompt2, additional_image_output, seed2, randomize_seed2, width2, height2, strength2, guidance_scale2, num_inference_steps2],
outputs=[final_image_output, seed2]
)
demo.launch()