|
from functools import lru_cache

import faiss
import numpy as np
import streamlit as st
from sentence_transformers import SentenceTransformer
from transformers import RagTokenizer, RagRetriever, RagSequenceForGeneration
|
|
|
|
|
# Load the pretrained RAG components once at import time. These download
# large checkpoints on first run, so startup can be slow.
tokenizer = RagTokenizer.from_pretrained("facebook/rag-token-nq")

model = RagSequenceForGeneration.from_pretrained("facebook/rag-token-nq")

# NOTE(review): `retriever` is loaded here but never referenced anywhere else
# in this file — retrieval is done manually through the FAISS index below.
# Confirm whether this load (and its index download) is still needed.
retriever = RagRetriever.from_pretrained("facebook/rag-token-nq", index_name="faiss", trust_remote_code=True)
|
|
|
|
|
def setup_faiss():
    """Build an in-memory FAISS index over a small multilingual demo corpus.

    Returns:
        tuple: ``(faiss_index, docs)`` where ``faiss_index`` is a populated
        ``faiss.IndexFlatL2`` and ``docs`` is the list of documents in the
        same order as their index rows.
    """
    model_embed = SentenceTransformer('sentence-transformers/LaBSE')

    # Tiny demo corpus: the same question in English, French, and Urdu.
    docs = [
        "How to learn programming?",
        "Comment apprendre la programmation?",
        "پروگرامنگ سیکھنے کا طریقہ کیا ہے؟"
    ]

    # Bug fix: the original encoded with convert_to_tensor=True and then
    # called np.array() on the result, which raises for CUDA tensors and
    # yields float64/float32 ambiguity. Encode straight to NumPy and cast to
    # float32, which is what FAISS requires.
    embeddings = model_embed.encode(docs, convert_to_numpy=True).astype("float32")

    faiss_index = faiss.IndexFlatL2(embeddings.shape[1])
    faiss_index.add(embeddings)

    return faiss_index, docs
|
|
|
|
|
# Build the index and corpus once at startup; both are read by retrieve_docs.
faiss_index, docs = setup_faiss()
|
|
|
|
|
@lru_cache(maxsize=1)
def _get_embedder():
    """Load the LaBSE sentence embedder once and reuse it across queries."""
    return SentenceTransformer('sentence-transformers/LaBSE')


def retrieve_docs(query):
    """Return the single document from ``docs`` nearest to ``query``.

    Args:
        query: user question string (any LaBSE-supported language).

    Returns:
        str: the closest document by L2 distance in embedding space.
    """
    # Bug fix: the original instantiated a brand-new SentenceTransformer on
    # every call — a multi-second model load per query. The cached helper
    # above loads it exactly once.
    # Also encode straight to a float32 NumPy array: np.array() on a torch
    # tensor fails for CUDA tensors, and FAISS requires float32 input.
    query_embedding = _get_embedder().encode([query], convert_to_numpy=True).astype("float32")

    # k=1 nearest-neighbour search; the returned index addresses `docs`.
    _, neighbor_ids = faiss_index.search(query_embedding, 1)

    return docs[neighbor_ids[0][0]]
|
|
|
|
|
def answer_question(query):
    """Answer ``query`` by retrieving the nearest stored document and running
    the RAG generator over the combined input.

    Args:
        query: user question string.

    Returns:
        str: generated answer text with special tokens stripped.
    """
    retrieved_doc = retrieve_docs(query)

    # NOTE(review): `retrieved_doc` is passed as the tokenizer's second
    # positional argument, which encodes it as a text *pair* rather than as
    # retrieved context — RAG models normally consume context through their
    # retriever (the module-level `retriever` is unused here). Confirm this
    # produces the intended model input.
    inputs = tokenizer(query, retrieved_doc, return_tensors="pt", padding=True, truncation=True)

    generated = model.generate(input_ids=inputs["input_ids"], attention_mask=inputs["attention_mask"])

    # Decode only the first (and only) generated sequence.
    answer = tokenizer.decode(generated[0], skip_special_tokens=True)
    return answer
|
|
|
|
|
# --- Streamlit UI ------------------------------------------------------------
st.title("Multilingual RAG Translator/Answer Bot")
st.write("Ask a question in your preferred language (Urdu, French, Hindi)")

# Answer as soon as the user submits non-empty text.
if user_question := st.text_input("Enter your question:"):
    st.write(f"Answer: {answer_question(user_question)}")
|
|