SeedOfEvil committed · verified
Commit 1bb64e4 · 1 Parent(s): 5b25f5e

Update app.py

Files changed (1):
  1. app.py +19 -15
app.py CHANGED
@@ -10,18 +10,19 @@ device = "cuda" if torch.cuda.is_available() else "cpu"
 model_repo_id = "stabilityai/stable-diffusion-3.5-large"
 torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
 
-# Load the Stable Diffusion pipeline and move it to the appropriate device.
+# Preload the Stable Diffusion pipeline on CPU at startup.
 pipe = DiffusionPipeline.from_pretrained(model_repo_id, torch_dtype=torch_dtype)
-pipe = pipe.to(device)
+pipe = pipe.to("cpu")
 
 MAX_SEED = np.iinfo(np.int32).max
 MAX_IMAGE_SIZE = 1024
 
-# Define a helper function to truncate text to a maximum of 77 tokens.
 def truncate_text(text, max_tokens=77):
+    """
+    Truncate a given text to a maximum of max_tokens using the pipeline's tokenizer.
+    """
     if text.strip() == "":
         return text
-    # Use the pipeline's tokenizer (CLIP tokenizer)
     tokens = pipe.tokenizer(text, truncation=True, max_length=max_tokens, add_special_tokens=True)
     truncated_text = pipe.tokenizer.decode(tokens["input_ids"], skip_special_tokens=True)
     return truncated_text
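
Review note: truncate_text trims a prompt to CLIP's 77-token window by encoding
with truncation=True and decoding the surviving ids back to text. A minimal
usage sketch, assuming the pipeline above is loaded (the long prompt is a
made-up example):

    # Deliberately over-long prompt, then trim it with the helper.
    long_prompt = "a photo of a " + "very " * 200 + "detailed castle"
    short_prompt = truncate_text(long_prompt, max_tokens=77)
    # Re-encoding the truncated text should land back inside the window
    # (up to minor BPE round-trip drift), 77 ids including BOS/EOS.
    n_ids = len(pipe.tokenizer(short_prompt)["input_ids"])
    print(n_ids, n_ids <= 77)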
@@ -40,16 +41,18 @@ def infer(
 ):
     if randomize_seed:
         seed = random.randint(0, MAX_SEED)
-    generator = torch.Generator(device=device).manual_seed(seed)
+    generator = torch.Generator(device="cuda").manual_seed(seed)
 
-    # Truncate both prompt and negative prompt to 77 tokens.
+    # Truncate prompts to avoid CLIP token length issues.
     prompt = truncate_text(prompt, max_tokens=77)
     negative_prompt = truncate_text(negative_prompt, max_tokens=77) if negative_prompt.strip() else ""
 
-    # Move model to GPU before inference.
-    pipe.model.to("cuda")
-
-    # Generate image using the truncated prompts.
+    # Move pipeline components to GPU
+    pipe.unet.to("cuda")
+    pipe.text_encoder.to("cuda")
+    pipe.vae.to("cuda")
+
+    # Generate the image
     image = pipe(
         prompt=prompt,
         negative_prompt=negative_prompt,
@@ -60,12 +63,14 @@ def infer(
         generator=generator,
     ).images[0]
 
-    # Move model back to CPU after inference.
-    pipe.model.to("cpu")
+    # Move pipeline components back to CPU
+    pipe.unet.to("cpu")
+    pipe.text_encoder.to("cpu")
+    pipe.vae.to("cpu")
 
     return image, seed
 
-# UI layout remains unchanged.
+# Gradio UI definition
 examples = [
     "A capybara wearing a suit holding a sign that reads Hello World",
 ]
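
Review note: the manual offload assumes the pipeline exposes a unet, but
stable-diffusion-3.5-large loads as a StableDiffusion3Pipeline, whose denoiser
is pipe.transformer and which carries three text encoders (text_encoder,
text_encoder_2, text_encoder_3); pipe.unet.to("cuda") will raise
AttributeError, and the extra encoders would stay on CPU. diffusers' built-in
enable_model_cpu_offload() (it requires the accelerate package) gets the same
VRAM saving without tracking components by hand. A minimal sketch, assuming a
CUDA device is available:

    import torch
    from diffusers import DiffusionPipeline

    pipe = DiffusionPipeline.from_pretrained(
        "stabilityai/stable-diffusion-3.5-large", torch_dtype=torch.float16
    )
    # Each component is moved to the GPU only while it runs,
    # then returned to CPU; no manual .to() bookkeeping.
    pipe.enable_model_cpu_offload()

    generator = torch.Generator(device="cuda").manual_seed(42)
    image = pipe(
        prompt="A capybara wearing a suit holding a sign that reads Hello World",
        generator=generator,
    ).images[0]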
@@ -154,5 +159,4 @@ with gr.Blocks(css=css) as demo:
     )
 
 if __name__ == "__main__":
-    demo.launch(share=True)
-
+    demo.launch()
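
Review note: dropping share=True is right for a Space; share links exist to
tunnel a locally running app to a public gradio.live URL, and a Space is
already served publicly.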
 