File size: 18,332 Bytes
2be2394
f9a665c
 
 
 
 
 
778e6dd
e4ba08d
b1abf7a
2be2394
f9a665c
a56d7be
 
b1abf7a
a56d7be
b1abf7a
 
 
 
 
f9a665c
 
f2d6482
a113176
f9a665c
013c0f5
 
b1abf7a
 
013c0f5
 
96023a2
f9a665c
 
2be2394
 
 
b1abf7a
f496ef0
 
 
a0a8457
a56d7be
f496ef0
 
a56d7be
 
f496ef0
 
b1abf7a
 
 
 
f9a665c
b1abf7a
e4ba08d
 
b1abf7a
e4ba08d
 
 
b1abf7a
a113176
b1abf7a
 
a113176
b1abf7a
a56d7be
 
 
b1abf7a
 
14839a7
a113176
b1abf7a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a56d7be
b1abf7a
 
 
 
 
 
 
f496ef0
b1abf7a
 
 
 
f496ef0
a56d7be
a113176
b1abf7a
f9a665c
a113176
b1abf7a
a113176
b1abf7a
e4ba08d
 
a113176
e4ba08d
 
 
 
a56d7be
 
a113176
f2d6482
b1abf7a
f2d6482
 
b1abf7a
f2d6482
 
a56d7be
a113176
e4ba08d
 
a113176
b1abf7a
f2d6482
96023a2
2be2394
e4ba08d
f9a665c
 
f496ef0
a113176
b1abf7a
e4ba08d
 
 
013c0f5
e4ba08d
 
 
b1abf7a
 
e4ba08d
 
b1abf7a
e4ba08d
013c0f5
 
 
b1abf7a
e4ba08d
 
b1abf7a
f2d6482
013c0f5
b1abf7a
 
a113176
 
b1abf7a
 
 
013c0f5
 
b1abf7a
 
013c0f5
 
 
b1abf7a
 
013c0f5
a113176
b1abf7a
 
a113176
f2d6482
b1abf7a
 
 
013c0f5
e4ba08d
b1abf7a
 
 
f2d6482
f9a665c
013c0f5
b1abf7a
013c0f5
 
 
b1abf7a
013c0f5
 
 
f9a665c
f496ef0
a56d7be
f496ef0
f9a665c
f496ef0
b1abf7a
f9a665c
b1abf7a
a56d7be
b1abf7a
 
 
f9a665c
a56d7be
 
f9a665c
a56d7be
 
f9a665c
 
 
e4ba08d
f496ef0
b1abf7a
a56d7be
 
b1abf7a
 
 
 
 
 
 
 
a56d7be
b1abf7a
f496ef0
b1abf7a
 
 
 
 
 
 
f496ef0
b1abf7a
 
 
 
f496ef0
b1abf7a
 
 
 
 
 
 
e4ba08d
a113176
2be2394
f9a665c
b1abf7a
f9a665c
013c0f5
 
 
f9a665c
 
 
b1abf7a
 
e4ba08d
a113176
b1abf7a
013c0f5
b1abf7a
 
2be2394
 
b1abf7a
 
f9a665c
 
f496ef0
f9a665c
f496ef0
 
f9a665c
b1abf7a
 
e4ba08d
b1abf7a
f9a665c
b1abf7a
f9a665c
a113176
f496ef0
b1abf7a
 
 
f9a665c
b1abf7a
 
 
f9a665c
 
b1abf7a
 
a113176
f9a665c
b1abf7a
f9a665c
b1abf7a
f9a665c
 
b1abf7a
a113176
b1abf7a
 
f9a665c
 
b1abf7a
013c0f5
f9a665c
 
b1abf7a
f9a665c
 
b1abf7a
 
a113176
 
b1abf7a
f9a665c
 
a113176
b1abf7a
 
f9a665c
 
 
 
 
 
 
b1abf7a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
from fastapi import FastAPI, File, UploadFile, Form, HTTPException
# Keep these if you use them elsewhere in your app (HTML, static files)
# from fastapi.responses import HTMLResponse
# from fastapi.staticfiles import StaticFiles
# from fastapi.templating import Jinja2Templates
# from fastapi.responses import FileResponse

# Removed 'requests' as we are using gradio_client
# import requests
import base64 # Keep if needed elsewhere (not strictly needed for this version)
import os
import random
# Removed unused IO import
# from typing import IO

# Import necessary classes from transformers (Keeping only AutoTokenizer)
from transformers import AutoTokenizer

# Import necessary modules for llama-cpp-python and downloading from Hub
from llama_cpp import Llama # The core Llama class
from huggingface_hub import hf_hub_download # For downloading GGUF files


# Import the Gradio Client and handle_file
from gradio_client import Client, handle_file

# Import necessary modules for temporary file handling
import tempfile
# shutil is not strictly necessary for this version, os.remove is sufficient
# import shutil


from deep_translator import GoogleTranslator
from deep_translator.exceptions import InvalidSourceOrTargetLanguage


app = FastAPI()

# --- Local language model (GGUF weights run on CPU through llama.cpp) ---
# Hugging Face Hub repository hosting the official Qwen1.5-0.5B-Chat GGUF builds.
LLM_MODEL_REPO = "Qwen/Qwen1.5-0.5B-Chat-GGUF"
# Q4_K_M-quantized weights inside LLM_MODEL_REPO (a compact quantization suited to CPU).
LLM_MODEL_FILE = "qwen1_5-0_5b-chat-q4_k_m.gguf"
# Regular (non-GGUF) model repository that AutoTokenizer.from_pretrained reads.
ORIGINAL_MODEL_NAME = "Qwen/Qwen1.5-0.5B-Chat"

# Filled in by load_language_model() at startup; both stay None if loading fails.
tokenizer = None  # transformers AutoTokenizer instance
llm_model = None  # llama_cpp.Llama instance

# --- External image captioning (Gradio Space called through gradio_client) ---
# Space identifier handed to gradio_client.Client.
CAPTION_SPACE_URL = "Makhinur/Image-to-Text-Salesforce-blip-image-captioning-base"
# Filled in by initialize_caption_client() at startup; stays None if that fails.
caption_client = None


# Function to load the language model (GGUF via llama.cpp) and its tokenizer (from transformers)
def load_language_model(n_ctx: int = 4096, n_threads: int = 2):
    """Load the Qwen tokenizer and GGUF model into the module-level globals.

    Sets the global ``tokenizer`` to a transformers ``AutoTokenizer`` loaded from
    ``ORIGINAL_MODEL_NAME``. Sets the global ``llm_model`` to a CPU-only
    ``llama_cpp.Llama`` built from ``LLM_MODEL_FILE``, downloaded from
    ``LLM_MODEL_REPO`` with ``hf_hub_download``.

    Any failure is printed and both globals are reset to ``None``. That lets
    startup continue and lets request handlers report the model as unavailable.

    Args:
        n_ctx: Context window size in tokens passed to ``Llama``.
        n_threads: Number of CPU threads llama.cpp may use (2 matches a
            2-vCPU host).
    """
    global tokenizer, llm_model
    print(f"Loading language model: {LLM_MODEL_FILE} from {LLM_MODEL_REPO}...")
    try:
        # --- Tokenizer (transformers), loaded from the non-GGUF model repo ---
        print(f"Loading tokenizer from original model repo: {ORIGINAL_MODEL_NAME}...")
        tokenizer = AutoTokenizer.from_pretrained(ORIGINAL_MODEL_NAME)

        # Make sure a pad token exists, falling back to EOS, then UNK.
        if tokenizer.pad_token is None:
            if tokenizer.eos_token is not None:
                tokenizer.pad_token = tokenizer.eos_token
            elif tokenizer.unk_token is not None:
                tokenizer.pad_token = tokenizer.unk_token
            else:
                # pad_token is already None in this branch, so there is nothing to assign.
                print("Warning: Neither EOS nor UNK token found for tokenizer. Leaving pad_token as None.")

        # --- GGUF weights (huggingface_hub download; cached after first run) ---
        print(f"Downloading GGUF model file: {LLM_MODEL_FILE} from {LLM_MODEL_REPO}...")
        model_path = hf_hub_download(
            repo_id=LLM_MODEL_REPO,
            filename=LLM_MODEL_FILE,
        )
        print(f"GGUF model downloaded to: {model_path}")

        # --- llama.cpp model ---
        print("Loading GGUF model into llama_cpp...")
        llm_model = Llama(
            model_path=model_path,
            n_gpu_layers=0,  # offload nothing to a GPU: CPU-only inference
            n_ctx=n_ctx,
            n_threads=n_threads,
        )
        print("Llama.cpp model loaded successfully.")

    except Exception as e:
        # Startup boundary: log and degrade rather than crash the whole app.
        print(f"Error loading language model {LLM_MODEL_REPO}/{LLM_MODEL_FILE}: {e}")
        tokenizer = None
        llm_model = None


# Function to initialize the Gradio Client for the captioning Space
def initialize_caption_client():
    """Connect the global Gradio client to the captioning Space.

    Creates ``Client(CAPTION_SPACE_URL)`` (the Space is assumed to be public)
    and stores it in the global ``caption_client``. If construction raises,
    the error is printed and ``caption_client`` is set to ``None``. The caption
    function then reports the problem instead of crashing.

    For a private Space, store HF_TOKEN as a Space secret and construct the
    client as ``Client(CAPTION_SPACE_URL, hf_token=os.environ.get("HF_TOKEN"))``.
    """
    global caption_client
    space_id = CAPTION_SPACE_URL
    print(f"Initializing Gradio client for {space_id}...")
    try:
        connected = Client(space_id)
    except Exception as exc:
        print(f"Error initializing Gradio client for {space_id}: {exc}")
        connected = None
    else:
        print("Gradio client initialized successfully.")
    caption_client = connected


# Load models and initialize clients when the app starts
@app.on_event("startup")
async def startup_event():
    """Run the one-time initializers when the application starts.

    The local Qwen GGUF model is loaded first, then the Gradio client for the
    external captioning Space is created. Each initializer handles its own
    errors and leaves its global as ``None`` on failure.
    """
    # NOTE(review): FastAPI deprecates on_event in favor of lifespan handlers;
    # migrating requires passing lifespan=... to the FastAPI(...) constructor.
    for initializer in (load_language_model, initialize_caption_client):
        initializer()


# --- Image Captioning Function (Using gradio_client and temporary file) ---
def generate_image_caption(image_file: UploadFile):
    """Caption an uploaded image using the external Gradio captioning Space.

    The upload's bytes are copied into a named temporary file, because
    ``handle_file`` is given a filesystem path. The file is sent to the Space's
    ``/predict`` endpoint and is always removed afterwards, even on failure.

    Args:
        image_file: The uploaded image. Its ``filename`` may be ``None``; the
            temporary file's suffix then defaults to ``.jpg``.

    Returns:
        The caption string on success. On failure, a string starting with
        ``"Error:"`` that describes the problem; callers check this prefix.
    """
    if caption_client is None:
        # The client failed to initialize at startup.
        error_msg = "Gradio caption client not initialized. Cannot generate caption."
        print(error_msg)
        return f"Error: {error_msg}"

    temp_file_path = None  # set as soon as a temp file exists, so finally can remove it

    try:
        print(f"Attempting to generate caption for file: {image_file.filename}")

        # Rewind in case something already read from the upload stream.
        image_file.file.seek(0)
        image_bytes = image_file.file.read()

        # Keep the upload's extension as a file-type hint; filename can be None.
        suffix = os.path.splitext(image_file.filename or "")[1] or ".jpg"

        # delete=False keeps the file after the handle closes, so gradio_client
        # can open it by path. The path is recorded before writing, so a failed
        # write is still cleaned up by the finally block.
        with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as temp_file:
            temp_file_path = temp_file.name
            temp_file.write(image_bytes)

        print(f"Saved uploaded file temporarily to: {temp_file_path}")

        # api_name="/predict" is the Space's captioning endpoint.
        caption = caption_client.predict(img=handle_file(temp_file_path), api_name="/predict")

        if not isinstance(caption, str):
            # Callers use str methods on the result; report an unexpected
            # payload as an error instead of letting them crash.
            return f"Error: Unexpected caption response type from API: {type(caption).__name__}"

        print("Caption generated successfully.")
        return caption

    except Exception as e:
        # Reading, writing or the API call failed; log and report in-band.
        print(f"Error during caption generation API call: {e}")
        return f"Error: Unable to generate caption from API. Details: {type(e).__name__}: {e}"

    finally:
        # Remove the temporary file whether the call succeeded or not.
        if temp_file_path and os.path.exists(temp_file_path):
            print(f"Cleaning up temporary file: {temp_file_path}")
            try:
                os.remove(temp_file_path)
            except OSError as e:
                print(f"Error removing temporary file {temp_file_path}: {e}")


# --- Language Model Story Generation Function (Qwen1.5 0.5B via llama.cpp) ---
# Renamed function to reflect the model being used
def generate_story_qwen_0_5b(prompt_text: str, max_new_tokens: int = 300, temperature: float = 0.7, top_p: float = 0.9, top_k: int = 50) -> str:
    """Generate text for *prompt_text* with the local Qwen1.5-0.5B GGUF model.

    The prompt is sent as a single user message to
    ``llm_model.create_chat_completion``, which applies the chat template
    itself. The transformers tokenizer is only checked for presence here; it
    is not used to build the prompt.

    Args:
        prompt_text: Content of the user message.
        max_new_tokens: Maximum number of tokens to generate.
        temperature: Sampling temperature.
        top_p: Nucleus-sampling probability mass.
        top_k: Top-k sampling cutoff (forwarded to llama.cpp).

    Returns:
        The generated text with surrounding whitespace stripped. Returns ``""``
        when the response has no choices or no message content.

    Raises:
        RuntimeError: If the model or tokenizer did not load at startup, or if
            llama.cpp inference (or reading its response) fails.
    """
    if tokenizer is None or llm_model is None:
        raise RuntimeError("Language model (llama.cpp) or tokenizer not loaded.")

    # A single user turn; a {"role": "system", ...} entry could be prepended
    # to steer style.
    messages = [{"role": "user", "content": prompt_text}]

    try:
        print("Calling llama.cpp create_chat_completion for Qwen 0.5B...")
        response = llm_model.create_chat_completion(
            messages=messages,
            max_tokens=max_new_tokens,
            temperature=temperature,
            top_p=top_p,
            top_k=top_k,
            stream=False,  # return the full completion at once
        )
        print("Llama.cpp completion received for Qwen 0.5B.")

        # OpenAI-style response: {"choices": [{"message": {"content": ...}}]}.
        choices = response.get("choices") if response else None
        story = None
        if choices:
            story = (choices[0].get("message") or {}).get("content")
        if story is None:
            # Either no choices or a message without content (content may be None).
            print("Warning: Llama.cpp Qwen 0.5B response structure unexpected or content missing.")
            story = ""

    except Exception as e:
        print(f"Llama.cpp Qwen 0.5B inference failed: {e}")
        raise RuntimeError(f"Llama.cpp inference failed: {type(e).__name__}: {e}") from e

    return story.strip()


# --- FastAPI Endpoint for Story Generation ---
@app.post("/generate-story/")
async def generate_story_endpoint(image_file: UploadFile = File(...), language: str = Form(...)):
    """Caption an uploaded image, write a story about it, and optionally translate it.

    Steps:
    1. Caption the image through the external Gradio Space.
    2. Prompt the local Qwen model with a random theme plus that caption.
    3. Translate the story with ``GoogleTranslator`` unless English is requested.

    Args:
        image_file: Uploaded image (multipart form field ``image_file``).
        language: Target language name, e.g. ``"french"`` (form field
            ``language``). Case and surrounding whitespace are ignored.
            ``"english"`` or an empty value skips translation.

    Returns:
        ``{"story": <text>}``. If translation comes back empty, the English
        story is returned with an appended note.

    Raises:
        HTTPException: 500 if captioning fails; 503 if the language model is
            unavailable or inference fails; 500 for other story-generation
            errors; 400 for an invalid or unsupported target language; 500 for
            other translation errors.
    """
    # NOTE(review): captioning, inference and translation below are synchronous
    # calls inside an async handler, so they block the event loop for the whole
    # request. Moving them to a threadpool would presumably also need a lock
    # around llm_model (llama_cpp.Llama is assumed not thread-safe) — TODO.

    story_theme = random.choice([
        'an adventurous journey', 'a mysterious encounter', 'a heroic quest',
        'a magical adventure', 'a thrilling escape', 'an unexpected discovery',
        'a dangerous mission', 'a romantic escapade', 'an epic battle',
        'a journey into the unknown'
    ])

    # Step 1: caption the image; failures come back as "Error: ..." strings.
    caption = generate_image_caption(image_file)
    if caption.startswith("Error:"):
        print(f"Caption generation failed: {caption}")
        raise HTTPException(status_code=500, detail=caption)

    # Step 2: build the prompt from the theme and the caption.
    prompt_text = f"Write an attractive story of around 300 words about {story_theme}. Incorporate the following details from an image description into the story: {caption}\n\nStory:"

    # Step 3: generate the story locally.
    try:
        story = generate_story_qwen_0_5b(
            prompt_text,
            max_new_tokens=300,
            temperature=0.7,
            top_p=0.9,
            top_k=50,
        )
        story = story.strip()
    except RuntimeError as e:
        # The model is not loaded, or llama.cpp inference failed.
        print(f"Language model generation error: {e}")
        raise HTTPException(status_code=503, detail=f"Story generation failed (LLM): {e}") from e
    except Exception as e:
        print(f"An unexpected error occurred during story generation: {e}")
        raise HTTPException(status_code=500, detail=f"An unexpected error occurred during story generation: {type(e).__name__}: {e}") from e

    # Step 4: translate unless English (or nothing) was requested.
    target_language = (language or "").strip().lower()
    if target_language and target_language != "english":
        # deep_translator raises LanguageNotSupportedException for unknown
        # language names; InvalidSourceOrTargetLanguage covers empty values.
        from deep_translator.exceptions import LanguageNotSupportedException

        try:
            translator = GoogleTranslator(source='english', target=target_language)
            translated_story = translator.translate(story)
        except (InvalidSourceOrTargetLanguage, LanguageNotSupportedException) as e:
            print(f"Invalid target language requested: {language}")
            raise HTTPException(status_code=400, detail=f"Invalid target language: {language}") from e
        except Exception as e:
            # Network or upstream API problems.
            print(f"Translation failed for language {language}: {e}")
            raise HTTPException(status_code=500, detail=f"Translation failed: {type(e).__name__}: {e}") from e

        if not translated_story:
            # None or "" means the translation failed; fall back to English.
            print(f"Translation returned None or empty string for language: {language}")
            return {"story": story + "\n\n(Note: Automatic translation to your requested language failed.)"}

        story = translated_story

    # Step 5: return the final (possibly translated) story.
    return {"story": story}

# --- Optional: Serve a simple HTML form for testing ---
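# To use this, uncomment the HTMLResponse import at the top of the file and all lines below
# (they import Request, Jinja2Templates and StaticFiles themselves). The app.mount line needs a
# 'static' directory. A 'templates' directory with 'index.html' is only needed for the
# TemplateResponse variant; read_root as written serves inline HTML.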
# from fastapi import Request
# from fastapi.templating import Jinja2Templates
# from fastapi.staticfiles import StaticFiles
# templates = Jinja2Templates(directory="templates")
# app.mount("/static", StaticFiles(directory="static"), name="static")
# @app.get("/", response_class=HTMLResponse)
# async def read_root(request: Request):
#     # Simple HTML form to upload an image and specify language
#     html_content = """
#     <!DOCTYPE html>
#     <html>
#     <head><title>Story Generator</title></head>
#     <body>
#         <h1>Generate a Story from an Image</h1>
#         <form action="/generate-story/" method="post" enctype="multipart/form-data">
#             <input type="file" name="image_file" accept="image/*" required><br><br>
#             Target Language (e.g., english, french, spanish): <input type="text" name="language" value="english"><br><br>
#             <button type="submit">Generate Story</button>
#         </form>
#     </body>
#     </html>
#     """
#     # If using templates: return templates.TemplateResponse("index.html", {"request": request})
#     return HTMLResponse(content=html_content) # Using direct HTML for simplicity if templates not set up