"""FastAPI service that turns an uploaded image into a short story.

Pipeline: the image is captioned by an external Gradio Space, the caption is
fed to a local Qwen1.5 0.5B GGUF model (via llama.cpp) to write a story, and
the story is optionally translated with deep_translator's GoogleTranslator.
"""

from fastapi import FastAPI, File, UploadFile, Form, HTTPException

# Keep these if you use them elsewhere in your app (HTML, static files)
# from fastapi.responses import HTMLResponse
# from fastapi.staticfiles import StaticFiles
# from fastapi.templating import Jinja2Templates
# from fastapi.responses import FileResponse

# Removed 'requests' as we are using gradio_client
# import requests

import base64  # Keep if needed elsewhere (not strictly needed for this version)
import os
import random

# Removed unused IO import
# from typing import IO

# Import necessary classes from transformers (Keeping only AutoTokenizer)
from transformers import AutoTokenizer

# Import necessary modules for llama-cpp-python and downloading from Hub
from llama_cpp import Llama  # The core Llama class
from huggingface_hub import hf_hub_download  # For downloading GGUF files

# Import the Gradio Client and handle_file
from gradio_client import Client, handle_file

# Import necessary modules for temporary file handling
import tempfile
# shutil is not strictly necessary for this version, os.remove is sufficient
# import shutil

from deep_translator import GoogleTranslator
from deep_translator.exceptions import (
    InvalidSourceOrTargetLanguage,
    LanguageNotSupportedException,
)

app = FastAPI()

# --- Llama.cpp Language Model Setup (Local CPU Inference) ---

# Repository on Hugging Face Hub containing the Qwen1.5 0.5B GGUF file
LLM_MODEL_REPO = "Qwen/Qwen1.5-0.5B-Chat-GGUF"

# Filename of the Q4_K_M quantized GGUF file inside LLM_MODEL_REPO
LLM_MODEL_FILE = "qwen1_5-0_5b-chat-q4_k_m.gguf"

# Original model name for the tokenizer (needed by transformers).
# This points to the base model repository for the tokenizer files.
ORIGINAL_MODEL_NAME = "Qwen/Qwen1.5-0.5B-Chat"

tokenizer = None  # transformers tokenizer, set by load_language_model()
llm_model = None  # llama_cpp.Llama instance, set by load_language_model()

# --- Hugging Face Gradio Space Client Setup (For External Image Captioning) ---

# Global Gradio Client for captioning, set by initialize_caption_client()
caption_client = None

# The external Gradio Space used for image captioning
CAPTION_SPACE_URL = "Makhinur/Image-to-Text-Salesforce-blip-image-captioning-base"

# Themes the story endpoint picks from at random for each request
STORY_THEMES = (
    'an adventurous journey',
    'a mysterious encounter',
    'a heroic quest',
    'a magical adventure',
    'a thrilling escape',
    'an unexpected discovery',
    'a dangerous mission',
    'a romantic escapade',
    'an epic battle',
    'a journey into the unknown',
)


