# Scraped from the Hugging Face file viewer (author: rahideer).
# Last commit: "Update app.py" (6273efa, verified); file size 2.19 kB.
import streamlit as st
from datasets import load_dataset
from transformers import (
    RagRetriever,
    RagSequenceForGeneration,
    RagTokenForGeneration,
    RagTokenizer,
)
# Load a multilingual dataset (XNLI) for the app.
@st.cache_resource
def _load_xnli(config_name, split):
    """Download and return one XNLI split, cached across Streamlit reruns.

    Raises on failure; Streamlit does not cache a call that raised, so a
    transient download error is retried on the next rerun.
    """
    return load_dataset("xnli", name=config_name, split=split)


def load_data(config_name="all_languages", split="validation"):
    """Load an XNLI split and report its size in the UI.

    Args:
        config_name: XNLI configuration to load. XNLI has no default config,
            so one must be given: a language code such as "fr", "hi" or "ur",
            or "all_languages".
        split: Dataset split to load.

    Returns:
        The loaded dataset, or None if loading failed. On failure the error
        is shown in the Streamlit page.
    """
    try:
        dataset = _load_xnli(config_name, split)
    except Exception as e:  # UI boundary: show the error instead of crashing
        st.error(f"Error loading 'xnli' dataset: {e}")
        return None
    st.write(f"Loaded {len(dataset)} examples from the '{split}' split.")
    return dataset
# Initialize RAG model components
@st.cache_resource
def _load_rag_components(index_name, use_dummy_dataset):
    """Load tokenizer, retriever and model once per server process.

    Raises on failure so that a failed load is not cached by Streamlit.
    """
    checkpoint = "facebook/rag-token-nq"
    tokenizer = RagTokenizer.from_pretrained(checkpoint)
    # `passages_path` is only consulted when index_name == "custom", so it is
    # not passed here; canonical indexes are resolved from the checkpoint config.
    retriever = RagRetriever.from_pretrained(
        checkpoint,
        index_name=index_name,
        use_dummy_dataset=use_dummy_dataset,
    )
    # rag-token-* checkpoints belong with RagTokenForGeneration. Attaching the
    # retriever lets generate() do retrieval from plain input_ids.
    model = RagTokenForGeneration.from_pretrained(checkpoint, retriever=retriever)
    return tokenizer, retriever, model


def initialize_rag(index_name="compressed", use_dummy_dataset=False):
    """Return the (tokenizer, retriever, model) triple for RAG answering.

    Args:
        index_name: Retriever index variant passed to RagRetriever
            (e.g. "compressed", "exact").
        use_dummy_dataset: If True, load the small dummy variant of the
            retrieval corpus. This is useful for quick demos.

    Returns:
        (tokenizer, retriever, model) on success, or (None, None, None) if
        any component failed to load. On failure the error is shown in the
        Streamlit page.
    """
    try:
        return _load_rag_components(index_name, use_dummy_dataset)
    except Exception as e:  # UI boundary: report and let the caller stop
        st.error(f"Error initializing RAG components: {e}")
        return None, None, None
# Main function to run the app
def main():
    """Render the Streamlit page and answer the user's question with RAG.

    Loads the dataset and RAG components, reads a question from a text box,
    and writes the generated answer. Stops early, with a message, if any
    component failed to load.
    """
    st.title("Multilingual RAG Translator/Answer Bot")

    dataset = load_data()
    if dataset is None:
        st.error("Dataset could not be loaded.")
        return

    tokenizer, retriever, model = initialize_rag()
    if tokenizer is None or retriever is None or model is None:
        st.error("RAG components could not be initialized.")
        return

    # NOTE(review): facebook/rag-token-nq was trained on English Natural
    # Questions; Urdu/Hindi/French queries may retrieve poorly. Confirm
    # whether a multilingual checkpoint is intended.
    query = st.text_input("Enter your question in Urdu, Hindi, or French:")
    if not query.strip():
        return

    # Attach the retriever (idempotent) so generate() encodes the question,
    # retrieves and scores passages itself. RagRetriever.retrieve() expects
    # question hidden states, not a raw string.
    model.set_retriever(retriever)
    try:
        inputs = tokenizer(query, return_tensors="pt")
        generated = model.generate(
            input_ids=inputs["input_ids"],
            attention_mask=inputs["attention_mask"],
        )
    except Exception as e:  # UI boundary: show the failure on the page
        st.error(f"Error generating an answer: {e}")
        return

    answer = tokenizer.decode(generated[0], skip_special_tokens=True)
    st.write("Answer:", answer)
# Entry point: build the page only when this file runs as the main script,
# not when it is imported.
if __name__ == "__main__":
    main()