# Author: rahideer — "Update app.py" (commit c858539, verified)
import streamlit as st
from transformers import RagTokenizer, RagRetriever, RagSequenceForGeneration
from sentence_transformers import SentenceTransformer
import faiss
import numpy as np
# Load the pre-trained RAG tokenizer, retriever and generator.
# NOTE: facebook/rag-token-nq is an English Natural Questions checkpoint; only
# the LaBSE document retrieval set up below is multilingual.
@st.cache_resource
def _load_rag(model_name="facebook/rag-token-nq", index_name="exact", use_dummy_dataset=True):
    """Load the RAG tokenizer, retriever and model and wire them together.

    Streamlit re-executes this script on every widget interaction, so the
    result is cached with ``st.cache_resource``: the weights are loaded once
    per server process instead of on every rerun.

    Args:
        model_name: Hugging Face checkpoint all three components come from.
        index_name: RAG index variant ("exact", "compressed", "legacy" or
            "custom"). "faiss" is not an accepted value.
        use_dummy_dataset: use the small dummy wiki_dpr split instead of the
            full (very large) Wikipedia index.

    Returns:
        A ``(tokenizer, model, retriever)`` tuple.
    """
    # Function-scope import: a *-token* checkpoint pairs with RagTokenForGeneration.
    from transformers import RagTokenForGeneration

    rag_tokenizer = RagTokenizer.from_pretrained(model_name)
    rag_retriever = RagRetriever.from_pretrained(
        model_name,
        index_name=index_name,
        use_dummy_dataset=use_dummy_dataset,
        trust_remote_code=True,
    )
    # The retriever must be attached to the model: generate() is later called
    # with only input_ids, and RAG then needs the retriever to obtain context.
    rag_model = RagTokenForGeneration.from_pretrained(model_name, retriever=rag_retriever)
    return rag_tokenizer, rag_model, rag_retriever


tokenizer, model, retriever = _load_rag()
# Multilingual sentence encoder shared by document indexing and query embedding.
EMBED_MODEL_NAME = "sentence-transformers/LaBSE"


@st.cache_resource
def _load_embedder():
    """Return the LaBSE ``SentenceTransformer``, loaded once per server process.

    Building the encoder is expensive and Streamlit reruns the script on every
    interaction, so it is cached and shared by indexing and querying.
    """
    return SentenceTransformer(EMBED_MODEL_NAME)


def _embed(texts):
    """Encode a list of strings into a C-contiguous float32 matrix, one row per text.

    FAISS only accepts float32 NumPy arrays. Asking the encoder for NumPy
    output directly also avoids converting a torch tensor with ``np.array``,
    which fails when the encoder runs on a GPU.
    """
    vectors = _load_embedder().encode(texts, convert_to_numpy=True)
    return np.ascontiguousarray(vectors, dtype=np.float32)


# Set up FAISS for multilingual document retrieval
@st.cache_resource
def setup_faiss():
    """Embed the example documents with LaBSE and add them to a flat L2 FAISS index.

    Cached so the documents are embedded once per server process rather than
    on every Streamlit rerun.

    Returns:
        ``(faiss_index, docs)``, where row *i* of the index corresponds to ``docs[i]``.
    """
    # Example multilingual documents (English, French, Urdu)
    docs = [
        "How to learn programming?",
        "Comment apprendre la programmation?",
        "پروگرامنگ سیکھنے کا طریقہ کیا ہے؟",
    ]
    embeddings = _embed(docs)
    faiss_index = faiss.IndexFlatL2(embeddings.shape[1])
    faiss_index.add(embeddings)
    return faiss_index, docs


# Build the FAISS index (cached across Streamlit reruns)
faiss_index, docs = setup_faiss()


def retrieve_docs(query):
    """Return the indexed document nearest to *query* by L2 distance between LaBSE embeddings.

    Args:
        query: question text, in any language LaBSE can embed.

    Returns:
        The single closest string from ``docs``.
    """
    _distances, indices = faiss_index.search(_embed([query]), 1)
    return docs[indices[0][0]]
# Handle question-answering
def answer_question(query):
    """Generate an answer to *query* with the RAG model.

    The best-matching local document (from ``retrieve_docs``) is passed to the
    RAG tokenizer as the second text of a ``(query, document)`` pair. The
    resulting ids and mask go to ``model.generate``, and the first generated
    sequence is decoded with special tokens removed.

    NOTE(review): the local document only affects the question-encoder input
    here. It is not supplied as generator context (no ``context_input_ids``),
    so the context the generator sees presumably comes from the model's own
    retriever. Confirm this is intended.
    """
    best_doc = retrieve_docs(query)
    encoded = tokenizer(
        query,
        best_doc,
        return_tensors="pt",
        padding=True,
        truncation=True,
    )
    output_ids = model.generate(
        input_ids=encoded["input_ids"],
        attention_mask=encoded["attention_mask"],
    )
    first_sequence = output_ids[0]
    return tokenizer.decode(first_sequence, skip_special_tokens=True)
# Streamlit interface for user input
st.title("Multilingual RAG Translator/Answer Bot")
# NOTE(review): the indexed example documents are English, French and Urdu, and
# the RAG generator checkpoint is English-only, so the language list shown to
# users may overstate what is supported. Confirm the wording.
st.write("Ask a question in your preferred language (Urdu, French, Hindi)")
query = st.text_input("Enter your question:")
# Skip empty and whitespace-only input so no retrieval or generation runs on it.
if query.strip():
    # Retrieval plus generation can take several seconds; show progress meanwhile.
    with st.spinner("Generating answer..."):
        answer = answer_question(query.strip())
    st.write(f"Answer: {answer}")