def load_language_model():
    """Load the tokenizer and the GGUF model into the module-level globals.

    On any failure the error is printed and both ``tokenizer`` and
    ``llm_model`` are reset to ``None`` so callers can detect that the model
    is unavailable; the exception is not propagated.
    """
    global tokenizer, llm_model
    print(f"Loading language model: {LLM_MODEL_FILE} from {LLM_MODEL_REPO}...")
    try:
        # --- Load Tokenizer (using transformers) ---
        print(f"Loading tokenizer from original model repo: {ORIGINAL_MODEL_NAME}...")
        tokenizer = AutoTokenizer.from_pretrained(ORIGINAL_MODEL_NAME)

        # Make sure a pad token is defined: prefer EOS, then UNK.
        if tokenizer.pad_token is None:
            if tokenizer.eos_token is not None:
                tokenizer.pad_token = tokenizer.eos_token
            elif tokenizer.unk_token is not None:
                tokenizer.pad_token = tokenizer.unk_token
            else:
                # Fallback if neither exists
                print("Warning: Neither EOS nor UNK token found for tokenizer. Setting pad_token to None.")
                tokenizer.pad_token = None

        # --- Download GGUF model file (using huggingface_hub) ---
        print(f"Downloading GGUF model file: {LLM_MODEL_FILE} from {LLM_MODEL_REPO}...")
        model_path = hf_hub_download(
            repo_id=LLM_MODEL_REPO,
            filename=LLM_MODEL_FILE,
            # cache_dir="/tmp/hf_cache"  # Optional: specify a custom cache directory
        )
        print(f"GGUF model downloaded to: {model_path}")

        # --- Load the GGUF model (using llama-cpp-python) ---
        print("Loading GGUF model into llama_cpp...")
        llm_model = Llama(
            model_path=model_path,
            n_gpu_layers=0,  # Explicitly use CPU
            n_ctx=4096,      # Context window size in tokens
            n_threads=2,     # Number of CPU threads to use
        )
        print("Llama.cpp model loaded successfully.")
    except Exception as e:
        print(f"Error loading language model {LLM_MODEL_REPO}/{LLM_MODEL_FILE}: {e}")
        # Ensure both are None if loading fails, so the generation function
        # reports the model as unavailable instead of using a partial setup.
        tokenizer = None
        llm_model = None


def initialize_caption_client():
    """Create the Gradio client for the captioning Space.

    On failure the error is printed and ``caption_client`` is set to ``None``
    so that generate_image_caption() can report the problem.
    """
    global caption_client
    print(f"Initializing Gradio client for {CAPTION_SPACE_URL}...")
    try:
        # If the target Gradio Space requires authentication (e.g., private)
        # store HF_TOKEN as a Space Secret and uncomment these lines.
        # HF_TOKEN = os.environ.get("HF_TOKEN")
        # if HF_TOKEN:
        #     print("Using HF_TOKEN for Gradio client.")
        #     caption_client = Client(CAPTION_SPACE_URL, hf_token=HF_TOKEN)
        # else:
        #     print("HF_TOKEN not found. Initializing public Gradio client.")
        #     caption_client = Client(CAPTION_SPACE_URL)

        # Assuming the caption space is public
        caption_client = Client(CAPTION_SPACE_URL)
        print("Gradio client initialized successfully.")
    except Exception as e:
        print(f"Error initializing Gradio client for {CAPTION_SPACE_URL}: {e}")
        # Set client to None so the endpoint can check and return an error
        caption_client = None


# Load models and initialize clients when the app starts
@app.on_event("startup")
async def startup_event():
    """Load the language model and create the captioning client at startup."""
    load_language_model()
    initialize_caption_client()


# --- Image Captioning Function (Using gradio_client and temporary file) ---
def generate_image_caption(image_file: UploadFile):
    """Return a caption for the uploaded image from the external Gradio Space.

    The upload is copied to a temporary file whose path is handed to
    ``handle_file`` for the API call; the temporary file is always removed
    afterwards.

    Returns:
        The caption string on success. On any failure (client not
        initialized, I/O error, API error, non-string API result) a string
        starting with ``"Error:"`` is returned instead of raising.
    """
    if caption_client is None:
        # The client failed to initialize at startup
        error_msg = "Gradio caption client not initialized. Cannot generate caption."
        print(error_msg)
        return f"Error: {error_msg}"

    temp_file_path = None  # Path of the temporary file, once it exists
    try:
        print(f"Attempting to generate caption for file: {image_file.filename}")

        # Seek to the beginning in case the file pointer was moved earlier
        image_file.file.seek(0)
        image_bytes = image_file.file.read()

        # The filename may be missing; fall back to '.jpg' as the suffix hint.
        suffix = os.path.splitext(image_file.filename or "")[1] or '.jpg'

        # delete=False keeps the file on disk after the handle is closed so
        # gradio_client can open it by path. The path is recorded before
        # writing so the finally-block removes the file even if write() fails.
        with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as temp_file:
            temp_file_path = temp_file.name
            temp_file.write(image_bytes)
        print(f"Saved uploaded file temporarily to: {temp_file_path}")

        prepared_input = handle_file(temp_file_path)
        caption = caption_client.predict(img=prepared_input, api_name="/predict")

        # Callers treat the result as a string (e.g. caption.startswith), so
        # report anything else through the same "Error:" convention.
        if not isinstance(caption, str):
            error_msg = f"Unexpected caption response type from API: {type(caption).__name__}"
            print(error_msg)
            return f"Error: {error_msg}"

        print("Caption generated successfully.")
        return caption

    except Exception as e:
        # Covers failures while reading, writing, or calling the API
        print(f"Error during caption generation API call: {e}")
        return f"Error: Unable to generate caption from API. Details: {type(e).__name__}: {e}"

    finally:
        # Clean up the temporary file whether the call succeeded or failed
        if temp_file_path and os.path.exists(temp_file_path):
            print(f"Cleaning up temporary file: {temp_file_path}")
            try:
                os.remove(temp_file_path)
            except OSError as e:
                print(f"Error removing temporary file {temp_file_path}: {e}")


