# Hugging Face Spaces page banner captured together with this file:
# Spaces: Runtime error
import gradio as gr | |
from diffusers import StableDiffusionPipeline, DPMSolverMultistepScheduler | |
from transformers import AutoProcessor, AutoModelForVision2Seq, AutoModelForCausalLM, AutoTokenizer | |
import torch | |
from PIL import Image, ImageDraw, ImageFont | |
import numpy as np | |
import textwrap | |
import os | |
import gc | |
import re | |
from datetime import datetime | |
import spaces | |
from kokoro import KPipeline | |
import soundfile as sf | |
# Initialize models at startup - outside of functions
print("Loading models...")

# SmolVLM: vision-language model used by analyze_image to describe the landscape.
processor_vlm = AutoProcessor.from_pretrained("HuggingFaceTB/SmolVLM-500M-Instruct")
model_vlm = AutoModelForVision2Seq.from_pretrained(
    "HuggingFaceTB/SmolVLM-500M-Instruct",
    torch_dtype=torch.bfloat16,
    use_safetensors=True,
)

# SmolLM2: text model used for the story and for the per-scene image prompts.
checkpoint = "HuggingFaceTB/SmolLM2-1.7B-Instruct"
tokenizer_lm = AutoTokenizer.from_pretrained(checkpoint)
model_lm = AutoModelForCausalLM.from_pretrained(
    checkpoint,
    use_safetensors=True,
)

# Stable Diffusion v1.5 in float16 with a DPM multistep scheduler.
# The original "runwayml/stable-diffusion-v1-5" repo was removed from the Hub;
# "stable-diffusion-v1-5/stable-diffusion-v1-5" is the maintained mirror of the
# same weights, so the SD1.5-based bulldog LoRA still applies.
SD_CHECKPOINT = "stable-diffusion-v1-5/stable-diffusion-v1-5"
pipe = StableDiffusionPipeline.from_pretrained(
    SD_CHECKPOINT,
    torch_dtype=torch.float16,
    use_safetensors=True,
)
pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)

# Move models to GPU if available; otherwise they stay on the CPU.
if torch.cuda.is_available():
    model_vlm = model_vlm.to("cuda")
    model_lm = model_lm.to("cuda")
    pipe = pipe.to("cuda")
def generate_image():
    """Generate a random landscape image with the base Stable Diffusion pipeline.

    A fresh random seed is drawn on every call, so repeated clicks give
    different landscapes.

    Returns:
        The first image of the pipeline output (``.images[0]``).
    """
    torch.cuda.empty_cache()
    default_prompt = "a beautiful, professional landscape photograph"
    default_negative_prompt = "blurry, bad quality, distorted, deformed"
    default_steps = 30
    default_guidance = 7.5
    default_seed = torch.randint(0, 2**32 - 1, (1,)).item()
    # Create the generator on the pipeline's own device. A hard-coded "cuda"
    # generator fails whenever CUDA is unavailable and the pipeline stayed on
    # the CPU (models are only moved to CUDA when it is available).
    generator = torch.Generator(device=pipe.device).manual_seed(default_seed)
    image = pipe(
        prompt=default_prompt,
        negative_prompt=default_negative_prompt,
        num_inference_steps=default_steps,
        guidance_scale=default_guidance,
        generator=generator,
    ).images[0]
    return image
def analyze_image(image):
    """Describe an image in a few sentences using SmolVLM.

    Args:
        image: A PIL image or a numpy array (converted with ``Image.fromarray``),
            or ``None`` if nothing has been generated yet.

    Returns:
        The model's description with surrounding whitespace stripped, or a
        hint message when ``image`` is ``None``.
    """
    if image is None:
        return "Please generate an image first."
    torch.cuda.empty_cache()
    if isinstance(image, np.ndarray):
        image = Image.fromarray(image)
    messages = [
        {
            "role": "user",
            "content": [
                {"type": "image"},
                {"type": "text", "text": "Describe this image very briefly in five sentences or less."},
            ],
        }
    ]
    prompt = processor_vlm.apply_chat_template(messages, add_generation_prompt=True)
    # Follow the model to whichever device it was placed on at startup instead
    # of assuming CUDA.
    inputs = processor_vlm(
        text=prompt,
        images=[image],
        return_tensors="pt",
    ).to(model_vlm.device)
    # Forward every processor output (including pixel_attention_mask, which the
    # original call dropped) rather than a hand-picked subset.
    outputs = model_vlm.generate(
        **inputs,
        num_return_sequences=1,
        no_repeat_ngram_size=2,
        max_new_tokens=500,
        min_new_tokens=10,
    )
    # Decode only the newly generated tokens so the echoed chat prompt never
    # ends up in the description.
    new_tokens = outputs[0, inputs.input_ids.shape[1]:]
    description = processor_vlm.decode(new_tokens, skip_special_tokens=True)
    return description.strip()
