Makhinur commited on
Commit
013c0f5
·
verified ·
1 Parent(s): a113176

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +73 -144
main.py CHANGED
@@ -5,12 +5,12 @@ from fastapi import FastAPI, File, UploadFile, Form, HTTPException
5
  # from fastapi.templating import Jinja2Templates
6
  # from fastapi.responses import FileResponse
7
 
8
- # Removed 'requests' as we'll primarily use gradio_client for captioning
9
  # import requests
10
- import base64 # Keep if you might need base64 for other purposes
11
  import os
12
  import random
13
- from typing import IO # Import IO for type hinting file-like objects
14
 
15
  # Import necessary classes from transformers
16
  import torch
@@ -19,6 +19,11 @@ from transformers import AutoModelForCausalLM, AutoTokenizer
19
  # Import the Gradio Client and handle_file
20
  from gradio_client import Client, handle_file
21
 
 
 
 
 
 
22
  from deep_translator import GoogleTranslator
23
  from deep_translator.exceptions import InvalidSourceOrTargetLanguage
24
 
@@ -26,10 +31,7 @@ from deep_translator.exceptions import InvalidSourceOrTargetLanguage
26
  app = FastAPI()
27
 
28
  # --- Hugging Face Language Model Setup (Local Inference) ---
29
- # Model name for TinyLlama 1.1B Chat (instruction-tuned version)
30
- # Chosen for balance of quality and speed on CPU basic (faster than Gemma 2B, better than GPT-2 base)
31
- # If you get access to Gemma 2B-IT and prefer its quality (accepting slower speed),
32
- # change this to "google/gemma-2b-it" and use the generate_story_gemma function below.
33
  model_name = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
34
  tokenizer = None
35
  model = None
@@ -47,23 +49,17 @@ def load_language_model():
47
  print(f"Loading language model: {model_name}...")
48
  try:
49
  tokenizer = AutoTokenizer.from_pretrained(model_name)
50
- # TinyLlama might not have an explicit pad token, setting it to EOS is common practice
51
  if tokenizer.pad_token is None:
52
  tokenizer.pad_token = tokenizer.eos_token
53
 
54
- # Load model weights. Using float16 to reduce memory footprint on CPU.
55
  model = AutoModelForCausalLM.from_pretrained(
56
  model_name,
57
  torch_dtype=torch.float16, # Use float16 precision
58
- # device_map="auto" # Not needed for single CPU
59
  )
60
- # model.to("cpu") # Explicitly ensure it's on CPU, although from_pretrained does this by default on CPU-only systems.
61
 
62
  print(f"Language model {model_name} loaded successfully.")
63
  except Exception as e:
64
  print(f"Error loading language model {model_name}: {e}")
65
- # Depending on requirements, you might want to exit or set model/tokenizer to None permanently
66
- # For now, we let the app start, but subsequent generation calls will fail gracefully.
67
  tokenizer = None
68
  model = None
69
 
@@ -73,69 +69,85 @@ def initialize_caption_client():
73
  global caption_client
74
  print(f"Initializing Gradio client for {CAPTION_SPACE_URL}...")
75
  try:
76
- # If the target Gradio Space is private or requires authentication,
77
- # uncomment the lines below and set your HF_TOKEN as a Space Secret.
78
  # HF_TOKEN = os.environ.get("HF_TOKEN")
79
  # if HF_TOKEN:
80
  # caption_client = Client(CAPTION_SPACE_URL, hf_token=HF_TOKEN)
81
  # else:
82
  # caption_client = Client(CAPTION_SPACE_URL)
83
 
84
- # Assuming the caption space is public and does not require a token
85
  caption_client = Client(CAPTION_SPACE_URL)
86
  print("Gradio client initialized successfully.")
87
  except Exception as e:
88
  print(f"Error initializing Gradio client for {CAPTION_SPACE_URL}: {e}")
89
- # Set client to None so the endpoint can check and return an error
90
  caption_client = None
91
 
92
 
93
  # Load models and initialize clients when the app starts
94
  @app.on_event("startup")
95
  async def startup_event():
96
- # Load the language model (TinyLlama or Gemma)
97
  load_language_model()
98
- # Initialize the client for the captioning Space
99
  initialize_caption_client()
100
 
101
 
102
- # --- Image Captioning Function (Using gradio_client) ---
103
  def generate_image_caption(image_file: UploadFile):
104
  """
105
  Generates a caption for the uploaded image using the external Gradio Space API.
106
- Reads the file content and uses handle_file for correct API input format.
107
  """
