from fastapi import FastAPI, File, UploadFile, Form, HTTPException
# Optional imports, used only by the HTML test form at the bottom of this file:
# from fastapi.responses import HTMLResponse
# from fastapi.staticfiles import StaticFiles
# from fastapi.templating import Jinja2Templates
import os
import random
# Tokenizer from transformers (only AutoTokenizer is needed)
from transformers import AutoTokenizer
# llama-cpp-python for local CPU inference, huggingface_hub for downloading the GGUF file
from llama_cpp import Llama  # the core Llama class
from huggingface_hub import hf_hub_download  # for downloading GGUF files
# Gradio client for calling the external captioning Space
from gradio_client import Client, handle_file
# Temporary file handling for uploads passed to the Gradio client
import tempfile
# Translation of the generated story
from deep_translator import GoogleTranslator
from deep_translator.exceptions import InvalidSourceOrTargetLanguage

app = FastAPI()

# --- Llama.cpp Language Model Setup (Local CPU Inference) ---
# Hugging Face Hub repository containing the Qwen1.5 0.5B Chat GGUF files
LLM_MODEL_REPO = "Qwen/Qwen1.5-0.5B-Chat-GGUF"
# Q4_K_M quantized file: a good balance of speed and quality on CPU
LLM_MODEL_FILE = "qwen1_5-0_5b-chat-q4_k_m.gguf"
# Base model repository, used only to load the tokenizer via transformers
ORIGINAL_MODEL_NAME = "Qwen/Qwen1.5-0.5B-Chat"

tokenizer = None   # transformers tokenizer (loaded at startup)
llm_model = None   # llama_cpp.Llama instance (loaded at startup)

# --- Hugging Face Gradio Space Client Setup (For External Image Captioning) ---
# Global Gradio client for the captioning Space
caption_client = None
# The external Gradio Space that provides image captioning
CAPTION_SPACE_URL = "Makhinur/Image-to-Text-Salesforce-blip-image-captioning-base"
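
# Other quantizations from the same repo can be swapped in by changing LLM_MODEL_FILE.
# Illustrative example (verify the exact filename on the Hub before using it):
# LLM_MODEL_FILE = "qwen1_5-0_5b-chat-q8_0.gguf"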

# Load the language model (GGUF via llama.cpp) and its tokenizer (from transformers)
def load_language_model():
    global tokenizer, llm_model
    print(f"Loading language model: {LLM_MODEL_FILE} from {LLM_MODEL_REPO}...")
    try:
        # --- Load tokenizer (using transformers) ---
        print(f"Loading tokenizer from original model repo: {ORIGINAL_MODEL_NAME}...")
        tokenizer = AutoTokenizer.from_pretrained(ORIGINAL_MODEL_NAME)
        # Set pad_token if it is not already defined; this is often necessary for
        # correct batching/generation behavior. Qwen tokenizers usually define it,
        # but the check makes the code robust.
        if tokenizer.pad_token is None:
            if tokenizer.eos_token is not None:
                tokenizer.pad_token = tokenizer.eos_token
            elif tokenizer.unk_token is not None:
                tokenizer.pad_token = tokenizer.unk_token
            else:
                # Fallback if neither exists (very rare)
                print("Warning: Neither EOS nor UNK token found for tokenizer. Leaving pad_token unset.")

        # --- Download the GGUF model file (using huggingface_hub) ---
        print(f"Downloading GGUF model file: {LLM_MODEL_FILE} from {LLM_MODEL_REPO}...")
        model_path = hf_hub_download(
            repo_id=LLM_MODEL_REPO,
            filename=LLM_MODEL_FILE,
            # cache_dir="/tmp/hf_cache"  # optional: custom cache directory
        )
        print(f"GGUF model downloaded to: {model_path}")

        # --- Load the GGUF model (using llama-cpp-python) ---
        print("Loading GGUF model into llama_cpp...")
        # n_gpu_layers=0 forces CPU-only inference.
        # n_ctx is the context window size (how many tokens the model can consider);
        # 4096 is a safe value for this model on CPU.
        # n_threads is the number of CPU threads; set it to your vCPU count.
        llm_model = Llama(
            model_path=model_path,
            n_gpu_layers=0,  # explicitly use CPU
            n_ctx=4096,      # context window size
            n_threads=2      # use 2 CPU threads
        )
        print("Llama.cpp model loaded successfully.")
    except Exception as e:
        print(f"Error loading language model {LLM_MODEL_REPO}/{LLM_MODEL_FILE}: {e}")
        tokenizer = None
        llm_model = None  # ensure the model is None if loading fails
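
# Illustrative smoke test (run manually, e.g. from a REPL; assumes the download and load
# above succeed). It calls llama.cpp's OpenAI-style chat completion directly, bypassing FastAPI:
# load_language_model()
# out = llm_model.create_chat_completion(
#     messages=[{"role": "user", "content": "Say hello in one short sentence."}],
#     max_tokens=32,
# )
# print(out["choices"][0]["message"]["content"])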

# Initialize the Gradio client for the captioning Space
def initialize_caption_client():
    global caption_client
    print(f"Initializing Gradio client for {CAPTION_SPACE_URL}...")
    try:
        # If the target Gradio Space requires authentication (e.g., it is private),
        # store HF_TOKEN as a Space secret and uncomment these lines:
        # HF_TOKEN = os.environ.get("HF_TOKEN")
        # if HF_TOKEN:
        #     print("Using HF_TOKEN for Gradio client.")
        #     caption_client = Client(CAPTION_SPACE_URL, hf_token=HF_TOKEN)
        # else:
        #     print("HF_TOKEN not found. Initializing public Gradio client.")
        #     caption_client = Client(CAPTION_SPACE_URL)
        # Assuming the caption Space is public:
        caption_client = Client(CAPTION_SPACE_URL)
        print("Gradio client initialized successfully.")
    except Exception as e:
        print(f"Error initializing Gradio client for {CAPTION_SPACE_URL}: {e}")
        # Leave the client as None so the endpoint can check and return an error
        caption_client = None
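
# Tip (illustrative): once the client is initialized, you can inspect the Space's
# callable endpoints and their parameters with:
# caption_client.view_api()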

# Load models and initialize clients when the app starts
@app.on_event("startup")
async def startup_event():
    # Load the language model (Qwen1.5 0.5B GGUF via llama.cpp)
    load_language_model()
    # Initialize the client for the external captioning Space
    initialize_caption_client()

# --- Image Captioning Function (Using gradio_client and a temporary file) ---
def generate_image_caption(image_file: UploadFile):
    """
    Generates a caption for the uploaded image using the external Gradio Space API.
    Reads the uploaded file's content, saves it to a temporary file,
    and passes the temporary file's path through handle_file for the API call.
    """
    if caption_client is None:
        # The client failed to initialize at startup
        error_msg = "Gradio caption client not initialized. Cannot generate caption."
        print(error_msg)
        return f"Error: {error_msg}"

    temp_file_path = None  # path of the temporary file, kept for cleanup
    try:
        print(f"Attempting to generate caption for file: {image_file.filename}")
        # Read the content of the uploaded file.
        # Seek to the beginning in case the file-like object's pointer was moved.
        image_file.file.seek(0)
        image_bytes = image_file.file.read()

        # Create a temporary file on the local filesystem.
        # delete=False keeps the file after the handle is closed;
        # the suffix hints at the file type for the Gradio API.
        temp_file = tempfile.NamedTemporaryFile(
            delete=False,
            suffix=os.path.splitext(image_file.filename)[1] or '.jpg'
        )
        temp_file.write(image_bytes)
        temp_file.close()  # close the handle so gradio_client can access the file
        temp_file_path = temp_file.name  # full path to the temporary file
        print(f"Saved uploaded file temporarily to: {temp_file_path}")

        # handle_file() prepares the local file path for the Gradio API input.
        prepared_input = handle_file(temp_file_path)

        # Call the predict method on the initialized gradio_client.
        # api_name="/predict" matches the endpoint in the Space's API docs.
        caption = caption_client.predict(img=prepared_input, api_name="/predict")
        print("Caption generated successfully.")
        # Return the caption string received from the API
        return caption
    except Exception as e:
        # Catch any exception raised during reading, writing, or the API call
        print(f"Error during caption generation API call: {e}")  # log the error server-side
        # Return a structured error string with the exception type and message
        return f"Error: Unable to generate caption from API. Details: {type(e).__name__}: {e}"
    finally:
        # Clean up the temporary file whether or not the call succeeded
        if temp_file_path and os.path.exists(temp_file_path):
            print(f"Cleaning up temporary file: {temp_file_path}")
            try:
                os.remove(temp_file_path)
            except OSError as e:
                print(f"Error removing temporary file {temp_file_path}: {e}")  # log cleanup errors

# --- Language Model Story Generation Function (Qwen1.5 0.5B via llama.cpp) ---
def generate_story_qwen_0_5b(prompt_text: str, max_new_tokens: int = 300, temperature: float = 0.7, top_p: float = 0.9, top_k: int = 50) -> str:
    """
    Generates text with the loaded Qwen1.5 0.5B model via llama.cpp.
    Builds a chat-style messages list and calls llama.cpp's chat completion,
    which applies the model's chat template internally.
    """
    # Check that the language model was loaded successfully at startup
    # (both the tokenizer and the llama_cpp.Llama instance).
    if tokenizer is None or llm_model is None:
        # Raise a RuntimeError, which is caught by the calling endpoint
        raise RuntimeError("Language model (llama.cpp) or tokenizer not loaded.")

    # Construct the messages list in the chat format expected by Qwen1.5 Chat
    # (a standard ChatML-like format).
    messages = [
        # A system message is optional but can help guide the model's persona/style:
        # {"role": "system", "content": "You are a helpful and creative assistant."},
        {"role": "user", "content": prompt_text}
    ]
    try:
        print("Calling llama.cpp create_chat_completion for Qwen 0.5B...")
        # create_chat_completion handles the chat templating internally for models like Qwen.
        # max_tokens caps the number of generated tokens; temperature and top_p control sampling.
        response = llm_model.create_chat_completion(
            messages=messages,
            max_tokens=max_new_tokens,
            temperature=temperature,
            top_p=top_p,
            # top_k is accepted by some llama-cpp-python versions; check the docs before enabling it
            # top_k=top_k,
            stream=False  # return the full response at once
        )
        print("Llama.cpp completion received for Qwen 0.5B.")
        # Parse the response; its structure mirrors OpenAI's chat API response.
        if response and response.get('choices') and len(response['choices']) > 0:
            story = response['choices'][0].get('message', {}).get('content', '')
        else:
            # Handle an empty response or an unexpected structure
            print("Warning: Llama.cpp Qwen 0.5B response structure unexpected or content missing.")
            story = ""  # return an empty string if content is not found
    except Exception as e:
        # Catch any exception raised during llama.cpp inference
        print(f"Llama.cpp Qwen 0.5B inference failed: {e}")  # log the error server-side
        # Re-raise as a RuntimeError to signal failure to the endpoint
        raise RuntimeError(f"Llama.cpp inference failed: {type(e).__name__}: {e}")

    # Return the generated story text without leading/trailing whitespace
    return story.strip()
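
# Illustrative usage sketch (run manually after the model has loaded; the prompt and
# parameter values below are examples, not part of the API):
# story_text = generate_story_qwen_0_5b(
#     "Write a short story about a lighthouse keeper.",
#     max_new_tokens=200,
# )
# print(story_text)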

# --- FastAPI Endpoint for Story Generation ---
@app.post("/generate-story/")
async def generate_story_endpoint(image_file: UploadFile = File(...), language: str = Form(...)):
    # Choose a random theme for the story prompt
    story_theme = random.choice([
        'an adventurous journey', 'a mysterious encounter', 'a heroic quest',
        'a magical adventure', 'a thrilling escape', 'an unexpected discovery',
        'a dangerous mission', 'a romantic escapade', 'an epic battle',
        'a journey into the unknown'
    ])

    # Step 1: Get an image caption from the external Gradio API via gradio_client,
    # passing the UploadFile object directly to the captioning function.
    caption = generate_image_caption(image_file)
    # Check whether caption generation returned an error string
    if caption.startswith("Error:"):
        print(f"Caption generation failed: {caption}")  # log the error detail server-side
        # Raise an HTTPException with a 500 status code and the error message
        raise HTTPException(status_code=500, detail=caption)

    # Step 2: Construct the prompt for the language model.
    # It tells the model what to write and incorporates the caption.
    prompt_text = f"Write an attractive story of around 300 words about {story_theme}. Incorporate the following details from an image description into the story: {caption}\n\nStory:"

    # Step 3: Generate the story with the local language model (Qwen 0.5B via llama.cpp)
    try:
        story = generate_story_qwen_0_5b(
            prompt_text,
            max_new_tokens=300,   # request ~300 new tokens
            temperature=0.7,      # sampling parameters
            top_p=0.9,
            top_k=50              # note: top_k may not be passed through to create_chat_completion
        )
        story = story.strip()  # basic cleanup of the generated text
    except RuntimeError as e:
        # Raised by generate_story_qwen_0_5b if LLM loading or inference failed
        print(f"Language model generation error: {e}")  # log server-side
        # Return 503 Service Unavailable if the LLM is not available or failed
        raise HTTPException(status_code=503, detail=f"Story generation failed (LLM): {e}")
    except Exception as e:
        # Catch any other unexpected error during story generation
        print(f"An unexpected error occurred during story generation: {e}")  # log server-side
        raise HTTPException(status_code=500, detail=f"An unexpected error occurred during story generation: {type(e).__name__}: {e}")

    # Step 4: Translate the generated story if the target language is not English
    if language and language.lower() != "english":
        try:
            # Initialize GoogleTranslator with English source and the requested target language
            translator = GoogleTranslator(source='english', target=language.lower())
            translated_story = translator.translate(story)
            # A None or empty result indicates the translation failed
            if not translated_story:
                print(f"Translation returned None or an empty string for language: {language}")
                # Fall back to the original English story with a warning note
                return {"story": story + "\n\n(Note: Automatic translation to your requested language failed.)"}
            # Translation succeeded; use the translated text
            story = translated_story
        except InvalidSourceOrTargetLanguage:
            print(f"Invalid target language requested: {language}")  # log invalid language
            raise HTTPException(status_code=400, detail=f"Invalid target language: {language}")
        except Exception as e:
            # Catch any other translation error (e.g., network issues, API problems)
            print(f"Translation failed for language {language}: {e}")  # log server-side
            raise HTTPException(status_code=500, detail=f"Translation failed: {type(e).__name__}: {e}")

    # Step 5: Return the final (possibly translated) story as a JSON response
    return {"story": story}

# --- Optional: Serve a simple HTML form for testing ---
# To use this, uncomment the HTMLResponse / StaticFiles / Jinja2Templates imports
# at the top of the file (and Request below). If you use templates instead of the
# inline HTML, create a 'templates' directory with an 'index.html' file.
# from fastapi import Request
# from fastapi.templating import Jinja2Templates
# from fastapi.staticfiles import StaticFiles
# templates = Jinja2Templates(directory="templates")
# app.mount("/static", StaticFiles(directory="static"), name="static")
#
# @app.get("/", response_class=HTMLResponse)
# async def read_root(request: Request):
#     # Simple HTML form to upload an image and specify a language
#     html_content = """
#     <!DOCTYPE html>
#     <html>
#     <head><title>Story Generator</title></head>
#     <body>
#         <h1>Generate a Story from an Image</h1>
#         <form action="/generate-story/" method="post" enctype="multipart/form-data">
#             <input type="file" name="image_file" accept="image/*" required><br><br>
#             Target Language (e.g., english, french, spanish): <input type="text" name="language" value="english"><br><br>
#             <button type="submit">Generate Story</button>
#         </form>
#     </body>
#     </html>
#     """
#     # If using templates instead: return templates.TemplateResponse("index.html", {"request": request})
#     return HTMLResponse(content=html_content)  # direct HTML for simplicity if templates are not set up
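
# --- Optional: run the app directly (a minimal sketch; assumes uvicorn is installed) ---
# The host and port below are illustrative; Hugging Face Spaces conventionally expects port 7860.
# if __name__ == "__main__":
#     import uvicorn
#     uvicorn.run(app, host="0.0.0.0", port=7860)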