def generate_story(image_description):
    """Write a short children's story about Champ the bulldog, inspired by a scene.

    Args:
        image_description: Text describing the scene (the output of ``analyze_image``).

    Returns:
        The generated story after post-processing by ``clean_story_output``.
    """
    torch.cuda.empty_cache()
    story_prompt = f"""Write a short children's story (one chapter, about 500 words) based on this scene: {image_description}
Requirements:
1. Main character: An English bulldog named Champ
2. Include these values: confidence, teamwork, caring, and hope
3. Theme: "We are stronger together than as individuals"
4. Keep it simple and engaging for young children
5. End with a simple moral lesson"""
    messages = [{"role": "user", "content": story_prompt}]
    # add_generation_prompt=True appends the assistant header, so the model
    # answers the request instead of continuing the user's turn.
    input_text = tokenizer_lm.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    # Follow the model to whichever device it was placed on at startup.
    inputs = tokenizer_lm.encode(input_text, return_tensors="pt").to(model_lm.device)
    outputs = model_lm.generate(
        inputs,
        max_new_tokens=750,
        temperature=0.7,
        top_p=0.9,
        do_sample=True,
        repetition_penalty=1.2,
    )
    # Decode only the continuation. The prompt itself mentions "Champ" and
    # embeds the scene description, so decoding it too lets clean_story_output
    # mistake prompt text for the beginning of the story.
    story = tokenizer_lm.decode(outputs[0, inputs.shape[1]:], skip_special_tokens=True)
    story = clean_story_output(story)
    return story
def generate_image_prompts(story_text):
    """Turn each story paragraph into a short image prompt using SmolLM2.

    Args:
        story_text: The story, with paragraphs separated by blank lines.

    Returns:
        One section per paragraph, joined with newlines. Each section contains
        a "Paragraph i:" header with the paragraph text and a
        "Scenery Prompt i:" header with the prompt. It ends with a separator of
        50 "=" characters, which the later scene and text-overlay steps split on.
    """
    torch.cuda.empty_cache()
    paragraphs = split_into_paragraphs(story_text)
    all_prompts = []
    prompt_instruction = '''Here is a story paragraph: {paragraph}
Start your response with "Watercolor bulldog" and describe what Champ is doing in this scene. Add where it takes place and one mood detail. Keep it short.'''
    for i, paragraph in enumerate(paragraphs, 1):
        messages = [{"role": "user", "content": prompt_instruction.format(paragraph=paragraph)}]
        # Append the assistant header so the reply starts a fresh assistant turn.
        input_text = tokenizer_lm.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )
        inputs = tokenizer_lm.encode(input_text, return_tensors="pt").to(model_lm.device)
        outputs = model_lm.generate(
            inputs,
            max_new_tokens=30,
            temperature=0.5,
            top_p=0.9,
            do_sample=True,
            repetition_penalty=1.2,
        )
        # Decode only the model's reply, without the prompt or chat markup.
        response = tokenizer_lm.decode(outputs[0, inputs.shape[1]:], skip_special_tokens=True)
        prompt = process_generated_prompt(response, paragraph)
        section = f"Paragraph {i}:\n{paragraph}\n\nScenery Prompt {i}:\n{prompt}\n\n{'='*50}"
        all_prompts.append(section)
    return '\n'.join(all_prompts)
# Hub repo and local adapter name for the bulldog LoRA used by generate_story_image.
BULLDOG_LORA_REPO = "Prof-Hunt/lora-bulldog"
BULLDOG_LORA_ADAPTER = "bulldog"