108
  if caption_client is None:
109
  error_msg = "Gradio caption client not initialized. Cannot generate caption."
110
- print(error_msg) # Log the error server-side
111
- return f"Error: {error_msg}" # Return an error string to the caller
 
 
112
 
113
  try:
114
  print(f"Calling caption API /predict for file {image_file.filename}...")
115
 
116
- # Read the content of the uploaded file into bytes
117
- # It's important to seek(0) in case the file-like object has been read before
118
- image_file.file.seek(0)
119
  image_bytes = image_file.file.read()
120
 
121
- # Use handle_file() with the byte content. This prepares the bytes
122
- # into the format expected by the Gradio API (often base64).
123
- prepared_input = handle_file(image_bytes)
 
 
 
 
 
 
 
 
 
 
124
 
125
  # Call the predict method on the initialized client with the prepared input
126
  caption = caption_client.predict(img=prepared_input, api_name="/predict")
127
 
128
  print(f"Caption generated: {caption}")
129
  return caption # Return the successful caption string
 
130
  except Exception as e:
131
- # Catch potential exceptions from gradio_client.predict (network, API error, etc.)
132
- print(f"Error during caption generation API call: {e}") # Log the exception server-side
133
- # Return an informative error string including the exception type and message
134
  return f"Error: Unable to generate caption from API. Details: {type(e).__name__}: {e}"
135
 
 
 
 
 
 
 
 
 
 
136
 
137
- # --- Language Model Story Generation Function ---
138
- # This function uses the loaded TinyLlama model to generate the story.
139
  def generate_story_tinyllama(prompt_text: str, max_new_tokens: int = 300, temperature: float = 0.7, top_p: float = 0.9, top_k: int = 50) -> str:
140
  """
141
  Generates text using the loaded TinyLlama model based on the prompt.
@@ -144,117 +156,56 @@ def generate_story_tinyllama(prompt_text: str, max_new_tokens: int = 300, temper
144
  if tokenizer is None or model is None:
145
  raise RuntimeError("Language model and tokenizer not loaded. Cannot generate story.")
146
 
147
- # TinyLlama-Chat uses a chat template (similar to Llama/Gemma's instruction format)
148
  messages = [
149
  {"role": "user", "content": prompt_text}
150
- # Add {"role": "system", "content": "Your system prompt here"} if needed
151
  ]
152
 
153
- # Apply the chat template to format the prompt correctly for the model
154
  try:
155
  input_text = tokenizer.apply_chat_template(
156
  messages,
157
- tokenize=False, # Return as string before tokenizing
158
- add_generation_prompt=True # Adds the assistant turn prompt token(s)
159
  )
160
  except AttributeError:
161
- # Fallback for models that don't have a chat template defined
162
- print("Warning: apply_chat_template not found for this tokenizer. Using basic prompt formatting.")
163
- input_text = f"<s>[INST] {prompt_text} [/INST]" # Basic Llama/TinyLlama instruction format
164
 
165
- # Encode the templated prompt into input IDs
166
- # max_length should be within the model's context window (e.g., 4096 for TinyLlama)
167
- inputs = tokenizer(input_text, return_tensors="pt", padding=True, truncation=True, max_length=1024) # Truncate if prompt is too long
168
 
169
- # Ensure input tensors are on the same device as the model (CPU by default)
170
- # inputs = {k: v.to(model.device) for k, v in inputs.items()} # Redundant if model is on CPU
171
-
172
- # Generate new tokens based on the input prompt
173
  generate_ids = model.generate(
174
  inputs.input_ids,
175
- max_new_tokens=max_new_tokens, # Maximum number of tokens to generate
176
- do_sample=True, # Enable sampling for creative output
177
- temperature=temperature, # Control randomness
178
- top_p=top_p, # Control diversity (nucleus sampling)
179
- top_k=top_k, # Control diversity (top-k sampling)
180
- pad_token_id=tokenizer.pad_token_id, # Specify pad token for generation
181
- # eos_token_id=tokenizer.eos_token_id # Optional: specify end-of-sequence token id to stop generation early
182
  )
183
 
184
- # Decode the generated token IDs back into text
185
- # Slice [0, inputs.input_ids.shape[-1]:] to get only the newly generated tokens
186
- # skip_special_tokens=True removes tokens like <s>, </s>, <pad>
187
  generated_text = tokenizer.decode(generate_ids[0, inputs.input_ids.shape[-1]:], skip_special_tokens=True)
188
-
189
- return generated_text.strip() # Return the generated text, stripped of leading/trailing whitespace
190
-
191
- # --- Optional: Gemma 2B Story Generation Function (if you prefer Gemma and get access) ---
192
- # def generate_story_gemma(prompt_text: str, max_new_tokens: int = 300, temperature: float = 0.7, top_p: float = 0.9, top_k: int = 50) -> str:
193
- # """
194
- # Generates text using the loaded Gemma model based on the prompt.
195
- # Applies the Gemma-IT Chat template.
196
- # """
197
- # if tokenizer is None or model is None:
198
- # raise RuntimeError("Language model and tokenizer not loaded. Cannot generate story.")
199
-
200
- # # Gemma-IT uses a specific chat template
201
- # messages = [
202
- # {"role": "user", "content": prompt_text}
203
- # # {"role": "system", "content": "Your system prompt here"} # Gemma also supports system prompts
204
- # ]
205
- # input_text = tokenizer.apply_chat_template(
206
- # messages,
207
- # tokenize=False,
208
- # add_generation_prompt=True # Adds the assistant turn prompt token(s)
209
- # )
210
-
211
- # inputs = tokenizer(input_text, return_tensors="pt", padding=True, truncation=True, max_length=1024)
212
-
213
- # generate_ids = model.generate(
214
- # inputs.input_ids,
215
- # max_new_tokens=max_new_tokens,
216
- # do_sample=True,
217
- # temperature=temperature,
218
- # top_p=top_p,
219
- # top_k=top_k,
220
- # pad_token_id=tokenizer.pad_token_id,
221
- # )
222
-
223
- # generated_text = tokenizer.decode(generate_ids[0, inputs.input_ids.shape[-1]:], skip_special_tokens=True)
224
- # return generated_text.strip()
225
-
226
 
