ikraamkb commited on
Commit
13eeded
Β·
verified Β·
1 Parent(s): 09a2e2c

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +49 -32
main.py CHANGED
@@ -1,16 +1,14 @@
1
- # main.py
2
- from fastapi import FastAPI, UploadFile, File, Form, Request
3
  from fastapi.responses import HTMLResponse, JSONResponse, FileResponse
4
  from fastapi.staticfiles import StaticFiles
5
  from fastapi.templating import Jinja2Templates
6
- from fastapi.middleware.cors import CORSMiddleware
7
- import tempfile, os
8
-
9
- from app import summarize_document
10
- from appImage import caption_image
11
 
12
  app = FastAPI()
13
 
 
14
  app.add_middleware(
15
  CORSMiddleware,
16
  allow_origins=["*"],
@@ -19,37 +17,56 @@ app.add_middleware(
19
  allow_headers=["*"],
20
  )
21
 
22
- app.mount("/static", StaticFiles(directory="static"), name="static")
23
  app.mount("/resources", StaticFiles(directory="resources"), name="resources")
 
 
 
24
  templates = Jinja2Templates(directory="templates")
25
 
26
- # Serve homepage
27
  @app.get("/", response_class=HTMLResponse)
28
  async def serve_home(request: Request):
29
- return templates.TemplateResponse("HomeS.html", {"request": request})
30
 
31
- # Document summarization endpoint
32
- @app.post("/summarize/")
33
- async def summarize(file: UploadFile = File(...), length: str = Form("medium")):
34
  try:
35
- result = await summarize_document(file, length)
36
- return JSONResponse(result)
37
- except Exception as e:
38
- return JSONResponse({"error": f"Summarization failed: {str(e)}"}, status_code=500)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
39
 
40
- # Image captioning endpoint
41
- @app.post("/imagecaption/")
42
- async def caption(file: UploadFile = File(...)):
43
- try:
44
- result = await caption_image(file)
45
- return JSONResponse(result)
46
  except Exception as e:
47
- return JSONResponse({"error": f"Image captioning failed: {str(e)}"}, status_code=500)
48
-
49
- # Serve audio/pdf generated
50
- @app.get("/files/{filename}")
51
- async def serve_file(filename: str):
52
- filepath = os.path.join(tempfile.gettempdir(), filename)
53
- if os.path.exists(filepath):
54
- return FileResponse(filepath)
55
- return JSONResponse({"error": "File not found"}, status_code=404)
 
1
+ from fastapi import FastAPI, UploadFile, Form, Request
2
+ from fastapi.middleware.cors import CORSMiddleware
3
  from fastapi.responses import HTMLResponse, JSONResponse, FileResponse
4
  from fastapi.staticfiles import StaticFiles
5
  from fastapi.templating import Jinja2Templates
6
+ import shutil, os
7
+ from tempfile import gettempdir
 
 
 
8
 
9
  app = FastAPI()
10
 
11
+ # βœ… CORS to allow frontend access
12
  app.add_middleware(
13
  CORSMiddleware,
14
  allow_origins=["*"],
 
17
  allow_headers=["*"],
18
  )
19
 
20
+ # βœ… Static assets
21
  app.mount("/resources", StaticFiles(directory="resources"), name="resources")
22
+ app.mount("/static", StaticFiles(directory="static"), name="static")
23
+
24
+ # βœ… Jinja2 Templates
25
  templates = Jinja2Templates(directory="templates")
26
 
27
+ # βœ… Serve Homepage
28
  @app.get("/", response_class=HTMLResponse)
29
  async def serve_home(request: Request):
30
+ return templates.TemplateResponse("home.html", {"request": request})
31
 
32
+ # βœ… Predict endpoint (handles image + document)
33
+ @app.post("/predict")
34
+ async def predict(question: str = Form(...), file: UploadFile = Form(...)):
35
  try:
36
+ temp_path = f"temp_{file.filename}"
37
+ with open(temp_path, "wb") as f:
38
+ shutil.copyfileobj(file.file, f)
39
+
40
+ is_image = file.content_type.startswith("image/")
41
+
42
+ if is_image:
43
+ from appImage import answer_question_from_image
44
+ from PIL import Image
45
+ image = Image.open(temp_path).convert("RGB")
46
+ answer, audio_path = answer_question_from_image(image, question)
47
+
48
+ else:
49
+ from app import answer_question_from_doc
50
+ class NamedFile:
51
+ def __init__(self, name): self.filename = name
52
+ def read(self): return open(self.filename, "rb").read()
53
+ answer, audio_path = answer_question_from_doc(NamedFile(temp_path), question)
54
+
55
+ os.remove(temp_path)
56
+
57
+ if audio_path and os.path.exists(audio_path):
58
+ return JSONResponse({
59
+ "answer": answer,
60
+ "audio": f"/audio/{os.path.basename(audio_path)}"
61
+ })
62
+ else:
63
+ return JSONResponse({"answer": answer})
64
 
 
 
 
 
 
 
65
  except Exception as e:
66
+ return JSONResponse({"error": str(e)}, status_code=500)
67
+
68
+ # βœ… Serve audio
69
+ @app.get("/audio/{filename}")
70
+ async def get_audio(filename: str):
71
+ filepath = os.path.join(gettempdir(), filename)
72
+ return FileResponse(filepath, media_type="audio/mpeg")