def generate_story_image(prompt):
    """Render one watercolor story illustration with the bulldog LoRA.

    The LoRA is loaded into the shared ``pipe`` on first use only. It is
    enabled just for this generation and switched off again afterwards, so
    other callers of ``pipe`` see the base model.

    Args:
        prompt: Scene prompt; style keywords are appended to it.

    Returns:
        The first image of the pipeline output (``.images[0]``).
    """
    torch.cuda.empty_cache()
    loaded_adapters = pipe.get_list_adapters()
    if not any(BULLDOG_LORA_ADAPTER in names for names in loaded_adapters.values()):
        # Load only once: calling load_lora_weights on every scene reloads the
        # weights and injects another adapter each time.
        pipe.load_lora_weights(BULLDOG_LORA_REPO, adapter_name=BULLDOG_LORA_ADAPTER)
    pipe.enable_lora()
    enhanced_prompt = f"{prompt}, watercolor style, children's book illustration, soft colors"
    try:
        image = pipe(
            prompt=enhanced_prompt,
            negative_prompt="deformed, ugly, blurry, bad art, poor quality, distorted",
            num_inference_steps=50,
            guidance_scale=15,
        ).images[0]
    finally:
        # Keep the bulldog style from leaking into other uses of the shared
        # pipeline (e.g. the random landscape in generate_image).
        pipe.disable_lora()
    return image
# NOTE(review): `spaces` is imported at the top of the file but no function is
# decorated with it; this block likely once carried a longer-duration
# `@spaces.GPU(...)` decorator for multi-image generation - confirm.
def generate_all_scenes(prompts_text):
    """Render one illustration per "Scenery Prompt" section of ``prompts_text``.

    Sections are separated by a line of 50 "=" characters. In each section,
    the line after the first "Scenery Prompt N:" header is used as the prompt.
    A failure while rendering one scene is printed and that scene is skipped.

    Returns:
        A tuple of (list of rendered images as numpy arrays, the
        "Scene N: <prompt>" summaries joined by blank lines).
    """
    rendered_images = []
    scene_summaries = []
    for block in prompts_text.split('=' * 50):
        if not block.strip():
            continue
        rows = [row.strip() for row in block.split('\n') if row.strip()]
        scene_prompt = None
        for idx, row in enumerate(rows):
            if 'Scenery Prompt' not in row:
                continue
            scene_num = row.split('Scenery Prompt')[1].split(':')[0].strip()
            if idx + 1 < len(rows):
                scene_prompt = rows[idx + 1]
                scene_summaries.append(f"Scene {scene_num}: {scene_prompt}")
            break
        if not scene_prompt:
            continue
        try:
            torch.cuda.empty_cache()
            picture = generate_story_image(scene_prompt)
            if picture is not None:
                rendered_images.append(np.array(picture))
        except Exception as e:
            print(f"Error generating image: {str(e)}")
            continue
    return rendered_images, "\n\n".join(scene_summaries)
# Helper functions without GPU usage
def clean_story_output(story):
    """Extract the story text from raw model output.

    Steps:
    1. If a "<|im_start|>assistant" marker is present, keep only the text
       after the last one, so echoed prompt text is discarded.
    2. Remove any remaining ``<|...|>`` chat special tokens.
    3. Start at "Once upon", or else at the first of "One day", "In a",
       "There was", "Champ".
    4. Drop blank lines, numbered requirement lines ("1." to "5.") and lines
       echoing the story-prompt instructions.

    Returns:
        The kept lines joined by blank lines.
    """
    assistant_marker = "<|im_start|>assistant"
    marker_pos = story.rfind(assistant_marker)
    if marker_pos != -1:
        story = story[marker_pos + len(assistant_marker):]
    # Strip every ChatML-style token (<|im_end|>, <|im_start|>, <|endoftext|>, ...).
    story = re.sub(r"<\|[^|]*\|>", "", story)

    story_start = story.find("Once upon")
    if story_start == -1:
        for marker in ("One day", "In a", "There was", "Champ"):
            story_start = story.find(marker)
            if story_start != -1:
                break
    if story_start != -1:
        story = story[story_start:]

    # Phrases specific enough to the prompt's instructions that they will not
    # discard ordinary story sentences such as "Champ wanted to write a song".
    skip_phrases = (
        "requirements:",
        "include these values",
        "theme:",
        "keep it simple",
        "end with a simple moral lesson",
        "write a short children's story",
    )
    cleaned_lines = []
    for line in story.split('\n'):
        line = line.strip()
        if not line:
            continue
        if any(phrase in line.lower() for phrase in skip_phrases):
            continue
        if line.startswith(('1.', '2.', '3.', '4.', '5.')):
            continue
        cleaned_lines.append(line)
    return '\n\n'.join(cleaned_lines).strip()
