ikraamkb commited on
Commit
935d12d
·
verified ·
1 Parent(s): c7979e3

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +118 -1
app.py CHANGED
@@ -1,4 +1,4 @@
1
- from fastapi import FastAPI, Form, File, UploadFile
2
  from fastapi.responses import RedirectResponse
3
  from fastapi.staticfiles import StaticFiles
4
  from pydantic import BaseModel
@@ -99,3 +99,120 @@ async def question_answering_image(question: str = Form(...), image_file: Upload
99
  @app.get("/docs")
100
  async def get_docs(request: Request):
101
  return templates.TemplateResponse("index.html", {"request": request})
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """from fastapi import FastAPI, Form, File, UploadFile
2
  from fastapi.responses import RedirectResponse
3
  from fastapi.staticfiles import StaticFiles
4
  from pydantic import BaseModel
 
99
  @app.get("/docs")
100
  async def get_docs(request: Request):
101
  return templates.TemplateResponse("index.html", {"request": request})
102
+ """
103
+ from fastapi import FastAPI, Form, File, UploadFile
104
+ from fastapi.responses import RedirectResponse
105
+ from fastapi.staticfiles import StaticFiles
106
+ from pydantic import BaseModel
107
+ from transformers import pipeline
108
+ import os
109
+ from PIL import Image
110
+ import io
111
+ import pdfplumber
112
+ import docx
113
+ import openpyxl
114
+ import pytesseract
115
+ from io import BytesIO
116
+ import fitz # PyMuPDF
117
+ import easyocr
118
+ from fastapi.templating import Jinja2Templates
119
+ from starlette.requests import Request
120
+ from fastapi.middleware.cors import CORSMiddleware
121
+
122
+ # Initialize the app
123
+ app = FastAPI()
124
+
125
+ # Enable CORS (for frontend interaction)
126
+ app.add_middleware(
127
+ CORSMiddleware,
128
+ allow_origins=["*"], # Adjust this for security
129
+ allow_credentials=True,
130
+ allow_methods=["*"],
131
+ allow_headers=["*"],
132
+ )
133
+
134
+ # Mount the static directory to serve HTML, CSS, JS files
135
+ app.mount("/static", StaticFiles(directory="static"), name="static")
136
+
137
+ # Initialize transformers pipelines
138
+ qa_pipeline = pipeline("question-answering", model="microsoft/phi-2", tokenizer="microsoft/phi-2")
139
+ image_qa_pipeline = pipeline("vqa", model="Salesforce/blip-vqa-base")
140
+
141
+ # Initialize EasyOCR for image-based text extraction
142
+ reader = easyocr.Reader(['en'])
143
+
144
+ # Define a template for rendering HTML
145
+ templates = Jinja2Templates(directory="templates")
146
+
147
+ # Ensure temp_files directory exists
148
+ temp_dir = "temp_files"
149
+ os.makedirs(temp_dir, exist_ok=True)
150
+
151
+ # Function to process PDFs
152
+ def extract_pdf_text(file_path: str):
153
+ with pdfplumber.open(file_path) as pdf:
154
+ text = ""
155
+ for page in pdf.pages:
156
+ extracted = page.extract_text()
157
+ if extracted:
158
+ text += extracted + "\n"
159
+ return text
160
+
161
+ # Function to process DOCX files
162
+ def extract_docx_text(file_path: str):
163
+ doc = docx.Document(file_path)
164
+ return "\n".join([para.text for para in doc.paragraphs])
165
+
166
+ # Function to process PPTX files
167
+ def extract_pptx_text(file_path: str):
168
+ from pptx import Presentation
169
+ prs = Presentation(file_path)
170
+ return "\n".join([shape.text for slide in prs.slides for shape in slide.shapes if hasattr(shape, "text")])
171
+
172
+ # Function to extract text from images using OCR
173
+ def extract_text_from_image(image: Image):
174
+ return pytesseract.image_to_string(image)
175
+
176
+ # Home route
177
+ @app.get("/")
178
+ def home():
179
+ return RedirectResponse(url="/docs")
180
+
181
+ # Function to answer questions based on document content
182
+ @app.post("/docs")
183
+ async def question_answering_doc(question: str = Form(...), file: UploadFile = File(...)):
184
+ file_ext = file.filename.split(".")[-1].lower()
185
+ file_path = os.path.join(temp_dir, file.filename)
186
+
187
+ with open(file_path, "wb") as f:
188
+ f.write(await file.read())
189
+
190
+ if file_ext in ["pdf"]:
191
+ text = extract_pdf_text(file_path)
192
+ elif file_ext in ["docx"]:
193
+ text = extract_docx_text(file_path)
194
+ elif file_ext in ["pptx"]:
195
+ text = extract_pptx_text(file_path)
196
+ else:
197
+ return {"error": "Unsupported file format"}
198
+
199
+ qa_result = qa_pipeline(question=question, context=text)
200
+ return {"answer": qa_result['answer']}
201
+
202
+ # Function to answer questions based on images
203
+ @app.post("/images")
204
+ async def question_answering_image(question: str = Form(...), image_file: UploadFile = File(...)):
205
+ image = Image.open(BytesIO(await image_file.read()))
206
+
207
+ # Extract text from image using OCR
208
+ image_text = extract_text_from_image(image)
209
+
210
+ # Get answer from the VQA model
211
+ image_qa_result = image_qa_pipeline({"image": image, "question": question})
212
+
213
+ return {"answer": image_qa_result[0]['answer'], "image_text": image_text}
214
+
215
+ # Serve the application in Hugging Face space
216
+ @app.get("/docs")
217
+ async def get_docs(request: Request):
218
+ return templates.TemplateResponse("index.html", {"request": request})