import marimo

__generated_with = "0.9.14"
app = marimo.App(width="medium")
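
# This is a marimo notebook: open it with `marimo edit` to modify the cells
# interactively, or serve it as an app with `marimo run`.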


@app.cell(hide_code=True)
def __():
    import random

    import marimo as mo
    import numpy as np
    import tqdm
    import tqdm.auto  # imported explicitly so the submodules exist as attributes
    import tqdm.notebook
    import transformers

    # Use tqdm's notebook-style progress bars for model download and inference.
    tqdm.auto.tqdm = tqdm.notebook.tqdm

    import torch
    from diffusers import DiffusionPipeline

    MAX_SEED = np.iinfo(np.int32).max
    MAX_IMAGE_SIZE = 1024

    model_repo_id = "stabilityai/sdxl-turbo"
    return (
        DiffusionPipeline,
        MAX_IMAGE_SIZE,
        MAX_SEED,
        mo,
        model_repo_id,
        np,
        random,
        torch,
        tqdm,
        transformers,
    )


@app.cell
def __(mo, model_repo_id):
    mo.md(f"""# HuggingFace Text-to-Image: **{model_repo_id}**""")
    return


@app.cell(hide_code=True)
def __():
    examples = [
        "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k",
        "An astronaut riding a green horse",
        "A delicious ceviche cheesecake slice",
    ]
    return (examples,)


@app.cell
def __(
    DiffusionPipeline,
    MAX_SEED,
    mo,
    model_repo_id,
    random,
    torch,
):
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Half precision on GPU, full precision on CPU.
    if torch.cuda.is_available():
        torch_dtype = torch.float16
    else:
        torch_dtype = torch.float32

    pipe = DiffusionPipeline.from_pretrained(model_repo_id, torch_dtype=torch_dtype)
    pipe = pipe.to(device)

    def infer(
        prompt,
        negative_prompt,
        seed,
        randomize_seed,
        width,
        height,
        guidance_scale,
        num_inference_steps,
    ):
        """Generate one image with the loaded pipeline; returns (image, seed)."""
        if randomize_seed:
            seed = random.randint(0, MAX_SEED)

        generator = torch.Generator().manual_seed(seed)

        image = pipe(
            prompt=prompt,
            negative_prompt=negative_prompt,
            guidance_scale=guidance_scale,
            num_inference_steps=num_inference_steps,
            width=width,
            height=height,
            generator=generator,
        ).images[0]

        return image, seed

    # Clear any loading output so this cell renders nothing.
    mo.output.clear()
    return device, infer, pipe, torch_dtype


@app.cell
def __(mo):
    # Shared state: the example buttons write to it, the prompt box reads from it.
    get_prompt, set_prompt = mo.state("")
    return get_prompt, set_prompt


@app.cell
def __(get_prompt, mo, set_prompt):
    prompt = mo.ui.text_area(
        placeholder="Enter your prompt",
        label="Prompt",
        full_width=True,
        value=get_prompt(),
        on_change=set_prompt,
    )
    return (prompt,)


@app.cell
def __(examples, mo, set_prompt):
    def _on_click(example):
        # Each example button writes its prompt into the shared state when clicked.
        def handle(v):
            set_prompt(example)

        return handle

    buttons = mo.ui.array(
        [
            mo.ui.button(label=example, on_click=_on_click(example))
            for example in examples
        ]
    )

    example_options = mo.vstack(buttons)
    return buttons, example_options


@app.cell
def __(MAX_IMAGE_SIZE, MAX_SEED, example_options, mo, prompt):
    run_button = mo.ui.run_button(label="Run", kind="success", full_width=True)

    negative_prompt = mo.ui.text_area(
        placeholder="Enter a negative prompt", label="Negative prompt"
    )

    seed = mo.ui.slider(start=0, stop=MAX_SEED, value=0, label="Seed")

    randomize_seed = mo.ui.checkbox(label="Randomize seed", value=True)

    width = mo.ui.slider(
        start=256, stop=MAX_IMAGE_SIZE, step=32, value=1024, label="Width"
    )
    height = mo.ui.slider(
        start=256, stop=MAX_IMAGE_SIZE, step=32, value=1024, label="Height"
    )

    guidance_scale = mo.ui.slider(
        start=0.0, stop=10.0, step=0.1, value=0.0, label="Guidance scale"
    )

    num_inference_steps = mo.ui.slider(
        start=1, stop=50, step=1, value=2, label="Number of inference steps"
    )

    # Group the example prompts and the advanced controls into an accordion.
    advanced_settings = mo.accordion(
        {
            "::lucide:list:: Examples": example_options,
            "::lucide:settings:: Advanced Settings": mo.hstack(
                [
                    mo.vstack([negative_prompt, seed, randomize_seed]),
                    mo.vstack(
                        [width, height, guidance_scale, num_inference_steps],
                        align="end",
                    ),
                ]
            ).style(padding="10px"),
        },
    )

    # This cell's output: the prompt box, the run button, and the accordion.
    mo.vstack([prompt, run_button, advanced_settings])
    return (
        advanced_settings,
        guidance_scale,
        height,
        negative_prompt,
        num_inference_steps,
        randomize_seed,
        run_button,
        seed,
        width,
    )


@app.cell
def __(mo):
    # Shared state holding the most recently generated image.
    get_image, set_image = mo.state(None)
    return get_image, set_image


@app.cell
def __(
    guidance_scale,
    height,
    infer,
    mo,
    negative_prompt,
    num_inference_steps,
    prompt,
    randomize_seed,
    run_button,
    seed,
    set_image,
    width,
):
    # Don't run inference until the Run button has been clicked.
    mo.stop(not run_button.value)

    _image, _seed = infer(
        prompt.value,
        negative_prompt.value,
        seed.value,
        randomize_seed.value,
        width.value,
        height.value,
        guidance_scale.value,
        num_inference_steps.value,
    )
    set_image(_image)
    # Clear this cell's output; the image is displayed by the cell below.
    mo.output.clear()
    return


@app.cell
def __(get_image):
    # The most recently generated image (if any) is this cell's output.
    get_image()
    return


if __name__ == "__main__":
    app.run()