Rohit1412 committed on
Commit 3a1d8c8 · verified · 1 Parent(s): 12e6360

Update app.py

Files changed (1)
  1. app.py +75 -144
app.py CHANGED
@@ -1,161 +1,92 @@
-import os
 import gradio as gr
-import faiss
-import numpy as np
-from transformers import AutoTokenizer, AutoModelForCausalLM
-from sentence_transformers import SentenceTransformer
-
-# ---------------------------
-# Load Models (cached on first run)
-# ---------------------------
-def load_models():
-    hf_token = os.getenv("HF_TOKEN")  # Set this secret in your HF Space settings
-    embed_model = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')  # For embeddings
-    tokenizer = AutoTokenizer.from_pretrained("google/gemma-3-4b-it", use_auth_token=hf_token)
-    model = AutoModelForCausalLM.from_pretrained(
-        "google/gemma-3-4b-it",
-        device_map="auto",
-        low_cpu_mem_usage=True,
-        use_auth_token=hf_token
-    )
-    return embed_model, tokenizer, model
-
-embed_model, tokenizer, model = load_models()

-# ---------------------------
-# Global state for FAISS index and document chunks.
-# Using a dictionary to hold state.
-state = {
-    "faiss_index": None,
-    "doc_chunks": []
-}

-# ---------------------------
-# Document Processing Function
-# ---------------------------
-def process_document(file, chunk_size, chunk_overlap):
     """
-    Reads the uploaded file (PDF or text), extracts text, splits into chunks,
-    computes embeddings, and builds a FAISS index.
     """
-    if file is None:
-        return "No file uploaded."
-
-    file_bytes = file.read()
-    file_name = file.name
-    text = ""
-
-    if file_name.lower().endswith(".pdf"):
-        try:
-            from PyPDF2 import PdfReader
-        except ImportError:
-            return "Error: PyPDF2 is required for PDF extraction."
-        # Save file to temporary path
-        temp_path = os.path.join("temp", file_name)
-        os.makedirs("temp", exist_ok=True)
-        with open(temp_path, "wb") as f:
-            f.write(file_bytes)
-        reader = PdfReader(temp_path)
-        for page in reader.pages:
-            text += page.extract_text() or ""
-    else:
-        # Assume it's a text file
-        text = file_bytes.decode("utf-8", errors="ignore")
-
-    if text.strip() == "":
-        return "No text found in the document."
-
-    # Split text into overlapping chunks
-    chunks = []
-    for start in range(0, len(text), chunk_size - chunk_overlap):
-        chunk_text = text[start: start + chunk_size]
-        chunks.append(chunk_text)
-
-    # Compute embeddings for each chunk using the embedding model.
-    embeddings = embed_model.encode(chunks, normalize_embeddings=True).astype('float32')
-    dim = embeddings.shape[1]
-
-    # Build FAISS index using cosine similarity (normalized vectors -> inner product)
-    index = faiss.IndexFlatIP(dim)
-    index.add(embeddings)
-
-    # Update global state
-    state["faiss_index"] = index
-    state["doc_chunks"] = chunks
-
-    # Return a preview (first 500 characters of the first chunk) and status.
-    preview = chunks[0][:500] if chunks else "No content"
-    return f"Indexed {len(chunks)} chunks.\n\n**Document Preview:**\n{preview}"

-# ---------------------------
-# Question Answering Function
-# ---------------------------
-def answer_question(query, top_k):
     """
-    Retrieves the top_k chunks most relevant to the query using the FAISS index,
-    builds a prompt with the retrieved context, and generates an answer using the Gemma model.
     """
-    index = state.get("faiss_index")
-    chunks = state.get("doc_chunks")
-    if index is None or len(chunks) == 0:
-        return "No document processed. Please upload a document first."
-
-    # Encode query using the same embedding model
-    query_vec = embed_model.encode([query], normalize_embeddings=True).astype('float32')
-    D, I = index.search(query_vec, top_k)

-    # Concatenate retrieved chunks as context
-    retrieved_text = ""
-    for idx in I[0]:
-        retrieved_text += chunks[idx] + "\n"

-    # Formulate the prompt for the generative model
-    prompt = f"Context:\n{retrieved_text}\nQuestion: {query}\nAnswer:"

-    # Tokenize and generate answer
-    input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(model.device)
-    output_ids = model.generate(input_ids, max_new_tokens=200, temperature=0.2)
-    answer = tokenizer.decode(output_ids[0][input_ids.size(1):], skip_special_tokens=True)
-    return answer.strip()

-# ---------------------------
-# Gradio Interface
-# ---------------------------
-with gr.Blocks(title="RAG System with Gemma‑3‑4B‑it") as demo:
     gr.Markdown(
-        """
-        # RAG System with Gemma‑3‑4B‑it
-        Upload a document (PDF or TXT) below. The system will extract text, split it into chunks,
-        build a vector index using FAISS, and then allow you to ask questions based on the document.
-        """
     )

-    with gr.Tab("Document Upload & Processing"):
-        with gr.Row():
-            file_input = gr.File(label="Upload Document (PDF or TXT)", file_count="single")
-        with gr.Row():
-            chunk_size_input = gr.Number(label="Chunk Size (characters)", value=1000, precision=0)
-            chunk_overlap_input = gr.Number(label="Chunk Overlap (characters)", value=100, precision=0)
-        process_btn = gr.Button("Process Document")
-        process_output = gr.Markdown()

-    with gr.Tab("Ask a Question"):
-        query_input = gr.Textbox(label="Enter your question", placeholder="Type your question here...")
-        top_k_input = gr.Number(label="Number of Chunks to Retrieve", value=3, precision=0)
-        answer_btn = gr.Button("Get Answer")
-        answer_output = gr.Markdown(label="Answer")

-    # Set up actions
-    process_btn.click(
-        fn=process_document,
-        inputs=[file_input, chunk_size_input, chunk_overlap_input],
-        outputs=process_output
-    )
-    answer_btn.click(
-        fn=answer_question,
-        inputs=[query_input, top_k_input],
-        outputs=answer_output
-    )
-
-if __name__ == "__main__":
-    demo.launch()

 import gradio as gr
+import requests
+import os
+import PyPDF2

+# Set your Hugging Face API token.
+# Option 1: Set it as an environment variable named "HF_TOKEN".
+# Option 2: Replace "YOUR_HUGGINGFACE_API_TOKEN" with your token directly.
+API_TOKEN = os.environ.get("HF_TOKEN", "YOUR_HUGGINGFACE_API_TOKEN")
+headers = {"Authorization": f"Bearer {API_TOKEN}"}

+def extract_pdf_text(pdf_file):
     """
+    Extracts text from a PDF file using PyPDF2.
     """
+    pdf_text = ""
+    try:
+        with open(pdf_file, "rb") as f:
+            reader = PyPDF2.PdfReader(f)
+            for page in reader.pages:
+                text = page.extract_text()
+                if text:
+                    pdf_text += text + "\n"
+    except Exception as e:
+        print("Error reading PDF:", e)
+    return pdf_text

+def generate_response(query, pdf_file=None):
     """
+    If a PDF file is uploaded, extract its text and combine a limited part of it
+    with the user query to form a prompt. Then send the prompt to the Hugging Face
+    Inference API using the RAG model.
     """
+    pdf_text = ""
+    if pdf_file is not None:
+        pdf_text = extract_pdf_text(pdf_file)
+
+    # If PDF text is available, append its (truncated) content as context
+    if pdf_text:
+        # Limit the context to avoid token overflow; adjust as needed.
+        context = pdf_text[:2000]
+        full_input = "Context: " + context + "\n\nQuestion: " + query
+    else:
+        full_input = query
+
+    # Define the model and endpoint for the RAG model.
+    model_id = "facebook/rag-token-nq"
+    api_url = f"https://api-inference.huggingface.co/models/{model_id}"

+    payload = {"inputs": full_input}

+    response = requests.post(api_url, headers=headers, json=payload)
+    if response.status_code != 200:
+        return "Error: " + response.text

+    result = response.json()
+    # Extract the generated text if available.
+    if isinstance(result, list) and result and "generated_text" in result[0]:
+        return result[0]["generated_text"]
+    else:
+        return str(result)

+with gr.Blocks() as demo:
+    gr.Markdown("# Retrieval Augmented Generation (RAG) Chatbot with PDF Input")
     gr.Markdown(
+        "Powered by the Hugging Face Inference API. "
+        "Optionally upload a PDF file and ask a question related to its content. "
+        "If no PDF is uploaded, the model will answer based solely on the query."
     )

+    with gr.Row():
+        with gr.Column():
+            pdf_input = gr.File(label="Upload PDF (optional)", file_types=[".pdf"])
+            query_input = gr.Textbox(label="Your Question", placeholder="Type your question here...", lines=3)
+            submit_button = gr.Button("Submit")
+            gr.Examples(
+                examples=[
+                    ["What is the main argument in the document?"],
+                    ["Summarize the content of the PDF."],
+                    ["What conclusions can be drawn from the report?"],
+                ],
+                inputs=query_input,
+                label="Try one of these examples:"
+            )
+        with gr.Column():
+            response_output = gr.Textbox(label="Response", placeholder="The answer will appear here...", lines=10)

+    # Link the button click to the generate_response function.
+    submit_button.click(fn=generate_response, inputs=[query_input, pdf_input], outputs=response_output)

+# Launch the app locally
+demo.launch()
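
For quick testing outside the Gradio UI, a minimal standalone sketch of the same Inference API request the updated app.py sends. The endpoint, model ID ("facebook/rag-token-nq"), auth header, and payload shape are taken from the diff above; the example context and question are placeholders only, and the response handling mirrors the app's assumption that a successful call returns a list containing "generated_text".

import os
import requests

# Same token handling and endpoint as the updated app.py (sketch for local testing).
API_TOKEN = os.environ.get("HF_TOKEN", "YOUR_HUGGINGFACE_API_TOKEN")
headers = {"Authorization": f"Bearer {API_TOKEN}"}
api_url = "https://api-inference.huggingface.co/models/facebook/rag-token-nq"

# Placeholder prompt; in the app this is the truncated PDF text plus the user question.
payload = {"inputs": "Context: ...\n\nQuestion: What is the main argument in the document?"}

response = requests.post(api_url, headers=headers, json=payload, timeout=60)
if response.status_code != 200:
    print("Error:", response.text)
else:
    result = response.json()
    # The app expects a list of dicts with "generated_text"; fall back to printing raw JSON.
    if isinstance(result, list) and result and "generated_text" in result[0]:
        print(result[0]["generated_text"])
    else:
        print(result)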