Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
@@ -35,7 +35,7 @@ def img2text(url):
|
|
35 |
# Function to generate a kid-friendly superhero story from the text caption
|
36 |
def text2story(text):
|
37 |
"""
|
38 |
-
Generates a
|
39 |
|
40 |
Args:
|
41 |
text (str): Text caption generated from the image.
|
@@ -49,7 +49,7 @@ def text2story(text):
|
|
49 |
|
50 |
# Generate the story with the superhero genre
|
51 |
prompt = f"<BOS> <superhero> {text}" # Add genre tags to the prompt
|
52 |
-
story = story_generator(prompt, max_length=
|
53 |
|
54 |
# Remove <BOS> and <superhero> tags from the generated story
|
55 |
story = story.replace("<BOS>", "").replace("<superhero>", "").strip()
|
@@ -61,6 +61,14 @@ def text2story(text):
|
|
61 |
# Ensure the story is within 100 words by truncating if necessary
|
62 |
story = " ".join(story.split()[:100])
|
63 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
64 |
return story
|
65 |
except Exception as e:
|
66 |
st.error(f"Error generating story: {e}") # Display error message if something goes wrong
|
@@ -109,6 +117,14 @@ def main():
|
|
109 |
# Upload image
|
110 |
uploaded_file = st.file_uploader("📷 **Upload your picture here!**", type=["jpg", "jpeg", "png"])
|
111 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
112 |
if uploaded_file is not None:
|
113 |
# Save the uploaded file to disk
|
114 |
image_bytes = uploaded_file.getvalue()
|
@@ -120,22 +136,26 @@ def main():
|
|
120 |
|
121 |
# Stage 1: Image to Text
|
122 |
with st.spinner('✨ Turning your picture into words...'):
|
123 |
-
scenario
|
124 |
-
|
125 |
-
st.
|
|
|
|
|
|
|
126 |
|
127 |
# Stage 2: Text to Story
|
128 |
with st.spinner('π Creating a fun superhero story for you...'):
|
129 |
-
story
|
130 |
-
|
131 |
-
|
|
|
|
|
132 |
|
133 |
# Stage 3: Story to Audio
|
134 |
with st.spinner('🎧 Turning your story into audio...'):
|
135 |
-
|
136 |
-
|
137 |
-
|
138 |
-
|
139 |
# Play button for the generated audio
|
140 |
if st.button("🎵 **Play Audio**"):
|
141 |
if os.path.exists(st.session_state.audio_file):
|
@@ -143,11 +163,9 @@ def main():
|
|
143 |
else:
|
144 |
st.error("Audio file not found. Please try again.") # Display error if audio file is missing
|
145 |
|
146 |
-
# Clean up temporary files (uploaded image
|
147 |
if os.path.exists(uploaded_file.name):
|
148 |
os.remove(uploaded_file.name) # Delete the uploaded image file
|
149 |
-
if 'audio_file' in st.session_state and os.path.exists(st.session_state.audio_file):
|
150 |
-
os.remove(st.session_state.audio_file) # Delete the generated audio file
|
151 |
|
152 |
# Run the application
|
153 |
if __name__ == "__main__":
|
|
|
35 |
# Function to generate a kid-friendly superhero story from the text caption
|
36 |
def text2story(text):
|
37 |
"""
|
38 |
+
Generates a superhero story from the text caption using the pranavpsv/gpt2-genre-story-generator model.
|
39 |
|
40 |
Args:
|
41 |
text (str): Text caption generated from the image.
|
|
|
49 |
|
50 |
# Generate the story with the superhero genre
|
51 |
prompt = f"<BOS> <superhero> {text}" # Add genre tags to the prompt
|
52 |
+
story = story_generator(prompt, max_length=150, num_return_sequences=1)[0]['generated_text']
|
53 |
|
54 |
# Remove <BOS> and <superhero> tags from the generated story
|
55 |
story = story.replace("<BOS>", "").replace("<superhero>", "").strip()
|
|
|
61 |
# Ensure the story is within 100 words by truncating if necessary
|
62 |
story = " ".join(story.split()[:100])
|
63 |
|
64 |
+
# If the story is too short, regenerate it
|
65 |
+
if len(story.split()) < 50: # Minimum 50 words
|
66 |
+
story = story_generator(prompt, max_length=200, num_return_sequences=1)[0]['generated_text']
|
67 |
+
story = story.replace("<BOS>", "").replace("<superhero>", "").strip()
|
68 |
+
if text in story:
|
69 |
+
story = story.replace(text, "").strip()
|
70 |
+
story = " ".join(story.split()[:100])
|
71 |
+
|
72 |
return story
|
73 |
except Exception as e:
|
74 |
st.error(f"Error generating story: {e}") # Display error message if something goes wrong
|
|
|
117 |
# Upload image
|
118 |
uploaded_file = st.file_uploader("📷 **Upload your picture here!**", type=["jpg", "jpeg", "png"])
|
119 |
|
120 |
+
# Initialize session state variables
|
121 |
+
if 'scenario' not in st.session_state:
|
122 |
+
st.session_state.scenario = None
|
123 |
+
if 'story' not in st.session_state:
|
124 |
+
st.session_state.story = None
|
125 |
+
if 'audio_file' not in st.session_state:
|
126 |
+
st.session_state.audio_file = None
|
127 |
+
|
128 |
if uploaded_file is not None:
|
129 |
# Save the uploaded file to disk
|
130 |
image_bytes = uploaded_file.getvalue()
|
|
|
136 |
|
137 |
# Stage 1: Image to Text
|
138 |
with st.spinner('✨ Turning your picture into words...'):
|
139 |
+
if st.session_state.scenario is None or uploaded_file.name != st.session_state.get('uploaded_file_name', None):
|
140 |
+
st.session_state.scenario = img2text(uploaded_file.name) # Generate text caption from the image
|
141 |
+
st.session_state.uploaded_file_name = uploaded_file.name # Store the uploaded file name
|
142 |
+
|
143 |
+
if st.session_state.scenario:
|
144 |
+
st.write("**What we see:**", st.session_state.scenario) # Display the generated caption
|
145 |
|
146 |
# Stage 2: Text to Story
|
147 |
with st.spinner('π Creating a fun superhero story for you...'):
|
148 |
+
if st.session_state.story is None or uploaded_file.name != st.session_state.uploaded_file_name:
|
149 |
+
st.session_state.story = text2story(st.session_state.scenario) # Generate a superhero story from the caption
|
150 |
+
|
151 |
+
if st.session_state.story:
|
152 |
+
st.write("**Your superhero story:**", st.session_state.story) # Display the generated story
|
153 |
|
154 |
# Stage 3: Story to Audio
|
155 |
with st.spinner('🎧 Turning your story into audio...'):
|
156 |
+
if st.session_state.audio_file is None or uploaded_file.name != st.session_state.uploaded_file_name:
|
157 |
+
st.session_state.audio_file = text2audio(st.session_state.story) # Generate audio file
|
158 |
+
|
|
|
159 |
# Play button for the generated audio
|
160 |
if st.button("🎵 **Play Audio**"):
|
161 |
if os.path.exists(st.session_state.audio_file):
|
|
|
163 |
else:
|
164 |
st.error("Audio file not found. Please try again.") # Display error if audio file is missing
|
165 |
|
166 |
+
# Clean up temporary files (uploaded image)
|
167 |
if os.path.exists(uploaded_file.name):
|
168 |
os.remove(uploaded_file.name) # Delete the uploaded image file
|
|
|
|
|
169 |
|
170 |
# Run the application
|
171 |
if __name__ == "__main__":
|