import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch


# Load GPT-Neo 1.3B and its tokenizer once at startup; inference runs on the CPU.
model_name = "EleutherAI/gpt-neo-1.3B"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name).to("cpu")
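
# Optional sketch (not part of the original, which pins the model to the CPU):
# device = "cuda" if torch.cuda.is_available() else "cpu"
# model = AutoModelForCausalLM.from_pretrained(model_name).to(device)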


def remove_repeated_sentences(text):
    # Deduplicate sentences (split on ". ") while preserving their original order.
    sentences = text.split(". ")
    unique_sentences = []
    seen = set()
    for sentence in sentences:
        if sentence not in seen:
            unique_sentences.append(sentence)
            seen.add(sentence)
    return ". ".join(unique_sentences)


def generate_text(prompt, max_length=300, temperature=0.5, top_p=0.9, top_k=50, repetition_penalty=1.2):
    try:
        # Truncate the prompt so that prompt plus generated text fits GPT-Neo's 2048-token context.
        max_input_length = 2048 - max_length
        inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=max_input_length)

        outputs = model.generate(
            **inputs,
            max_new_tokens=max_length,  # cap only the newly generated tokens, not prompt + output
            temperature=temperature,
            top_p=top_p,
            top_k=top_k,
            repetition_penalty=repetition_penalty,
            do_sample=True,
            pad_token_id=tokenizer.eos_token_id
        )
        generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
        return remove_repeated_sentences(generated_text)
    except Exception as e:
        return f"Error during generation: {str(e)}"


# Module-level state shared across the three Gradio tabs.
global_synopsis = ""
global_chapters = ""


def generate_synopsis(topic):
    global global_synopsis
    try:
        prompt = f"Write a brief synopsis for a story about {topic}. Avoid repeating ideas or phrases. Keep the synopsis concise and clear."
        global_synopsis = generate_text(prompt, max_length=300)
        return global_synopsis
    except Exception as e:
        return f"Error generating synopsis: {str(e)}"


def generate_chapters():
    global global_synopsis, global_chapters
    if not global_synopsis:
        return "Please generate a synopsis first."
    try:
        prompt = f'''Based on this synopsis for a book: {global_synopsis}. Divide the story into 4 chapters with brief descriptions for each.
Enumerate every chapter created followed by its description. Make the first chapter sound like an introduction and the last as the epilogue.
Keep each title and description pair under 500 characters.'''
        global_chapters = generate_text(prompt, max_length=700)
        return global_chapters
    except Exception as e:
        return f"Error generating chapters: {str(e)}"


def expand_chapter(chapter_number):
    global global_chapters
    if not global_chapters:
        return "Please generate chapters first."
    try:
        # Keep only non-empty lines; each line is treated as one chapter title/description.
        chapters = [line for line in global_chapters.split("\n") if line.strip()]
        if chapter_number <= 0 or chapter_number > len(chapters):
            return f"Select a number between 1 and {len(chapters)}."
        prompt = f'''Knowing this synopsis for a book: {global_synopsis}. Expand and describe Chapter {chapter_number}
in more detail. The title and current brief description of this chapter is: {chapters[chapter_number - 1]}'''
        return generate_text(prompt, max_length=500)
    except Exception as e:
        return f"Error expanding chapter: {str(e)}"


# Gradio UI: one tab per step (synopsis -> chapters -> chapter expansion).
with gr.Blocks() as demo:
    gr.Markdown("## AI Hierarchical Story Generator")

    with gr.Tab("Generate Synopsis"):
        topic_input = gr.Textbox(label="Enter the story's main topic")
        synopsis_output = gr.Textbox(label="Generated Synopsis", interactive=False)
        synopsis_button = gr.Button("Generate Synopsis")

    with gr.Tab("Generate Chapters"):
        chapters_output = gr.Textbox(label="Generated Chapters", interactive=False)
        chapters_button = gr.Button("Generate Chapters")

    with gr.Tab("Expand Chapter"):
        chapter_input = gr.Number(label="Chapter Number", precision=0)
        chapter_detail_output = gr.Textbox(label="Expanded Chapter", interactive=False)
        chapter_button = gr.Button("Expand Chapter")

    # Wire each button to its handler.
    synopsis_button.click(generate_synopsis, inputs=topic_input, outputs=synopsis_output)
    chapters_button.click(generate_chapters, outputs=chapters_output)
    chapter_button.click(expand_chapter, inputs=chapter_input, outputs=chapter_detail_output)


demo.launch()
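# Alternative (sketch, not in the original): demo.launch(share=True) serves a temporary public URL.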