import gradio as gr
from diffusers import StableDiffusionPipeline, DPMSolverMultistepScheduler
from transformers import AutoProcessor, AutoModelForVision2Seq, AutoModelForCausalLM, AutoTokenizer
import torch
from PIL import Image, ImageDraw, ImageFont
import numpy as np
import textwrap
import os
import gc
import re
import psutil
import spaces  # kept for Hugging Face Spaces (ZeroGPU) integration
from kokoro import KPipeline
import soundfile as sf
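
# Tech Tales: an end-to-end children's-book generator. The Gradio UI below
# walks through seven steps: (1) generate a landscape with Stable Diffusion,
# (2) describe it with SmolVLM, (3) write a story with SmolLM2, (4) derive one
# scenery prompt per paragraph, (5) illustrate each scene with a bulldog LoRA,
# (6) overlay the story text on the images, and (7) narrate it with Kokoro TTS.
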
def clear_memory():
    """Clear CUDA and system memory; safe to call in the Spaces environment."""
    gc.collect()
    # torch.cuda.is_available() is False outside a GPU task on ZeroGPU Spaces,
    # so the CUDA cleanup and logging below only run when a GPU is attached.
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        torch.cuda.synchronize()
        print(f"GPU Memory allocated: {torch.cuda.memory_allocated()/1024**2:.2f} MB")
        print(f"GPU Memory cached: {torch.cuda.memory_reserved()/1024**2:.2f} MB")
    process = psutil.Process(os.getpid())
    print(f"CPU RAM used: {process.memory_info().rss/1024**2:.2f} MB")

# Initialize models at startup - only the lightweight ones
print("Loading models...")

# Load SmolVLM for image analysis
processor_vlm = AutoProcessor.from_pretrained("HuggingFaceTB/SmolVLM-500M-Instruct")
model_vlm = AutoModelForVision2Seq.from_pretrained(
    "HuggingFaceTB/SmolVLM-500M-Instruct",
    torch_dtype=torch.bfloat16
).to("cuda")

# Load SmolLM2 for story and prompt generation
checkpoint = "HuggingFaceTB/SmolLM2-1.7B-Instruct"
tokenizer_lm = AutoTokenizer.from_pretrained(checkpoint)
model_lm = AutoModelForCausalLM.from_pretrained(checkpoint).to("cuda")

# Initialize Kokoro TTS pipeline ('a' selects American English voices)
tts_pipeline = KPipeline(lang_code='a')

def load_sd_model():
    """Load Stable Diffusion model only when needed."""
    pipe = StableDiffusionPipeline.from_pretrained(
        "runwayml/stable-diffusion-v1-5",
        torch_dtype=torch.float16,
    )
    pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
    pipe.to("cuda")
    pipe.enable_attention_slicing()
    return pipe
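
# Note: unlike SmolVLM and SmolLM2 above, Stable Diffusion is loaded on demand
# and deleted after each use (see generate_image and generate_story_image), so
# its ~2 GB of fp16 weights don't sit in VRAM between requests.
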
def generate_image():
    """Generate a random landscape image."""
    clear_memory()
    pipe = load_sd_model()
    default_prompt = "a beautiful, professional landscape photograph"
    default_negative_prompt = "blurry, bad quality, distorted, deformed"
    default_steps = 30
    default_guidance = 7.5
    default_seed = torch.randint(0, 2**32 - 1, (1,)).item()
    generator = torch.Generator("cuda").manual_seed(default_seed)
    try:
        image = pipe(
            prompt=default_prompt,
            negative_prompt=default_negative_prompt,
            num_inference_steps=default_steps,
            guidance_scale=default_guidance,
            generator=generator,
        ).images[0]
        del pipe
        clear_memory()
        return image
    except Exception as e:
        print(f"Error generating image: {e}")
        if 'pipe' in locals():
            del pipe
        clear_memory()
        return None

def analyze_image(image):
    """Describe the generated image with SmolVLM, trimmed to three sentences."""
    if image is None:
        return "Please generate an image first."
    clear_memory()
    if isinstance(image, np.ndarray):
        image = Image.fromarray(image)
    messages = [
        {
            "role": "user",
            "content": [
                {"type": "image"},
                {"type": "text", "text": "Describe this image. Be brief but descriptive."}
            ]
        }
    ]
    try:
        prompt = processor_vlm.apply_chat_template(messages, add_generation_prompt=True)
        inputs = processor_vlm(
            text=prompt,
            images=[image],
            return_tensors="pt"
        ).to('cuda')
        outputs = model_vlm.generate(
            input_ids=inputs.input_ids,
            pixel_values=inputs.pixel_values,
            attention_mask=inputs.attention_mask,
            num_return_sequences=1,
            no_repeat_ngram_size=2,
            max_new_tokens=500,
            min_new_tokens=10
        )
        description = processor_vlm.decode(outputs[0], skip_special_tokens=True)
        # Drop the echoed prompt: keep only the text after "Assistant:"
        description = re.sub(r".*?Assistant:\s*", "", description, flags=re.DOTALL).strip()
        # Split into sentences and keep only the first three
        sentences = re.split(r'(?<=[.!?])\s+', description)
        description = ' '.join(sentences[:3])
        clear_memory()
        return description
    except Exception as e:
        print(f"Error analyzing image: {e}")
        clear_memory()
        return "Error analyzing image. Please try again."

def generate_story(image_description):
    """Write a short children's story about Champ based on the image description."""
    clear_memory()
    story_prompt = f"""Write a short children's story (one chapter, about 500 words) based on this scene: {image_description}

Requirements:
1. Main character: An English bulldog named Champ
2. Include these values: confidence, teamwork, caring, and hope
3. Theme: "We are stronger together than as individuals"
4. Keep it simple and engaging for young children
5. End with a simple moral lesson"""
    try:
        messages = [{"role": "user", "content": story_prompt}]
        input_text = tokenizer_lm.apply_chat_template(messages, tokenize=False)
        inputs = tokenizer_lm.encode(input_text, return_tensors="pt").to("cuda")
        outputs = model_lm.generate(
            inputs,
            max_new_tokens=750,
            temperature=0.7,
            top_p=0.9,
            do_sample=True,
            repetition_penalty=1.2
        )
        story = tokenizer_lm.decode(outputs[0])
        story = clean_story_output(story)
        clear_memory()
        return story
    except Exception as e:
        print(f"Error generating story: {e}")
        clear_memory()
        return "Error generating story. Please try again."

def generate_image_prompts(story_text):
    """Generate one watercolor scenery prompt per story paragraph."""
    clear_memory()
    paragraphs = split_into_paragraphs(story_text)
    all_prompts = []
    prompt_instruction = '''Here is a story paragraph: {paragraph}

Start your response with "Watercolor bulldog" and describe what Champ is doing in this scene. Add where it takes place and one mood detail. Keep it short.'''
    try:
        for i, paragraph in enumerate(paragraphs, 1):
            messages = [{"role": "user", "content": prompt_instruction.format(paragraph=paragraph)}]
            input_text = tokenizer_lm.apply_chat_template(messages, tokenize=False)
            inputs = tokenizer_lm.encode(input_text, return_tensors="pt").to("cuda")
            outputs = model_lm.generate(
                inputs,
                max_new_tokens=30,
                temperature=0.5,
                top_p=0.9,
                do_sample=True,
                repetition_penalty=1.2
            )
            prompt = process_generated_prompt(tokenizer_lm.decode(outputs[0]), paragraph)
            section = f"Paragraph {i}:\n{paragraph}\n\nScenery Prompt {i}:\n{prompt}\n\n{'='*50}"
            all_prompts.append(section)
        clear_memory()
        return '\n'.join(all_prompts)
    except Exception as e:
        print(f"Error generating prompts: {e}")
        clear_memory()
        return "Error generating prompts. Please try again."
def generate_story_image(prompt, seed=-1):
    """Render one story scene with Stable Diffusion plus the bulldog LoRA."""
    clear_memory()
    pipe = load_sd_model()
    try:
        pipe.load_lora_weights("Prof-Hunt/lora-bulldog")
        generator = torch.Generator("cuda")
        if seed != -1:
            generator.manual_seed(seed)
        else:
            generator.manual_seed(torch.randint(0, 2**32 - 1, (1,)).item())
        enhanced_prompt = f"{prompt}, watercolor style, children's book illustration, soft colors"
        image = pipe(
            prompt=enhanced_prompt,
            negative_prompt="deformed, ugly, blurry, bad art, poor quality, distorted",
            num_inference_steps=50,
            guidance_scale=15,
            generator=generator
        ).images[0]
        pipe.unload_lora_weights()
        del pipe
        clear_memory()
        return image
    except Exception as e:
        print(f"Error generating image: {e}")
        if 'pipe' in locals():
            pipe.unload_lora_weights()
            del pipe
        clear_memory()
        return None

def generate_all_scenes(prompts_text):
    """Generate an illustration for every scenery prompt, yielding progress updates."""
    clear_memory()
    generated_images = []
    formatted_prompts = []
    progress_messages = []
    # Filter out empty sections up front so numbering and totals stay accurate
    sections = [s for s in prompts_text.split('=' * 50) if s.strip()]
    total_scenes = len(sections)

    def update_progress():
        """Create a progress message showing completed/total scenes"""
        message = f"Generated {len(generated_images)}/{total_scenes} scenes\n\n"
        if progress_messages:
            message += "\n".join(progress_messages[-3:])  # Show last 3 status messages
        return message

    for section_num, section in enumerate(sections, 1):
        # The scenery prompt sits on the line after the "Scenery Prompt N:" header
        scene_prompt = None
        lines = section.split('\n')
        for idx, line in enumerate(lines):
            if 'Scenery Prompt' in line:
                scene_num = line.split('Scenery Prompt')[1].split(':')[0].strip()
                if idx + 1 < len(lines):
                    scene_prompt = lines[idx + 1].strip()
                    formatted_prompts.append(f"Scene {scene_num}: {scene_prompt}")
                break
        if scene_prompt:
            try:
                clear_memory()
                progress_messages.append(f"🎨 Creating scene {section_num}: '{scene_prompt[:50]}...'")
                # Yield progress update
                yield generated_images, "\n\n".join(formatted_prompts), update_progress()
                image = generate_story_image(scene_prompt)
                if image is not None:
                    # Convert PIL Image to numpy array with explicit mode conversion
                    pil_image = image if isinstance(image, Image.Image) else Image.fromarray(image)
                    img_array = np.array(pil_image.convert('RGB'))
                    # Verify array shape and type
                    if img_array.ndim == 3 and img_array.shape[2] == 3:
                        generated_images.append(img_array)
                        progress_messages.append(f"✅ Successfully completed scene {section_num}")
                    else:
                        progress_messages.append(f"❌ Error: Invalid image format for scene {section_num}")
                else:
                    progress_messages.append(f"❌ Failed to generate scene {section_num}")
                clear_memory()
            except Exception as e:
                progress_messages.append(f"❌ Error generating scene {section_num}: {e}")
                clear_memory()
                continue
        # Yield progress update after each scene
        yield generated_images, "\n\n".join(formatted_prompts), update_progress()

    # Final status update
    if not generated_images:
        progress_messages.append("❌ No images were successfully generated")
    else:
        progress_messages.append(f"✅ Successfully completed all {len(generated_images)} scenes!")
    yield generated_images, "\n\n".join(formatted_prompts), update_progress()
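
# Because generate_all_scenes is a generator, Gradio streams each yield to the
# UI, so the gallery and progress box update after every scene instead of
# waiting for the whole batch to finish.
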
def add_text_to_scenes(gallery_images, prompts_text):
    """Overlay each story paragraph on its scene image and save the book pages."""
    if not isinstance(gallery_images, list):
        return [], []
    clear_memory()
    # Filter out empty sections up front so zip() keeps images and paragraphs
    # aligned one-to-one.
    sections = [s for s in prompts_text.split('=' * 50) if s.strip()]
    overlaid_images = []
    output_files = []
    temp_dir = "temp_book_pages"
    os.makedirs(temp_dir, exist_ok=True)
    for i, (image_data, section) in enumerate(zip(gallery_images, sections)):
        lines = [line.strip() for line in section.split('\n') if line.strip()]
        # The paragraph text is the line after the "Paragraph N:" header
        paragraph = None
        for j, line in enumerate(lines):
            if line.startswith('Paragraph'):
                if j + 1 < len(lines):
                    paragraph = lines[j + 1]
                break
        if paragraph and image_data is not None:
            try:
                # Handle tuple case (image, label) from gallery
                if isinstance(image_data, tuple):
                    image_data = image_data[0]
                # Convert numpy array to PIL Image
                if isinstance(image_data, np.ndarray):
                    image = Image.fromarray(image_data)
                else:
                    image = image_data
                print(f"Processing image {i+1}, type: {type(image)}")
                # Ensure we have a PIL Image
                if not isinstance(image, Image.Image):
                    raise TypeError(f"Expected PIL Image, got {type(image)}")
                overlaid_img = overlay_text_on_image(image, paragraph)
                if overlaid_img is not None:
                    overlaid_images.append(np.array(overlaid_img))
                    output_path = os.path.join(temp_dir, f"panel_{i+1}.png")
                    overlaid_img.save(output_path)
                    output_files.append(output_path)
                    print(f"Successfully processed image {i+1}")
            except Exception as e:
                print(f"Error processing image {i+1}: {str(e)}")
                continue
    if not overlaid_images:
        print("No images were successfully processed")
    else:
        print(f"Successfully processed {len(overlaid_images)} images")
    clear_memory()
    return overlaid_images, output_files

def overlay_text_on_image(image, text):
    """Helper function to overlay text on an image"""
    if image is None:
        return None
    try:
        # Ensure we're working in RGB mode
        img = image.convert('RGB')
        draw = ImageDraw.Draw(img)
        # Scale the font with the image width
        font_size = int(img.width * 0.025)
        try:
            font = ImageFont.truetype("/usr/share/fonts/truetype/dejavu/DejaVuSans-Bold.ttf", font_size)
        except OSError:
            font = ImageFont.load_default()
        # Calculate text positioning
        y_position = int(img.height * 0.005)
        x_margin = int(img.width * 0.005)
        available_width = img.width - (2 * x_margin)
        # Wrap text to fit the image width (~0.6 * font_size per character)
        wrapped_text = textwrap.fill(text, width=int(available_width / (font_size * 0.6)))
        # Draw a white outline around black text for readability
        outline_color = (255, 255, 255)
        text_color = (0, 0, 0)
        offsets = [-2, -1, 1, 2]
        for dx in offsets:
            for dy in offsets:
                draw.multiline_text(
                    (x_margin + dx, y_position + dy),
                    wrapped_text,
                    font=font,
                    fill=outline_color
                )
        # Draw the main text on top of the outline
        draw.multiline_text(
            (x_margin, y_position),
            wrapped_text,
            font=font,
            fill=text_color
        )
        return img
    except Exception as e:
        print(f"Error in overlay_text_on_image: {e}")
        return None

def generate_combined_audio_from_story(story_text, voice='af_heart', speed=1):
    """Narrate the whole story with Kokoro TTS and return a WAV filename."""
    clear_memory()
    if not story_text:
        return None
    paragraphs = split_into_paragraphs(story_text)
    combined_audio = []
    try:
        for paragraph in paragraphs:
            if not paragraph.strip():
                continue
            generator = tts_pipeline(
                paragraph,
                voice=voice,
                speed=speed,
                split_pattern=r'\n+'
            )
            for _, _, audio in generator:
                combined_audio.extend(audio)
        # Concatenate the chunks and save as a single WAV file
        combined_audio = np.array(combined_audio)
        filename = "combined_story.wav"
        sf.write(filename, combined_audio, 24000)  # Kokoro outputs 24 kHz audio
        clear_memory()
        return filename
    except Exception as e:
        print(f"Error generating audio: {e}")
        clear_memory()
        return None
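
# Kokoro's KPipeline yields (graphemes, phonemes, audio) triples per text
# chunk; the audio chunks are concatenated above and written out as one
# 24 kHz WAV file.
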
# Helper functions
def clean_story_output(story):
    """Clean up the generated story text."""
    story = story.replace("<|im_end|>", "")
    # Find where the story actually starts, skipping the echoed prompt
    story_start = story.find("Once upon")
    if story_start == -1:
        possible_starts = ["One day", "In a", "There was", "Champ"]
        for marker in possible_starts:
            story_start = story.find(marker)
            if story_start != -1:
                break
    if story_start != -1:
        story = story[story_start:]
    # Drop any leftover instruction lines from the prompt
    lines = story.split('\n')
    cleaned_lines = []
    for line in lines:
        line = line.strip()
        if line and not any(skip in line.lower() for skip in ['requirement', 'include these values', 'theme:', 'keep it simple', 'end with', 'write a']):
            if not line.startswith(('1.', '2.', '3.', '4.', '5.')):
                cleaned_lines.append(line)
    return '\n\n'.join(cleaned_lines).strip()

def split_into_paragraphs(text):
    """Split text into paragraphs, dropping leftover instruction lines."""
    paragraphs = []
    current_paragraph = []
    for line in text.split('\n'):
        line = line.strip()
        if not line:
            if current_paragraph:
                paragraphs.append(' '.join(current_paragraph))
                current_paragraph = []
        else:
            current_paragraph.append(line)
    if current_paragraph:
        paragraphs.append(' '.join(current_paragraph))
    return [p for p in paragraphs if not any(skip in p.lower()
            for skip in ['requirement', 'include these values', 'theme:',
                         'keep it simple', 'end with', 'write a'])]

def process_generated_prompt(prompt, paragraph):
    """Process and clean up generated image prompts."""
    prompt = prompt.replace("<|im_start|>", "").replace("<|im_end|>", "")
    prompt = prompt.replace("assistant", "").replace("system", "").replace("user", "")
    cleaned_lines = [line.strip() for line in prompt.split('\n')
                     if line.strip().lower().startswith("watercolor bulldog")]
    if cleaned_lines:
        prompt = cleaned_lines[0]
    else:
        # Fallback prompt derived from the paragraph when the model goes off-script
        setting = "quiet town" if "quiet town" in paragraph.lower() else "park"
        mood = "hopeful" if "wished" in paragraph.lower() else "peaceful"
        prompt = f"Watercolor bulldog watching friends play in {setting}, {mood} atmosphere."
    if not prompt.endswith('.'):
        prompt = prompt + '.'
    return prompt

# Create the interface
def create_interface():
    with gr.Blocks() as demo:
        gr.Markdown("# Tech Tales: Story Creation")
        with gr.Row():
            generate_btn = gr.Button("1. Generate Random Landscape")
        with gr.Row():
            image_output = gr.Image(label="Generated Image", type="pil", interactive=False)
        with gr.Row():
            analyze_btn = gr.Button("2. Get Brief Description")
        with gr.Row():
            analysis_output = gr.Textbox(label="Image Description", lines=3)
        with gr.Row():
            story_btn = gr.Button("3. Create Children's Story")
        with gr.Row():
            story_output = gr.Textbox(label="Generated Story", lines=10)
        with gr.Row():
            prompts_btn = gr.Button("4. Generate Scene Prompts")
        with gr.Row():
            prompts_output = gr.Textbox(label="Generated Scene Prompts", lines=20)
        with gr.Row():
            generate_scenes_btn = gr.Button("5. Generate Story Scenes", variant="primary")
        with gr.Row():
            scene_progress = gr.Textbox(
                label="Generation Progress",
                lines=6,
                interactive=False
            )
        with gr.Row():
            gallery = gr.Gallery(
                label="Story Scenes",
                show_label=True,
                columns=2,
                height="auto",
                interactive=False
            )
        with gr.Row():
            scene_prompts_display = gr.Textbox(
                label="Scene Descriptions",
                lines=8,
                interactive=False
            )
        with gr.Row():
            add_text_btn = gr.Button("6. Add Text to Scenes", variant="primary")
        with gr.Row():
            final_gallery = gr.Gallery(
                label="Story Book Pages",
                show_label=True,
                columns=2,
                height="auto",
                interactive=False
            )
        with gr.Row():
            download_btn = gr.File(
                label="Download Story Book",
                file_count="multiple",
                interactive=False
            )
        with gr.Row():
            tts_btn = gr.Button("7. Read Story Aloud")
            audio_output = gr.Audio(label="Story Audio")
        # Event handlers
        generate_btn.click(
            fn=generate_image,
            outputs=image_output
        )
        analyze_btn.click(
            fn=analyze_image,
            inputs=[image_output],
            outputs=analysis_output
        )
        story_btn.click(
            fn=generate_story,
            inputs=[analysis_output],
            outputs=story_output
        )
        prompts_btn.click(
            fn=generate_image_prompts,
            inputs=[story_output],
            outputs=prompts_output
        )
        generate_scenes_btn.click(
            fn=generate_all_scenes,
            inputs=[prompts_output],
            outputs=[gallery, scene_prompts_display, scene_progress]
        )
        add_text_btn.click(
            fn=add_text_to_scenes,
            inputs=[gallery, prompts_output],
            outputs=[final_gallery, download_btn]
        )
        tts_btn.click(
            fn=generate_combined_audio_from_story,
            inputs=[story_output],
            outputs=audio_output
        )
    return demo

if __name__ == "__main__":
    demo = create_interface()
    demo.launch()
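
# To run locally (assuming a CUDA GPU and the dependencies imported above,
# including the kokoro and soundfile packages): save this script, e.g. as
# app.py, run `python app.py`, and open the local URL that Gradio prints.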