SeedOfEvil committed on
Commit 10ad556 · verified · 1 Parent(s): 1bb64e4

Update app.py

Files changed (1)
  1. app.py +12 -20
app.py CHANGED
@@ -10,19 +10,20 @@ device = "cuda" if torch.cuda.is_available() else "cpu"
 model_repo_id = "stabilityai/stable-diffusion-3.5-large"
 torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
 
-# Preload the Stable Diffusion pipeline on CPU at startup.
+# Preload the Stable Diffusion pipeline on GPU (if available)
 pipe = DiffusionPipeline.from_pretrained(model_repo_id, torch_dtype=torch_dtype)
-pipe = pipe.to("cpu")
+pipe = pipe.to(device)
 
 MAX_SEED = np.iinfo(np.int32).max
 MAX_IMAGE_SIZE = 1024
 
 def truncate_text(text, max_tokens=77):
     """
-    Truncate a given text to a maximum of max_tokens using the pipeline's tokenizer.
+    Explicitly truncate a given text to a maximum of `max_tokens` using the pipeline's tokenizer.
     """
     if text.strip() == "":
         return text
+    # Tokenize with truncation enabled and a maximum length
     tokens = pipe.tokenizer(text, truncation=True, max_length=max_tokens, add_special_tokens=True)
     truncated_text = pipe.tokenizer.decode(tokens["input_ids"], skip_special_tokens=True)
     return truncated_text
@@ -39,20 +40,16 @@ def infer(
     num_inference_steps=40,
     progress=gr.Progress(track_tqdm=True),
 ):
+    # Randomize seed if requested
     if randomize_seed:
         seed = random.randint(0, MAX_SEED)
-    generator = torch.Generator(device="cuda").manual_seed(seed)
-
-    # Truncate prompts to avoid CLIP token length issues.
+    generator = torch.Generator(device=device).manual_seed(seed)
+
+    # Explicitly truncate both prompt and negative prompt to avoid CLIP token warnings.
     prompt = truncate_text(prompt, max_tokens=77)
     negative_prompt = truncate_text(negative_prompt, max_tokens=77) if negative_prompt.strip() else ""
-
-    # Move pipeline components to GPU
-    pipe.unet.to("cuda")
-    pipe.text_encoder.to("cuda")
-    pipe.vae.to("cuda")
-
-    # Generate the image
+
+    # Generate the image (the pipeline is already on GPU)
     image = pipe(
         prompt=prompt,
         negative_prompt=negative_prompt,
@@ -62,15 +59,10 @@ def infer(
         height=height,
         generator=generator,
     ).images[0]
-
-    # Move pipeline components back to CPU
-    pipe.unet.to("cpu")
-    pipe.text_encoder.to("cpu")
-    pipe.vae.to("cpu")
-
+
     return image, seed
 
-# Gradio UI definition
+# Gradio UI layout remains as before.
 examples = [
     "A capybara wearing a suit holding a sign that reads Hello World",
 ]
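
As a rough illustration of the truncation pattern introduced above, the sketch below exercises the same tokenize-truncate-decode round trip with a standalone CLIP tokenizer standing in for pipe.tokenizer (the openai/clip-vit-large-patch14 checkpoint and the sample prompt are illustrative assumptions, not part of this Space):

# Sketch only: a plain CLIP tokenizer stands in for pipe.tokenizer from app.py.
from transformers import CLIPTokenizer

tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")

def truncate_text(text, max_tokens=77):
    # Tokenize with truncation enabled, then decode back to a plain string.
    if text.strip() == "":
        return text
    tokens = tokenizer(text, truncation=True, max_length=max_tokens, add_special_tokens=True)
    return tokenizer.decode(tokens["input_ids"], skip_special_tokens=True)

long_prompt = "A capybara wearing a suit " * 40  # deliberately longer than 77 tokens
short_prompt = truncate_text(long_prompt)
print(len(tokenizer(short_prompt)["input_ids"]))  # should stay within the 77-token CLIP window

Re-encoding the decoded string should land at or under the 77-token limit, which is what keeps the pipeline from emitting CLIP truncation warnings at generation time.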
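
The switch from a hard-coded "cuda" generator to device=device also keeps the seeding logic runnable on CPU-only hosts. A minimal sketch of that pattern, independent of the Space and requiring no GPU:

import torch

device = "cuda" if torch.cuda.is_available() else "cpu"

seed = 42  # illustrative fixed seed
generator = torch.Generator(device=device).manual_seed(seed)

# Re-seeding with the same value on the same device reproduces the same noise,
# which is what makes the returned seed useful for reproducing an image.
noise_a = torch.randn(4, generator=generator, device=device)
generator.manual_seed(seed)
noise_b = torch.randn(4, generator=generator, device=device)
print(torch.equal(noise_a, noise_b))  # True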