SeedOfEvil commited on
Commit
5b25f5e
·
verified ·
1 Parent(s): f11f490

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +9 -6
app.py CHANGED
@@ -14,7 +14,6 @@ torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
14
  pipe = DiffusionPipeline.from_pretrained(model_repo_id, torch_dtype=torch_dtype)
15
  pipe = pipe.to(device)
16
 
17
- # Maximum values as defined in your original code
18
  MAX_SEED = np.iinfo(np.int32).max
19
  MAX_IMAGE_SIZE = 1024
20
 
@@ -39,7 +38,6 @@ def infer(
39
  num_inference_steps=40,
40
  progress=gr.Progress(track_tqdm=True),
41
  ):
42
- # Optionally randomize seed
43
  if randomize_seed:
44
  seed = random.randint(0, MAX_SEED)
45
  generator = torch.Generator(device=device).manual_seed(seed)
@@ -48,7 +46,10 @@ def infer(
48
  prompt = truncate_text(prompt, max_tokens=77)
49
  negative_prompt = truncate_text(negative_prompt, max_tokens=77) if negative_prompt.strip() else ""
50
 
51
- # Explicitly set pad_token_id to eos_token_id for open-end generation.
 
 
 
52
  image = pipe(
53
  prompt=prompt,
54
  negative_prompt=negative_prompt,
@@ -57,12 +58,14 @@ def infer(
57
  width=width,
58
  height=height,
59
  generator=generator,
60
- pad_token_id=pipe.tokenizer.eos_token_id,
61
  ).images[0]
62
 
 
 
 
63
  return image, seed
64
 
65
- # Example prompt for testing
66
  examples = [
67
  "A capybara wearing a suit holding a sign that reads Hello World",
68
  ]
@@ -151,5 +154,5 @@ with gr.Blocks(css=css) as demo:
151
  )
152
 
153
  if __name__ == "__main__":
154
- demo.launch()
155
 
 
14
  pipe = DiffusionPipeline.from_pretrained(model_repo_id, torch_dtype=torch_dtype)
15
  pipe = pipe.to(device)
16
 
 
17
  MAX_SEED = np.iinfo(np.int32).max
18
  MAX_IMAGE_SIZE = 1024
19
 
 
38
  num_inference_steps=40,
39
  progress=gr.Progress(track_tqdm=True),
40
  ):
 
41
  if randomize_seed:
42
  seed = random.randint(0, MAX_SEED)
43
  generator = torch.Generator(device=device).manual_seed(seed)
 
46
  prompt = truncate_text(prompt, max_tokens=77)
47
  negative_prompt = truncate_text(negative_prompt, max_tokens=77) if negative_prompt.strip() else ""
48
 
49
+ # Move model to GPU before inference.
50
+ pipe.model.to("cuda")
51
+
52
+ # Generate image using the truncated prompts.
53
  image = pipe(
54
  prompt=prompt,
55
  negative_prompt=negative_prompt,
 
58
  width=width,
59
  height=height,
60
  generator=generator,
 
61
  ).images[0]
62
 
63
+ # Move model back to CPU after inference.
64
+ pipe.model.to("cpu")
65
+
66
  return image, seed
67
 
68
+ # UI layout remains unchanged.
69
  examples = [
70
  "A capybara wearing a suit holding a sign that reads Hello World",
71
  ]
 
154
  )
155
 
156
  if __name__ == "__main__":
157
+ demo.launch(share=True)
158