koey811 commited on
Commit
7f3369b
·
verified ·
1 Parent(s): 510b489

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +9 -3
app.py CHANGED
@@ -17,7 +17,9 @@ except ImportError:
17
  # Load the image captioning model
18
  caption_model = pipeline("image-to-text", model="unography/blip-large-long-cap")
19
 
20
- story_generator = pipeline("text-generation", model="distilbert/distilgpt2")
 
 
21
 
22
  def generate_caption(image):
23
  # Generate the caption for the uploaded image
@@ -26,11 +28,15 @@ def generate_caption(image):
26
 
27
  def generate_story(caption):
28
  # Generate the story based on the caption using the GPT-2 model
29
- prompt = f"Imagine a delightful children's fairy tale inspired by the image described as '{caption}'. Itnshould be interesting, easy to understand, and use age-appropriate language suitable for children aged 3-10. Let the magical story unfold:\n\n"
30
  story = story_generator(prompt, max_length=500, num_return_sequences=1)[0]["generated_text"]
31
 
32
  # Extract the story text from the generated output
33
- story = story.split("\n\n")[1].strip()
 
 
 
 
34
 
35
  # Post-process the story (example: remove inappropriate words)
36
  inappropriate_words = ["violence", "horror", "scary"]
 
17
  # Load the image captioning model
18
  caption_model = pipeline("image-to-text", model="unography/blip-large-long-cap")
19
 
20
+ #story_generator = pipeline("text-generation", model="distilbert/distilgpt2")
21
+
22
+ story_generator = pipeline("text-generation", model="isarth/distill_gpt2_story_generator")
23
 
24
  def generate_caption(image):
25
  # Generate the caption for the uploaded image
 
28
 
29
  def generate_story(caption):
30
  # Generate the story based on the caption using the GPT-2 model
31
+ prompt = f"Imagine a delightful children's fairy tale inspired by the image described as '{caption}'. The story should be enchanting, easy to understand, and use age-appropriate language suitable for children aged 3-10. Let the magical story unfold:\n\n"
32
  story = story_generator(prompt, max_length=500, num_return_sequences=1)[0]["generated_text"]
33
 
34
  # Extract the story text from the generated output
35
+ story_parts = story.split("\n\n")
36
+ if len(story_parts) > 1:
37
+ story = "\n\n".join(story_parts[1:]).strip()
38
+ else:
39
+ story = story_parts[0].strip()
40
 
41
  # Post-process the story (example: remove inappropriate words)
42
  inappropriate_words = ["violence", "horror", "scary"]