smtsead commited on
Commit
aee30ec
Β·
verified Β·
1 Parent(s): 28c2183

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +33 -15
app.py CHANGED
@@ -35,7 +35,7 @@ def img2text(url):
35
  # Function to generate a kid-friendly superhero story from the text caption
36
  def text2story(text):
37
  """
38
- Generates a kid-friendly superhero story from the text caption using the pranavpsv/gpt2-genre-story-generator model.
39
 
40
  Args:
41
  text (str): Text caption generated from the image.
@@ -49,7 +49,7 @@ def text2story(text):
49
 
50
  # Generate the story with the superhero genre
51
  prompt = f"<BOS> <superhero> {text}" # Add genre tags to the prompt
52
- story = story_generator(prompt, max_length=100, num_return_sequences=1)[0]['generated_text']
53
 
54
  # Remove <BOS> and <superhero> tags from the generated story
55
  story = story.replace("<BOS>", "").replace("<superhero>", "").strip()
@@ -61,6 +61,14 @@ def text2story(text):
61
  # Ensure the story is within 100 words by truncating if necessary
62
  story = " ".join(story.split()[:100])
63
 
 
 
 
 
 
 
 
 
64
  return story
65
  except Exception as e:
66
  st.error(f"Error generating story: {e}") # Display error message if something goes wrong
@@ -109,6 +117,14 @@ def main():
109
  # Upload image
110
  uploaded_file = st.file_uploader("πŸ“· **Upload your picture here!**", type=["jpg", "jpeg", "png"])
111
 
 
 
 
 
 
 
 
 
112
  if uploaded_file is not None:
113
  # Save the uploaded file to disk
114
  image_bytes = uploaded_file.getvalue()
@@ -120,22 +136,26 @@ def main():
120
 
121
  # Stage 1: Image to Text
122
  with st.spinner('✨ Turning your picture into words...'):
123
- scenario = img2text(uploaded_file.name) # Generate text caption from the image
124
- if scenario:
125
- st.write("**What we see:**", scenario) # Display the generated caption
 
 
 
126
 
127
  # Stage 2: Text to Story
128
  with st.spinner('πŸ“– Creating a fun superhero story for you...'):
129
- story = text2story(scenario) # Generate a superhero story from the caption
130
- if story:
131
- st.write("**Your superhero story:**", story) # Display the generated story
 
 
132
 
133
  # Stage 3: Story to Audio
134
  with st.spinner('🎧 Turning your story into audio...'):
135
- # Generate audio file if it doesn't already exist in the session state
136
- if 'audio_file' not in st.session_state:
137
- st.session_state.audio_file = text2audio(story)
138
-
139
  # Play button for the generated audio
140
  if st.button("🎡 **Play Audio**"):
141
  if os.path.exists(st.session_state.audio_file):
@@ -143,11 +163,9 @@ def main():
143
  else:
144
  st.error("Audio file not found. Please try again.") # Display error if audio file is missing
145
 
146
- # Clean up temporary files (uploaded image and generated audio)
147
  if os.path.exists(uploaded_file.name):
148
  os.remove(uploaded_file.name) # Delete the uploaded image file
149
- if 'audio_file' in st.session_state and os.path.exists(st.session_state.audio_file):
150
- os.remove(st.session_state.audio_file) # Delete the generated audio file
151
 
152
  # Run the application
153
  if __name__ == "__main__":
 
35
  # Function to generate a kid-friendly superhero story from the text caption
36
  def text2story(text):
37
  """
38
+ Generates a superhero story from the text caption using the pranavpsv/gpt2-genre-story-generator model.
39
 
40
  Args:
41
  text (str): Text caption generated from the image.
 
49
 
50
  # Generate the story with the superhero genre
51
  prompt = f"<BOS> <superhero> {text}" # Add genre tags to the prompt
52
+ story = story_generator(prompt, max_length=150, num_return_sequences=1)[0]['generated_text']
53
 
54
  # Remove <BOS> and <superhero> tags from the generated story
55
  story = story.replace("<BOS>", "").replace("<superhero>", "").strip()
 
61
  # Ensure the story is within 100 words by truncating if necessary
62
  story = " ".join(story.split()[:100])
63
 
64
+ # If the story is too short, regenerate it
65
+ if len(story.split()) < 50: # Minimum 50 words
66
+ story = story_generator(prompt, max_length=200, num_return_sequences=1)[0]['generated_text']
67
+ story = story.replace("<BOS>", "").replace("<superhero>", "").strip()
68
+ if text in story:
69
+ story = story.replace(text, "").strip()
70
+ story = " ".join(story.split()[:100])
71
+
72
  return story
73
  except Exception as e:
74
  st.error(f"Error generating story: {e}") # Display error message if something goes wrong
 
117
  # Upload image
118
  uploaded_file = st.file_uploader("πŸ“· **Upload your picture here!**", type=["jpg", "jpeg", "png"])
119
 
120
+ # Initialize session state variables
121
+ if 'scenario' not in st.session_state:
122
+ st.session_state.scenario = None
123
+ if 'story' not in st.session_state:
124
+ st.session_state.story = None
125
+ if 'audio_file' not in st.session_state:
126
+ st.session_state.audio_file = None
127
+
128
  if uploaded_file is not None:
129
  # Save the uploaded file to disk
130
  image_bytes = uploaded_file.getvalue()
 
136
 
137
  # Stage 1: Image to Text
138
  with st.spinner('✨ Turning your picture into words...'):
139
+ if st.session_state.scenario is None or uploaded_file.name != st.session_state.get('uploaded_file_name', None):
140
+ st.session_state.scenario = img2text(uploaded_file.name) # Generate text caption from the image
141
+ st.session_state.uploaded_file_name = uploaded_file.name # Store the uploaded file name
142
+
143
+ if st.session_state.scenario:
144
+ st.write("**What we see:**", st.session_state.scenario) # Display the generated caption
145
 
146
  # Stage 2: Text to Story
147
  with st.spinner('πŸ“– Creating a fun superhero story for you...'):
148
+ if st.session_state.story is None or uploaded_file.name != st.session_state.uploaded_file_name:
149
+ st.session_state.story = text2story(st.session_state.scenario) # Generate a superhero story from the caption
150
+
151
+ if st.session_state.story:
152
+ st.write("**Your superhero story:**", st.session_state.story) # Display the generated story
153
 
154
  # Stage 3: Story to Audio
155
  with st.spinner('🎧 Turning your story into audio...'):
156
+ if st.session_state.audio_file is None or uploaded_file.name != st.session_state.uploaded_file_name:
157
+ st.session_state.audio_file = text2audio(st.session_state.story) # Generate audio file
158
+
 
159
  # Play button for the generated audio
160
  if st.button("🎡 **Play Audio**"):
161
  if os.path.exists(st.session_state.audio_file):
 
163
  else:
164
  st.error("Audio file not found. Please try again.") # Display error if audio file is missing
165
 
166
+ # Clean up temporary files (uploaded image)
167
  if os.path.exists(uploaded_file.name):
168
  os.remove(uploaded_file.name) # Delete the uploaded image file
 
 
169
 
170
  # Run the application
171
  if __name__ == "__main__":