import streamlit as st
from datasets import load_dataset
from transformers import RagTokenizer, RagRetriever, RagTokenForGeneration


# Load a multilingual dataset (xnli or tydi_qa)
def load_data():
    try:
        # 'xnli' requires a language config; "all_languages" keeps every language column
        dataset = load_dataset("xnli", "all_languages", split="validation")
        st.write(f"Loaded {len(dataset)} examples from the 'validation' split.")
        return dataset
    except Exception as e:
        st.error(f"Error loading 'xnli' dataset: {e}")
        return None


# Initialize RAG model components
def initialize_rag():
    try:
        # Tokenizer and retriever; "compressed" uses the pre-built wiki_dpr FAISS index
        # (a passages_path is only needed together with index_name="custom")
        tokenizer = RagTokenizer.from_pretrained("facebook/rag-token-nq")
        retriever = RagRetriever.from_pretrained(
            "facebook/rag-token-nq", index_name="compressed"
        )
        # The rag-token-nq checkpoint pairs with the RagTokenForGeneration class;
        # attaching the retriever lets generate() fetch passages internally
        model = RagTokenForGeneration.from_pretrained(
            "facebook/rag-token-nq", retriever=retriever
        )
        return tokenizer, retriever, model
    except Exception as e:
        st.error(f"Error initializing RAG components: {e}")
        return None, None, None


# Main function to run the app
def main():
    st.title("Multilingual RAG Translator/Answer Bot")

    # Load the dataset (shown for reference; retrieval itself uses the wiki_dpr index)
    dataset = load_data()
    if dataset is None:
        st.write("Dataset could not be loaded.")
        return

    # Initialize RAG model components
    tokenizer, retriever, model = initialize_rag()
    if tokenizer is None or retriever is None or model is None:
        st.write("RAG components could not be initialized.")
        return

    # UI to input a query
    query = st.text_input("Enter your question in Urdu, Hindi, or French:")
    if query:
        # Tokenize the input query
        inputs = tokenizer(query, return_tensors="pt")
        # Generate an answer; the model retrieves supporting passages via the attached retriever
        generated = model.generate(input_ids=inputs["input_ids"])
        answer = tokenizer.batch_decode(generated, skip_special_tokens=True)[0]
        st.write("Answer:", answer)


# Run the Streamlit app
if __name__ == "__main__":
    main()
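
# Usage note: a minimal sketch for running the app locally, assuming the script is
# saved as app.py and the dependencies below are installed (RagRetriever needs the
# `datasets` and `faiss` packages to load its index):
#
#   pip install streamlit datasets transformers torch faiss-cpu
#   streamlit run app.py
#
# On first use the retriever downloads the pre-built wiki_dpr index, which is a
# large download, so expect a long initial startup.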