ikraamkb committed on
Commit
7a6dca4
·
verified ·
1 Parent(s): 81bb8d5

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +39 -46
app.py CHANGED
@@ -103,108 +103,101 @@ async def get_docs(request: Request):
103
  from fastapi import FastAPI, Form, File, UploadFile
104
  from fastapi.responses import RedirectResponse
105
  from fastapi.staticfiles import StaticFiles
106
- from fastapi.middleware.cors import CORSMiddleware
107
- from fastapi.templating import Jinja2Templates
108
- from starlette.requests import Request
109
  from transformers import pipeline
110
  import os
111
  from PIL import Image
 
112
  import pdfplumber
113
  import docx
 
114
  import pytesseract
115
  from io import BytesIO
116
  import fitz # PyMuPDF
117
  import easyocr
 
 
118
 
119
# Application instance shared by every route in this module.
app = FastAPI()

# Let browser frontends served from other origins call the API.
# NOTE(review): wildcard origins combined with allow_credentials=True is very
# permissive; confirm this is intended before exposing the service publicly.
app.add_middleware(
    CORSMiddleware,
    allow_headers=["*"],
    allow_methods=["*"],
    allow_credentials=True,
    allow_origins=["*"],
)

# Expose the local "static" directory under the /static URL prefix.
app.mount(
    "/static",
    StaticFiles(directory="static"),
    name="static",
)

# Hugging Face pipelines: one for text question answering, one for visual QA.
# NOTE(review): microsoft/phi-2 is published as a causal language model;
# verify that it loads under the extractive "question-answering" task.
qa_pipeline = pipeline(
    "question-answering",
    model="microsoft/phi-2",
    tokenizer="microsoft/phi-2",
)
image_qa_pipeline = pipeline(
    "vqa",
    model="Salesforce/blip-vqa-base",
)

# EasyOCR reader configured for English.
reader = easyocr.Reader(["en"])

# Jinja2 template loader rooted at the local "templates" directory.
templates = Jinja2Templates(directory="templates")

# Scratch directory for uploaded documents; created at import time if missing.
temp_dir = "temp_files"
os.makedirs(temp_dir, exist_ok=True)
147
 
148
# Pull the text layer out of a PDF file.
def extract_pdf_text(file_path: str):
    """Return the text of every page of the PDF at ``file_path``.

    Page texts are joined with newlines; a page whose ``extract_text()``
    result is falsy contributes an empty string.
    """
    page_texts = []
    with pdfplumber.open(file_path) as document:
        for page in document.pages:
            page_texts.append(page.extract_text() or "")
    return "\n".join(page_texts)
 
 
 
152
 
153
# Pull paragraph text out of a Word document.
def extract_docx_text(file_path: str):
    """Return the text of all paragraphs in the DOCX at ``file_path``, newline-joined."""
    paragraphs = docx.Document(file_path).paragraphs
    return "\n".join(paragraph.text for paragraph in paragraphs)
 
157
 
158
# Pull shape text out of a PowerPoint deck.
def extract_pptx_text(file_path: str):
    """Return the text of every shape that exposes a ``text`` attribute.

    Slides are visited in order and the collected strings are joined with
    newlines.
    """
    from pptx import Presentation

    deck = Presentation(file_path)
    pieces = []
    for slide in deck.slides:
        for shape in slide.shapes:
            # Shapes without a ``text`` attribute are skipped.
            if hasattr(shape, "text"):
                pieces.append(shape.text)
    return "\n".join(pieces)
 
163
 
164
# Run OCR over an already opened image.
def extract_text_from_image(image: Image):
    """Return whatever ``pytesseract.image_to_string`` recognises in ``image``."""
    recognised = pytesseract.image_to_string(image)
    return recognised
167
 
168
# Landing route: forward visitors to the custom frontend page.
@app.get("/")
def home():
    """Redirect the site root to ``/app-ui``."""
    target = "/app-ui"
    return RedirectResponse(url=target)
172
-
173
# HTML frontend, served from the templates directory.
@app.get("/app-ui")
async def get_ui(request: Request):
    """Render ``index.html`` with the current request in the template context."""
    context = {"request": request}
    return templates.TemplateResponse("index.html", context)
177
 