227
  # --- FastAPI Endpoint for Story Generation ---
228
  @app.post("/generate-story/")
229
  async def generate_story_endpoint(image_file: UploadFile = File(...), language: str = Form(...)):
230
- # No need to manually read image data here, generate_image_caption handles it
231
-
232
- # Choose a random theme for the story prompt
233
  story_theme = random.choice([
234
- 'an adventurous journey',
235
- 'a mysterious encounter',
236
- 'a heroic quest',
237
- 'a magical adventure',
238
- 'a thrilling escape',
239
- 'an unexpected discovery',
240
- 'a dangerous mission',
241
- 'a romantic escapade',
242
- 'an epic battle',
243
  'a journey into the unknown'
244
  ])
245
 
246
  # Step 1: Get image caption using the external API via gradio_client
247
- # Pass the UploadFile object directly to the captioning function
248
  caption = generate_image_caption(image_file)
249
 
250
  # Check if caption generation failed
251
- if caption.startswith("Error:"): # Check if the returned string indicates an error
252
- print(f"Caption generation failed: {caption}") # Log the error detail server-side
253
- # Return a 500 Internal Server Error with the error message
254
  raise HTTPException(status_code=500, detail=caption)
255
 
256
  # Step 2: Construct the prompt for the language model
257
- # We instruct the model to write a story based on the theme and incorporate the caption.
258
  prompt_text = f"Write an attractive story of around 300 words about {story_theme}. Incorporate the following details from an image description into the story: {caption}\n\nStory:"
259
 
260
  # Step 3: Generate the story using the local language model
