Spaces:
Sleeping
Sleeping
import streamlit as st | |
from datasets import load_dataset | |
from transformers import RagTokenizer, RagRetriever, RagSequenceForGeneration | |
@st.cache_resource
def _fetch_dataset(path, config, split):
    """Load one split of a Hugging Face Hub dataset, cached across reruns.

    Streamlit re-executes the whole script on every widget interaction, so
    the loaded dataset is kept in ``st.cache_resource``. Exceptions propagate
    and are therefore not cached, which lets a failed load be retried on the
    next rerun.
    """
    return load_dataset(path, config, split=split)


def load_data(path="xnli", config="all_languages", split="validation"):
    """Load a multilingual dataset (XNLI by default) for the app.

    Args:
        path: Hub dataset id. Defaults to ``"xnli"``.
        config: Dataset configuration name. XNLI provides one config per
            language (e.g. ``"fr"``, ``"hi"``, ``"ur"``) plus
            ``"all_languages"``. A config is mandatory for XNLI; omitting it
            makes ``load_dataset`` raise. The default covers every language
            the UI advertises.
        split: Split to load. Defaults to ``"validation"``.

    Returns:
        The loaded dataset, or ``None`` if loading failed. In that case the
        error is reported on the Streamlit page.
    """
    try:
        dataset = _fetch_dataset(path, config, split)
    except Exception as e:  # UI boundary: report the failure instead of crashing
        st.error(f"Error loading '{path}' dataset: {e}")
        return None
    st.write(f"Loaded {len(dataset)} examples from the '{split}' split.")
    return dataset
@st.cache_resource
def _load_rag_components(model_name, index_name, use_dummy_dataset):
    """Load the RAG tokenizer, retriever and generator once per process.

    Cached with ``st.cache_resource`` because Streamlit re-runs the script on
    every interaction; without caching, the checkpoint and retrieval index
    would be reloaded for every query. Exceptions propagate (and so are not
    cached).
    """
    tokenizer = RagTokenizer.from_pretrained(model_name)
    retriever = RagRetriever.from_pretrained(
        model_name,
        index_name=index_name,
        use_dummy_dataset=use_dummy_dataset,
    )
    # Attaching the retriever lets model.generate() encode the question,
    # retrieve passages and condition generation on them in a single call.
    model = RagSequenceForGeneration.from_pretrained(model_name, retriever=retriever)
    return tokenizer, retriever, model


def initialize_rag(
    model_name="facebook/rag-sequence-nq",
    index_name="compressed",
    use_dummy_dataset=False,
):
    """Initialize the RAG tokenizer, retriever and sequence-generation model.

    Args:
        model_name: Checkpoint used for all three components. Defaults to
            ``facebook/rag-sequence-nq``, the checkpoint that matches
            ``RagSequenceForGeneration``. The previous ``rag-token-nq``
            checkpoint belongs to ``RagTokenForGeneration``.
        index_name: Retrieval index to use (``"compressed"`` by default). A
            ``"custom"`` index would additionally need passage/index paths,
            which this function does not pass.
        use_dummy_dataset: Forwarded to the retriever config. NOTE(review):
            the full Wikipedia DPR index is very large; setting this to
            ``True`` is presumably preferable for a lightweight demo. Confirm
            before changing the default.

    Returns:
        ``(tokenizer, retriever, model)`` on success, or
        ``(None, None, None)`` if any component failed to load. In that case
        the error is reported on the Streamlit page.
    """
    try:
        return _load_rag_components(model_name, index_name, use_dummy_dataset)
    except Exception as e:  # UI boundary: report the failure instead of crashing
        st.error(f"Error initializing RAG components: {e}")
        return None, None, None
def main():
    """Run the Streamlit UI: load resources, read a question, show an answer.

    Stops early, after reporting the problem, if the dataset or any RAG
    component could not be loaded.
    """
    st.title("Multilingual RAG Translator/Answer Bot")

    # The dataset is only loaded and its size reported; it is not used for
    # retrieval below.
    dataset = load_data()
    if dataset is None:
        st.write("Dataset could not be loaded.")
        return

    tokenizer, retriever, model = initialize_rag()
    if tokenizer is None or retriever is None or model is None:
        st.write("RAG components could not be initialized.")
        return

    # RagRetriever.retrieve() takes question-encoder hidden states (not raw
    # text) and returns a tuple, so it cannot be fed the query string or
    # indexed like a dict. Attach the retriever to the model instead, so that
    # generate() handles question encoding and passage retrieval itself.
    model.set_retriever(retriever)

    query = st.text_input("Enter your question in Urdu, Hindi, or French:")
    question = query.strip()
    if not question:
        # Nothing (or only whitespace) entered yet.
        return

    inputs = tokenizer(question, return_tensors="pt")
    try:
        generated = model.generate(
            input_ids=inputs["input_ids"],
            attention_mask=inputs["attention_mask"],
        )
    except Exception as e:  # UI boundary: show the error instead of a traceback
        st.error(f"Error generating an answer: {e}")
        return

    answer = tokenizer.batch_decode(generated, skip_special_tokens=True)[0]
    st.write("Answer:", answer)
# Entry point. `streamlit run` executes this file as the ``__main__`` module,
# so the guard fires under Streamlit too while keeping the module importable
# (e.g. from tests) without launching the UI.
if __name__ == "__main__":
    main()