178
# Document-based question answering (PDF, DOCX or PPTX uploads).
@app.post("/qa-docs")
async def question_answering_doc(question: str = Form(...), file: UploadFile = File(...)):
    """Answer ``question`` using the text extracted from the uploaded document.

    The format is chosen from the (case-insensitive) text after the last dot
    of the uploaded filename. Returns ``{"answer": ...}`` on success or
    ``{"error": ...}`` when the format is unsupported or no text could be
    extracted.
    """
    import tempfile

    extractors = {
        "pdf": extract_pdf_text,
        "docx": extract_docx_text,
        "pptx": extract_pptx_text,
    }
    file_ext = (file.filename or "").rsplit(".", 1)[-1].lower()
    extractor = extractors.get(file_ext)
    if extractor is None:
        # Reject before touching the disk.
        return {"error": "Unsupported file format"}

    # Never build the path from the client-supplied filename: it may contain
    # directory components, and two uploads with the same name would collide.
    tmp = tempfile.NamedTemporaryFile(dir=temp_dir, suffix="." + file_ext, delete=False)
    try:
        with tmp:
            tmp.write(await file.read())
        text = extractor(tmp.name)
    finally:
        # The upload is only needed for extraction; do not let files pile up.
        os.remove(tmp.name)

    if not text.strip():
        # E.g. a scanned PDF with no text layer; there is nothing to answer from.
        return {"error": "No extractable text found in the uploaded document"}

    qa_result = qa_pipeline(question=question, context=text)
    return {"answer": qa_result['answer']}
198
 
199
# Image-based question answering.
@app.post("/qa-images")
async def question_answering_image(question: str = Form(...), image_file: UploadFile = File(...)):
    """Answer ``question`` about the uploaded image.

    Returns the first answer produced by the VQA pipeline together with the
    text that OCR recognised in the image.
    """
    raw_bytes = await image_file.read()
    image = Image.open(BytesIO(raw_bytes))

    # OCR runs first, then the visual question answering model.
    ocr_text = extract_text_from_image(image)
    vqa_output = image_qa_pipeline({"image": image, "question": question})

    top_answer = vqa_output[0]['answer']
    return {"answer": top_answer, "image_text": ocr_text}
 
 
 
 
 
 
103
  from fastapi import FastAPI, Form, File, UploadFile
104
  from fastapi.responses import RedirectResponse
105
  from fastapi.staticfiles import StaticFiles
106
+ from pydantic import BaseModel
 
 
107
  from transformers import pipeline
108
  import os
109
  from PIL import Image
110
+ import io
111
  import pdfplumber
112
  import docx
113
+ import openpyxl
114
  import pytesseract
115
  from io import BytesIO
116
  import fitz # PyMuPDF
117
  import easyocr
118
+ from fastapi.templating import Jinja2Templates
119
+ from starlette.requests import Request
120
 
121
# Application instance shared by every route in this module.
app = FastAPI()

# Expose the local "static" directory (HTML, CSS, JS) under /static.
app.mount(
    "/static",
    StaticFiles(directory="static"),
    name="static",
)

# Hugging Face pipelines: one for text question answering, one for visual QA.
# NOTE(review): microsoft/phi-2 is published as a causal language model;
# verify that it loads under the extractive "question-answering" task.
qa_pipeline = pipeline(
    "question-answering",
    model="microsoft/phi-2",
    tokenizer="microsoft/phi-2",
)
image_qa_pipeline = pipeline(
    "vqa",
    model="Salesforce/blip-vqa-base",
)

# EasyOCR reader configured for English.
reader = easyocr.Reader(["en"])

# Jinja2 template loader rooted at the local "templates" directory.
templates = Jinja2Templates(directory="templates")

# Scratch directory for uploaded documents; created at import time if missing.
temp_dir = "temp_files"
os.makedirs(temp_dir, exist_ok=True)
140
 
141
# Function to process PDFs
def extract_pdf_text(file_path: str):
    """Return the text of every page of the PDF at ``file_path``.

    Page texts are joined with newlines so that the last word of one page
    does not fuse with the first word of the next, matching how the DOCX and
    PPTX extractors separate their pieces.
    """
    with pdfplumber.open(file_path) as pdf:
        # extract_text() may yield None for a page without a text layer
        # (depending on the pdfplumber version); treat that as empty text
        # instead of failing on str + None.
        return "\n".join(page.extract_text() or "" for page in pdf.pages)
148
 
