Spaces:
Sleeping
Sleeping
File size: 6,091 Bytes
3ee9503 d927270 3ee9503 3f58313 3ee9503 f1d41aa 3ee9503 d927270 52bfc97 d927270 3ee9503 3f58313 3ee9503 3f58313 d29e675 3ee9503 3f58313 3ee9503 3f58313 3ee9503 3f58313 3ee9503 3f58313 3ee9503 3f58313 3ee9503 f1d41aa |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 |
import asyncio
import os
import threading

import fal_client
from fastapi import FastAPI, HTTPException
from huggingface_hub.file_download import http_get
from langchain_core.callbacks import CallbackManager, StreamingStdOutCallbackHandler
from langchain_core.messages import HumanMessage
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from llama_cpp import Llama
from pydantic import BaseModel
# FastAPI app
app = FastAPI()

# fal_client is configured through the FAL_KEY environment variable.
# Provide it as a deployment secret; never hard-code credentials in source.
# NOTE(review): the key that used to be committed on this line is public and
# should be revoked/rotated on fal.ai.
if not os.environ.get("FAL_KEY"):
    print("Warning: FAL_KEY is not set; fal.ai image generation requests are expected to fail.")
# Request body model for POST /generate_story/.
class StoryRequest(BaseModel):
    mood: str  # interpolated into the LLM user message and every image prompt
    story_type: str  # interpolated into the LLM user message and every image prompt
    theme: str  # interpolated into the LLM user message and every image prompt
    length: int  # NOTE(review): not read anywhere in this file
    num_scenes: int  # one fal.ai image is requested per scene
    txt: str  # free-form extra details, sent to the model as "Details Provided"
# LangChain callback manager that would stream tokens to stdout.
# NOTE(review): it is not passed to the Llama model, so nothing uses it yet.
callback_manager = CallbackManager(handlers=[StreamingStdOutCallbackHandler()])
def load_model(
    directory: str = ".",
    model_name: str = "natsumura-storytelling-rp-1.0-llama-3.1-8B.Q3_K_M.gguf",
    model_url: str = "https://huggingface.co/tohur/natsumura-storytelling-rp-1.0-llama-3.1-8b-GGUF/resolve/main/natsumura-storytelling-rp-1.0-llama-3.1-8B.Q3_K_M.gguf",
    n_ctx: int = 1024,
) -> Llama:
    """Download the GGUF model if it is not cached yet, then load it with llama.cpp.

    The download is written to a ``.part`` file and renamed into place only
    after it completes. An interrupted download is therefore retried on the
    next start, instead of leaving a truncated file that would be treated as
    cached.

    Args:
        directory: Directory holding the model file; created if missing.
        model_name: File name of the GGUF model inside ``directory``.
        model_url: URL to download the model from when the file is absent.
        n_ctx: Context window size (prompt + generated tokens) passed to ``Llama``.

    Returns:
        The loaded ``llama_cpp.Llama`` model.
    """
    final_model_path = os.path.join(directory, model_name)
    if not os.path.exists(final_model_path):
        print("Downloading all files...")
        if directory:
            os.makedirs(directory, exist_ok=True)
        partial_path = final_model_path + ".part"
        try:
            with open(partial_path, "wb") as f:
                http_get(model_url, f)
            os.replace(partial_path, final_model_path)
        finally:
            # The partial file only remains if the download or the rename failed.
            if os.path.exists(partial_path):
                os.remove(partial_path)
        # Owner read/write, others read-only: the model file only needs to be
        # readable, never world-writable or executable.
        os.chmod(final_model_path, 0o644)
        print("Files downloaded!")
    model = Llama(
        model_path=final_model_path,
        n_ctx=n_ctx,
    )
    print("Model loaded!")
    return model