# --- Language Model Story Generation Function (Qwen1.5 0.5B via llama.cpp) ---
def generate_story_qwen_0_5b(prompt_text: str, max_new_tokens: int = 300, temperature: float = 0.7, top_p: float = 0.9, top_k: int = 50) -> str:
    """Generate text for ``prompt_text`` with the loaded llama.cpp model.

    The prompt is sent as a single user message to
    ``Llama.create_chat_completion``.

    Args:
        prompt_text: The user prompt.
        max_new_tokens: Maximum number of tokens to generate.
        temperature: Sampling temperature.
        top_p: Nucleus sampling threshold.
        top_k: Top-k sampling cutoff.

    Returns:
        The generated text with surrounding whitespace removed; an empty
        string if the response carries no content.

    Raises:
        RuntimeError: If the model/tokenizer is not loaded or inference fails.
    """
    if tokenizer is None or llm_model is None:
        raise RuntimeError("Language model (llama.cpp) or tokenizer not loaded.")

    messages = [
        # System message is optional but can help guide the model's persona/style
        # {"role": "system", "content": "You are a helpful and creative assistant."}
        {"role": "user", "content": prompt_text}
    ]

    try:
        print("Calling llama.cpp create_chat_completion for Qwen 0.5B...")
        response = llm_model.create_chat_completion(
            messages=messages,
            max_tokens=max_new_tokens,
            temperature=temperature,
            top_p=top_p,
            top_k=top_k,   # Honour the caller's top_k instead of dropping it
            stream=False,  # We want the full response at once
        )
        print("Llama.cpp completion received for Qwen 0.5B.")

        # The response follows the OpenAI chat-completion layout
        choices = response.get('choices') if response else None
        if choices:
            # 'message' or 'content' may be missing or None; treat as empty
            story = (choices[0].get('message') or {}).get('content') or ""
        else:
            print("Warning: Llama.cpp Qwen 0.5B response structure unexpected or content missing.")
            story = ""

    except Exception as e:
        print(f"Llama.cpp Qwen 0.5B inference failed: {e}")
        # Re-raise as RuntimeError (what the endpoint catches), keeping the cause
        raise RuntimeError(f"Llama.cpp inference failed: {type(e).__name__}: {e}") from e

    return story.strip()


