koey811 committed
Commit dc29be0 · verified
1 Parent(s): 1a7366d

Update app.py

Files changed (1)
app.py +17 -15
app.py CHANGED
@@ -1,30 +1,34 @@
 import streamlit as st
-from transformers import pipeline
+from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
 from gtts import gTTS
-import os
+import io
+
+# Load the image captioning model
+caption_model = pipeline("image-to-text", model="Salesforce/blip-image-captioning-base")
+
+# Load the text generation model
+text_generation_model = AutoModelForCausalLM.from_pretrained("gpt2")
+tokenizer = AutoTokenizer.from_pretrained("gpt2")
 
 def generate_caption(image):
-    # Load the image captioning model
-    caption_model = pipeline("image-to-text", model="facebook/blip-image-captioning-base")
-
     # Generate the caption for the uploaded image
     caption = caption_model(image)[0]["generated_text"]
-
     return caption
 
 def generate_story(caption):
-    # Load the text generation model
-    text_generation_model = pipeline("text-generation", model="gpt2")
-
     # Generate the story based on the caption
-    story = text_generation_model(caption, max_length=200, num_return_sequences=1)[0]["generated_text"]
-
+    input_ids = tokenizer.encode(caption, return_tensors="pt")
+    output = text_generation_model.generate(input_ids, max_length=200, num_return_sequences=1)
+    story = tokenizer.decode(output[0], skip_special_tokens=True)
     return story
 
 def convert_to_audio(story):
     # Convert the story to audio using gTTS
     tts = gTTS(text=story, lang="en")
-    tts.save("story_audio.mp3")
+    audio_bytes = io.BytesIO()
+    tts.write_to_fp(audio_bytes)
+    audio_bytes.seek(0)
+    return audio_bytes
 
 def main():
     st.title("Storytelling Application")
@@ -47,11 +51,9 @@ def main():
         st.write(story)
 
         # Convert the story to audio
-        convert_to_audio(story)
+        audio_bytes = convert_to_audio(story)
 
         # Display the audio player
-        audio_file = open("story_audio.mp3", "rb")
-        audio_bytes = audio_file.read()
         st.audio(audio_bytes, format="audio/mp3")
 
 if __name__ == "__main__":
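The net effect of the diff: both models are loaded at module scope instead of inside each helper, story generation uses an explicit encode → generate → decode sequence against GPT-2, and the audio is produced in memory instead of via a story_audio.mp3 temp file. A minimal sketch of the in-memory audio pattern, assuming only gtts and the standard library (the helper name text_to_mp3_bytes is illustrative, not part of this commit):

import io

from gtts import gTTS

def text_to_mp3_bytes(text: str, lang: str = "en") -> io.BytesIO:
    # gTTS.write_to_fp() accepts any file-like object, so the MP3 bytes
    # never touch the filesystem.
    buf = io.BytesIO()
    gTTS(text=text, lang=lang).write_to_fp(buf)
    buf.seek(0)  # rewind so the consumer reads from the first byte
    return buf

st.audio accepts a file-like object (or raw bytes), so the buffer returned by convert_to_audio can be passed to it directly. One caveat: Streamlit re-executes the script on every interaction, so the module-scope loads still run on each rerun; wrapping them in st.cache_resource would be the usual way to load the models only once per process.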