def split_into_paragraphs(text):
    """Split ``text`` into paragraphs on blank lines.

    Consecutive non-blank lines are stripped and joined with single spaces.
    Paragraphs that contain one of the story-prompt instruction phrases are
    dropped as echoed prompt text.

    Returns:
        The list of kept paragraphs, in order.
    """
    # Same phrase list as clean_story_output. Kept specific so that ordinary
    # story paragraphs ("... decided to write a letter") are not discarded.
    skip_phrases = (
        "requirements:",
        "include these values",
        "theme:",
        "keep it simple",
        "end with a simple moral lesson",
        "write a short children's story",
    )
    paragraphs = []
    current_paragraph = []
    for line in text.split('\n'):
        line = line.strip()
        if not line:
            if current_paragraph:
                paragraphs.append(' '.join(current_paragraph))
                current_paragraph = []
        else:
            current_paragraph.append(line)
    if current_paragraph:
        paragraphs.append(' '.join(current_paragraph))
    return [p for p in paragraphs if not any(phrase in p.lower() for phrase in skip_phrases)]
def process_generated_prompt(prompt, paragraph):
    """Extract a "Watercolor bulldog ..." image prompt from raw model output.

    Args:
        prompt: Decoded model output; it may still contain chat tokens and
            role labels.
        paragraph: The story paragraph; used only to build a fallback prompt.

    Returns:
        The first line starting with "watercolor bulldog" (case-insensitive).
        If there is none, a templated fallback built from ``paragraph``. The
        result always ends with a period.
    """
    # Remove every ChatML-style special token (<|im_start|>, <|im_end|>, ...).
    prompt = re.sub(r"<\|[^|]*\|>", "", prompt)
    cleaned_lines = []
    for raw_line in prompt.split('\n'):
        # Strip a role label only at the start of a line. A global
        # .replace("user", "") would turn "users" into "s" and
        # "ecosystem" into "eco".
        line = re.sub(r"^\s*(?:assistant|system|user)\s*", "", raw_line).strip()
        if line.lower().startswith("watercolor bulldog"):
            cleaned_lines.append(line)
    if cleaned_lines:
        prompt = cleaned_lines[0]
    else:
        lowered = paragraph.lower()
        setting = "quiet town" if "quiet town" in lowered else "park"
        mood = "hopeful" if "wished" in lowered else "peaceful"
        prompt = f"Watercolor bulldog watching friends play in {setting}, {mood} atmosphere."
    if not prompt.endswith('.'):
        prompt = prompt + '.'
    return prompt
