smartdocai / app.py
malaknihed's picture
Update app.py
c4767a5 verified
raw
history blame
14.6 kB
# 🔧 FastAPI & middlewares
from fastapi import FastAPI, File, UploadFile, Form
from fastapi.responses import FileResponse, RedirectResponse, JSONResponse
from fastapi.staticfiles import StaticFiles
from fastapi.middleware.cors import CORSMiddleware
# 🧠 Transformers (NLP)
from transformers import (
pipeline,
AutoTokenizer,
AutoModelForCausalLM,
MarianMTModel,
MarianTokenizer,
M2M100ForConditionalGeneration,
M2M100Tokenizer
)
# 📄 Lecture de fichiers
from PyPDF2 import PdfReader
from pdfminer.high_level import extract_text
from docx import Document
import docx2txt
from pptx import Presentation
import openpyxl # Pour fichiers Excel (.xlsx)
import fitz # PyMuPDF
# 🖼️ Images
from PIL import Image
# 📊 Visualisation (Matplotlib/Seaborn)
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
import seaborn as sns
# 🧰 Divers
import os
import shutil
import io
import logging
import re
import pandas as pd
# Configuration du logging
logging.basicConfig(level=logging.INFO)
app = FastAPI()
# Configuration CORS
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
UPLOAD_DIR = "uploads"
os.makedirs(UPLOAD_DIR, exist_ok=True)
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
model_name = "facebook/m2m100_418M"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
# Fonction pour extraire le texte
def extract_text_from_pdf(file):
doc = fitz.open(stream=file.file.read(), filetype="pdf")
return "\n".join([page.get_text() for page in doc]).strip()
def extract_text_from_docx(file):
doc = Document(io.BytesIO(file.file.read()))
return "\n".join([para.text for para in doc.paragraphs]).strip()
def extract_text_from_pptx(file):
prs = Presentation(io.BytesIO(file.file.read()))
return "\n".join([shape.text for slide in prs.slides for shape in slide.shapes if hasattr(shape, "text")]).strip()
def extract_text_from_excel(file):
wb = openpyxl.load_workbook(io.BytesIO(file.file.read()), data_only=True)
text = [str(cell) for sheet in wb.worksheets for row in sheet.iter_rows(values_only=True) for cell in row if cell]
return "\n".join(text).strip()
@app.post("/translate/")
async def translate_document(file: UploadFile = File(...), target_lang: str = Form(...)):
"""API pour traduire un document."""
try:
logging.info(f"📥 Fichier reçu : {file.filename}")
logging.info(f"🌍 Langue cible reçue : {target_lang}")
if model is None or tokenizer is None:
return JSONResponse(status_code=500, content={"error": "Modèle de traduction non chargé"})
# Extraction du texte
if file.filename.endswith(".pdf"):
text = extract_text_from_pdf(file)
elif file.filename.endswith(".docx"):
text = extract_text_from_docx(file)
elif file.filename.endswith(".pptx"):
text = extract_text_from_pptx(file)
elif file.filename.endswith(".xlsx"):
text = extract_text_from_excel(file)
else:
return JSONResponse(status_code=400, content={"error": "Format non supporté"})
logging.info(f"📜 Texte extrait : {text[:50]}...")
if not text:
return JSONResponse(status_code=400, content={"error": "Aucun texte trouvé dans le document"})
# Vérifier si la langue cible est supportée
target_lang_id = tokenizer.get_lang_id(target_lang)
if target_lang_id is None:
return JSONResponse(
status_code=400,
content={"error": f"Langue cible '{target_lang}' non supportée. Langues disponibles : {list(tokenizer.lang_code_to_id.keys())}"}
)
# Traduction
tokenizer.src_lang = "fr"
encoded_text = tokenizer(text, return_tensors="pt", padding=True, truncation=True)
logging.info(f"🔍 ID de la langue cible : {target_lang_id}")
generated_tokens = model.generate(**encoded_text, forced_bos_token_id=target_lang_id)
translated_text = tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)[0]
logging.info(f"✅ Traduction réussie : {translated_text[:50]}...")
return {"translated_text": translated_text}
except Exception as e:
logging.error(f"❌ Erreur lors de la traduction : {e}")
return JSONResponse(status_code=500, content={"error": "Échec de la traduction"})
# Charger le modèle pour la génération de code
codegen_model_name = "Salesforce/codegen-350M-mono"
device = "cuda" if torch.cuda.is_available() else "cpu"
codegen_tokenizer = AutoTokenizer.from_pretrained(codegen_model_name)
codegen_model = AutoModelForCausalLM.from_pretrained(codegen_model_name).to(device)
VALID_PLOTS = {"histplot", "scatterplot", "barplot", "lineplot", "boxplot"}
@app.post("/generate_viz/")
async def generate_viz(file: UploadFile = File(...), query: str = Form(...)):
try:
if query not in VALID_PLOTS:
return {"error": f"Type de graphique invalide. Choisissez parmi : {', '.join(VALID_PLOTS)}"}
df = pd.read_excel(file.file)
numeric_cols = df.select_dtypes(include=["number"]).columns
if len(numeric_cols) < 2:
return {"error": "Le fichier doit contenir au moins deux colonnes numériques."}
x_col, y_col = numeric_cols[:2]
# Contraintes spécifiques pour éviter l'erreur avec histplot
if query == "histplot":
prompt_y = ""
else:
prompt_y = f', y="{y_col}"'
# Générer l'invite pour le modèle
prompt = f"""
### Génère uniquement du code Python fonctionnel pour tracer un {query} avec Matplotlib et Seaborn ###
# Contraintes :
# - Utilise 'df' sans recréer de nouvelles données
# - Axe X : '{x_col}'
# - Enregistre le graphique sous 'plot.png'
# - Ne génère que du code Python valide, sans texte explicatif
# Contraintes spécifiques pour sns.histplot :
# - N'inclut pas "y=" car histplot ne supporte qu'un axe
import matplotlib.pyplot as plt
import seaborn as sns
plt.figure(figsize=(8,6))
sns.{query}(data=df, x="{x_col}"{prompt_y})
plt.savefig("plot.png")
plt.close()
"""
# Génération du code
inputs = codegen_tokenizer(prompt, return_tensors="pt").to(device)
outputs = codegen_model.generate(**inputs, max_new_tokens=120, pad_token_id=codegen_tokenizer.eos_token_id)
generated_code = codegen_tokenizer.decode(outputs[0], skip_special_tokens=True).strip()
# Nettoyage du code
generated_code = re.sub(r"(import matplotlib.pyplot as plt\nimport seaborn as sns\n)+", "import matplotlib.pyplot as plt\nimport seaborn as sns\n", generated_code)
if generated_code.strip().endswith("sns."):
generated_code = generated_code.rsplit("\n", 1)[0] # Supprime la dernière ligne incomplète
print("🔹 Code généré par l'IA :\n", generated_code)
# Vérification syntaxique avant exécution
try:
compile(generated_code, "<string>", "exec")
except SyntaxError as e:
return {"error": f"Erreur de syntaxe détectée : {e}\nCode généré :\n{generated_code}"}
# Vérification des données
print(df.head()) # Affiche les premières lignes du dataframe
print(df.dtypes) # Vérifie les types de colonnes
print(f"Colonne '{x_col}' - Valeurs uniques:", df[x_col].unique())
if df.empty or x_col not in df.columns or df[x_col].isnull().all():
return {"error": f"La colonne '{x_col}' est absente ou ne contient pas de données valides."}
# Exécution du code généré
exec_env = {"df": df, "plt": plt, "sns": sns, "pd": pd}
exec(generated_code, exec_env)
# Vérification de l'image générée
img_path = "plot.png"
if not os.path.exists(img_path):
return {"error": "Le fichier plot.png n'a pas été généré."}
if os.path.getsize(img_path) == 0:
return {"error": "Le fichier plot.png est vide."}
plt.close()
return FileResponse(img_path, media_type="image/png")
except Exception as e:
return {"error": f"Erreur lors de la génération du graphique : {str(e)}"}
# Charger le modèle de résumé
summarizer = None
try:
summarizer = pipeline("summarization", model="facebook/bart-large-cnn")
logging.info("✅ Modèle de résumé chargé avec succès !")
except Exception as e:
logging.error(f"❌ Erreur chargement modèle résumé : {e}")
try:
image_captioning = pipeline("image-to-text", model="nlpconnect/vit-gpt2-image-captioning")
logging.info("✅ Modèle d'image chargé avec succès !")
except Exception as e:
image_captioning = None
logging.error(f"❌ Erreur chargement modèle image : {e}")
# Fonction pour extraire le texte d'un fichier Word
def extract_text_from_docx(docx_file):
doc = Document(BytesIO(docx_file))
text = "\n".join([para.text for para in doc.paragraphs])
return text
# Fonction pour extraire le texte d'un fichier Excel
def extract_text_from_excel(xlsx_file):
# Utiliser pandas pour lire le fichier Excel
df = pd.read_excel(BytesIO(xlsx_file))
text = df.to_string(index=False)
return text
# Fonction pour extraire le texte d'un fichier PowerPoint
def extract_text_from_pptx(pptx_file):
presentation = Presentation(BytesIO(pptx_file))
text = ""
for slide in presentation.slides:
for shape in slide.shapes:
if hasattr(shape, "text"):
text += shape.text + "\n"
return text
# Endpoint pour la fonctionnalité de résumé
@app.post("/summarize/")
async def summarize(file: UploadFile = File(...)):
# Si le modèle n'est pas encore chargé, retourner un message indiquant que le modèle est en train de se charger
if summarizer is None:
return {"message": "Le modèle est en cours de chargement, veuillez patienter..."}
# Extraire le contenu du fichier téléchargé
contents = await file.read()
# Identifier le type de fichier et extraire le texte
if file.filename.endswith(".pdf"):
text = extract_text(BytesIO(contents))
elif file.filename.endswith(".docx"):
text = extract_text_from_docx(contents)
elif file.filename.endswith(".xls") or file.filename.endswith(".xlsx"):
text = extract_text_from_excel(contents)
elif file.filename.endswith(".pptx") or file.filename.endswith(".ppt"):
text = extract_text_from_pptx(contents)
else:
return {"summary": "Résumé non disponible pour ce format de fichier."}
# Si un modèle de résumé est chargé, effectuer le résumé
try:
if summarizer:
summary = summarizer(text[:1024]) # Limiter la taille d'entrée pour le modèle
summary_text = summary[0]['summary_text']
else:
summary_text = "❌ Modèle de résumé non disponible."
except Exception as e:
summary_text = f"❌ Erreur lors de la génération du résumé : {e}"
# Retourner le résumé généré
return {"summary": summary_text}
@app.post("/image-caption/")
async def caption_image(file: UploadFile = File(...)):
if image_captioning is None:
return JSONResponse(content={"error": "Le modèle de captioning n'est pas disponible."}, status_code=500)
try:
contents = await file.read()
image = Image.open(io.BytesIO(contents)).convert("RGB")
result = image_captioning(image)
caption = result[0]['generated_text']
return {"caption": caption}
except Exception as e:
return JSONResponse(content={"error": str(e)}, status_code=500)
try:
qa_pipeline = pipeline("question-answering", model="deepset/roberta-base-squad2")
logging.info("✅ Modèle QA Texte chargé avec succès !")
except Exception as e:
qa_pipeline = None
logging.error(f"❌ Erreur chargement modèle QA Texte : {e}")
try:
image_qa_pipeline = pipeline("visual-question-answering", model="Salesforce/blip-vqa-base")
logging.info("✅ Modèle QA Image chargé avec succès !")
except Exception as e:
image_qa_pipeline = None
logging.error(f"❌ Erreur chargement modèle QA Image : {e}")
@app.post("/doc-qa/")
async def doc_question_answer(file: UploadFile = File(...), question: str = Form(...)):
if qa_pipeline is None:
return JSONResponse(content={"error": "Modèle indisponible."}, status_code=500)
try:
contents = await file.read()
filename = file.filename.lower()
if filename.endswith(".docx"):
with open("temp.docx", "wb") as f:
f.write(contents)
context = docx2txt.process("temp.docx")
elif filename.endswith((".xlsx", ".xls")):
df = pd.read_excel(BytesIO(contents))
context = df.to_string(index=False)
elif filename.endswith(".pptx"):
presentation = Presentation(BytesIO(contents))
context = ""
for slide in presentation.slides:
for shape in slide.shapes:
if hasattr(shape, "text"):
context += shape.text + "\n"
elif filename.endswith(".pdf"):
context = extract_text(BytesIO(contents))
else:
return JSONResponse(content={"error": "Format non supporté."}, status_code=400)
result = qa_pipeline(question=question, context=context)
return {"answer": result["answer"]}
except Exception as e:
return JSONResponse(content={"error": str(e)}, status_code=500)
@app.post("/image-qa/")
async def image_qa(file: UploadFile = File(...), question: str = Form(...)):
if image_qa_pipeline is None:
return JSONResponse(content={"error": "Le modèle n'est pas disponible."}, status_code=500)
try:
contents = await file.read()
image = Image.open(io.BytesIO(contents)).convert("RGB")
result = image_qa_pipeline(image=image, question=question)
answer = result[0]['answer']
return {"answer": answer}
except Exception as e:
return JSONResponse(content={"error": str(e)}, status_code=500)
# Servir les fichiers statiques (HTML, CSS, JS)
app.mount("/static", StaticFiles(directory="static", html=True), name="static")
@app.get("/")
async def root():
return RedirectResponse(url="/static/principal.html")