# A 500-600 word story (roughly 700-850 tokens) plus the system and user
# prompts can exceed 1024 tokens, so the story model gets a larger window.
llm = load_model(n_ctx=2048)
# Chat prompt for the story model: fixed system instructions, followed by the
# request-specific messages injected through the "messages" placeholder.
prompt = ChatPromptTemplate.from_messages(
    [
        (
            "system",
            """You are a helpful and creative assistant that specializes in generating engaging and imaginative short stories for kids.
Based on the user's provided mood, preferred story type, theme, and details, create a unique and captivating story of 500-600 words.
Always start with the story title, then write a single story. The story begins on Page 1 and ends on Page 7; write every page heading in bold.
The story has exactly seven pages and each page has one short paragraph. Don't ask for any feedback at the end; just sign off with a cute closing inviting the reader
to create another adventure soon!
""",
        ),
        MessagesPlaceholder(variable_name="messages"),
    ]
)
# Serializes access to the shared llama.cpp model. Inference runs in worker
# threads, and a single Llama instance is not documented as safe to drive
# from several threads at once.
_LLM_LOCK = threading.Lock()

# LangChain message types -> role names of the OpenAI-style chat format
# accepted by ``Llama.create_chat_completion``.
_CHAT_ROLES = {"system": "system", "human": "user", "ai": "assistant"}


def _complete_chat(chat_messages: list[dict[str, str]]) -> str:
    """Run one chat completion on the shared model and return the reply text (blocking)."""
    with _LLM_LOCK:
        # max_tokens=None lets generation run until the model stops or the
        # context window is full; plain ``Llama.__call__`` completions would
        # default to only 16 tokens.
        completion = llm.create_chat_completion(messages=chat_messages, max_tokens=None)
    return completion["choices"][0]["message"]["content"] or ""


def _generate_image(image_prompt: str) -> str:
    """Render one image with fal.ai FLUX schnell and return its URL (blocking)."""
    handler = fal_client.submit(
        "fal-ai/flux/schnell",
        arguments={
            "prompt": image_prompt,
            "num_images": 1,
            "enable_safety_checker": True,
        },
    )
    result = handler.get()
    try:
        return result["images"][0]["url"]
    except (KeyError, IndexError, TypeError) as err:
        raise HTTPException(status_code=502, detail="Image generation returned no image") from err


# FastAPI endpoint to generate the story
@app.post("/generate_story/")
async def generate_story(story_request: StoryRequest):
    """Generate a kids' story with the local LLM, plus one image per scene.

    Args:
        story_request: Request body with mood, story type, theme, extra
            details, and the number of scene images to render.

    Returns:
        ``{"story": <story text>, "images": [<image URL>, ...]}``, with
        ``num_scenes`` image URLs in scene order.

    Raises:
        HTTPException: 500 if the model returns an empty story; 502 if the
            image service returns a result without an image URL.
    """
    # NOTE(review): story_request.length is not forwarded; the system prompt
    # fixes the story length.
    user_message = (
        "here are the inputs from user:\n"
        f"- **Mood:** {story_request.mood}\n"
        f"- **Story Type:** {story_request.story_type}\n"
        f"- **Theme:** {story_request.theme}\n"
        f"- **Details Provided:** {story_request.txt}\n"
    )
    chat_messages = [
        {"role": _CHAT_ROLES[message.type], "content": message.content}
        for message in prompt.format_messages(messages=[HumanMessage(content=user_message)])
    ]

    # Inference and the fal.ai calls block for a long time; run them off the
    # event loop so the server stays responsive to other requests.
    story_text = (await asyncio.to_thread(_complete_chat, chat_messages)).strip()
    if not story_text:
        raise HTTPException(status_code=500, detail="Failed to generate the story")

    images = []
    for i in range(story_request.num_scenes):
        scene = i + 1
        # NOTE(review): embedding the whole story makes this prompt long;
        # confirm the image model does not truncate the trailing
        # per-scene instruction.
        image_prompt = (
            f"Generate an image for Scene {scene}. "
            f"This image should represent the details described in paragraph {scene} of the story. "
            # story_type is a plain string; ', '.join() would split it into characters.
            f"Mood: {story_request.mood}, Story Type: {story_request.story_type}, Theme: {story_request.theme}. "
            f"Story: {story_text} "
            f"Focus on the key elements in paragraph {scene}."
        )
        images.append(await asyncio.to_thread(_generate_image, image_prompt))

    return {
        "story": story_text,
        "images": images,
    }
|