def overlay_text_on_image(image, text):
    """Draw ``text`` near the top-left of ``image`` in black with a white outline.

    Args:
        image: A PIL image or a numpy array (converted with ``Image.fromarray``).
        text: The caption; wrapped to roughly fit the image width.

    Returns:
        A new RGB PIL image with the text drawn on it. ``convert`` returns a
        copy, so the input image is left unchanged.
    """
    if isinstance(image, np.ndarray):
        image = Image.fromarray(image)
    img = image.convert('RGB')
    draw = ImageDraw.Draw(img)
    # Scale the font with the image width, but never below 1. For very small
    # images int(width * 0.025) is 0, which used to crash the wrap width
    # computation below with ZeroDivisionError.
    font_size = max(1, int(img.width * 0.025))
    try:
        font = ImageFont.truetype("/usr/share/fonts/truetype/dejavu/DejaVuSans-Bold.ttf", font_size)
    except (OSError, ValueError):
        # Font file missing or unusable on this host: fall back to Pillow's default font.
        font = ImageFont.load_default()
    y_position = int(img.height * 0.005)
    x_margin = int(img.width * 0.005)
    available_width = img.width - (2 * x_margin)
    # Estimate characters per line assuming ~0.6 * font_size per glyph;
    # textwrap requires a width of at least 1.
    chars_per_line = max(1, int(available_width / (font_size * 0.6)))
    wrapped_text = textwrap.fill(text, width=chars_per_line)
    outline_color = (255, 255, 255)
    text_color = (0, 0, 0)
    # Stamp the text at small offsets around the target position to form a
    # white outline, then draw the black text on top.
    offsets = [-2, -1, 1, 2]
    for dx in offsets:
        for dy in offsets:
            draw.multiline_text(
                (x_margin + dx, y_position + dy),
                wrapped_text,
                font=font,
                fill=outline_color,
            )
    draw.multiline_text(
        (x_margin, y_position),
        wrapped_text,
        font=font,
        fill=text_color,
    )
    return img
# Kokoro text-to-speech pipeline, built once at import time and reused by
# generate_combined_audio_from_story for every narration request.
pipeline = KPipeline(lang_code='a')  # lang_code 'a' selects American English
def generate_combined_audio_from_story(story_text, voice='af_heart', speed=1):
    """Narrate the whole story with Kokoro and save it as a single WAV file.

    Args:
        story_text: The story; blank lines separate paragraphs.
        voice: Kokoro voice name passed through to the pipeline.
        speed: Speech speed passed through to the pipeline.

    Returns:
        The path "combined_story.wav" (written at 24000 Hz), or ``None`` when
        ``story_text`` is empty or no audio was produced.
    """
    from itertools import groupby

    if not story_text:
        return None
    # Blank lines end a paragraph; consecutive non-blank lines are stripped and
    # joined with single spaces.
    stripped_lines = (line.strip() for line in story_text.split('\n'))
    paragraphs = [
        ' '.join(group) for has_text, group in groupby(stripped_lines, key=bool) if has_text
    ]

    # Collect whole chunks and concatenate once. Extending a Python list with
    # every individual sample was very slow and memory-hungry for long stories.
    chunks = []
    for paragraph in paragraphs:
        generator = pipeline(
            paragraph,
            voice=voice,
            speed=speed,
            split_pattern=r'\n+',  # Split on newlines
        )
        for _, _, audio in generator:
            if audio is None:
                continue  # nothing synthesized for this segment
            if torch.is_tensor(audio):
                audio = audio.detach().cpu().numpy()
            chunks.append(np.asarray(audio, dtype=np.float32).reshape(-1))

    if not chunks:
        # Avoid writing an empty WAV file when there is nothing to narrate.
        return None
    combined_audio = np.concatenate(chunks)
    filename = "combined_story.wav"
    sf.write(filename, combined_audio, 24000)  # Save audio as .wav
    return filename
