assignment1 / app.py
koey811's picture
Update app.py
e220f85 verified
raw
history blame
3.88 kB
import streamlit as st
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
from gtts import gTTS
import io
from PIL import Image
# Install PyTorch at runtime if it is missing.
# NOTE(review): listing `torch` in requirements.txt is the more reliable fix.
# `transformers` is imported above, before this check runs, and may already
# have recorded torch as unavailable; if pipeline creation still fails after
# this install, restart the app.
try:
    import torch
except ImportError:
    import importlib
    import subprocess
    import sys

    st.warning("PyTorch is not installed. Installing PyTorch...")
    # Use the interpreter that is running this app: a bare "pip" found on PATH
    # may belong to a different Python installation.
    if subprocess.run([sys.executable, "-m", "pip", "install", "torch"]).returncode != 0:
        st.error("Failed to install PyTorch. Please install it manually and restart the app.")
        st.stop()
    # Make the import system notice a package installed while this process runs.
    importlib.invalidate_caches()
    st.success("PyTorch has been successfully installed!")
    import torch
# Streamlit re-executes this script on every rerun. st.cache_resource keeps the
# constructed pipelines, so the models are not rebuilt on each interaction.
@st.cache_resource
def _load_caption_model():
    """Build the BLIP image-to-text pipeline used for captioning."""
    return pipeline("image-to-text", model="unography/blip-large-long-cap")


@st.cache_resource
def _load_story_generator():
    """Build the DistilGPT-2 text-generation pipeline used for stories."""
    return pipeline("text-generation", model="distilbert/distilgpt2")


# Load the image captioning model
caption_model = _load_caption_model()
# Load the DistilGPT-2 model for story generation
story_generator = _load_story_generator()
def generate_caption(image):
    """Describe an image in words using the captioning pipeline.

    Args:
        image: Image handed straight to ``caption_model`` (``main`` passes a
            PIL image opened from the upload).

    Returns:
        The ``"generated_text"`` entry of the pipeline's first result.
    """
    results = caption_model(image)
    best = results[0]
    return best["generated_text"]
def generate_story(caption, generator=None, max_words=100):
    """Write a short children's story inspired by an image caption.

    The caption is embedded in an instruction-style prompt for the
    text-generation pipeline, and only the newly generated continuation is
    used. Processing steps:

    1. Keep the first non-empty paragraph of the continuation.
    2. Remove a few unwanted words (whole words, any letter case).
    3. Normalise whitespace.
    4. Cut the text to at most ``max_words`` words.

    Args:
        caption: Text describing the uploaded image.
        generator: Optional callable using the Hugging Face text-generation
            pipeline calling convention, returning ``[{"generated_text": ...}]``.
            Defaults to the module-level ``story_generator``.
        max_words: Word limit; longer stories are truncated and end in "...".

    Returns:
        The story text, or ``""`` if the model produced no usable text.
    """
    import re

    if generator is None:
        generator = story_generator

    prompt = (
        f"Once upon a time, based on the image described as '{caption}', "
        "here is a short, simple, and engaging story for children aged 3-10. "
        "The story should be easy to understand, use age-appropriate language, "
        "and convey a positive message. Focus on the main elements in the image "
        "and create a story that sparks their imagination:\n\n"
    )
    # return_full_text=False asks the pipeline for the continuation only, so
    # the instructions are never mistaken for (or read out as) the story.
    generated = generator(
        prompt,
        max_length=500,
        num_return_sequences=1,
        return_full_text=False,
    )[0]["generated_text"]
    # Defensive: drop the prompt if a generator echoes it anyway.
    if generated.startswith(prompt):
        generated = generated[len(prompt):]

    # Keep only the first non-empty paragraph. Blindly indexing the split
    # result could pick an empty string or a piece of the prompt.
    story = next(
        (para.strip() for para in generated.split("\n\n") if para.strip()),
        "",
    )

    # Remove unwanted words as whole words in any letter case. A plain
    # str.replace would miss "Scary" and cut "scary" out of longer words.
    inappropriate_words = ("violence", "horror", "scary")
    pattern = r"\b(?:" + "|".join(map(re.escape, inappropriate_words)) + r")\b"
    story = re.sub(pattern, "", story, flags=re.IGNORECASE)

    # Normalise whitespace (closing gaps left by removed words) and enforce
    # the word limit.
    words = story.split()
    if len(words) > max_words:
        return " ".join(words[:max_words]) + "..."
    return " ".join(words)
def convert_to_audio(story):
    """Synthesize English speech for ``story`` with gTTS.

    Args:
        story: Text to be spoken.

    Returns:
        An in-memory ``io.BytesIO`` holding the audio, rewound to position 0
        so it can be read from the start.
    """
    buffer = io.BytesIO()
    gTTS(text=story, lang="en").write_to_fp(buffer)
    buffer.seek(0)
    return buffer
def main():
    """Run the Streamlit UI.

    The user uploads a JPG; the app shows it, captions it, writes a story from
    the caption, and plays the story back as audio.
    """
    from gtts import gTTSError

    st.title("Storytelling Application")

    # File uploader for the image (restricted to JPG)
    uploaded_image = st.file_uploader("Upload an image", type=["jpg"])
    if uploaded_image is None:
        # Nothing uploaded yet on this run; wait for the user.
        return

    # Convert the uploaded file to a PIL image and display it
    image = Image.open(uploaded_image)
    st.image(image, caption="Uploaded Image", use_container_width=True)

    with st.spinner("Generating caption..."):
        caption = generate_caption(image)
    st.subheader("Generated Caption:")
    st.write(caption)

    with st.spinner("Writing the story..."):
        story = generate_story(caption)
    st.subheader("Generated Story:")
    st.write(story)

    # gTTS refuses empty text, so skip audio rather than crash the page.
    if not story or not story.strip():
        st.warning("The model did not produce a story for this image. Please try again.")
        return

    try:
        with st.spinner("Converting the story to audio..."):
            audio_bytes = convert_to_audio(story)
    except gTTSError as err:
        # gTTS calls a remote service; report failures instead of a traceback.
        st.error(f"Could not convert the story to audio: {err}")
        return
    st.audio(audio_bytes, format="audio/mp3")


if __name__ == "__main__":
    main()