Rohit1412 committed · verified
Commit 1c5c7d4 · 1 Parent(s): 16025b7

Update app.py

Files changed (1)
  1. app.py +113 -72
app.py CHANGED
@@ -1,92 +1,133 @@
  import gradio as gr
- import requests
- import os
  import PyPDF2
 
- # Set your Hugging Face API token.
- # Option 1: Set it as an environment variable named "HF_API_TOKEN".
- # Option 2: Replace "YOUR_HUGGINGFACE_API_TOKEN" with your token directly.
- API_TOKEN = os.environ.get("HF_TOKEN", "YOUR_HUGGINGFACE_API_TOKEN")
- headers = {"Authorization": f"Bearer {API_TOKEN}"}
-
- def extract_pdf_text(pdf_file):
-     """
-     Extracts text from a PDF file using PyPDF2.
-     """
-     pdf_text = ""
      try:
-         with open(pdf_file, "rb") as f:
              reader = PyPDF2.PdfReader(f)
              for page in reader.pages:
                  text = page.extract_text()
                  if text:
-                     pdf_text += text + "\n"
      except Exception as e:
-         print("Error reading PDF:", e)
-     return pdf_text
-
- def generate_response(query, pdf_file=None):
-     """
-     If a PDF file is uploaded, extract its text and combine a limited part of it
-     with the user query to form a prompt. Then send the prompt to the Hugging Face
-     Inference API using the RAG model.
-     """
-     pdf_text = ""
-     if pdf_file is not None:
-         pdf_text = extract_pdf_text(pdf_file)
-
-     # If PDF text is available, append its (truncated) content as context
-     if pdf_text:
-         # Limit the context to avoid token overflow; adjust as needed.
-         context = pdf_text[:2000]
-         full_input = "Context: " + context + "\n\nQuestion: " + query
-     else:
-         full_input = query
-
-     # Define the model and endpoint for the RAG model.
-     model_id = "google/flan-t5-large"
-     api_url = f"https://api-inference.huggingface.co/models/{model_id}"
-
-     payload = {"inputs": full_input}
-
-     response = requests.post(api_url, headers=headers, json=payload)
-     if response.status_code != 200:
-         return "Error: " + response.text
-
-     result = response.json()
-     # Extract the generated text if available.
-     if isinstance(result, list) and result and "generated_text" in result[0]:
-         return result[0]["generated_text"]
      else:
-         return str(result)
 
  with gr.Blocks() as demo:
-     gr.Markdown("# Retrieval Augmented Generation (RAG) Chatbot with PDF Input")
      gr.Markdown(
-         "Powered by the Hugging Face Inference API. "
-         "Optionally upload a PDF file and ask a question related to its content. "
-         "If no PDF is uploaded, the model will answer based solely on the query."
      )
-
      with gr.Row():
          with gr.Column():
-             pdf_input = gr.File(label="Upload PDF (optional)", file_types=[".pdf"])
-             query_input = gr.Textbox(label="Your Question", placeholder="Type your question here...", lines=3)
              submit_button = gr.Button("Submit")
-             gr.Examples(
-                 examples=[
-                     ["What is the main argument in the document?"],
-                     ["Summarize the content of the PDF."],
-                     ["What conclusions can be drawn from the report?"],
-                 ],
-                 inputs=query_input,
-                 label="Try one of these examples:"
-             )
          with gr.Column():
              response_output = gr.Textbox(label="Response", placeholder="The answer will appear here...", lines=10)
 
-     # Link the button click to the generate_response function.
-     submit_button.click(fn=generate_response, inputs=[query_input, pdf_input], outputs=response_output)
-
- # Launch the app locally
- demo.launch()
 
  import gradio as gr
+ import torch
+ from sentence_transformers import SentenceTransformer, util
+ from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
  import PyPDF2
+ import os
+ import time
+ import logging
+
+ # Set up logging
+ logging.basicConfig(level=logging.INFO)
+ logger = logging.getLogger(__name__)
 
+ # Load models
+ retriever_model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
+ gen_tokenizer = AutoTokenizer.from_pretrained("google/flan-t5-small")
+ gen_model = AutoModelForSeq2SeqLM.from_pretrained("google/flan-t5-small")
+
+ # Cache for document embeddings
+ embedding_cache = {}
+
+ def extract_text_from_pdf(pdf_file):
+     """Extract text from a PDF file, returning a list of page texts."""
+     pages = []
      try:
+         with open(pdf_file.name, "rb") as f:
              reader = PyPDF2.PdfReader(f)
              for page in reader.pages:
                  text = page.extract_text()
                  if text:
+                     pages.append(text.strip())
      except Exception as e:
+         logger.error(f"Error reading PDF {pdf_file.name}: {str(e)}")
+         pages.append(f"Error reading PDF: {str(e)}")
+     return pages
+
+ def chunk_text(text, chunk_size=500):
+     """Split text into chunks of approximately chunk_size characters."""
+     words = text.split()
+     chunks = []
+     current_chunk = []
+     current_length = 0
+     for word in words:
+         if current_length + len(word) > chunk_size and current_chunk:
+             chunks.append(" ".join(current_chunk))
+             current_chunk = []
+             current_length = 0
+         current_chunk.append(word)
+         current_length += len(word) + 1  # +1 for space
+     if current_chunk:
+         chunks.append(" ".join(current_chunk))
+     return chunks
+
+ def get_document_embeddings(documents):
+     """Compute embeddings for documents, using cache if available."""
+     embeddings = []
+     for doc in documents:
+         if doc in embedding_cache:
+             embeddings.append(embedding_cache[doc])
+         else:
+             emb = retriever_model.encode(doc, convert_to_tensor=True)
+             embedding_cache[doc] = emb
+             embeddings.append(emb)
+     return embeddings
+
+ def rag_pipeline(question, pdf_files):
+     """Optimized RAG pipeline with caching, chunking, and improved retrieval."""
+     start_time = time.time()
+     documents = []
+
+     # Process PDFs if provided
+     if pdf_files:
+         for pdf in pdf_files:
+             pages = extract_text_from_pdf(pdf)
+             for page in pages:
+                 chunks = chunk_text(page)
+                 documents.extend(chunks)
      else:
+         # Default documents if no PDFs
+         documents = [
+             "Paris is the capital of France and is known for its art, gastronomy, and culture.",
+             "France is a country in Western Europe with diverse landscapes and a rich history.",
+             "The Eiffel Tower is one of the most famous landmarks in Paris, France.",
+             "Paris has a population of over 2 million people and is a major global city.",
+         ]
+
+     if not documents:
+         return "No valid text could be extracted from the PDFs."
+
+     # Compute embeddings with caching
+     doc_embeddings = get_document_embeddings(documents)
+
+     # Embed the query
+     query_embedding = retriever_model.encode(question, convert_to_tensor=True)
+
+     # Retrieve top 3 chunks using cosine similarity
+     cos_scores = util.pytorch_cos_sim(query_embedding, doc_embeddings)[0]
+     top_results = torch.topk(cos_scores, k=min(3, len(documents)))
+     retrieved_context = ""
+     for score, idx in zip(top_results.values, top_results.indices):
+         retrieved_context += f"Context: {documents[idx]}\n"
+
+     # Optimized prompt for the generator
+     prompt = f"Using the provided context, answer the following question:\n\nContext:\n{retrieved_context}\n\nQuestion: {question}\n\nAnswer:"
 
+     # Generate answer
+     inputs = gen_tokenizer(prompt, return_tensors="pt")
+     outputs = gen_model.generate(**inputs, max_new_tokens=100)
+     answer = gen_tokenizer.decode(outputs[0], skip_special_tokens=True)
+
+     # Log processing time
+     logger.info(f"Processing time: {time.time() - start_time:.2f} seconds")
+     return answer
+
+ # Gradio UI
  with gr.Blocks() as demo:
+     gr.Markdown("# Improved Lightweight Local RAG Pipeline with PDF Input")
      gr.Markdown(
+         "Upload one or more PDF files (or leave blank for default documents), enter your question, "
+         "and get an answer generated using an optimized retrieval step (all-MiniLM-L6-v2) and a small "
+         "generator model (flan-t5-small). Designed for 2 vCPUs and 16GB RAM."
      )
      with gr.Row():
          with gr.Column():
+             question_input = gr.Textbox(label="Your Question", placeholder="Type your question here...", lines=3)
+             pdf_input = gr.File(label="Upload PDF(s) (optional)", file_types=[".pdf"], file_count="multiple")
              submit_button = gr.Button("Submit")
          with gr.Column():
              response_output = gr.Textbox(label="Response", placeholder="The answer will appear here...", lines=10)
 
+     submit_button.click(fn=rag_pipeline, inputs=[question_input, pdf_input], outputs=response_output)
+
+ demo.launch()
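
Below is a short, self-contained sketch of the retrieval step that the new rag_pipeline relies on: embed the chunks and the question with all-MiniLM-L6-v2, then keep the top-scoring chunks by cosine similarity. It is illustrative only and not part of the commit; the sample documents are copied from the default list in the diff, and the question and variable names are made up.

# Illustrative only (not part of the commit): the retrieval step used inside rag_pipeline.
import torch
from sentence_transformers import SentenceTransformer, util

retriever = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")

# Two of the default documents from the diff; a real run would use PDF chunks instead.
documents = [
    "Paris is the capital of France and is known for its art, gastronomy, and culture.",
    "The Eiffel Tower is one of the most famous landmarks in Paris, France.",
]
question = "What is the capital of France?"  # made-up sample question

doc_embeddings = retriever.encode(documents, convert_to_tensor=True)  # shape: (num_docs, 384)
query_embedding = retriever.encode(question, convert_to_tensor=True)  # shape: (384,)

# Cosine similarity between the question and every chunk, then top-k selection.
scores = util.pytorch_cos_sim(query_embedding, doc_embeddings)[0]
top = torch.topk(scores, k=min(2, len(documents)))
for score, idx in zip(top.values, top.indices):
    print(f"{score.item():.3f}  {documents[idx]}")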