# --- FastAPI Endpoint for Story Generation ---
@app.post("/generate-story/")
async def generate_story_endpoint(image_file: UploadFile = File(...), language: str = Form(...)):
    """Caption the uploaded image, write a story about it, and translate it.

    Args:
        image_file: The uploaded image.
        language: Target language name for the story; "english" (any case)
            or a blank value skips translation.

    Returns:
        ``{"story": <text>}``. If translation yields nothing, the English
        story is returned with a note appended.

    Raises:
        HTTPException: 500 if captioning fails or an unexpected error occurs,
            503 if the language model is unavailable or inference fails,
            400 if the target language is not accepted by the translator.
    """
    # NOTE(review): the captioning, inference, and translation calls below
    # are plain synchronous calls made directly from this async handler.
    story_theme = random.choice(STORY_THEMES)

    # Step 1: Get image caption using the external Gradio API
    caption = generate_image_caption(image_file)

    # The captioning function signals failure with an "Error:" prefix
    if caption.startswith("Error:"):
        print(f"Caption generation failed: {caption}")
        raise HTTPException(status_code=500, detail=caption)

    # Step 2: Construct the prompt text for the language model
    prompt_text = f"Write an attractive story of around 300 words about {story_theme}. Incorporate the following details from an image description into the story: {caption}\n\nStory:"

    # Step 3: Generate the story using the local language model
    try:
        story = generate_story_qwen_0_5b(
            prompt_text,
            max_new_tokens=300,
            temperature=0.7,
            top_p=0.9,
            top_k=50,
        )
        story = story.strip()

    except RuntimeError as e:
        # Raised when the LLM is not loaded or inference fails
        print(f"Language model generation error: {e}")
        raise HTTPException(status_code=503, detail=f"Story generation failed (LLM): {e}") from e
    except Exception as e:
        print(f"An unexpected error occurred during story generation: {e}")
        raise HTTPException(status_code=500, detail=f"An unexpected error occurred during story generation: {type(e).__name__}: {e}") from e

    # Step 4: Translate the story if the target language is not English
    target_language = language.strip().lower() if language else ""
    if target_language and target_language != "english":
        try:
            translator = GoogleTranslator(source='english', target=target_language)
            translated_story = translator.translate(story)

            # None or an empty string indicates the translation failed
            if not translated_story:
                print(f"Translation returned None or empty string for language: {language}")
                return {"story": story + "\n\n(Note: Automatic translation to your requested language failed.)"}

            story = translated_story

        except (InvalidSourceOrTargetLanguage, LanguageNotSupportedException) as e:
            # Both mean the requested language is unusable: a client error
            print(f"Invalid target language requested: {language}")
            raise HTTPException(status_code=400, detail=f"Invalid target language: {language}") from e
        except Exception as e:
            # Any other translation failure (e.g. network issues, API problems)
            print(f"Translation failed for language {language}: {e}")
            raise HTTPException(status_code=500, detail=f"Translation failed: {type(e).__name__}: {e}") from e

    # Step 5: Return the final generated (and potentially translated) story
    return {"story": story}


# --- Optional: Serve a simple HTML form for testing ---
# To use this, uncomment the imports related to HTMLResponse, StaticFiles,
# Jinja2Templates, Request at the top of the file, and create a 'templates'
# directory with an 'index.html' file.
# NOTE(review): the HTML tags of this example were missing from the file; the
# markup below is a reconstruction using the endpoint's form field names
# (image_file, language) and should be checked before use.

# from fastapi import Request
# from fastapi.templating import Jinja2Templates
# from fastapi.staticfiles import StaticFiles

# templates = Jinja2Templates(directory="templates")
# app.mount("/static", StaticFiles(directory="static"), name="static")

# @app.get("/", response_class=HTMLResponse)
# async def read_root(request: Request):
#     # Simple HTML form to upload an image and specify language
#     html_content = """
#     <!DOCTYPE html>
#     <html>
#     <head><title>Story Generator</title></head>
#     <body>
#         <h1>Generate a Story from an Image</h1>
#         <form action="/generate-story/" method="post" enctype="multipart/form-data">
#             <input type="file" name="image_file" accept="image/*" required><br><br>
#             <label for="language">Target Language (e.g., english, french, spanish):</label><br>
#             <input type="text" id="language" name="language" value="english" required><br><br>
#             <button type="submit">Generate Story</button>
#         </form>
#     </body>
#     </html>
#     """
#     # If using templates: return templates.TemplateResponse("index.html", {"request": request})
#     return HTMLResponse(content=html_content)  # Using direct HTML for simplicity if templates not set up