Spaces:
Runtime error
Runtime error
Update app.py
Browse files
app.py
CHANGED
@@ -181,19 +181,38 @@ def generate_image_prompts(story_text):
|
|
181 |
|
182 |
@torch.inference_mode()
|
183 |
@spaces.GPU(duration=60)
|
184 |
-
def generate_story_image(prompt):
|
|
|
185 |
torch.cuda.empty_cache()
|
186 |
|
187 |
-
|
|
|
|
|
|
|
|
|
|
|
188 |
enhanced_prompt = f"{prompt}, watercolor style, children's book illustration, soft colors"
|
189 |
-
|
190 |
-
|
191 |
-
|
192 |
-
|
193 |
-
|
194 |
-
|
195 |
-
|
196 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
197 |
return image
|
198 |
|
199 |
@torch.inference_mode()
|
@@ -221,20 +240,28 @@ def generate_all_scenes(prompts_text):
|
|
221 |
|
222 |
if scene_prompt:
|
223 |
try:
|
224 |
-
torch.cuda.empty_cache()
|
225 |
print(f"Generating image for scene: {scene_prompt}") # Debugging
|
|
|
226 |
image = generate_story_image(scene_prompt)
|
227 |
|
228 |
if image is not None:
|
229 |
-
|
|
|
|
|
|
|
|
|
230 |
|
231 |
-
torch.cuda.empty_cache()
|
232 |
except Exception as e:
|
233 |
print(f"Error generating image: {str(e)}")
|
234 |
continue
|
235 |
|
|
|
236 |
return generated_images, "\n\n".join(formatted_prompts)
|
237 |
|
|
|
|
|
238 |
# Helper functions without GPU usage
|
239 |
def clean_story_output(story):
|
240 |
story = story.replace("<|im_end|>", "")
|
|
|
181 |
|
182 |
@torch.inference_mode()
|
183 |
@spaces.GPU(duration=60)
|
184 |
+
def generate_story_image(prompt, seed=-1):
|
185 |
+
"""Generate an image using Stable Diffusion with LoRA temporarily loaded."""
|
186 |
torch.cuda.empty_cache()
|
187 |
|
188 |
+
generator = torch.Generator("cuda")
|
189 |
+
if seed != -1:
|
190 |
+
generator.manual_seed(seed)
|
191 |
+
else:
|
192 |
+
generator.manual_seed(torch.randint(0, 2**32 - 1, (1,)).item())
|
193 |
+
|
194 |
enhanced_prompt = f"{prompt}, watercolor style, children's book illustration, soft colors"
|
195 |
+
|
196 |
+
try:
|
197 |
+
# Load LoRA only for this function
|
198 |
+
pipe.load_lora_weights("Prof-Hunt/lora-bulldog")
|
199 |
+
|
200 |
+
image = pipe(
|
201 |
+
prompt=enhanced_prompt,
|
202 |
+
negative_prompt="deformed, ugly, blurry, bad art, poor quality, distorted",
|
203 |
+
num_inference_steps=50,
|
204 |
+
guidance_scale=15,
|
205 |
+
generator=generator
|
206 |
+
).images[0]
|
207 |
+
|
208 |
+
# Unload LoRA properly
|
209 |
+
pipe.unload_lora_weights()
|
210 |
+
torch.cuda.empty_cache()
|
211 |
+
|
212 |
+
except Exception as e:
|
213 |
+
print(f"Error generating image: {e}")
|
214 |
+
return None
|
215 |
+
|
216 |
return image
|
217 |
|
218 |
@torch.inference_mode()
|
|
|
240 |
|
241 |
if scene_prompt:
|
242 |
try:
|
243 |
+
torch.cuda.empty_cache()
|
244 |
print(f"Generating image for scene: {scene_prompt}") # Debugging
|
245 |
+
|
246 |
image = generate_story_image(scene_prompt)
|
247 |
|
248 |
if image is not None:
|
249 |
+
img_array = np.array(image)
|
250 |
+
|
251 |
+
# Ensure the image is valid
|
252 |
+
if img_array.shape[0] > 0:
|
253 |
+
generated_images.append(img_array)
|
254 |
|
255 |
+
torch.cuda.empty_cache()
|
256 |
except Exception as e:
|
257 |
print(f"Error generating image: {str(e)}")
|
258 |
continue
|
259 |
|
260 |
+
print(f"Generated {len(generated_images)} images.")
|
261 |
return generated_images, "\n\n".join(formatted_prompts)
|
262 |
|
263 |
+
|
264 |
+
|
265 |
# Helper functions without GPU usage
|
266 |
def clean_story_output(story):
|
267 |
story = story.replace("<|im_end|>", "")
|