import gradio as gr
import torch
from sentence_transformers import SentenceTransformer, util
import PyPDF2
import os
import time
import logging
from yacana import Agent, Task, LoggerManager
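
# Yacana drives the generation side of the pipeline: an Agent wraps a model
# served by Ollama, and Tasks carry the prompts (see rag_pipeline below).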
# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
LoggerManager.set_log_level("INFO")  # Yacana's LoggerManager is configured via class methods, not instantiated
# Load retriever model
retriever_model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
# Cache for document embeddings
embedding_cache = {}
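# NOTE: the cache is keyed by raw chunk text and never evicted, so it grows
# with every distinct chunk seen; acceptable for a demo, but a bounded cache
# (e.g. an LRU) would be safer for a long-running Space.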
def extract_text_from_pdf(pdf_file):
    """Extract text from a PDF file, returning a list of page texts."""
    pages = []
    # Gradio may hand over a plain file path (str) or a tempfile wrapper with a .name attribute.
    path = pdf_file if isinstance(pdf_file, str) else pdf_file.name
    try:
        with open(path, "rb") as f:
            reader = PyPDF2.PdfReader(f)
            for page in reader.pages:
                text = page.extract_text()
                if text:
                    pages.append(text.strip())
    except Exception as e:
        logger.error(f"Error reading PDF {path}: {str(e)}")
        pages.append(f"Error reading PDF: {str(e)}")
    return pages
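
# chunk_text packs words greedily under a character budget (current_length
# counts each word plus one joining space), e.g.
# chunk_text("a b c", chunk_size=3) -> ["a b", "c"].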
def chunk_text(text, chunk_size=500):
    """Split text into chunks of approximately chunk_size characters."""
    words = text.split()
    chunks = []
    current_chunk = []
    current_length = 0
    for word in words:
        if current_length + len(word) > chunk_size and current_chunk:
            chunks.append(" ".join(current_chunk))
            current_chunk = []
            current_length = 0
        current_chunk.append(word)
        current_length += len(word) + 1
    if current_chunk:
        chunks.append(" ".join(current_chunk))
    return chunks
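
# get_document_embeddings encodes uncached chunks one at a time so each can be
# cached individually; passing the whole uncached list to a single batched
# retriever_model.encode(...) call would likely be faster on CPU.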
def get_document_embeddings(documents):
    """Compute embeddings for documents, using cache if available."""
    embeddings = []
    for doc in documents:
        if doc in embedding_cache:
            embeddings.append(embedding_cache[doc])
        else:
            emb = retriever_model.encode(doc, convert_to_tensor=True)
            embedding_cache[doc] = emb
            embeddings.append(emb)
    return torch.stack(embeddings)
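
# retrieve_context ranks chunks by cosine similarity between the question and
# chunk embeddings, cos(q, d) = (q . d) / (|q||d|), and keeps the top 3.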
def retrieve_context(question, documents):
    """Retrieve top 3 relevant chunks."""
    doc_embeddings = get_document_embeddings(documents)
    query_embedding = retriever_model.encode(question, convert_to_tensor=True)
    cos_scores = util.pytorch_cos_sim(query_embedding, doc_embeddings)[0]
    top_results = torch.topk(cos_scores, k=min(3, len(documents)))
    retrieved_context = ""
    for score, idx in zip(top_results.values, top_results.indices):
        retrieved_context += f"- {documents[idx]} (score: {score:.2f})\n"
    return retrieved_context
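
# End-to-end flow: PDF text -> ~500-char chunks -> top-3 retrieval -> two-step
# generation (draft answer, then refinement) with a single Yacana agent.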
def rag_pipeline(question, pdf_files):
    """RAG pipeline with Yacana and Phi-1."""
    start_time = time.time()
    documents = []
    # Process PDFs or use default documents
    if pdf_files:
        for pdf in pdf_files:
            pages = extract_text_from_pdf(pdf)
            for page in pages:
                chunks = chunk_text(page)
                documents.extend(chunks)
    else:
        documents = [
            "Artificial Intelligence (AI) is the simulation of human intelligence in machines.",
            "Data Science involves extracting insights from structured and unstructured data using statistical methods.",
            "AI and Data Science often work together to build predictive models and automate decision-making.",
            "Machine learning, a subset of AI, is widely used in Data Science for pattern recognition.",
        ]
    if not documents:
        return "No valid text could be extracted from the PDFs."
    # Retrieve context
    retrieved_context = retrieve_context(question, documents)
    logger.info(f"Retrieved context:\n{retrieved_context}")
    # Define the Yacana agent; assumes phi-1 served by a local Ollama instance
    # (the "phi" model must already be pulled).
    agent = Agent("Phi1Agent", "phi")
    # Task 1: initial answer from the retrieved context
    initial_prompt = (
        f"Using the following context, provide a concise answer to the question:\n\n"
        f"Context:\n{retrieved_context}\n\n"
        f"Question: {question}\n\n"
        f"Answer:"
    )
    # Execute both tasks; Yacana tasks are built as Task(prompt, agent) and
    # solved with .solve(), which returns a message whose .content is the text.
    try:
        initial_result = Task(initial_prompt, agent).solve().content
        logger.info(f"Initial answer: {initial_result}")
        # Task 2: refine the first answer using the same context
        refine_prompt = (
            f"Given the context and initial answer, refine and improve the response:\n\n"
            f"Context:\n{retrieved_context}\n\n"
            f"Question: {question}\n\n"
            f"Initial Answer: {initial_result}\n\n"
            f"Refined Answer:"
        )
        refined_result = Task(refine_prompt, agent).solve().content
        logger.info(f"Refined answer: {refined_result}")
    except Exception as e:
        logger.error(f"Error in Yacana tasks: {str(e)}")
        return f"Task execution error: {str(e)}"
    logger.info(f"Processing time: {time.time() - start_time:.2f} seconds")
    return refined_result if refined_result else "Unable to generate a meaningful response."

# Gradio UI
with gr.Blocks() as demo:
    gr.Markdown("# RAG Pipeline with microsoft/phi-1 and Yacana")
    gr.Markdown(
        "Upload PDFs (or use default AI/Data Science docs), ask a question, "
        "and get refined answers using Phi-1 via Yacana on 2 vCPUs and 16GB RAM."
    )
    with gr.Row():
        with gr.Column():
            question_input = gr.Textbox(label="Your Question", placeholder="e.g., What is AI and Data Science?", lines=3)
            pdf_input = gr.File(label="Upload PDF(s) (optional)", file_types=[".pdf"], file_count="multiple")
            submit_button = gr.Button("Submit")
        with gr.Column():
            response_output = gr.Textbox(label="Response", placeholder="The answer will appear here...", lines=10)
    submit_button.click(fn=rag_pipeline, inputs=[question_input, pdf_input], outputs=response_output)

demo.launch()