@@ -262,72 +213,50 @@ async def generate_story_endpoint(image_file: UploadFile = File(...), language:
262
  # Call the appropriate story generation function (TinyLlama in this case)
263
  story = generate_story_tinyllama(
264
  prompt_text,
265
- max_new_tokens=300, # Aim for ~300 new tokens
266
- temperature=0.7, # Standard creative sampling parameters
267
  top_p=0.9,
268
  top_k=50
269
  )
270
- story = story.strip() # Basic cleanup of potential extra whitespace
271
 
272
  except RuntimeError as e:
273
- # Catch errors specifically from model not being loaded
274
  print(f"Language model not loaded error: {e}")
275
  raise HTTPException(status_code=503, detail=f"Language model not available: {e}")
276
  except Exception as e:
277
- # Catch other potential errors during generation
278
- print(f"Story generation failed: {e}") # Log the exception server-side
279
  raise HTTPException(status_code=500, detail=f"Story generation failed: {type(e).__name__}: {e}. Please check Space logs for details.")
280
 
281
 
282
  # Step 4: Translate the story if the target language is not English
283
  if language and language.lower() != "english":
284
  try:
285
- # Use GoogleTranslator with specified source and target languages
286
  translator = GoogleTranslator(source='english', target=language.lower())
287
  translated_story = translator.translate(story)
288
 
289
- # Check if translation was successful or returned None
290
  if translated_story is None or translated_story == "":
291
  print(f"Translation returned None or empty string for language: {language}")
292
- # If translation fails, return the English story with a warning message
293
  return {"story": story + "\n\n(Note: Automatic translation to your requested language failed.)"}
294
 
295
- story = translated_story # Use the translated story
296
 
297
  except InvalidSourceOrTargetLanguage:
298
  print(f"Invalid target language requested: {language}")
299
  raise HTTPException(status_code=400, detail=f"Invalid target language: {language}")
300
  except Exception as e:
301
- # Catch other potential translation errors (network, API issues, etc.)
302
- print(f"Translation failed for language {language}: {e}") # Log server-side
303
  raise HTTPException(status_code=500, detail=f"Translation failed: {type(e).__name__}: {e}")
304
 
305
-
306
- # Step 5: Return the final generated (and potentially translated) story
307
  return {"story": story}
308
 
309
  # --- Optional: Serve a simple HTML form for testing ---
310
- # To enable this, create a 'templates' directory and an 'index.html' file inside it.
311
- # Also uncomment the imports at the top related to HTMLResponse, StaticFiles, Jinja2Templates, Request.
312
  # from fastapi import Request
313
  # from fastapi.templating import Jinja2Templates
314
  # from fastapi.staticfiles import StaticFiles
315
  # templates = Jinja2Templates(directory="templates")
316
  # app.mount("/static", StaticFiles(directory="static"), name="static")
317
-
318
  # @app.get("/", response_class=HTMLResponse)
319
  # async def read_root(request: Request):
320
- # # Example index.html structure for a simple form:
321
- # # <!DOCTYPE html>
322
- # # <html>
323
- # # <head><title>Story Generator</title></head>
324
- # # <body>
325
- # # <h1>Generate a Story from an Image</h1>
326
- # # <form action="/generate-story/" method="post" enctype="multipart/form-data">
327
- # # <input type="file" name="image_file" accept="image/*" required><br><br>
328
- # # Target Language (e.g., english, french, spanish): <input type="text" name="language" value="english"><br><br>
329
- # # <button type="submit">Generate Story</button>
330
- # # </form>
331
- # # </body>
332
- # # </html>
333
  # return templates.TemplateResponse("index.html", {"request": request})
 
5
  # from fastapi.templating import Jinja2Templates
6
  # from fastapi.responses import FileResponse
7
 
8
+ # Removed 'requests'
9
  # import requests
10
+ import base64 # Keep if needed elsewhere
11
  import os
12
  import random
13
+ from typing import IO
14
 
15
  # Import necessary classes from transformers
16
  import torch
 
19
  # Import the Gradio Client and handle_file
20
  from gradio_client import Client, handle_file
21
 
22
+ # Import necessary modules for temporary file handling
23
+ import tempfile
24
+ import shutil # Useful for potential cleanup, although os.remove is sufficient here
25
+
26
+
27
  from deep_translator import GoogleTranslator
28
  from deep_translator.exceptions import InvalidSourceOrTargetLanguage
29
 
 
31
  app = FastAPI()
32
 
33
  # --- Hugging Face Language Model Setup (Local Inference) ---
34
+ # Using TinyLlama 1.1B Chat as the example
 
 
 
35
  model_name = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
36
  tokenizer = None
37
  model = None
 
49
  print(f"Loading language model: {model_name}...")
50
  try:
51
  tokenizer = AutoTokenizer.from_pretrained(model_name)
 
52
  if tokenizer.pad_token is None:
53
  tokenizer.pad_token = tokenizer.eos_token
54
 
 
55
  model = AutoModelForCausalLM.from_pretrained(
56
  model_name,
57
  torch_dtype=torch.float16, # Use float16 precision
 
58
  )
 
59
 
60
  print(f"Language model {model_name} loaded successfully.")
61
  except Exception as e:
62
  print(f"Error loading language model {model_name}: {e}")
 
 
63
  tokenizer = None
64
  model = None
65
 
 
69
  global caption_client
70
  print(f"Initializing Gradio client for {CAPTION_SPACE_URL}...")
71
  try:
72
+ # If the target Space is private or requires authentication,
73
+ # uncomment and set HF_TOKEN as a Space Secret.
74
  # HF_TOKEN = os.environ.get("HF_TOKEN")
75
  # if HF_TOKEN:
76
  # caption_client = Client(CAPTION_SPACE_URL, hf_token=HF_TOKEN)
77
  # else:
78
  # caption_client = Client(CAPTION_SPACE_URL)
79
 
80
+ # Assuming the caption space is public
81
  caption_client = Client(CAPTION_SPACE_URL)
82
  print("Gradio client initialized successfully.")
83
  except Exception as e:
84
  print(f"Error initializing Gradio client for {CAPTION_SPACE_URL}: {e}")
 
85
  caption_client = None
86
 
87
 
88
  # Load models and initialize clients when the app starts
89
  @app.on_event("startup")
90
  async def startup_event():
 
91
  load_language_model()
 
92
  initialize_caption_client()
93
 
94
 
95
+ # --- Image Captioning Function (Using gradio_client and temporary file) ---
96
  def generate_image_caption(image_file: UploadFile):
97
  """
98
  Generates a caption for the uploaded image using the external Gradio Space API.
99
+ Saves the uploaded file to a temporary file and uses its path with handle_file.
100
  """
101
  if caption_client is None:
102
  error_msg = "Gradio caption client not initialized. Cannot generate caption."
103
+ print(error_msg)
104
+ return f"Error: {error_msg}"
105
+
106
+ temp_file_path = None # Initialize temporary file path variable
107
 
108
  try:
109
  print(f"Calling caption API /predict for file {image_file.filename}...")
110
 
111
+ # Read the content of the uploaded file
112
+ image_file.file.seek(0) # Ensure pointer is at the beginning
 
113
  image_bytes = image_file.file.read()
114
 
115
+ # Create a temporary file and write the bytes to it
116
+ # use delete=False so the file isn't deleted automatically when closed
117
+ # Use suffix to hint at the file type
118
+ temp_file = tempfile.NamedTemporaryFile(delete=False, suffix=os.path.splitext(image_file.filename)[1] or '.jpg')
119
+ temp_file.write(image_bytes)
120
+ temp_file.close() # Close the handle so gradio_client can access it
121
+ temp_file_path = temp_file.name # Get the actual path string
122
+
123
+ print(f"Saved uploaded file temporarily to: {temp_file_path}")
124
+
125
+ # Use handle_file() with the path to the temporary file (a string)
126
+ # This aligns with the documentation "img filepath Required" and example handle_file('...')
127
+ prepared_input = handle_file(temp_file_path)
128
 
129
  # Call the predict method on the initialized client with the prepared input
130
  caption = caption_client.predict(img=prepared_input, api_name="/predict")
131
 
132
  print(f"Caption generated: {caption}")
133
  return caption # Return the successful caption string
134
+
135
  except Exception as e:
136
+ # Catch potential exceptions during the process
137
+ print(f"Error during caption generation API call: {e}")
 
138
  return f"Error: Unable to generate caption from API. Details: {type(e).__name__}: {e}"
139
 
140
+ finally:
141
+ # Clean up the temporary file
142
+ if temp_file_path and os.path.exists(temp_file_path):
143
+ print(f"Cleaning up temporary file: {temp_file_path}")
144
+ try:
145
+ os.remove(temp_file_path)
146
+ except OSError as e:
147
+ print(f"Error removing temporary file {temp_file_path}: {e}") # Log cleanup errors
148
+
149
 
150
+ # --- Language Model Story Generation Function (TinyLlama) ---
 
151
  def generate_story_tinyllama(prompt_text: str, max_new_tokens: int = 300, temperature: float = 0.7, top_p: float = 0.9, top_k: int = 50) -> str:
152
  """
153
  Generates text using the loaded TinyLlama model based on the prompt.
 
156
  if tokenizer is None or model is None:
157
  raise RuntimeError("Language model and tokenizer not loaded. Cannot generate story.")
158
 
 
159
  messages = [
160
  {"role": "user", "content": prompt_text}
 
161
  ]
162
 
 
163
  try:
164
  input_text = tokenizer.apply_chat_template(
165
  messages,
166
+ tokenize=False,
167
+ add_generation_prompt=True
168
  )
169
  except AttributeError:
170
+ print("Warning: apply_chat_template not found. Using basic prompt formatting.")
171
+ input_text = f"<s>[INST] {prompt_text} [/INST]"
 
172
 
173
+ inputs = tokenizer(input_text, return_tensors="pt", padding=True, truncation=True, max_length=1024)
 
 
174
 
 
 
 
 
175
  generate_ids = model.generate(
176
  inputs.input_ids,
177
+ max_new_tokens=max_new_tokens,
178
+ do_sample=True,
179
+ temperature=temperature,
180
+ top_p=top_p,
181
+ top_k=top_k,
182
+ pad_token_id=tokenizer.pad_token_id,
 
183
  )
184
 
 
 
 
185
  generated_text = tokenizer.decode(generate_ids[0, inputs.input_ids.shape[-1]:], skip_special_tokens=True)
186
+ return generated_text.strip()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
187
 
188
  # --- FastAPI Endpoint for Story Generation ---
189
  @app.post("/generate-story/")
190
  async def generate_story_endpoint(image_file: UploadFile = File(...), language: str = Form(...)):
191
+ # Choose a random theme
 
 
192
  story_theme = random.choice([
193
+ 'an adventurous journey', 'a mysterious encounter', 'a heroic quest',
194
+ 'a magical adventure', 'a thrilling escape', 'an unexpected discovery',
195
+ 'a dangerous mission', 'a romantic escapade', 'an epic battle',
 
 
 
 
 
 
196
  'a journey into the unknown'
197
  ])
198
 
199
  # Step 1: Get image caption using the external API via gradio_client
200
+ # Pass the UploadFile object to the captioning function
201
  caption = generate_image_caption(image_file)
202
 
203
  # Check if caption generation failed
204
+ if caption.startswith("Error:"):
205
+ print(f"Caption generation failed: {caption}")
 
206
  raise HTTPException(status_code=500, detail=caption)
207
 
208
  # Step 2: Construct the prompt for the language model
 
209
  prompt_text = f"Write an attractive story of around 300 words about {story_theme}. Incorporate the following details from an image description into the story: {caption}\n\nStory:"
210
 
211
  # Step 3: Generate the story using the local language model
 
213
  # Call the appropriate story generation function (TinyLlama in this case)
214
  story = generate_story_tinyllama(
215
  prompt_text,
216
+ max_new_tokens=300,
217
+ temperature=0.7,
218
  top_p=0.9,
219
  top_k=50
220
  )
221
+ story = story.strip()
222
 
223
  except RuntimeError as e:
 
224
  print(f"Language model not loaded error: {e}")
225
  raise HTTPException(status_code=503, detail=f"Language model not available: {e}")
226
  except Exception as e:
227
+ print(f"Story generation failed: {e}")
 
228
  raise HTTPException(status_code=500, detail=f"Story generation failed: {type(e).__name__}: {e}. Please check Space logs for details.")
229
 
230
 
231
  # Step 4: Translate the story if the target language is not English
232
  if language and language.lower() != "english":
233
  try:
 
234
  translator = GoogleTranslator(source='english', target=language.lower())
235
  translated_story = translator.translate(story)
236
 
 
237
  if translated_story is None or translated_story == "":
238
  print(f"Translation returned None or empty string for language: {language}")
 
239
  return {"story": story + "\n\n(Note: Automatic translation to your requested language failed.)"}
240
 
241
+ story = translated_story
242
 
243
  except InvalidSourceOrTargetLanguage:
244
  print(f"Invalid target language requested: {language}")
245
  raise HTTPException(status_code=400, detail=f"Invalid target language: {language}")
246
  except Exception as e:
247
+ print(f"Translation failed for language {language}: {e}")
 
248
  raise HTTPException(status_code=500, detail=f"Translation failed: {type(e).__name__}: {e}")
249
 
250
+ # Step 5: Return the final story
 
251
  return {"story": story}
252
 
253
  # --- Optional: Serve a simple HTML form for testing ---
254
+ # Requires uncommenting imports above and creating 'templates/index.html'
 
255
  # from fastapi import Request
256
  # from fastapi.templating import Jinja2Templates
257
  # from fastapi.staticfiles import StaticFiles
258
  # templates = Jinja2Templates(directory="templates")
259
  # app.mount("/static", StaticFiles(directory="static"), name="static")
 
260
  # @app.get("/", response_class=HTMLResponse)
261
  # async def read_root(request: Request):
 
 
 
 
 
 
 
 
 
 
 
 
 
262
  # return templates.TemplateResponse("index.html", {"request": request})