ikraamkb commited on
Commit
70781e0
·
verified ·
1 Parent(s): 935d12d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +30 -38
app.py CHANGED
@@ -103,83 +103,80 @@ async def get_docs(request: Request):
103
  from fastapi import FastAPI, Form, File, UploadFile
104
  from fastapi.responses import RedirectResponse
105
  from fastapi.staticfiles import StaticFiles
106
- from pydantic import BaseModel
 
 
107
  from transformers import pipeline
108
  import os
109
  from PIL import Image
110
- import io
111
  import pdfplumber
112
  import docx
113
- import openpyxl
114
  import pytesseract
115
  from io import BytesIO
116
  import fitz # PyMuPDF
117
  import easyocr
118
- from fastapi.templating import Jinja2Templates
119
- from starlette.requests import Request
120
- from fastapi.middleware.cors import CORSMiddleware
121
 
122
- # Initialize the app
123
  app = FastAPI()
124
 
125
- # Enable CORS (for frontend interaction)
126
  app.add_middleware(
127
  CORSMiddleware,
128
- allow_origins=["*"], # Adjust this for security
129
  allow_credentials=True,
130
  allow_methods=["*"],
131
  allow_headers=["*"],
132
  )
133
 
134
- # Mount the static directory to serve HTML, CSS, JS files
135
  app.mount("/static", StaticFiles(directory="static"), name="static")
136
 
137
- # Initialize transformers pipelines
138
  qa_pipeline = pipeline("question-answering", model="microsoft/phi-2", tokenizer="microsoft/phi-2")
139
  image_qa_pipeline = pipeline("vqa", model="Salesforce/blip-vqa-base")
140
 
141
- # Initialize EasyOCR for image-based text extraction
142
  reader = easyocr.Reader(['en'])
143
 
144
- # Define a template for rendering HTML
145
  templates = Jinja2Templates(directory="templates")
146
 
147
- # Ensure temp_files directory exists
148
  temp_dir = "temp_files"
149
  os.makedirs(temp_dir, exist_ok=True)
150
 
151
- # Function to process PDFs
152
  def extract_pdf_text(file_path: str):
153
  with pdfplumber.open(file_path) as pdf:
154
- text = ""
155
- for page in pdf.pages:
156
- extracted = page.extract_text()
157
- if extracted:
158
- text += extracted + "\n"
159
- return text
160
 
161
- # Function to process DOCX files
162
  def extract_docx_text(file_path: str):
163
  doc = docx.Document(file_path)
164
  return "\n".join([para.text for para in doc.paragraphs])
165
 
166
- # Function to process PPTX files
167
  def extract_pptx_text(file_path: str):
168
  from pptx import Presentation
169
  prs = Presentation(file_path)
170
  return "\n".join([shape.text for slide in prs.slides for shape in slide.shapes if hasattr(shape, "text")])
171
 
172
- # Function to extract text from images using OCR
173
  def extract_text_from_image(image: Image):
174
  return pytesseract.image_to_string(image)
175
 
176
- # Home route
177
  @app.get("/")
178
  def home():
179
- return RedirectResponse(url="/docs")
180
 
181
- # Function to answer questions based on document content
182
- @app.post("/docs")
 
 
 
 
 
183
  async def question_answering_doc(question: str = Form(...), file: UploadFile = File(...)):
184
  file_ext = file.filename.split(".")[-1].lower()
185
  file_path = os.path.join(temp_dir, file.filename)
