Prof-Hunt committed on
Commit 01c7a6f · verified · 1 Parent(s): a5abdd6

Update app.py

Files changed (1):
  1. app.py +313 -260

app.py CHANGED
@@ -8,49 +8,68 @@ import textwrap
 import os
 import gc
 import re
+import psutil
 from datetime import datetime
 import spaces
 from kokoro import KPipeline
 import soundfile as sf
 
-# Initialize models at startup - outside of functions
+def clear_memory():
+    """Helper function to clear both CUDA and system memory"""
+    gc.collect()
+    torch.cuda.empty_cache()
+    if torch.cuda.is_available():
+        torch.cuda.synchronize()
+
+    process = psutil.Process(os.getpid())
+    if hasattr(process, 'memory_info'):
+        process.memory_info().rss
+
+    gc.collect(generation=0)
+    gc.collect(generation=1)
+    gc.collect(generation=2)
+
+    if torch.cuda.is_available():
+        print(f"GPU Memory allocated: {torch.cuda.memory_allocated()/1024**2:.2f} MB")
+        print(f"GPU Memory cached: {torch.cuda.memory_reserved()/1024**2:.2f} MB")
+        print(f"CPU RAM used: {process.memory_info().rss/1024**2:.2f} MB")
+
+# Initialize models at startup - only the lightweight ones
 print("Loading models...")
 
 # Load SmolVLM for image analysis
 processor_vlm = AutoProcessor.from_pretrained("HuggingFaceTB/SmolVLM-500M-Instruct")
 model_vlm = AutoModelForVision2Seq.from_pretrained(
     "HuggingFaceTB/SmolVLM-500M-Instruct",
-    torch_dtype=torch.bfloat16,
-    use_safetensors=True
-)
+    torch_dtype=torch.bfloat16
+).to("cuda")
 
 # Load SmolLM2 for story and prompt generation
 checkpoint = "HuggingFaceTB/SmolLM2-1.7B-Instruct"
 tokenizer_lm = AutoTokenizer.from_pretrained(checkpoint)
-model_lm = AutoModelForCausalLM.from_pretrained(
-    checkpoint,
-    use_safetensors=True
-)
-
-# Load Stable Diffusion pipeline
-pipe = StableDiffusionPipeline.from_pretrained(
-    "runwayml/stable-diffusion-v1-5",
-    torch_dtype=torch.float16,
-    use_safetensors=True
-)
-pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
-
-# Move models to GPU if available
-if torch.cuda.is_available():
-    model_vlm = model_vlm.to("cuda")
-    model_lm = model_lm.to("cuda")
-    pipe = pipe.to("cuda")
+model_lm = AutoModelForCausalLM.from_pretrained(checkpoint).to("cuda")
+
+# Initialize Kokoro TTS pipeline
+pipeline = KPipeline(lang_code='a')  # 'a' for American English
+
+def load_sd_model():
+    """Load Stable Diffusion model only when needed"""
+    pipe = StableDiffusionPipeline.from_pretrained(
+        "runwayml/stable-diffusion-v1-5",
+        torch_dtype=torch.float16,
+    )
+    pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
+    pipe.to("cuda")
+    pipe.enable_attention_slicing()
+    return pipe
 
 @torch.inference_mode()
 @spaces.GPU(duration=30)
 def generate_image():
     """Generate a random landscape image."""
-    torch.cuda.empty_cache()
+    clear_memory()
+
+    pipe = load_sd_model()
 
     default_prompt = "a beautiful, professional landscape photograph"
     default_negative_prompt = "blurry, bad quality, distorted, deformed"
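
The heart of this commit is the memory discipline visible in the hunk above: collect aggressively, release the CUDA cache, and load Stable Diffusion lazily per request instead of pinning it at startup. A minimal, self-contained sketch of that load-use-free cycle, assuming a CUDA device and the diffusers package; report() is a hypothetical helper added here for illustration:

```python
import gc
import torch
from diffusers import StableDiffusionPipeline

def report(tag):
    # Show live GPU usage so the effect of freeing the pipeline is visible.
    print(f"{tag}: {torch.cuda.memory_allocated() / 1024**2:.0f} MB allocated")

report("before load")
pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
).to("cuda")
report("after load")

image = pipe("a watercolor bulldog", num_inference_steps=20).images[0]

del pipe                  # drop the last reference so the weights become collectable
gc.collect()              # reclaim the Python-side objects
torch.cuda.empty_cache()  # hand the cached CUDA blocks back to the allocator
report("after free")
```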
@@ -60,15 +79,25 @@ def generate_image():
 
     generator = torch.Generator("cuda").manual_seed(default_seed)
 
-    image = pipe(
-        prompt=default_prompt,
-        negative_prompt=default_negative_prompt,
-        num_inference_steps=default_steps,
-        guidance_scale=default_guidance,
-        generator=generator,
-    ).images[0]
-
-    return image
+    try:
+        image = pipe(
+            prompt=default_prompt,
+            negative_prompt=default_negative_prompt,
+            num_inference_steps=default_steps,
+            guidance_scale=default_guidance,
+            generator=generator,
+        ).images[0]
+
+        del pipe
+        clear_memory()
+        return image
+
+    except Exception as e:
+        print(f"Error generating image: {e}")
+        if 'pipe' in locals():
+            del pipe
+        clear_memory()
+        return None
 
 @torch.inference_mode()
 @spaces.GPU(duration=30)
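
Besides the try/except cleanup, generate_image keeps a seeded torch.Generator, which is what makes a run reproducible: identical seeds yield identical latent noise and therefore identical images for the same prompt. A quick demonstration of that property in isolation, assuming a CUDA device:

```python
import torch

g1 = torch.Generator("cuda").manual_seed(42)
g2 = torch.Generator("cuda").manual_seed(42)

# Same seed, same noise: a diffusion run starting from this noise
# reproduces the same image for the same prompt and settings.
assert torch.equal(
    torch.randn(4, generator=g1, device="cuda"),
    torch.randn(4, generator=g2, device="cuda"),
)
```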
@@ -76,7 +105,7 @@ def analyze_image(image):
     if image is None:
         return "Please generate an image first."
 
-    torch.cuda.empty_cache()
+    clear_memory()
 
     if isinstance(image, np.ndarray):
         image = Image.fromarray(image)
@@ -86,38 +115,49 @@
             "role": "user",
             "content": [
                 {"type": "image"},
-                {"type": "text", "text": "Describe this image very briefly in five sentences or less. Short description."}
+                {"type": "text", "text": "Describe this image and Be brief but descriptive."}
             ]
         }
     ]
 
-    prompt = processor_vlm.apply_chat_template(messages, add_generation_prompt=True)
-
-    inputs = processor_vlm(
-        text=prompt,
-        images=[image],
-        return_tensors="pt"
-    ).to('cuda')
-
-    outputs = model_vlm.generate(
-        input_ids=inputs.input_ids,
-        pixel_values=inputs.pixel_values,
-        attention_mask=inputs.attention_mask,
-        num_return_sequences=1,
-        no_repeat_ngram_size=2,
-        max_new_tokens=500,
-        min_new_tokens=10
-    )
-
-    description = processor_vlm.decode(outputs[0], skip_special_tokens=True)
-    description = re.sub(r".*?Assistant:\s*", "", description, flags=re.DOTALL).strip()
-
-    return description
+    try:
+        prompt = processor_vlm.apply_chat_template(messages, add_generation_prompt=True)
+
+        inputs = processor_vlm(
+            text=prompt,
+            images=[image],
+            return_tensors="pt"
+        ).to('cuda')
+
+        outputs = model_vlm.generate(
+            input_ids=inputs.input_ids,
+            pixel_values=inputs.pixel_values,
+            attention_mask=inputs.attention_mask,
+            num_return_sequences=1,
+            no_repeat_ngram_size=2,
+            max_new_tokens=500,
+            min_new_tokens=10
+        )
+
+        description = processor_vlm.decode(outputs[0], skip_special_tokens=True)
+        description = re.sub(r".*?Assistant:\s*", "", description, flags=re.DOTALL).strip()
+
+        # Split into sentences and take only the first three
+        sentences = re.split(r'(?<=[.!?])\s+', description)
+        description = ' '.join(sentences[:3])
+
+        clear_memory()
+        return description
+
+    except Exception as e:
+        print(f"Error analyzing image: {e}")
+        clear_memory()
+        return "Error analyzing image. Please try again."
 
 @torch.inference_mode()
 @spaces.GPU(duration=30)
 def generate_story(image_description):
-    torch.cuda.empty_cache()
+    clear_memory()
 
     story_prompt = f"""Write a short children's story (one chapter, about 500 words) based on this scene: {image_description}
 
@@ -128,74 +168,89 @@ def generate_story(image_description):
     4. Keep it simple and engaging for young children
     5. End with a simple moral lesson"""
 
-    messages = [{"role": "user", "content": story_prompt}]
-    input_text = tokenizer_lm.apply_chat_template(messages, tokenize=False)
-
-    inputs = tokenizer_lm.encode(input_text, return_tensors="pt").to("cuda")
-
-    outputs = model_lm.generate(
-        inputs,
-        max_new_tokens=750,
-        temperature=0.7,
-        top_p=0.9,
-        do_sample=True,
-        repetition_penalty=1.2
-    )
-
-    story = tokenizer_lm.decode(outputs[0])
-    story = clean_story_output(story)
-
-    return story
-
-@torch.inference_mode()
-@spaces.GPU(duration=30)
-def generate_image_prompts(story_text):
-    torch.cuda.empty_cache()
-    paragraphs = split_into_paragraphs(story_text)
-
-    all_prompts = []
-    prompt_instruction = '''Here is a story paragraph: {paragraph}
-
-Start your response with "Watercolor bulldog" and describe what Champ is doing in this scene. Add where it takes place and one mood detail. Keep it short.'''
-
-    for i, paragraph in enumerate(paragraphs, 1):
-        messages = [{"role": "user", "content": prompt_instruction.format(paragraph=paragraph)}]
+    try:
+        messages = [{"role": "user", "content": story_prompt}]
         input_text = tokenizer_lm.apply_chat_template(messages, tokenize=False)
 
         inputs = tokenizer_lm.encode(input_text, return_tensors="pt").to("cuda")
 
         outputs = model_lm.generate(
             inputs,
-            max_new_tokens=30,
-            temperature=0.5,
+            max_new_tokens=750,
+            temperature=0.7,
             top_p=0.9,
             do_sample=True,
             repetition_penalty=1.2
         )
 
-        prompt = process_generated_prompt(tokenizer_lm.decode(outputs[0]), paragraph)
-        section = f"Paragraph {i}:\n{paragraph}\n\nScenery Prompt {i}:\n{prompt}\n\n{'='*50}"
-        all_prompts.append(section)
+        story = tokenizer_lm.decode(outputs[0])
+        story = clean_story_output(story)
+
+        clear_memory()
+        return story
+
+    except Exception as e:
+        print(f"Error generating story: {e}")
+        clear_memory()
+        return "Error generating story. Please try again."
+
+@torch.inference_mode()
+@spaces.GPU(duration=30)
+def generate_image_prompts(story_text):
+    clear_memory()
+
+    paragraphs = split_into_paragraphs(story_text)
+    all_prompts = []
+    prompt_instruction = '''Here is a story paragraph: {paragraph}
+
+Start your response with "Watercolor bulldog" and describe what Champ is doing in this scene. Add where it takes place and one mood detail. Keep it short.'''
 
-    return '\n'.join(all_prompts)
+    try:
+        for i, paragraph in enumerate(paragraphs, 1):
+            messages = [{"role": "user", "content": prompt_instruction.format(paragraph=paragraph)}]
+            input_text = tokenizer_lm.apply_chat_template(messages, tokenize=False)
+
+            inputs = tokenizer_lm.encode(input_text, return_tensors="pt").to("cuda")
+
+            outputs = model_lm.generate(
+                inputs,
+                max_new_tokens=30,
+                temperature=0.5,
+                top_p=0.9,
+                do_sample=True,
+                repetition_penalty=1.2
+            )
+
+            prompt = process_generated_prompt(tokenizer_lm.decode(outputs[0]), paragraph)
+            section = f"Paragraph {i}:\n{paragraph}\n\nScenery Prompt {i}:\n{prompt}\n\n{'='*50}"
+            all_prompts.append(section)
+
+            clear_memory()
+
+        return '\n'.join(all_prompts)
+
+    except Exception as e:
+        print(f"Error generating prompts: {e}")
+        clear_memory()
+        return "Error generating prompts. Please try again."
 
 @torch.inference_mode()
 @spaces.GPU(duration=60)
 def generate_story_image(prompt, seed=-1):
-    """Generate an image using Stable Diffusion with LoRA temporarily loaded."""
-    torch.cuda.empty_cache()
+    clear_memory()
+
+    pipe = load_sd_model()
 
-    generator = torch.Generator("cuda")
-    if seed != -1:
-        generator.manual_seed(seed)
-    else:
-        generator.manual_seed(torch.randint(0, 2**32 - 1, (1,)).item())
-
-    enhanced_prompt = f"{prompt}, watercolor style, children's book illustration, soft colors"
-
     try:
-        # Load LoRA only for this function
         pipe.load_lora_weights("Prof-Hunt/lora-bulldog")
+
+        generator = torch.Generator("cuda")
+        if seed != -1:
+            generator.manual_seed(seed)
+        else:
+            generator.manual_seed(torch.randint(0, 2**32 - 1, (1,)).item())
+
+        enhanced_prompt = f"{prompt}, watercolor style, children's book illustration, soft colors"
 
         image = pipe(
             prompt=enhanced_prompt,
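
Both text paths above (generate_story and generate_image_prompts) share one pattern: wrap the request in a chat message, render it with the tokenizer's chat template, then sample. A condensed sketch of that flow against the same SmolLM2 checkpoint; note that add_generation_prompt=True is added here for a cleaner completion, where the commit relies on post-processing instead:

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

checkpoint = "HuggingFaceTB/SmolLM2-1.7B-Instruct"
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
model = AutoModelForCausalLM.from_pretrained(checkpoint).to("cuda")

messages = [{"role": "user", "content": "Write one sentence about a bulldog."}]
# The chat template inserts the <|im_start|>/<|im_end|> markers the model was tuned on.
text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
inputs = tokenizer.encode(text, return_tensors="pt").to("cuda")

outputs = model.generate(inputs, max_new_tokens=50, temperature=0.7,
                         top_p=0.9, do_sample=True, repetition_penalty=1.2)
# decode() returns prompt + completion, which is why app.py strips the
# markers afterwards in clean_story_output() / process_generated_prompt().
print(tokenizer.decode(outputs[0]))
```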
@@ -205,19 +260,24 @@ def generate_story_image(prompt, seed=-1):
             generator=generator
         ).images[0]
 
-        # Unload LoRA properly
         pipe.unload_lora_weights()
-        torch.cuda.empty_cache()
+        del pipe
+        clear_memory()
+        return image
 
     except Exception as e:
         print(f"Error generating image: {e}")
+        if 'pipe' in locals():
+            pipe.unload_lora_weights()
+            del pipe
+        clear_memory()
         return None
 
-    return image
-
 @torch.inference_mode()
 @spaces.GPU(duration=180)
 def generate_all_scenes(prompts_text):
+    clear_memory()
+
     generated_images = []
     formatted_prompts = []
 
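
generate_story_image keeps the bulldog LoRA resident only for the duration of one call. Reduced to its essentials, the cycle looks like the sketch below, with the same base model and adapter as the commit; the try/finally is a substitution for the commit's repeated unload call in the except branch, giving the same guarantee more compactly:

```python
import torch
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
).to("cuda")

try:
    # Attach the adapter without merging it into the base weights.
    pipe.load_lora_weights("Prof-Hunt/lora-bulldog")
    image = pipe("Watercolor bulldog in a meadow", num_inference_steps=25).images[0]
finally:
    # Detach the adapter whether or not inference succeeded.
    pipe.unload_lora_weights()
```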
@@ -227,43 +287,168 @@ def generate_all_scenes(prompts_text):
         if not section.strip():
             continue
 
-        lines = [line.strip() for line in section.split('\n') if line.strip()]
-
         scene_prompt = None
-        for i, line in enumerate(lines):
+        for line in section.split('\n'):
             if 'Scenery Prompt' in line:
                 scene_num = line.split('Scenery Prompt')[1].split(':')[0].strip()
-                if i + 1 < len(lines):
-                    scene_prompt = lines[i + 1]
+                next_line_index = section.split('\n').index(line) + 1
+                if next_line_index < len(section.split('\n')):
+                    scene_prompt = section.split('\n')[next_line_index].strip()
                 formatted_prompts.append(f"Scene {scene_num}: {scene_prompt}")
                 break
 
         if scene_prompt:
             try:
-                torch.cuda.empty_cache()
-                print(f"Generating image for scene: {scene_prompt}")  # Debugging
+                clear_memory()
+                print(f"Generating image for scene: {scene_prompt}")
 
                 image = generate_story_image(scene_prompt)
 
                 if image is not None:
                     img_array = np.array(image)
-
-                    # Ensure the image is valid
-                    if img_array.shape[0] > 0:
-                        generated_images.append(img_array)
+                    generated_images.append(img_array)
+
+                    clear_memory()
 
-                torch.cuda.empty_cache()
             except Exception as e:
                 print(f"Error generating image: {str(e)}")
+                clear_memory()
                 continue
 
-    print(f"Generated {len(generated_images)} images.")
     return generated_images, "\n\n".join(formatted_prompts)
 
+def overlay_text_on_image(image, text):
+    if image is None:
+        return None
+
+    try:
+        img = image.convert('RGB')
+        draw = ImageDraw.Draw(img)
+
+        font_size = int(img.width * 0.025)
+        try:
+            font = ImageFont.truetype("/usr/share/fonts/truetype/dejavu/DejaVuSans-Bold.ttf", font_size)
+        except:
+            font = ImageFont.load_default()
+
+        y_position = int(img.height * 0.005)
+        x_margin = int(img.width * 0.005)
+        available_width = img.width - (2 * x_margin)
+
+        wrapped_text = textwrap.fill(text, width=int(available_width / (font_size * 0.6)))
+
+        outline_color = (255, 255, 255)
+        text_color = (0, 0, 0)
+        offsets = [-2, -1, 1, 2]
+
+        for dx in offsets:
+            for dy in offsets:
+                draw.multiline_text(
+                    (x_margin + dx, y_position + dy),
+                    wrapped_text,
+                    font=font,
+                    fill=outline_color
+                )
+
+        draw.multiline_text(
+            (x_margin, y_position),
+            wrapped_text,
+            font=font,
+            fill=text_color
+        )
+
+        return img
+
+    except Exception as e:
+        print(f"Error overlaying text: {e}")
+        return None
+
+def add_text_to_scenes(gallery_images, prompts_text):
+    if not isinstance(gallery_images, list):
+        return [], []
+
+    clear_memory()
+
+    sections = prompts_text.split('='*50)
+    overlaid_images = []
+    output_files = []
+
+    temp_dir = "temp_book_pages"
+    os.makedirs(temp_dir, exist_ok=True)
+
+    for i, (image_data, section) in enumerate(zip(gallery_images, sections)):
+        if not section.strip():
+            continue
+
+        lines = [line.strip() for line in section.split('\n') if line.strip()]
+        paragraph = None
+        for j, line in enumerate(lines):
+            if line.startswith('Paragraph'):
+                if j + 1 < len(lines):
+                    paragraph = lines[j + 1]
+                break
+
+        if paragraph and image_data is not None:
+            try:
+                if isinstance(image_data, np.ndarray):
+                    image = Image.fromarray(image_data)
+                else:
+                    image = image_data
+
+                overlaid_img = overlay_text_on_image(image, paragraph)
+                if overlaid_img is not None:
+                    overlaid_array = np.array(overlaid_img)
+                    overlaid_images.append(overlaid_array)
+
+                    output_path = os.path.join(temp_dir, f"panel_{i+1}.png")
+                    overlaid_img.save(output_path)
+                    output_files.append(output_path)
+            except Exception as e:
+                print(f"Error processing image: {str(e)}")
+                continue
+
+    clear_memory()
+    return overlaid_images, output_files
+
+def generate_combined_audio_from_story(story_text, voice='af_heart', speed=1):
+    clear_memory()
+
+    if not story_text:
+        return None
+
+    paragraphs = split_into_paragraphs(story_text)
+    combined_audio = []
+
+    try:
+        for paragraph in paragraphs:
+            if not paragraph.strip():
+                continue
 
+            generator = pipeline(
+                paragraph,
+                voice=voice,
+                speed=speed,
+                split_pattern=r'\n+'
+            )
+            for _, _, audio in generator:
+                combined_audio.extend(audio)
 
-# Helper functions without GPU usage
+        # Convert combined audio to NumPy array and save
+        combined_audio = np.array(combined_audio)
+        filename = "combined_story.wav"
+        sf.write(filename, combined_audio, 24000)  # Save audio as .wav
+
+        clear_memory()
+        return filename
+
+    except Exception as e:
+        print(f"Error generating audio: {e}")
+        clear_memory()
+        return None
+
+# Helper functions
 def clean_story_output(story):
+    """Clean up the generated story text."""
     story = story.replace("<|im_end|>", "")
 
     story_start = story.find("Once upon")
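
The new prompt parser above re-splits the section and locates the current line with list.index(line), which returns the first occurrence of that text. An equivalent parse driven by enumerate avoids the repeated splitting and stays correct even if a line's text happens to repeat:

```python
def extract_scene_prompt(section: str):
    """Return (scene_number, prompt_text) from one '='-delimited section."""
    lines = section.split('\n')
    for i, line in enumerate(lines):
        if 'Scenery Prompt' in line:
            scene_num = line.split('Scenery Prompt')[1].split(':')[0].strip()
            if i + 1 < len(lines):
                return scene_num, lines[i + 1].strip()
    return None, None

section = "Paragraph 1:\nChamp ran home.\n\nScenery Prompt 1:\nWatercolor bulldog running home at dusk"
print(extract_scene_prompt(section))
# -> ('1', 'Watercolor bulldog running home at dusk')
```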
@@ -288,6 +473,7 @@ def clean_story_output(story):
     return '\n\n'.join(cleaned_lines).strip()
 
 def split_into_paragraphs(text):
+    """Split text into paragraphs."""
     paragraphs = []
     current_paragraph = []
 
@@ -308,6 +494,7 @@ def split_into_paragraphs(text):
         'keep it simple', 'end with', 'write a'])]
 
 def process_generated_prompt(prompt, paragraph):
+    """Process and clean up generated image prompts."""
     prompt = prompt.replace("<|im_start|>", "").replace("<|im_end|>", "")
     prompt = prompt.replace("assistant", "").replace("system", "").replace("user", "")
 
@@ -326,143 +513,9 @@ def process_generated_prompt(prompt, paragraph):
 
     return prompt
 
-def overlay_text_on_image(image, text):
-    if isinstance(image, np.ndarray):
-        image = Image.fromarray(image)
-
-    img = image.convert('RGB')
-    draw = ImageDraw.Draw(img)
-
-    try:
-        font_size = int(img.width * 0.025)
-        font = ImageFont.truetype("/usr/share/fonts/truetype/dejavu/DejaVuSans-Bold.ttf", font_size)
-    except:
-        font = ImageFont.load_default()
-
-    y_position = int(img.height * 0.005)
-    x_margin = int(img.width * 0.005)
-    available_width = img.width - (2 * x_margin)
-
-    wrapped_text = textwrap.fill(text, width=int(available_width / (font_size * 0.6)))
-
-    outline_color = (255, 255, 255)
-    text_color = (0, 0, 0)
-    offsets = [-2, -1, 1, 2]
-
-    for dx in offsets:
-        for dy in offsets:
-            draw.multiline_text(
-                (x_margin + dx, y_position + dy),
-                wrapped_text,
-                font=font,
-                fill=outline_color
-            )
-
-    draw.multiline_text(
-        (x_margin, y_position),
-        wrapped_text,
-        font=font,
-        fill=text_color
-    )
-
-    return img
-
-# Initialize Kokoro TTS pipeline
-pipeline = KPipeline(lang_code='a')  # 'a' for American English
-
-def generate_combined_audio_from_story(story_text, voice='af_heart', speed=1):
-    """Generate a single audio file for all paragraphs in the story."""
-    if not story_text:
-        return None
-
-    # Split story into paragraphs
-    paragraphs = []
-    current_paragraph = []
-
-    for line in story_text.split('\n'):
-        line = line.strip()
-        if not line:  # Empty line indicates paragraph break
-            if current_paragraph:
-                paragraphs.append(' '.join(current_paragraph))
-                current_paragraph = []
-        else:
-            current_paragraph.append(line)
-
-    if current_paragraph:
-        paragraphs.append(' '.join(current_paragraph))
-
-    # Combine audio for all paragraphs
-    combined_audio = []
-    for paragraph in paragraphs:
-        if not paragraph.strip():
-            continue  # Skip empty paragraphs
-
-        generator = pipeline(
-            paragraph,
-            voice=voice,
-            speed=speed,
-            split_pattern=r'\n+'  # Split on newlines
-        )
-        for _, _, audio in generator:
-            combined_audio.extend(audio)  # Append audio data
-
-    # Convert combined audio to NumPy array and save
-    combined_audio = np.array(combined_audio)
-    filename = "combined_story.wav"
-    sf.write(filename, combined_audio, 24000)  # Save audio as .wav
-    return filename
-
-def add_text_to_scenes(gallery_images, prompts_text):
-    if not isinstance(gallery_images, list):
-        return [], []
-
-    sections = prompts_text.split('='*50)
-    overlaid_images = []
-    output_files = []
-
-    temp_dir = "temp_book_pages"
-    os.makedirs(temp_dir, exist_ok=True)
-
-    for i, (image_data, section) in enumerate(zip(gallery_images, sections)):
-        if not section.strip():
-            continue
-
-        lines = [line.strip() for line in section.split('\n') if line.strip()]
-        paragraph = None
-        for j, line in enumerate(lines):
-            if line.startswith('Paragraph'):
-                if j + 1 < len(lines):
-                    paragraph = lines[j + 1]
-                break
-
-        if paragraph and image_data is not None:
-            try:
-                overlaid_img = overlay_text_on_image(image_data, paragraph)
-                if overlaid_img is not None:
-                    overlaid_array = np.array(overlaid_img)
-                    overlaid_images.append(overlaid_array)
-
-                    output_path = os.path.join(temp_dir, f"panel_{i+1}.png")
-                    overlaid_img.save(output_path)
-                    output_files.append(output_path)
-            except Exception as e:
-                print(f"Error processing image: {str(e)}")
-                continue
-
-    return overlaid_images, output_files
-
+# Create the interface
 def create_interface():
-    theme = gr.themes.Soft().set(
-        body_background_fill="*primary_50",
-        button_primary_background_fill="rgb(173, 216, 230)",  # light blue
-        button_secondary_background_fill="rgb(255, 182, 193)",  # light red
-        button_primary_background_fill_hover="rgb(135, 206, 235)",  # slightly darker blue for hover
-        button_secondary_background_fill_hover="rgb(255, 160, 180)",  # slightly darker red for hover
-        block_title_text_color="*primary_500",
-        block_label_text_color="*secondary_500",
-    )
-
-    with gr.Blocks(theme=theme) as demo:
+    with gr.Blocks() as demo:
         gr.Markdown("# Tech Tales: Story Creation")
 
         with gr.Row():
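
The relocated generate_combined_audio_from_story stitches every paragraph's audio into a single 24 kHz WAV. A minimal sketch of the same Kokoro + soundfile flow, assuming the kokoro package with its 'af_heart' voice and the three-value generator unpacking this commit itself uses:

```python
import numpy as np
import soundfile as sf
from kokoro import KPipeline

tts = KPipeline(lang_code='a')  # 'a' selects American English

chunks = []
for text in ["Once upon a time, a bulldog named Champ lived by a lake.",
             "Every morning he greeted the sun with a happy bark."]:
    # Each yielded item carries (graphemes, phonemes, audio) for one segment.
    for _, _, audio in tts(text, voice='af_heart', speed=1):
        chunks.extend(audio)

sf.write("story.wav", np.array(chunks), 24000)  # Kokoro generates 24 kHz audio
```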