Prof-Hunt committed on
Commit
a5abdd6
·
verified ·
1 Parent(s): b64f67c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +40 -13
app.py CHANGED
@@ -181,19 +181,38 @@ def generate_image_prompts(story_text):
181
 
182
  @torch.inference_mode()
183
  @spaces.GPU(duration=60)
184
- def generate_story_image(prompt):
 
185
  torch.cuda.empty_cache()
186
 
187
- pipe.load_lora_weights("Prof-Hunt/lora-bulldog")
 
 
 
 
 
188
  enhanced_prompt = f"{prompt}, watercolor style, children's book illustration, soft colors"
189
-
190
- image = pipe(
191
- prompt=enhanced_prompt,
192
- negative_prompt="deformed, ugly, blurry, bad art, poor quality, distorted",
193
- num_inference_steps=50,
194
- guidance_scale=15,
195
- ).images[0]
196
-
 
 
 
 
 
 
 
 
 
 
 
 
 
197
  return image
198
 
199
  @torch.inference_mode()
@@ -221,20 +240,28 @@ def generate_all_scenes(prompts_text):
221
 
222
  if scene_prompt:
223
  try:
224
- torch.cuda.empty_cache() # Clear GPU memory before generation
225
  print(f"Generating image for scene: {scene_prompt}") # Debugging
 
226
  image = generate_story_image(scene_prompt)
227
 
228
  if image is not None:
229
- generated_images.append(np.array(image))
 
 
 
 
230
 
231
- torch.cuda.empty_cache() # Clear memory after each image
232
  except Exception as e:
233
  print(f"Error generating image: {str(e)}")
234
  continue
235
 
 
236
  return generated_images, "\n\n".join(formatted_prompts)
237
 
 
 
238
  # Helper functions without GPU usage
239
  def clean_story_output(story):
240
  story = story.replace("<|im_end|>", "")
 
181
 
182
@torch.inference_mode()
@spaces.GPU(duration=60)
def generate_story_image(prompt, seed=-1):
    """Generate one watercolor-style children's-book illustration for a scene.

    The bulldog LoRA is loaded only for the duration of this call and is
    unloaded again afterwards — even when generation fails — so repeated
    calls do not stack LoRA weights or leak GPU memory.

    Args:
        prompt: Scene description to illustrate.
        seed: RNG seed for reproducibility; -1 (default) draws a random seed.

    Returns:
        The generated image (first element of the pipeline's ``images``) on
        success, or ``None`` if generation raised.
    """
    torch.cuda.empty_cache()

    # Seed a CUDA generator so a caller-supplied seed reproduces the image.
    generator = torch.Generator("cuda")
    if seed != -1:
        generator.manual_seed(seed)
    else:
        generator.manual_seed(torch.randint(0, 2**32 - 1, (1,)).item())

    enhanced_prompt = f"{prompt}, watercolor style, children's book illustration, soft colors"

    try:
        # Load the LoRA only for this function's scope.
        pipe.load_lora_weights("Prof-Hunt/lora-bulldog")

        image = pipe(
            prompt=enhanced_prompt,
            negative_prompt="deformed, ugly, blurry, bad art, poor quality, distorted",
            num_inference_steps=50,
            guidance_scale=15,
            generator=generator
        ).images[0]
    except Exception as e:
        print(f"Error generating image: {e}")
        return None
    finally:
        # BUG FIX: the original unloaded the LoRA only on the success path,
        # so a failed generation left the weights attached and they would
        # stack on the next load_lora_weights call. Unload unconditionally.
        pipe.unload_lora_weights()
        torch.cuda.empty_cache()

    return image
217
 
218
  @torch.inference_mode()
 
240
 
241
  if scene_prompt:
242
  try:
243
+ torch.cuda.empty_cache()
244
  print(f"Generating image for scene: {scene_prompt}") # Debugging
245
+
246
  image = generate_story_image(scene_prompt)
247
 
248
  if image is not None:
249
+ img_array = np.array(image)
250
+
251
+ # Ensure the image is valid
252
+ if img_array.shape[0] > 0:
253
+ generated_images.append(img_array)
254
 
255
+ torch.cuda.empty_cache()
256
  except Exception as e:
257
  print(f"Error generating image: {str(e)}")
258
  continue
259
 
260
+ print(f"Generated {len(generated_images)} images.")
261
  return generated_images, "\n\n".join(formatted_prompts)
262
 
263
+
264
+
265
  # Helper functions without GPU usage
266
  def clean_story_output(story):
267
  story = story.replace("<|im_end|>", "")