Update main.py
main.py
CHANGED
@@ -37,18 +37,17 @@ from deep_translator.exceptions import InvalidSourceOrTargetLanguage
 app = FastAPI()

 # --- Llama.cpp Language Model Setup (Local CPU Inference) ---
-# Repository on Hugging Face Hub containing the Qwen1.5
-# Using the OFFICIAL Qwen repository:
-LLM_MODEL_REPO = "Qwen/Qwen1.5-
+# Repository on Hugging Face Hub containing the Qwen1.5 0.5B GGUF file
+# Using the OFFICIAL Qwen 0.5B repository shown in the user's image:
+LLM_MODEL_REPO = "Qwen/Qwen1.5-0.5B-Chat-GGUF" # Updated to official 0.5B repo

 # Specify the filename for a Q4_K_M quantized version (good balance of speed/quality on CPU)
-# Based on DIRECT VERIFICATION
-
-LLM_MODEL_FILE = "qwen1_5-1.8b-chat-Q4_K_M.gguf" # Correct filename for the OFFICIAL repo
+# Based on DIRECT VERIFICATION from the user's IMAGE of the 0.5B repo:
+LLM_MODEL_FILE = "qwen1_5-0_5b-chat-q4_k_m.gguf" # Exact filename from the 0.5B repo image

 # Original model name for the tokenizer (needed by transformers)
-# This
-ORIGINAL_MODEL_NAME = "Qwen/Qwen1.5-
+# This points to the base model repository for the tokenizer files.
+ORIGINAL_MODEL_NAME = "Qwen/Qwen1.5-0.5B-Chat" # Updated to the 0.5B Chat model

 tokenizer = None # Using transformers tokenizer for chat templating
 llm_model = None # This will hold the llama_cpp.Llama instance
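The constants above separate the quantized GGUF weights (pulled from the -GGUF repo) from the tokenizer, which still comes from the base chat repo. The tokenizer loading code is outside this hunk, so the following is only an illustrative sketch of how these names fit together, assuming the usual transformers call:

# Illustrative sketch only: the actual tokenizer load lives elsewhere in main.py.
from transformers import AutoTokenizer

LLM_MODEL_REPO = "Qwen/Qwen1.5-0.5B-Chat-GGUF"    # GGUF weights for llama.cpp
LLM_MODEL_FILE = "qwen1_5-0_5b-chat-q4_k_m.gguf"  # Q4_K_M quantization
ORIGINAL_MODEL_NAME = "Qwen/Qwen1.5-0.5B-Chat"    # base repo, used only for tokenizer files

# Chat templating uses the tokenizer from the base (non-GGUF) repository.
tokenizer = AutoTokenizer.from_pretrained(ORIGINAL_MODEL_NAME)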
@@ -86,7 +85,6 @@ def load_language_model():

 # --- Download GGUF model file (using huggingface_hub) ---
 print(f"Downloading GGUF model file: {LLM_MODEL_FILE} from {LLM_MODEL_REPO}...")
-# hf_hub_download downloads the file to the Hugging Face cache directory
 model_path = hf_hub_download(
 repo_id=LLM_MODEL_REPO,
 filename=LLM_MODEL_FILE,
@@ -98,12 +96,12 @@ def load_language_model():
 print(f"Loading GGUF model into llama_cpp...")
 # Instantiate the Llama model from the downloaded GGUF file
 # n_gpu_layers=0: Crucial for forcing CPU-only inference
-# n_ctx: Context window size (tokens model can consider), match model's spec if possible (Qwen1.5
+# n_ctx: Context window size (tokens model can consider), match model's spec if possible (Qwen1.5 0.5B has a smaller context than 1.8B, maybe 4096 or 8192 is standard)
 # n_threads: Number of CPU threads to use. Set to your vCPU count (2) for better performance.
 llm_model = Llama(
 model_path=model_path,
 n_gpu_layers=0, # Explicitly use CPU
-n_ctx=4096, # Context window size (
+n_ctx=4096, # Context window size (4096 is a common safe value)
 n_threads=2 # Use 2 CPU threads
 )
 print("Llama.cpp model loaded successfully.")
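Put together, the download-and-load path touched by the two hunks above amounts to roughly the following (a condensed sketch of load_language_model(), using only the calls visible in the diff):

# Condensed sketch of the flow shown above, not the full load_language_model().
from huggingface_hub import hf_hub_download
from llama_cpp import Llama

# hf_hub_download fetches the GGUF file into the local Hugging Face cache.
model_path = hf_hub_download(
    repo_id="Qwen/Qwen1.5-0.5B-Chat-GGUF",
    filename="qwen1_5-0_5b-chat-q4_k_m.gguf",
)

# CPU-only inference: no GPU layers, 4096-token context, 2 threads.
llm_model = Llama(
    model_path=model_path,
    n_gpu_layers=0,   # force CPU
    n_ctx=4096,       # context window
    n_threads=2,      # match the 2 vCPUs
)

A smaller n_ctx also keeps memory down, since the KV cache grows with the context size, so 4096 is a reasonable conservative value for a 2-vCPU box.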
@@ -141,7 +139,7 @@ def initialize_caption_client():
 # Load models and initialize clients when the app starts
 @app.on_event("startup")
 async def startup_event():
-# Load the language model (Qwen1.5
+# Load the language model (Qwen1.5 0.5B GGUF via llama.cpp)
 load_language_model()
 # Initialize the client for the external captioning Space
 initialize_caption_client()
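As a side note on this hook: @app.on_event("startup") still works but is deprecated in current FastAPI releases in favour of a lifespan handler. A hypothetical lifespan-based equivalent (not part of this change) would be:

# Hypothetical lifespan-based equivalent of the startup hook (not part of this commit).
from contextlib import asynccontextmanager
from fastapi import FastAPI

@asynccontextmanager
async def lifespan(app: FastAPI):
    load_language_model()        # load the Qwen1.5 0.5B GGUF via llama.cpp
    initialize_caption_client()  # connect to the external captioning Space
    yield                        # application serves requests while the context is open

app = FastAPI(lifespan=lifespan)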
@@ -208,11 +206,11 @@ def generate_image_caption(image_file: UploadFile):
 print(f"Error removing temporary file {temp_file_path}: {e}") # Log cleanup errors


-# --- Language Model Story Generation Function (Qwen1.5
+# --- Language Model Story Generation Function (Qwen1.5 0.5B via llama.cpp) ---
 # Renamed function to reflect the model being used
-def
+def generate_story_qwen_0_5b(prompt_text: str, max_new_tokens: int = 300, temperature: float = 0.7, top_p: float = 0.9, top_k: int = 50) -> str:
 """
-Generates text using the loaded Qwen1.5
+Generates text using the loaded Qwen1.5 0.5B model via llama.cpp.
 Uses the tokenizer to apply the chat template and calls llama.cpp's chat completion.
 """
 # Check if the language model was loaded successfully at startup
@@ -230,7 +228,7 @@ def generate_story_qwen(prompt_text: str, max_new_tokens: int = 300, temperature
 ]

 try:
-print("Calling llama.cpp create_chat_completion for Qwen...")
+print("Calling llama.cpp create_chat_completion for Qwen 0.5B...")
 # Call the create_chat_completion method from llama_cpp.Llama instance
 # This method handles the chat templating internally for models like Qwen.
 # max_tokens is the max number of tokens to generate
@@ -244,7 +242,7 @@ def generate_story_qwen(prompt_text: str, max_new_tokens: int = 300, temperature
 # top_k=top_k,
 stream=False # We want the full response at once
 )
-print("Llama.cpp completion received for Qwen.")
+print("Llama.cpp completion received for Qwen 0.5B.")

 # Parse the response to get the generated text content
 # The response structure is typically like OpenAI's chat API response
@@ -252,12 +250,12 @@ def generate_story_qwen(prompt_text: str, max_new_tokens: int = 300, temperature
 story = response['choices'][0].get('message', {}).get('content', '')
 else:
 # Handle cases where the response is empty or has an unexpected structure
-print("Warning: Llama.cpp Qwen response structure unexpected or content missing.")
+print("Warning: Llama.cpp Qwen 0.5B response structure unexpected or content missing.")
 story = "" # Return an empty string if content is not found

 except Exception as e:
 # Catch any exception that occurs during the llama.cpp inference process
-print(f"Llama.cpp Qwen inference failed: {e}") # Log the error server-side
+print(f"Llama.cpp Qwen 0.5B inference failed: {e}") # Log the error server-side
 # Re-raise as a RuntimeError to indicate failure to the endpoint
 raise RuntimeError(f"Llama.cpp inference failed: {type(e).__name__}: {e}")

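Stripped of logging, the renamed function's core call-and-parse logic from the three hunks above condenses to the sketch below. The exact messages list sits outside the hunks, so the single user message is an assumption, and the helper name _sketch_generate_story is hypothetical (the real function is generate_story_qwen_0_5b):

# Distilled sketch of the call-and-parse pattern shown in the hunks above.
def _sketch_generate_story(prompt_text: str, max_new_tokens: int = 300,
                           temperature: float = 0.7, top_p: float = 0.9) -> str:
    # Assumed shape of the chat messages; the real list is built earlier in the function.
    messages = [{"role": "user", "content": prompt_text}]
    try:
        response = llm_model.create_chat_completion(
            messages=messages,
            max_tokens=max_new_tokens,  # ~300 new tokens for the story
            temperature=temperature,
            top_p=top_p,
            # top_k is left commented out in the diff as well
            stream=False,               # full completion in one response
        )
        # OpenAI-style response shape: choices[0].message.content
        choices = response.get('choices') or []
        return choices[0].get('message', {}).get('content', '') if choices else ''
    except Exception as e:
        # Surface any llama.cpp failure to the endpoint as a RuntimeError
        raise RuntimeError(f"Llama.cpp inference failed: {type(e).__name__}: {e}")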
@@ -291,10 +289,10 @@ async def generate_story_endpoint(image_file: UploadFile = File(...), language:
 # This prompt instructs the model on what to write and incorporates the caption.
 prompt_text = f"Write an attractive story of around 300 words about {story_theme}. Incorporate the following details from an image description into the story: {caption}\n\nStory:"

-# Step 3: Generate the story using the local language model (Qwen via llama.cpp)
+# Step 3: Generate the story using the local language model (Qwen 0.5B via llama.cpp)
 try:
-# Call the Qwen story generation function
-story =
+# Call the Qwen 0.5B story generation function
+story = generate_story_qwen_0_5b( # <--- Use the updated function name
 prompt_text,
 max_new_tokens=300, # Request ~300 new tokens
 temperature=0.7, # Sampling parameters
@@ -304,7 +302,7 @@ async def generate_story_endpoint(image_file: UploadFile = File(...), language:
 story = story.strip() # Basic cleanup of generated story text

 except RuntimeError as e:
-# Catch specific RuntimeError raised by
+# Catch specific RuntimeError raised by generate_story_qwen_0_5b if LLM loading or inference fails
 print(f"Language model generation error: {e}") # Log the error server-side
 # Return a 503 Service Unavailable error if the LLM is not available or failed
 raise HTTPException(status_code=503, detail=f"Story generation failed (LLM): {e}")
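On the endpoint side, the change reduces to calling the renamed function and mapping its RuntimeError to a 503. A sketch of that pattern, distilled from the two hunks above (any trailing sampling arguments cut off in the original call, such as top_p/top_k, are omitted here):

# Sketch of the endpoint-side usage distilled from the hunks above.
from fastapi import HTTPException

try:
    story = generate_story_qwen_0_5b(
        prompt_text,
        max_new_tokens=300,   # request ~300 new tokens
        temperature=0.7,      # sampling temperature
    )
    story = story.strip()     # basic cleanup of the generated text
except RuntimeError as e:
    # LLM not loaded or inference failed -> 503 Service Unavailable
    raise HTTPException(status_code=503, detail=f"Story generation failed (LLM): {e}")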