149
# Pull paragraph text out of a Word document.
def extract_docx_text(file_path: str):
    """Return the text of all paragraphs in the DOCX at ``file_path``, newline-joined."""
    paragraphs = docx.Document(file_path).paragraphs
    return "\n".join(paragraph.text for paragraph in paragraphs)
154
 
155
# Pull shape text out of a PowerPoint deck.
def extract_pptx_text(file_path: str):
    """Return the text of every shape that exposes a ``text`` attribute.

    Slides are visited in order and the collected strings are joined with
    newlines.
    """
    from pptx import Presentation

    deck = Presentation(file_path)
    pieces = []
    for slide in deck.slides:
        for shape in slide.shapes:
            # Shapes without a ``text`` attribute are skipped.
            if hasattr(shape, "text"):
                pieces.append(shape.text)
    return "\n".join(pieces)
161
 
162
# Run OCR over an already opened image.
def extract_text_from_image(image: Image):
    """Return whatever ``pytesseract.image_to_string`` recognises in ``image``."""
    recognised = pytesseract.image_to_string(image)
    return recognised
165
 
166
# Landing route: forward visitors to /docs.
@app.get("/")
def home():
    """Redirect the site root to ``/docs``."""
    target = "/docs"
    return RedirectResponse(url=target)
 
 
 
 
 
170
 
171
# Function to answer questions based on document content
@app.post("/question-answering-doc")
async def question_answering_doc(request: Request, question: str = Form(...), file: UploadFile = File(...)):
    """Answer ``question`` from an uploaded PDF, DOCX or PPTX and render the result.

    ``request`` is injected by FastAPI (it is not a form field), so HTTP
    clients still send only ``question`` and ``file``. The template context
    needs the live request object, not the ``Request`` class.

    Returns the rendered ``index.html`` with an ``answer`` entry on success,
    or ``{"error": ...}`` when the format is unsupported or no text could be
    extracted.
    """
    import tempfile

    extractors = {
        ".pdf": extract_pdf_text,
        ".docx": extract_docx_text,
        ".pptx": extract_pptx_text,
    }
    # Match the suffix case-insensitively so "REPORT.PDF" is accepted too.
    filename = (file.filename or "").lower()
    matched = next(
        (item for item in extractors.items() if filename.endswith(item[0])),
        None,
    )
    if matched is None:
        # Reject before touching the disk.
        return {"error": "Unsupported file format"}
    suffix, extractor = matched

    # Never build the path from the client-supplied filename: it may contain
    # directory components, and two uploads with the same name would collide.
    tmp = tempfile.NamedTemporaryFile(dir=temp_dir, suffix=suffix, delete=False)
    try:
        with tmp:
            tmp.write(await file.read())
        text = extractor(tmp.name)
    finally:
        # The upload is only needed for extraction; do not let files pile up.
        os.remove(tmp.name)

    if not text.strip():
        # E.g. a scanned PDF with no text layer; there is nothing to answer from.
        return {"error": "No extractable text found in the uploaded document"}

    qa_result = qa_pipeline(question=question, context=text)
    return templates.TemplateResponse("index.html", {"request": request, "answer": qa_result['answer']})
189
 
190
# Function to answer questions based on images
@app.post("/question-answering-image")
async def question_answering_image(request: Request, question: str = Form(...), image_file: UploadFile = File(...)):
    """Answer ``question`` about the uploaded image and render the result.

    ``request`` is injected by FastAPI (it is not a form field), so HTTP
    clients still send only ``question`` and ``image_file``. The template
    context needs the live request object, not the ``Request`` class.

    The rendered ``index.html`` receives the first VQA answer as ``answer``
    and the OCR output as ``image_text``.
    """
    image = Image.open(BytesIO(await image_file.read()))

    # OCR text is shown next to the answer produced by the VQA model.
    image_text = extract_text_from_image(image)
    image_qa_result = image_qa_pipeline({"image": image, "question": question})

    context = {
        "request": request,
        "answer": image_qa_result[0]['answer'],
        "image_text": image_text,
    }
    return templates.TemplateResponse("index.html", context)
199
+
200
# HTML frontend for the Hugging Face Space.
# NOTE(review): the app is created with FastAPI() defaults, which appear to
# register an interactive API docs page at /docs as well; confirm that this
# handler, and not the built-in page, is what actually answers GET /docs.
@app.get("/docs")
async def get_docs(request: Request):
    """Render ``index.html`` with the current request in the template context."""
    context = {"request": request}
    return templates.TemplateResponse("index.html", context)