simple_rag / app.py
mischeiwiller's picture
Update to app.py to rag system
12f7691 verified
raw
history blame contribute delete
2.46 kB
import gradio as gr
import numpy as np
from transformers import AutoModelForCausalLM, AutoTokenizer
from sentence_transformers import SentenceTransformer
from datasets import load_dataset
import torch
# Initialize models
# Sentence-embedding model used to embed documents and queries for retrieval.
retriever = SentenceTransformer("all-MiniLM-L6-v2")
# Small causal LM that generates the answer conditioned on retrieved context.
generator = AutoModelForCausalLM.from_pretrained("distilgpt2")
# Tokenizer paired with the generator checkpoint.
tokenizer = AutoTokenizer.from_pretrained("distilgpt2")
# Simple vector store
class VectorStore:
    """In-memory store of documents and their sentence embeddings.

    Uses the module-level ``retriever`` (SentenceTransformer) to embed
    documents on insert and queries on search; ranking is by cosine
    similarity.
    """

    def __init__(self):
        self.documents = []   # raw document texts, parallel to self.embeddings
        self.embeddings = []  # one 1-D embedding vector per document

    def add_document(self, document):
        """Embed *document* with the retriever and append it to the store."""
        self.documents.append(document)
        embedding = retriever.encode(document)
        self.embeddings.append(embedding)

    def search(self, query, k=3):
        """Return up to *k* stored documents most cosine-similar to *query*.

        Fixes over the original: returns [] for an empty store instead of
        crashing inside ``np.dot``, and guards against zero-norm vectors
        (which would otherwise divide by zero and yield NaN similarities).
        """
        if not self.documents:
            return []
        query_embedding = retriever.encode(query)
        matrix = np.asarray(self.embeddings)
        # Cosine similarity: dot product scaled by the product of norms.
        denom = np.linalg.norm(matrix, axis=1) * np.linalg.norm(query_embedding)
        # Zero-length vectors get denominator 1 so their similarity is 0,
        # not NaN/inf — they simply never rank first.
        denom = np.where(denom == 0.0, 1.0, denom)
        similarities = matrix @ query_embedding / denom
        # argsort is ascending; take the last k indices and reverse them
        # so the best match comes first. k > len(documents) is safe.
        top_k_indices = np.argsort(similarities)[-k:][::-1]
        return [self.documents[i] for i in top_k_indices]
# Initialize vector store
vector_store = VectorStore()
# Load sample dataset (e.g., Wikipedia snippets)
# NOTE(review): trust_remote_code=True executes the dataset's loading script;
# fine for the official "wikipedia" builder, but worth confirming.
dataset = load_dataset(
    "wikipedia", "20220301.simple", split="train[:1000]", trust_remote_code=True
)
# Embed every article at import time. Encoding 1000 documents one-by-one is
# slow at startup; batch encoding would help but needs a VectorStore change.
for doc in dataset["text"]:
    vector_store.add_document(doc)
# RAG function
def rag_query(query, max_length=100):
    """Answer *query* with retrieval-augmented generation.

    Retrieves the top documents from the global ``vector_store``, builds a
    "Context / Question / Answer:" prompt, and lets ``generator`` continue it.

    Args:
        query: Natural-language question.
        max_length: Maximum number of NEW tokens to generate (parameter name
            kept for backward compatibility with existing callers).

    Returns:
        The generated answer text (everything after the last "Answer:").
    """
    # Retrieve relevant documents
    retrieved_docs = vector_store.search(query)
    context = " ".join(retrieved_docs)
    # Generate response.
    # BUG FIX: the original truncated the *whole* prompt to 512 tokens from
    # the right, which for long contexts silently dropped the question and
    # the "Answer:" cue. Tokenize the question part first, then truncate
    # only the context to the remaining budget.
    question_part = f"\n\nQuestion: {query}\nAnswer:"
    question_ids = tokenizer(question_part, return_tensors="pt").input_ids
    budget = max(512 - question_ids.shape[1], 0)
    context_ids = tokenizer(
        f"Context: {context}", return_tensors="pt", truncation=True, max_length=budget
    ).input_ids
    input_ids = torch.cat([context_ids, question_ids], dim=1)
    # Single unpadded sequence, so the mask is all ones; passing it
    # explicitly silences the transformers attention-mask warning.
    attention_mask = torch.ones_like(input_ids)
    with torch.no_grad():
        outputs = generator.generate(
            input_ids,
            attention_mask=attention_mask,
            max_new_tokens=max_length,  # equivalent to old max_length + len(input) math
            num_return_sequences=1,
            pad_token_id=tokenizer.eos_token_id,  # distilgpt2 has no pad token
        )
    response = tokenizer.decode(outputs[0], skip_special_tokens=True)
    return response.split("Answer:")[-1].strip()
# Gradio interface
def gradio_interface(query):
    """Gradio callback: forward the user's question to the RAG pipeline."""
    answer = rag_query(query)
    return answer
# Single text-in / text-out UI wired to the RAG callback above.
iface = gr.Interface(
    fn=gradio_interface,
    inputs=gr.Textbox(label="Enter your question"),
    outputs=gr.Textbox(label="Answer"),
    title="RAG System with Hugging Face and Gradio",
    description="Ask questions based on a Wikipedia-based knowledge base.",
)
# Launch the web UI only when run as a script (not when imported).
if __name__ == "__main__":
    iface.launch()