ikraamkb commited on
Commit
8b84edc
Β·
verified Β·
1 Parent(s): 4f031a5

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +58 -24
main.py CHANGED
@@ -1,38 +1,72 @@
1
- # main.py
2
- from fastapi import FastAPI, Request, UploadFile, Form
3
- from fastapi.responses import FileResponse, HTMLResponse, JSONResponse
4
  from fastapi.staticfiles import StaticFiles
5
  from fastapi.templating import Jinja2Templates
6
- import shutil
7
- import os
8
- import app
9
- import appImage
10
 
11
  app = FastAPI()
12
 
13
- # Mount static and templates
14
- app.mount("/static", StaticFiles(directory="static"), name="static")
 
 
 
 
 
 
 
 
15
  app.mount("/resources", StaticFiles(directory="resources"), name="resources")
 
16
 
 
17
  templates = Jinja2Templates(directory="templates")
18
 
19
- # Serve your main home page
20
  @app.get("/", response_class=HTMLResponse)
21
  async def serve_home(request: Request):
22
  return templates.TemplateResponse("home.html", {"request": request})
23
 
24
- # Route to handle prediction requests (image or doc)
25
  @app.post("/predict")
26
- async def predict(file: UploadFile = Form(...), question: str = Form(...)):
27
- ext = file.filename.split(".")[-1].lower()
28
- with open(file.filename, "wb") as buffer:
29
- shutil.copyfileobj(file.file, buffer)
30
-
31
- if ext in ["pdf", "docx", "pptx", "xlsx"]:
32
- answer, audio = app.answer_question_from_doc(file=buffer, question=question)
33
- else:
34
- image = appImage.Image.open(file.filename)
35
- answer, audio = appImage.answer_question_from_image(image, question)
36
-
37
- os.remove(file.filename)
38
- return JSONResponse({"answer": answer, "audio": audio})
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi import FastAPI, UploadFile, Form, Request
2
+ from fastapi.middleware.cors import CORSMiddleware
3
+ from fastapi.responses import HTMLResponse, JSONResponse, FileResponse
4
  from fastapi.staticfiles import StaticFiles
5
  from fastapi.templating import Jinja2Templates
6
+ import shutil, os
7
+ from tempfile import gettempdir
 
 
8
 
9
  app = FastAPI()
10
 
11
+ # βœ… CORS to allow frontend access
12
+ app.add_middleware(
13
+ CORSMiddleware,
14
+ allow_origins=["*"],
15
+ allow_credentials=True,
16
+ allow_methods=["*"],
17
+ allow_headers=["*"],
18
+ )
19
+
20
+ # βœ… Static assets
21
  app.mount("/resources", StaticFiles(directory="resources"), name="resources")
22
+ app.mount("/static", StaticFiles(directory="static"), name="static")
23
 
24
+ # βœ… Jinja2 Templates
25
  templates = Jinja2Templates(directory="templates")
26
 
27
+ # βœ… Serve Homepage
28
  @app.get("/", response_class=HTMLResponse)
29
  async def serve_home(request: Request):
30
  return templates.TemplateResponse("home.html", {"request": request})
31
 
32
+ # βœ… Predict endpoint (handles image + document)
33
  @app.post("/predict")
34
+ async def predict(question: str = Form(...), file: UploadFile = Form(...)):
35
+ try:
36
+ temp_path = f"temp_{file.filename}"
37
+ with open(temp_path, "wb") as f:
38
+ shutil.copyfileobj(file.file, f)
39
+
40
+ is_image = file.content_type.startswith("image/")
41
+
42
+ if is_image:
43
+ from appImage import answer_question_from_image
44
+ from PIL import Image
45
+ image = Image.open(temp_path).convert("RGB")
46
+ answer, audio_path = answer_question_from_image(image, question)
47
+
48
+ else:
49
+ from app import answer_question_from_doc
50
+ class NamedFile:
51
+ def __init__(self, name): self.filename = name
52
+ def read(self): return open(self.filename, "rb").read()
53
+ answer, audio_path = answer_question_from_doc(NamedFile(temp_path), question)
54
+
55
+ os.remove(temp_path)
56
+
57
+ if audio_path and os.path.exists(audio_path):
58
+ return JSONResponse({
59
+ "answer": answer,
60
+ "audio": f"/audio/{os.path.basename(audio_path)}"
61
+ })
62
+ else:
63
+ return JSONResponse({"answer": answer})
64
+
65
+ except Exception as e:
66
+ return JSONResponse({"error": str(e)}, status_code=500)
67
+
68
+ # βœ… Serve audio
69
+ @app.get("/audio/{filename}")
70
+ async def get_audio(filename: str):
71
+ filepath = os.path.join(gettempdir(), filename)
72
+ return FileResponse(filepath, media_type="audio/mpeg")