Update app.py
app.py CHANGED
@@ -10,18 +10,19 @@ device = "cuda" if torch.cuda.is_available() else "cpu"
 model_repo_id = "stabilityai/stable-diffusion-3.5-large"
 torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
 
-#
+# Preload the Stable Diffusion pipeline on CPU at startup.
 pipe = DiffusionPipeline.from_pretrained(model_repo_id, torch_dtype=torch_dtype)
-pipe = pipe.to(
+pipe = pipe.to("cpu")
 
 MAX_SEED = np.iinfo(np.int32).max
 MAX_IMAGE_SIZE = 1024
 
-# Define a helper function to truncate text to a maximum of 77 tokens.
 def truncate_text(text, max_tokens=77):
+    """
+    Truncate a given text to a maximum of max_tokens using the pipeline's tokenizer.
+    """
     if text.strip() == "":
         return text
-    # Use the pipeline's tokenizer (CLIP tokenizer)
     tokens = pipe.tokenizer(text, truncation=True, max_length=max_tokens, add_special_tokens=True)
     truncated_text = pipe.tokenizer.decode(tokens["input_ids"], skip_special_tokens=True)
     return truncated_text
@@ -40,16 +41,18 @@ def infer(
 ):
     if randomize_seed:
         seed = random.randint(0, MAX_SEED)
-    generator = torch.Generator(device=
+    generator = torch.Generator(device="cuda").manual_seed(seed)
 
-    # Truncate
+    # Truncate prompts to avoid CLIP token length issues.
     prompt = truncate_text(prompt, max_tokens=77)
     negative_prompt = truncate_text(negative_prompt, max_tokens=77) if negative_prompt.strip() else ""
 
-    # Move
-    pipe.
-
-
+    # Move pipeline components to GPU
+    pipe.unet.to("cuda")
+    pipe.text_encoder.to("cuda")
+    pipe.vae.to("cuda")
+
+    # Generate the image
     image = pipe(
         prompt=prompt,
         negative_prompt=negative_prompt,
@@ -60,12 +63,14 @@ def infer(
         generator=generator,
     ).images[0]
 
-    # Move
-    pipe.
+    # Move pipeline components back to CPU
+    pipe.unet.to("cpu")
+    pipe.text_encoder.to("cpu")
+    pipe.vae.to("cpu")
 
     return image, seed
 
-# UI
+# Gradio UI definition
 examples = [
     "A capybara wearing a suit holding a sign that reads Hello World",
 ]
@@ -154,5 +159,4 @@ with gr.Blocks(css=css) as demo:
     )
 
 if __name__ == "__main__":
-    demo.launch(
-
+    demo.launch()
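Note that the added lines hard-code `"cuda"` and assume the pipeline exposes `unet`, `text_encoder`, and `vae`. Stable Diffusion 3.5 is a DiT model: `StableDiffusion3Pipeline` exposes its denoiser as `pipe.transformer` (there is no `pipe.unet`) and carries three text encoders, so `pipe.unet.to("cuda")` would raise an `AttributeError`, and `torch.Generator(device="cuda")` fails on CPU-only hosts. Below is a minimal sketch of the same offload-per-request idea that sidesteps both issues by moving the whole pipeline at once; it reuses the module-level `device` variable from the top of app.py, and the `generate` helper is hypothetical:

```python
import torch
from diffusers import DiffusionPipeline

device = "cuda" if torch.cuda.is_available() else "cpu"

# Load once at startup and keep the weights on the CPU between requests.
pipe = DiffusionPipeline.from_pretrained(
    "stabilityai/stable-diffusion-3.5-large",
    torch_dtype=torch.float16 if device == "cuda" else torch.float32,
)

def generate(prompt, seed=0):
    # Hypothetical helper showing the diff's move-to-GPU / move-back pattern.
    # Create the generator on whatever device is actually available;
    # torch.Generator(device="cuda") raises on machines without CUDA.
    generator = torch.Generator(device=device).manual_seed(seed)
    pipe.to(device)  # .to() moves every registered component, transformer included
    try:
        return pipe(prompt=prompt, generator=generator).images[0]
    finally:
        pipe.to("cpu")  # release VRAM between requests
```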
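If the goal is simply to fit SD3.5 into limited VRAM, diffusers already ships this pattern: `enable_model_cpu_offload()` (backed by accelerate) keeps every component on the CPU and moves each one to the GPU only while it is needed. A single call at startup would replace all of the manual `.to()` shuttling inside `infer`:

```python
pipe = DiffusionPipeline.from_pretrained(
    "stabilityai/stable-diffusion-3.5-large", torch_dtype=torch.float16
)
# Requires the `accelerate` package; components are swapped onto the GPU on demand.
pipe.enable_model_cpu_offload()
```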