smtsead commited on
Commit
9955631
·
verified ·
1 Parent(s): 6070942

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +3 -3
app.py CHANGED
@@ -48,11 +48,11 @@ def text2story(text):
48
  story_generator = pipeline("text-generation", model="pranavpsv/gpt2-genre-story-generator")
49
 
50
  # Generate the story with the superhero genre
51
- prompt = f"<BOS> <fun superhero> {text}" # Add genre tags to the prompt
52
  story = story_generator(prompt, max_length=150, num_return_sequences=1)[0]['generated_text']
53
 
54
  # Remove <BOS> and <superhero> tags from the generated story
55
- story = story.replace("<BOS>", "").replace("<superhero>", "").strip()
56
 
57
 
58
  # Ensure the story is within 100 words by truncating if necessary
@@ -61,7 +61,7 @@ def text2story(text):
61
  # If the story is too short, regenerate it
62
  if len(story.split()) < 50: # Minimum 50 words
63
  story = story_generator(prompt, max_length=200, num_return_sequences=1)[0]['generated_text']
64
- story = story.replace("<BOS>", "").replace("<superhero>", "").strip()
65
  story = " ".join(story.split()[:100])
66
 
67
  return story
 
48
  story_generator = pipeline("text-generation", model="pranavpsv/gpt2-genre-story-generator")
49
 
50
  # Generate the story with the superhero genre
51
+ prompt = f"<BOS> <funny superhero> {text}" # Add genre tags to the prompt
52
  story = story_generator(prompt, max_length=150, num_return_sequences=1)[0]['generated_text']
53
 
54
  # Remove <BOS> and <superhero> tags from the generated story
55
+ story = story.replace("<BOS>", "").replace("<funny superhero>", "").strip()
56
 
57
 
58
  # Ensure the story is within 100 words by truncating if necessary
 
61
  # If the story is too short, regenerate it
62
  if len(story.split()) < 50: # Minimum 50 words
63
  story = story_generator(prompt, max_length=200, num_return_sequences=1)[0]['generated_text']
64
+ story = story.replace("<BOS>", "").replace("<funny superhero>", "").strip()
65
  story = " ".join(story.split()[:100])
66
 
67
  return story