from fastapi import FastAPI, File, UploadFile, Form, HTTPException
# Keep these if you use them elsewhere in your app (HTML, static files)
# from fastapi.responses import HTMLResponse
# from fastapi.staticfiles import StaticFiles
# from fastapi.templating import Jinja2Templates
# from fastapi.responses import FileResponse
# Removed 'requests' as we are using gradio_client
# import requests
import base64 # Keep if needed elsewhere (not strictly needed for this version)
import os
import random
# Removed unused IO import
# from typing import IO
# Import necessary classes from transformers (Keeping only AutoTokenizer)
from transformers import AutoTokenizer
# Import necessary modules for llama-cpp-python and downloading from Hub
from llama_cpp import Llama # The core Llama class
from huggingface_hub import hf_hub_download # For downloading GGUF files
# Import the Gradio Client and handle_file
from gradio_client import Client, handle_file
# Import necessary modules for temporary file handling
import tempfile
# shutil is not strictly necessary for this version, os.remove is sufficient
# import shutil
from deep_translator import GoogleTranslator
from deep_translator.exceptions import InvalidSourceOrTargetLanguage, LanguageNotSupportedException
app = FastAPI()
# --- Llama.cpp Language Model Setup (Local CPU Inference) ---
# Repository on Hugging Face Hub containing the Qwen1.5 0.5B GGUF file
# Using the official Qwen 0.5B GGUF repository:
LLM_MODEL_REPO = "Qwen/Qwen1.5-0.5B-Chat-GGUF"
# Specify the filename for a Q4_K_M quantized version (a good balance of
# speed and quality on CPU), exactly as listed in the repository's file browser:
LLM_MODEL_FILE = "qwen1_5-0_5b-chat-q4_k_m.gguf"
# Original model name for the tokenizer (needed by transformers)
# This points to the base model repository for the tokenizer files.
ORIGINAL_MODEL_NAME = "Qwen/Qwen1.5-0.5B-Chat"
tokenizer = None # Transformers tokenizer for the base model (llama.cpp applies the chat template itself during generation)
llm_model = None # This will hold the llama_cpp.Llama instance
# --- Hugging Face Gradio Space Client Setup (For External Image Captioning) ---
# Global Gradio Client for Captioning
caption_client = None
# The URL of the external Gradio Space for image captioning
CAPTION_SPACE_URL = "Makhinur/Image-to-Text-Salesforce-blip-image-captioning-base"
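# For reference, the call made later in generate_image_caption() has this shape
# (the "/predict" endpoint name comes from the Space's auto-generated API docs):
#   caption = caption_client.predict(img=handle_file("path/to/image.jpg"), api_name="/predict")
# and returns the generated caption as a plain string.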
# Function to load the language model (GGUF via llama.cpp) and its tokenizer (from transformers)
def load_language_model():
global tokenizer, llm_model
print(f"Loading language model: {LLM_MODEL_FILE} from {LLM_MODEL_REPO}...")
try:
# --- Load Tokenizer (using transformers) ---
# Load the tokenizer from the original model repo
print(f"Loading tokenizer from original model repo: {ORIGINAL_MODEL_NAME}...")
tokenizer = AutoTokenizer.from_pretrained(ORIGINAL_MODEL_NAME)
# Set pad_token if not already defined, often necessary for correct batching/generation behavior
# Qwen tokenizers should have pad_token, but this check is robust
if tokenizer.pad_token is None:
if tokenizer.eos_token is not None:
tokenizer.pad_token = tokenizer.eos_token
elif tokenizer.unk_token is not None:
tokenizer.pad_token = tokenizer.unk_token
            else:
                # Very rare: neither EOS nor UNK token exists; leave pad_token unset
                print("Warning: Neither EOS nor UNK token found for tokenizer; leaving pad_token unset.")
# --- Download GGUF model file (using huggingface_hub) ---
print(f"Downloading GGUF model file: {LLM_MODEL_FILE} from {LLM_MODEL_REPO}...")
model_path = hf_hub_download(
repo_id=LLM_MODEL_REPO,
filename=LLM_MODEL_FILE,
# cache_dir="/tmp/hf_cache" # Optional: specify a custom cache directory
)
print(f"GGUF model downloaded to: {model_path}")
# --- Load the GGUF model (using llama-cpp-python) ---
print(f"Loading GGUF model into llama_cpp...")
# Instantiate the Llama model from the downloaded GGUF file
# n_gpu_layers=0: Crucial for forcing CPU-only inference
        # n_ctx: Context window size (tokens the model can consider). Qwen1.5 models support long contexts (up to 32K), but a smaller window keeps CPU RAM use down.
        # n_threads: Number of CPU threads to use; set to the host's vCPU count (e.g. 2 on a small cloud instance) for best throughput.
llm_model = Llama(
model_path=model_path,
n_gpu_layers=0, # Explicitly use CPU
n_ctx=4096, # Context window size (4096 is a common safe value)
n_threads=2 # Use 2 CPU threads
)
print("Llama.cpp model loaded successfully.")
except Exception as e:
print(f"Error loading language model {LLM_MODEL_REPO}/{LLM_MODEL_FILE}: {e}")
tokenizer = None
llm_model = None # Ensure the model is None if loading fails
# Function to initialize the Gradio Client for the captioning Space
def initialize_caption_client():
global caption_client
print(f"Initializing Gradio client for {CAPTION_SPACE_URL}...")
try:
# If the target Gradio Space requires authentication (e.g., private)
# store HF_TOKEN as a Space Secret and uncomment these lines.
# HF_TOKEN = os.environ.get("HF_TOKEN")
# if HF_TOKEN:
# print("Using HF_TOKEN for Gradio client.")
# caption_client = Client(CAPTION_SPACE_URL, hf_token=HF_TOKEN)
# else:
# print("HF_TOKEN not found. Initializing public Gradio client.")
# caption_client = Client(CAPTION_SPACE_URL)
# Assuming the caption space is public
caption_client = Client(CAPTION_SPACE_URL)
print("Gradio client initialized successfully.")
except Exception as e:
print(f"Error initializing Gradio client for {CAPTION_SPACE_URL}: {e}")
# Set client to None so the endpoint can check and return an error
caption_client = None
# Load models and initialize clients when the app starts
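# Note: @app.on_event is deprecated on newer FastAPI versions in favor of
# lifespan context managers; the hooks below still work but may emit warnings.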
@app.on_event("startup")
async def startup_event():
# Load the language model (Qwen1.5 0.5B GGUF via llama.cpp)
load_language_model()
# Initialize the client for the external captioning Space
initialize_caption_client()
# --- Image Captioning Function (Using gradio_client and temporary file) ---
def generate_image_caption(image_file: UploadFile):
"""
Generates a caption for the uploaded image using the external Gradio Space API.
Reads the uploaded file's content, saves it to a temporary file,
and uses the temporary file's path with handle_file for the API call.
"""
if caption_client is None:
# If the client failed to initialize at startup
error_msg = "Gradio caption client not initialized. Cannot generate caption."
print(error_msg)
return f"Error: {error_msg}"
temp_file_path = None # Variable to store the path of the temporary file
try:
print(f"Attempting to generate caption for file: {image_file.filename}")
# Read the content of the uploaded file
# Seek to the beginning just in case the file-like object's pointer was moved
image_file.file.seek(0)
image_bytes = image_file.file.read()
# Create a temporary file on the local filesystem
# delete=False ensures the file persists after closing the handle
# suffix helps hint at the file type for the Gradio API
        temp_file = tempfile.NamedTemporaryFile(delete=False, suffix=os.path.splitext(image_file.filename or "")[1] or '.jpg')
temp_file.write(image_bytes)
temp_file.close() # Close the file handle so gradio_client can access the file
temp_file_path = temp_file.name # Get the full path to the temporary file
print(f"Saved uploaded file temporarily to: {temp_file_path}")
# Use handle_file() with the path string to the temporary file.
# This correctly prepares the file for the Gradio API input.
prepared_input = handle_file(temp_file_path)
# Call the predict method on the initialized gradio_client
# api_name="/predict" matches the endpoint specified in the Gradio API docs
caption = caption_client.predict(img=prepared_input, api_name="/predict")
print(f"Caption generated successfully.")
# Return the caption string received from the API
return caption
except Exception as e:
# Catch any exceptions that occur during reading, writing, or the API call
print(f"Error during caption generation API call: {e}") # Log the error details server-side
# Return a structured error string including the exception type and message
return f"Error: Unable to generate caption from API. Details: {type(e).__name__}: {e}"
finally:
# Clean up the temporary file regardless of whether the process succeeded or failed
if temp_file_path and os.path.exists(temp_file_path):
print(f"Cleaning up temporary file: {temp_file_path}")
try:
os.remove(temp_file_path) # Delete the file using its path
except OSError as e:
print(f"Error removing temporary file {temp_file_path}: {e}") # Log cleanup errors
# --- Language Model Story Generation Function (Qwen1.5 0.5B via llama.cpp) ---
def generate_story_qwen_0_5b(prompt_text: str, max_new_tokens: int = 300, temperature: float = 0.7, top_p: float = 0.9, top_k: int = 50) -> str:
"""
Generates text using the loaded Qwen1.5 0.5B model via llama.cpp.
Uses the tokenizer to apply the chat template and calls llama.cpp's chat completion.
"""
# Check if the language model was loaded successfully at startup
# Check for both tokenizer and llm_model (llama.cpp instance)
if tokenizer is None or llm_model is None:
# Raise a RuntimeError which is caught by the calling endpoint
raise RuntimeError("Language model (llama.cpp) or tokenizer not loaded.")
# Construct the messages list following the chat format for Qwen1.5 Chat
# Qwen models use a standard ChatML-like format.
messages = [
# System message is optional but can help guide the model's persona/style
# {"role": "system", "content": "You are a helpful and creative assistant."}
{"role": "user", "content": prompt_text}
]
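    # For reference, Qwen1.5's ChatML-style template renders this list roughly as:
    #   <|im_start|>user
    #   {prompt_text}<|im_end|>
    #   <|im_start|>assistant
    # llama.cpp applies the template itself, using metadata embedded in the GGUF file.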
try:
print("Calling llama.cpp create_chat_completion for Qwen 0.5B...")
# Call the create_chat_completion method from llama_cpp.Llama instance
# This method handles the chat templating internally for models like Qwen.
# max_tokens is the max number of tokens to generate
    # temperature, top_p, top_k control sampling; all three are accepted by llama-cpp-python's create_chat_completion.
response = llm_model.create_chat_completion(
messages=messages,
max_tokens=max_new_tokens,
temperature=temperature,
top_p=top_p,
            top_k=top_k, # Top-k sampling cutoff
stream=False # We want the full response at once
)
print("Llama.cpp completion received for Qwen 0.5B.")
# Parse the response to get the generated text content
# The response structure is typically like OpenAI's chat API response
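        # e.g. {"choices": [{"message": {"role": "assistant", "content": "..."}}], "usage": {...}}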
if response and response.get('choices') and len(response['choices']) > 0:
story = response['choices'][0].get('message', {}).get('content', '')
else:
# Handle cases where the response is empty or has an unexpected structure
print("Warning: Llama.cpp Qwen 0.5B response structure unexpected or content missing.")
story = "" # Return an empty string if content is not found
except Exception as e:
# Catch any exception that occurs during the llama.cpp inference process
print(f"Llama.cpp Qwen 0.5B inference failed: {e}") # Log the error server-side
# Re-raise as a RuntimeError to indicate failure to the endpoint
raise RuntimeError(f"Llama.cpp inference failed: {type(e).__name__}: {e}")
# Return the generated story text, removing leading/trailing whitespace
return story.strip()
# --- FastAPI Endpoint for Story Generation ---
@app.post("/generate-story/")
# Defined as a regular (sync) def so FastAPI runs it in a threadpool: the
# captioning, llama.cpp inference, and translation calls below all block.
def generate_story_endpoint(image_file: UploadFile = File(...), language: str = Form(...)):
# Choose a random theme for the story prompt
story_theme = random.choice([
'an adventurous journey', 'a mysterious encounter', 'a heroic quest',
'a magical adventure', 'a thrilling escape', 'an unexpected discovery',
'a dangerous mission', 'a romantic escapade', 'an epic battle',
'a journey into the unknown'
])
# Step 1: Get image caption using the external Gradio API via gradio_client
# Pass the UploadFile object directly to the captioning function
caption = generate_image_caption(image_file)
# Check if caption generation returned an error string
if caption.startswith("Error:"):
print(f"Caption generation failed: {caption}") # Log the error detail server-side
# Raise an HTTPException with a 500 status code and the error message
raise HTTPException(status_code=500, detail=caption)
# Step 2: Construct the prompt text for the language model
# This prompt instructs the model on what to write and incorporates the caption.
prompt_text = f"Write an attractive story of around 300 words about {story_theme}. Incorporate the following details from an image description into the story: {caption}\n\nStory:"
# Step 3: Generate the story using the local language model (Qwen 0.5B via llama.cpp)
try:
# Call the Qwen 0.5B story generation function
        story = generate_story_qwen_0_5b(
prompt_text,
max_new_tokens=300, # Request ~300 new tokens
temperature=0.7, # Sampling parameters
top_p=0.9,
            top_k=50 # Passed through to llama.cpp's top-k sampling
)
story = story.strip() # Basic cleanup of generated story text
except RuntimeError as e:
# Catch specific RuntimeError raised by generate_story_qwen_0_5b if LLM loading or inference fails
print(f"Language model generation error: {e}") # Log the error server-side
# Return a 503 Service Unavailable error if the LLM is not available or failed
raise HTTPException(status_code=503, detail=f"Story generation failed (LLM): {e}")
except Exception as e:
# Catch any other unexpected errors during story generation
print(f"An unexpected error occurred during story generation: {e}") # Log server-side
raise HTTPException(status_code=500, detail=f"An unexpected error occurred during story generation: {type(e).__name__}: {e}")
# Step 4: Translate the generated story if the target language is not English
# Check if language is provided and not English (case-insensitive)
if language and language.lower() != "english":
try:
# Initialize GoogleTranslator with English source and requested target language
translator = GoogleTranslator(source='english', target=language.lower())
# Perform the translation
translated_story = translator.translate(story)
# Check if translation returned None or an empty string (indicates failure)
if translated_story is None or translated_story == "":
print(f"Translation returned None or empty string for language: {language}") # Log failure
# If translation fails, return the original English story with a warning
return {"story": story + "\n\n(Note: Automatic translation to your requested language failed.)"}
# If translation was successful, use the translated text
story = translated_story
        except (InvalidSourceOrTargetLanguage, LanguageNotSupportedException):
            print(f"Invalid or unsupported target language requested: {language}") # Log invalid language
raise HTTPException(status_code=400, detail=f"Invalid target language: {language}")
except Exception as e:
# Catch any other errors during translation (e.g., network issues, API problems)
print(f"Translation failed for language {language}: {e}") # Log server-side
raise HTTPException(status_code=500, detail=f"Translation failed: {type(e).__name__}: {e}")
# Step 5: Return the final generated (and potentially translated) story as a JSON response
return {"story": story}
# --- Optional: Serve a simple HTML form for testing ---
# To use this, uncomment the imports related to HTMLResponse, StaticFiles, Jinja2Templates, Request
# at the top of the file, and create a 'templates' directory with an 'index.html' file.
# from fastapi import Request
# from fastapi.templating import Jinja2Templates
# from fastapi.staticfiles import StaticFiles
# templates = Jinja2Templates(directory="templates")
# app.mount("/static", StaticFiles(directory="static"), name="static")
# @app.get("/", response_class=HTMLResponse)
# async def read_root(request: Request):
# # Simple HTML form to upload an image and specify language
# html_content = """
# <!DOCTYPE html>
# <html>
# <head><title>Story Generator</title></head>
# <body>
# <h1>Generate a Story from an Image</h1>
# <form action="/generate-story/" method="post" enctype="multipart/form-data">
# <input type="file" name="image_file" accept="image/*" required><br><br>
# Target Language (e.g., english, french, spanish): <input type="text" name="language" value="english"><br><br>
# <button type="submit">Generate Story</button>
# </form>
# </body>
# </html>
# """
# # If using templates: return templates.TemplateResponse("index.html", {"request": request})
#     return HTMLResponse(content=html_content) # Using direct HTML for simplicity if templates not set up
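# Local development entry point (a minimal sketch; on hosted platforms such as
# Hugging Face Spaces the server is usually launched externally, so this is optional):
# if __name__ == "__main__":
#     import uvicorn
#     uvicorn.run(app, host="0.0.0.0", port=7860)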