Update app.py
app.py
CHANGED
@@ -10,19 +10,20 @@ device = "cuda" if torch.cuda.is_available() else "cpu"
 model_repo_id = "stabilityai/stable-diffusion-3.5-large"
 torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
 
-# Preload the Stable Diffusion pipeline on
+# Preload the Stable Diffusion pipeline on GPU (if available)
 pipe = DiffusionPipeline.from_pretrained(model_repo_id, torch_dtype=torch_dtype)
-pipe = pipe.to(
+pipe = pipe.to(device)
 
 MAX_SEED = np.iinfo(np.int32).max
 MAX_IMAGE_SIZE = 1024
 
 def truncate_text(text, max_tokens=77):
     """
-
+    Explicitly truncate a given text to a maximum of `max_tokens` using the pipeline's tokenizer.
     """
     if text.strip() == "":
         return text
+    # Tokenize with truncation enabled and a maximum length
     tokens = pipe.tokenizer(text, truncation=True, max_length=max_tokens, add_special_tokens=True)
     truncated_text = pipe.tokenizer.decode(tokens["input_ids"], skip_special_tokens=True)
     return truncated_text
@@ -39,20 +40,16 @@ def infer(
     num_inference_steps=40,
     progress=gr.Progress(track_tqdm=True),
 ):
+    # Randomize seed if requested
     if randomize_seed:
         seed = random.randint(0, MAX_SEED)
-    generator = torch.Generator(device=
-
-    #
+    generator = torch.Generator(device=device).manual_seed(seed)
+
+    # Explicitly truncate both prompt and negative prompt to avoid CLIP token warnings.
     prompt = truncate_text(prompt, max_tokens=77)
     negative_prompt = truncate_text(negative_prompt, max_tokens=77) if negative_prompt.strip() else ""
-
-    #
-    pipe.unet.to("cuda")
-    pipe.text_encoder.to("cuda")
-    pipe.vae.to("cuda")
-
-    # Generate the image
+
+    # Generate the image (the pipeline is already on GPU)
     image = pipe(
         prompt=prompt,
         negative_prompt=negative_prompt,
@@ -62,15 +59,10 @@ def infer(
         height=height,
         generator=generator,
     ).images[0]
-
-    # Move pipeline components back to CPU
-    pipe.unet.to("cpu")
-    pipe.text_encoder.to("cpu")
-    pipe.vae.to("cpu")
-
+
    return image, seed
 
-# Gradio UI
+# Gradio UI layout remains as before.
 examples = [
     "A capybara wearing a suit holding a sign that reads Hello World",
 ]
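For context, the truncation helper added in this change can be exercised on its own. Below is a minimal sketch, assuming the standalone CLIP tokenizer openai/clip-vit-large-patch14 as a stand-in for pipe.tokenizer (the SD 3.5 pipeline bundles its own tokenizers); it is an illustration, not part of the committed app.py.

# Minimal sketch of the truncation logic, using a standalone CLIP tokenizer
# as a stand-in for pipe.tokenizer (assumption: same 77-token CLIP limit).
from transformers import CLIPTokenizer

tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")

def truncate_text(text, max_tokens=77):
    """Truncate `text` to at most `max_tokens` tokens and decode it back to a string."""
    if text.strip() == "":
        return text
    # truncation=True drops everything past max_length up front
    tokens = tokenizer(text, truncation=True, max_length=max_tokens, add_special_tokens=True)
    return tokenizer.decode(tokens["input_ids"], skip_special_tokens=True)

long_prompt = "a photo of a capybara wearing a suit " * 40  # far beyond 77 tokens
short_prompt = truncate_text(long_prompt)
print(len(tokenizer(short_prompt)["input_ids"]))  # stays at or below 77

Because the truncated ids are decoded back to text, the pipeline later re-tokenizes a prompt that already fits CLIP's 77-token window, which is what suppresses the "input was truncated" warning mentioned in the new inline comment.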