@@ -187,11 +184,11 @@ async def question_answering_doc(question: str = Form(...), file: UploadFile = F
187
  with open(file_path, "wb") as f:
188
  f.write(await file.read())
189
 
190
- if file_ext in ["pdf"]:
191
  text = extract_pdf_text(file_path)
192
- elif file_ext in ["docx"]:
193
  text = extract_docx_text(file_path)
194
- elif file_ext in ["pptx"]:
195
  text = extract_pptx_text(file_path)
196
  else:
197
  return {"error": "Unsupported file format"}
@@ -199,8 +196,8 @@ async def question_answering_doc(question: str = Form(...), file: UploadFile = F
199
  qa_result = qa_pipeline(question=question, context=text)
200
  return {"answer": qa_result['answer']}
201
 
202
- # Function to answer questions based on images
203
- @app.post("/images")
204
  async def question_answering_image(question: str = Form(...), image_file: UploadFile = File(...)):
205
  image = Image.open(BytesIO(await image_file.read()))
206
 
@@ -211,8 +208,3 @@ async def question_answering_image(question: str = Form(...), image_file: Upload
211
  image_qa_result = image_qa_pipeline({"image": image, "question": question})
212
 
213
  return {"answer": image_qa_result[0]['answer'], "image_text": image_text}
214
-
215
- # Serve the application in Hugging Face space
216
- @app.get("/docs")
217
- async def get_docs(request: Request):
218
- return templates.TemplateResponse("index.html", {"request": request})
 
103
  from fastapi import FastAPI, Form, File, UploadFile
104
  from fastapi.responses import RedirectResponse
105
  from fastapi.staticfiles import StaticFiles
106
+ from fastapi.middleware.cors import CORSMiddleware
107
+ from fastapi.templating import Jinja2Templates
108
+ from starlette.requests import Request
109
  from transformers import pipeline
110
  import os
111
  from PIL import Image
 
112
  import pdfplumber
113
  import docx
 
114
  import pytesseract
115
  from io import BytesIO
116
  import fitz # PyMuPDF
117
  import easyocr
 
 
 
118
 
119
+ # Initialize the FastAPI app
120
  app = FastAPI()
121
 
122
+ # Enable CORS for frontend communication
123
  app.add_middleware(
124
  CORSMiddleware,
125
+ allow_origins=["*"],
126
  allow_credentials=True,
127
  allow_methods=["*"],
128
  allow_headers=["*"],
129
  )
130
 
131
+ # Mount static files (if you have HTML/CSS/JS)
132
  app.mount("/static", StaticFiles(directory="static"), name="static")
133
 
134
+ # Initialize transformer models
135
  qa_pipeline = pipeline("question-answering", model="microsoft/phi-2", tokenizer="microsoft/phi-2")
136
  image_qa_pipeline = pipeline("vqa", model="Salesforce/blip-vqa-base")
137
 
138
+ # Initialize OCR
139
  reader = easyocr.Reader(['en'])
140
 
141
+ # Define templates for HTML pages
142
  templates = Jinja2Templates(directory="templates")
143
 
144
+ # Ensure the temp directory exists
145
  temp_dir = "temp_files"
146
  os.makedirs(temp_dir, exist_ok=True)
147
 
148
+ # Function to extract text from PDF
149
  def extract_pdf_text(file_path: str):
150
  with pdfplumber.open(file_path) as pdf:
151
+ return "\n".join([page.extract_text() or "" for page in pdf.pages])
 
 
 
 
 
152
 
153
+ # Function to extract text from DOCX
154
  def extract_docx_text(file_path: str):
155
  doc = docx.Document(file_path)
156
  return "\n".join([para.text for para in doc.paragraphs])
157
 
158
+ # Function to extract text from PPTX
159
  def extract_pptx_text(file_path: str):
160
  from pptx import Presentation
161
  prs = Presentation(file_path)
162
  return "\n".join([shape.text for slide in prs.slides for shape in slide.shapes if hasattr(shape, "text")])
163
 
164
+ # Function to extract text from images
165
  def extract_text_from_image(image: Image):
166
  return pytesseract.image_to_string(image)
167
 
168
+ # Redirect home to custom frontend page
169
  @app.get("/")
170
  def home():
171
+ return RedirectResponse(url="/app-ui")
172
 
173
+ # Serve HTML interface (instead of showing FastAPI docs)
174
+ @app.get("/app-ui")
175
+ async def get_ui(request: Request):
176
+ return templates.TemplateResponse("index.html", {"request": request})
177
+
178
+ # New endpoint for document-based question answering
179
+ @app.post("/qa-docs") # 🚨 Changed from `/docs` to `/qa-docs`
180
  async def question_answering_doc(question: str = Form(...), file: UploadFile = File(...)):
181
  file_ext = file.filename.split(".")[-1].lower()
182
  file_path = os.path.join(temp_dir, file.filename)
 
184
  with open(file_path, "wb") as f:
185
  f.write(await file.read())
186
 
187
+ if file_ext == "pdf":
188
  text = extract_pdf_text(file_path)
189
+ elif file_ext == "docx":
190
  text = extract_docx_text(file_path)
191
+ elif file_ext == "pptx":
192
  text = extract_pptx_text(file_path)
193
  else:
194
  return {"error": "Unsupported file format"}
 
196
  qa_result = qa_pipeline(question=question, context=text)
197
  return {"answer": qa_result['answer']}
198
 
199
+ # New endpoint for image-based question answering
200
+ @app.post("/qa-images") # 🚨 Changed from `/images` to `/qa-images`
201
  async def question_answering_image(question: str = Form(...), image_file: UploadFile = File(...)):
202
  image = Image.open(BytesIO(await image_file.read()))
203
 
 
208
  image_qa_result = image_qa_pipeline({"image": image, "question": question})
209
 
210
  return {"answer": image_qa_result[0]['answer'], "image_text": image_text}