def add_text_to_scenes(gallery_images, prompts_text): | |
if not isinstance(gallery_images, list): | |
return [], [] | |
sections = prompts_text.split('='*50) | |
overlaid_images = [] | |
output_files = [] | |
temp_dir = "temp_book_pages" | |
os.makedirs(temp_dir, exist_ok=True) | |
for i, (image_data, section) in enumerate(zip(gallery_images, sections)): | |
if not section.strip(): | |
continue | |
lines = [line.strip() for line in section.split('\n') if line.strip()] | |
paragraph = None | |
for j, line in enumerate(lines): | |
if line.startswith('Paragraph'): | |
if j + 1 < len(lines): | |
paragraph = lines[j + 1] | |
break | |
if paragraph and image_data is not None: | |
try: | |
overlaid_img = overlay_text_on_image(image_data, paragraph) | |
if overlaid_img is not None: | |
overlaid_array = np.array(overlaid_img) | |
overlaid_images.append(overlaid_array) | |
output_path = os.path.join(temp_dir, f"panel_{i+1}.png") | |
overlaid_img.save(output_path) | |
output_files.append(output_path) | |
except Exception as e: | |
print(f"Error processing image: {str(e)}") | |
continue | |
return overlaid_images, output_files | |
def create_interface():
    """Build the Gradio Blocks UI for the seven-step story workflow.

    Steps: landscape -> description -> story -> scene prompts -> scene images
    -> text overlays and downloadable pages -> narrated audio. Each button is
    wired to the matching pipeline function at the end of the Blocks context.

    Returns:
        The configured ``gr.Blocks`` app (not launched).
    """
    theme = gr.themes.Soft().set(
        body_background_fill="*primary_50",
        button_primary_background_fill="rgb(173, 216, 230)",  # light blue
        button_secondary_background_fill="rgb(255, 182, 193)",  # light red
        button_primary_background_fill_hover="rgb(135, 206, 235)",  # slightly darker blue for hover
        button_secondary_background_fill_hover="rgb(255, 160, 180)",  # slightly darker red for hover
        block_title_text_color="*primary_500",
        block_label_text_color="*secondary_500",
    )
    with gr.Blocks(theme=theme) as demo:
        gr.Markdown("# Tech Tales: Story Creation")
        # NOTE(review): the source this was recovered from had no indentation.
        # Each step's button and its output are assumed to share one gr.Row
        # (steps 5-6 wrap their outputs in separate rows explicitly) - confirm
        # against the original layout.
        with gr.Row():
            generate_btn = gr.Button("1. Generate Random Landscape")
            image_output = gr.Image(label="Generated Image", type="pil")
        with gr.Row():
            analyze_btn = gr.Button("2. Get Brief Description")
            analysis_output = gr.Textbox(label="Image Description", lines=3)
        with gr.Row():
            story_btn = gr.Button("3. Create Children's Story")
            story_output = gr.Textbox(label="Generated Story", lines=10)
        with gr.Row():
            prompts_btn = gr.Button("4. Generate Scene Prompts")
            prompts_output = gr.Textbox(label="Generated Scene Prompts", lines=20)
        with gr.Row():
            generate_scenes_btn = gr.Button("5. Generate Story Scenes", variant="primary")
        with gr.Row():
            # Read-only summary of the "Scene N: prompt" lines from generate_all_scenes.
            scene_prompts_display = gr.Textbox(
                label="Scenes Being Generated",
                lines=8,
                interactive=False
            )
        with gr.Row():
            gallery = gr.Gallery(
                label="Story Scenes",
                show_label=True,
                columns=2,
                height="auto"
            )
        with gr.Row():
            add_text_btn = gr.Button("6. Add Text to Scenes", variant="primary")
        with gr.Row():
            final_gallery = gr.Gallery(
                label="Story Book Pages",
                show_label=True,
                columns=2,
                height="auto"
            )
        with gr.Row():
            # Despite the name, this is a gr.File output listing the saved page PNGs.
            download_btn = gr.File(
                label="Download Story Book",
                file_count="multiple",
                interactive=False
            )
        with gr.Row():
            tts_btn = gr.Button("7. Read Story Aloud")
            audio_output = gr.Audio(label="Story Audio")
        # Event handlers: each step reads the previous step's output component.
        generate_btn.click(
            fn=generate_image,
            outputs=image_output
        )
        analyze_btn.click(
            fn=analyze_image,
            inputs=[image_output],
            outputs=analysis_output
        )
        story_btn.click(
            fn=generate_story,
            inputs=[analysis_output],
            outputs=story_output
        )
        prompts_btn.click(
            fn=generate_image_prompts,
            inputs=[story_output],
            outputs=prompts_output
        )
        generate_scenes_btn.click(
            fn=generate_all_scenes,
            inputs=[prompts_output],
            outputs=[gallery, scene_prompts_display]
        )
        # Pages pair gallery images with the paragraphs parsed from prompts_output.
        add_text_btn.click(
            fn=add_text_to_scenes,
            inputs=[gallery, prompts_output],
            outputs=[final_gallery, download_btn]
        )
        tts_btn.click(
            fn=generate_combined_audio_from_story,
            inputs=[story_output],
            outputs=audio_output
        )
    return demo
if __name__ == "__main__":
    # Script entry point: build the Blocks app (kept as the module-level
    # `demo`) and start the Gradio server.
    demo = create_interface()
    demo.launch()