benkada committed on
Commit
6581e65
·
verified ·
1 Parent(s): 7de75e8

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +90 -117
main.py CHANGED
@@ -1,22 +1,26 @@
1
- import os
2
  from fastapi import FastAPI, UploadFile, File, Form
3
  from fastapi.middleware.cors import CORSMiddleware
4
- from fastapi.responses import JSONResponse, HTMLResponse
 
5
  from huggingface_hub import InferenceClient
6
  from PyPDF2 import PdfReader
7
  from docx import Document
8
  from PIL import Image
9
- import io
10
  from io import BytesIO
11
- import requests
12
 
13
- # Remplace ce token par le tien de manière sécurisée (variable d'environnement recommandée en production)
14
- HUGGINGFACE_TOKEN = os.getenv("HF_TOKEN")
 
 
 
15
 
16
- # Initialisation de l'app FastAPI
17
- app = FastAPI()
 
 
 
18
 
19
- # Autoriser les requêtes Cross-Origin
20
  app.add_middleware(
21
  CORSMiddleware,
22
  allow_origins=["*"],
@@ -25,142 +29,111 @@ app.add_middleware(
25
  allow_headers=["*"],
26
  )
27
 
28
- # Initialisation des clients Hugging Face avec authentification
29
- summary_client = InferenceClient(model="facebook/bart-large-cnn", token=HUGGINGFACE_TOKEN)
30
- qa_client = InferenceClient(model="deepset/roberta-base-squad2", token=HUGGINGFACE_TOKEN)
31
- image_caption_client = InferenceClient(model="nlpconnect/vit-gpt2-image-captioning", token=HUGGINGFACE_TOKEN)
 
 
 
 
 
 
 
 
 
32
 
33
- # Extraction du texte des fichiers
34
  def extract_text_from_pdf(content: bytes) -> str:
35
- text = ""
36
  reader = PdfReader(io.BytesIO(content))
37
- for page in reader.pages:
38
- if page.extract_text():
39
- text += page.extract_text() + "\n"
40
- return text.strip()
41
 
42
  def extract_text_from_docx(content: bytes) -> str:
43
- text = ""
44
  doc = Document(io.BytesIO(content))
45
- for para in doc.paragraphs:
46
- text += para.text + "\n"
47
- return text.strip()
48
 
49
  def process_uploaded_file(file: UploadFile) -> str:
50
- content = file.file.read()
51
- extension = file.filename.split('.')[-1].lower()
52
-
53
  if extension == "pdf":
54
  return extract_text_from_pdf(content)
55
- elif extension == "docx":
56
  return extract_text_from_docx(content)
57
- elif extension == "txt":
58
  return content.decode("utf-8").strip()
59
- else:
60
- raise ValueError("Type de fichier non supporté")
 
 
 
61
 
62
- # Point d'entrée HTML
63
  @app.get("/", response_class=HTMLResponse)
64
- async def serve_homepage():
65
- with open("index.html", "r", encoding="utf-8") as f:
66
- return HTMLResponse(content=f.read(), status_code=200)
 
 
67
 
68
- # Résumé
69
- @app.post("/analyze")
70
- async def analyze_file(file: UploadFile = File(...)):
71
  try:
72
  text = process_uploaded_file(file)
73
-
74
  if len(text) < 20:
75
- return {"summary": "Document trop court pour être résumé."}
76
-
77
- summary = summary_client.summarization(text[:3000])
78
- return {"summary": summary}
 
79
 
80
- except Exception as e:
81
- return JSONResponse(status_code=500, content={"error": f"Erreur lors de l'analyse: {str(e)}"})
82
 
83
- # Question-Réponse
84
- @app.post("/ask")
85
- async def ask_question(file: UploadFile = File(...), question: str = Form(...)):
86
  try:
87
- # Determine if the file is an image
88
- content_type = file.content_type
89
- if content_type.startswith("image/"):
90
- image_bytes = await file.read()
91
- image_pil = Image.open(io.BytesIO(image_bytes)).convert("RGB")
92
- image_pil.thumbnail((1024, 1024))
93
-
94
- img_byte_arr = BytesIO()
95
- image_pil.save(img_byte_arr, format='JPEG')
96
- img_byte_arr = img_byte_arr.getvalue()
97
-
98
- # Generate image description
99
- result = image_caption_client.image_to_text(img_byte_arr)
100
- if isinstance(result, dict):
101
- context = result.get("generated_text") or result.get("caption") or ""
102
- elif isinstance(result, list) and len(result) > 0:
103
- context = result[0].get("generated_text", "")
104
- elif isinstance(result, str):
105
- context = result
106
- else:
107
- context = ""
108
-
109
  else:
110
- # Not an image, process as document
111
- text = process_uploaded_file(file)
112
- if len(text) < 20:
113
- return {"answer": "Document trop court pour répondre à la question."}
114
- context = text[:3000]
115
 
116
- if not context:
117
- return {"answer": "Aucune information disponible pour répondre à la question."}
118
-
119
- result = qa_client.question_answering(question=question, context=context)
120
- return {"answer": result.get("answer", "Aucune réponse trouvée.")}
121
 
122
- except Exception as e:
123
- return JSONResponse(status_code=500, content={"error": f"Erreur lors de la recherche de réponse: {str(e)}"})
124
-
125
- # Interprétation d'Image
126
- @app.post("/interpret_image")
127
- async def interpret_image(image: UploadFile = File(...)):
128
  try:
129
- # Lire l'image
130
- image_bytes = await image.read()
131
-
132
- # Ouvrir l'image avec PIL
133
- image_pil = Image.open(io.BytesIO(image_bytes))
134
- image_pil = image_pil.convert("RGB")
135
- image_pil.thumbnail((1024, 1024))
136
-
137
- # Convertir en bytes (JPEG)
138
- img_byte_arr = BytesIO()
139
- image_pil.save(img_byte_arr, format='JPEG')
140
- img_byte_arr = img_byte_arr.getvalue()
141
-
142
- # Appeler le modèle
143
- result = image_caption_client.image_to_text(img_byte_arr)
144
-
145
- # 🔍 Affichage du résultat brut pour débogage
146
- print("Résultat brut du modèle image-to-text:", result)
147
-
148
- # Extraire la description si disponible
149
- if isinstance(result, dict):
150
- description = result.get("generated_text") or result.get("caption") or "Description non trouvée."
151
- elif isinstance(result, list) and len(result) > 0:
152
- description = result[0].get("generated_text", "Description non trouvée.")
153
- elif isinstance(result, str):
154
- description = result
155
  else:
156
- description = "Description non trouvée."
 
 
 
 
 
 
 
 
157
 
158
- return {"description": description}
 
 
159
 
160
- except Exception as e:
161
- return JSONResponse(status_code=500, content={"error": f"Erreur lors de l'interprétation de l'image: {str(e)}"})
 
162
 
163
- # Démarrage local
164
  if __name__ == "__main__":
165
  import uvicorn
166
- uvicorn.run("main:app", host="0.0.0.0", port=8000, reload=True)
 
1
+ import os, io
2
  from fastapi import FastAPI, UploadFile, File, Form
3
  from fastapi.middleware.cors import CORSMiddleware
4
+ from fastapi.responses import JSONResponse, HTMLResponse, FileResponse
5
+ from fastapi.staticfiles import StaticFiles
6
  from huggingface_hub import InferenceClient
7
  from PyPDF2 import PdfReader
8
  from docx import Document
9
  from PIL import Image
 
10
  from io import BytesIO
 
11
 
12
+ # -----------------------------------------------------------------------------
13
+ # CONFIGURATION
14
+ # -----------------------------------------------------------------------------
15
+ HUGGINGFACE_TOKEN = os.getenv("HF_TOKEN") # injected as a secret
16
+ PORT = int(os.getenv("PORT", 7860)) # HF Spaces provides it
17
 
18
+ app = FastAPI(
19
+ title="AI‑Powered Web‑App API",
20
+ description="Backend endpoints for summarisation, captioning and QA",
21
+ version="1.1.0",
22
+ )
23
 
 
24
  app.add_middleware(
25
  CORSMiddleware,
26
  allow_origins=["*"],
 
29
  allow_headers=["*"],
30
  )
31
 
32
+ # Serve static assets from /static (images, css, js). Not optional as written: the ./static directory must exist, otherwise StaticFiles raises at startup.
33
+ app.mount("/static", StaticFiles(directory="static"), name="static")
34
+
35
+ # -----------------------------------------------------------------------------
36
+ # MODEL CLIENTS (remote HuggingFace Inference API)
37
+ # -----------------------------------------------------------------------------
38
+ summary_client = InferenceClient("facebook/bart-large-cnn", token=HUGGINGFACE_TOKEN)
39
+ qa_client = InferenceClient("deepset/roberta-base-squad2", token=HUGGINGFACE_TOKEN)
40
+ image_caption_client = InferenceClient("nlpconnect/vit-gpt2-image-captioning", token=HUGGINGFACE_TOKEN)
41
+
42
+ # -----------------------------------------------------------------------------
43
+ # UTILITY FUNCTIONS
44
+ # -----------------------------------------------------------------------------
45
 
 
46
  def extract_text_from_pdf(content: bytes) -> str:
 
47
  reader = PdfReader(io.BytesIO(content))
48
+ return "\n".join(page.extract_text() or "" for page in reader.pages).strip()
 
 
 
49
 
50
  def extract_text_from_docx(content: bytes) -> str:
 
51
  doc = Document(io.BytesIO(content))
52
+ return "\n".join(p.text for p in doc.paragraphs).strip()
 
 
53
 
54
  def process_uploaded_file(file: UploadFile) -> str:
55
+ content = file.file.read()
56
+ extension = file.filename.split(".")[-1].lower()
 
57
  if extension == "pdf":
58
  return extract_text_from_pdf(content)
59
+ if extension == "docx":
60
  return extract_text_from_docx(content)
61
+ if extension == "txt":
62
  return content.decode("utf-8").strip()
63
+ raise ValueError("Unsupported file type")
64
+
65
+ # -----------------------------------------------------------------------------
66
+ # ROUTES
67
+ # -----------------------------------------------------------------------------
68
 
 
69
  @app.get("/", response_class=HTMLResponse)
70
+ async def serve_index():
71
+ """Send the frontend HTML."""
72
+ return FileResponse("index.html")
73
+
74
+ # ---------- Summarisation -----------------------------------------------------
75
 
76
+ @app.post("/api/summarize")
77
+ async def summarize_document(file: UploadFile = File(...)):
 
78
  try:
79
  text = process_uploaded_file(file)
 
80
  if len(text) < 20:
81
+ return {"result": "Document too short to summarise."}
82
+ summary_text = summary_client.summarization(text[:3000])
83
+ return {"result": str(summary_text)}
84
+ except Exception as exc:
85
+ return JSONResponse(status_code=500, content={"error": f"Analyse failure: {exc}"})
86
 
87
+ # ---------- Image Caption -----------------------------------------------------
 
88
 
89
+ @app.post("/api/caption")
90
+ async def caption_image(file: UploadFile = File(...)):
 
91
  try:
92
+ image_bytes = await file.read()
93
+ image_pil = Image.open(io.BytesIO(image_bytes)).convert("RGB")
94
+ image_pil.thumbnail((1024, 1024))
95
+ buf = BytesIO(); image_pil.save(buf, format="JPEG"); img = buf.getvalue()
96
+ result = image_caption_client.image_to_text(img)
97
+ if isinstance(result, dict):
98
+ caption = result.get("generated_text") or result.get("caption") or "No caption found."
99
+ elif isinstance(result, list):
100
+ caption = result[0].get("generated_text", "No caption found.")
 
 
 
 
 
 
 
 
 
 
 
 
 
101
  else:
102
+ caption = str(result)
103
+ return {"result": str(caption)}
104
+ except Exception as exc:
105
+ return JSONResponse(status_code=500, content={"error": f"Caption failure: {exc}"})
 
106
 
107
+ # ---------- Question Answering ----------------------------------------------
 
 
 
 
108
 
109
+ @app.post("/api/qa")
110
+ async def question_answering(file: UploadFile = File(...), question: str = Form(...)):
 
 
 
 
111
  try:
112
+ if file.content_type.startswith("image/"):
113
+ image_bytes = await file.read()
114
+ pil = Image.open(io.BytesIO(image_bytes)).convert("RGB"); pil.thumbnail((1024, 1024))
115
+ buf = BytesIO(); pil.save(buf, format="JPEG"); img = buf.getvalue()
116
+ res = image_caption_client.image_to_text(img)
117
+ context = res.get("generated_text") if isinstance(res, dict) else str(res)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
118
  else:
119
+ context = process_uploaded_file(file)[:3000]
120
+ if not context:
121
+ return {"result": "No context – cannot answer."}
122
+ answer = qa_client.question_answering(question=question, context=context)
123
+ return {"result": str(answer.get("answer", "No answer found."))}
124
+ except Exception as exc:
125
+ return JSONResponse(status_code=500, content={"error": f"QA failure: {exc}"})
126
+
127
+ # ---------- Health check ------------------------------------------------------
128
 
129
+ @app.get("/api/health")
130
+ async def health():
131
+ return {"status": "healthy", "hf_token_set": bool(HUGGINGFACE_TOKEN)}
132
 
133
+ # -----------------------------------------------------------------------------
134
+ # ENTRYPOINT
135
+ # -----------------------------------------------------------------------------
136
 
 
137
  if __name__ == "__main__":
138
  import uvicorn
139
+ uvicorn.run(app, host="0.0.0.0", port=PORT)