rahideer committed on
Commit
6273efa
·
verified ·
1 Parent(s): 8a21666

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +58 -39
app.py CHANGED
@@ -1,42 +1,61 @@
1
  import streamlit as st
2
- from transformers import RagTokenizer, RagRetriever, RagSequenceForGeneration
3
  from datasets import load_dataset
4
- from transformers import XLMRobertaTokenizer, XLMRobertaForSequenceClassification
5
-
6
- # Load a multilingual dataset (use "xnli" or "tydi_qa")
7
- try:
8
- dataset = load_dataset("xnli", "en", split="validation") # Using English subset as an example
9
- except Exception as e:
10
- st.error(f"Error loading the dataset: {e}")
11
-
12
- # Initialize tokenizer and retriever for multilingual support (using XLM-Roberta)
13
- tokenizer = XLMRobertaTokenizer.from_pretrained("xlm-roberta-base")
14
- retriever = RagRetriever.from_pretrained("facebook/rag-token-nq", index_name="compressed", passages_path="./path_to_multilingual_dataset")
15
-
16
- # Initialize the RAG model
17
- model = RagSequenceForGeneration.from_pretrained("facebook/rag-token-nq")
18
-
19
- # Define Streamlit app
20
- st.title('Multilingual RAG Translator/Answer Bot')
21
-
22
- st.markdown("This app uses a multilingual RAG model to answer your questions in the language of the query. Ask questions in languages like Urdu, Hindi, or French!")
23
-
24
- # User input for query
25
- user_query = st.text_input("Ask a question in Urdu, Hindi, or French:")
26
-
27
- if user_query:
28
- # Tokenize the input question
29
- inputs = tokenizer(user_query, return_tensors="pt", padding=True, truncation=True)
30
- input_ids = inputs['input_ids']
31
-
32
- # Use the retriever to get relevant context
33
- retrieved_docs = retriever.retrieve(input_ids)
34
-
35
- # Generate an answer using the context
36
- generated_ids = model.generate(input_ids, context_input_ids=retrieved_docs)
37
- answer = tokenizer.decode(generated_ids[0], skip_special_tokens=True)
38
-
39
- # Display the answer
40
- st.write(f"Answer: {answer}")
41
 
42
- # Display the most relevant documents
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import streamlit as st
 
2
  from datasets import load_dataset
3
+ from transformers import RagTokenizer, RagRetriever, RagSequenceForGeneration
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4
 
5
+ # Load a multilingual dataset (xnli or tydi_qa)
6
def load_data(config_name: str = "en"):
    """Load the validation split of the XNLI dataset and report its size.

    XNLI is published as one configuration per language (plus a combined
    one), so ``load_dataset`` must be given a configuration name; calling
    it with only the dataset name raises and the app can never get past
    the loading step.

    Args:
        config_name: XNLI configuration to load. Defaults to ``"en"``,
            the subset the previous revision of this app used. Other
            language codes (for example ``"ur"``, ``"hi"``, ``"fr"``) can
            be passed by callers that need them.

    Returns:
        The loaded dataset, or ``None`` when loading failed. The failure
        is reported on the Streamlit page rather than raised.
    """
    try:
        dataset = load_dataset("xnli", config_name, split="validation")
    except Exception as e:
        # UI boundary: surface the problem on the page instead of crashing.
        st.error(f"Error loading 'xnli' dataset: {e}")
        return None

    st.write(f"Loaded {len(dataset)} examples from the 'validation' split.")
    return dataset
15
+
16
+ # Initialize RAG model components
17
def initialize_rag(
    model_name: str = "facebook/rag-sequence-nq",
    index_name: str = "exact",
    use_dummy_dataset: bool = True,
):
    """Build the RAG tokenizer, retriever and generator from one checkpoint.

    All three components are loaded from the same checkpoint so that the
    model class (``RagSequenceForGeneration``) matches the weights it is
    given. The retriever is also attached to the model, which lets
    ``model.generate`` perform retrieval internally from the question's
    ``input_ids`` alone.

    Args:
        model_name: Hub identifier of a RAG-sequence checkpoint.
        index_name: Name of the retrieval index to load.
        use_dummy_dataset: When true, load the small sample index that the
            library provides for demos instead of the full index.
            NOTE(review): the full index is a very large download; confirm
            the deployment has room before switching this off.

    Returns:
        A ``(tokenizer, retriever, model)`` tuple, or ``(None, None, None)``
        when any component fails to load. The failure is reported on the
        Streamlit page rather than raised.
    """
    try:
        tokenizer = RagTokenizer.from_pretrained(model_name)
        # No passages_path here: it only applies to a custom on-disk index,
        # and the previous value was a placeholder path.
        retriever = RagRetriever.from_pretrained(
            model_name,
            index_name=index_name,
            use_dummy_dataset=use_dummy_dataset,
        )
        model = RagSequenceForGeneration.from_pretrained(
            model_name,
            retriever=retriever,
        )
    except Exception as e:
        # UI boundary: surface the problem on the page instead of crashing.
        st.error(f"Error initializing RAG components: {e}")
        return None, None, None

    return tokenizer, retriever, model
27
+
28
+ # Main function to run the app
29
def main():
    """Render the Streamlit page: load resources, read a question, answer it.

    The page stops early, with a message, if either the dataset or the RAG
    components could not be loaded. Otherwise it shows a text box and, once
    a non-blank question is entered, writes the generated answer.
    """
    st.title("Multilingual RAG Translator/Answer Bot")

    # Load the dataset
    dataset = load_data()
    if dataset is None:
        st.write("Dataset could not be loaded.")
        return

    # Initialize RAG model components
    tokenizer, retriever, model = initialize_rag()
    if tokenizer is None or retriever is None or model is None:
        st.write("RAG components could not be initialized.")
        return

    # Let the model run retrieval itself. RagRetriever.retrieve expects
    # question hidden states plus a document count, not a raw string, so
    # calling it directly with the query text cannot work.
    model.set_retriever(retriever)

    # UI to input a query
    query = st.text_input("Enter your question in Urdu, Hindi, or French:")
    if not query or not query.strip():
        # Nothing to answer yet (empty or whitespace-only input).
        return

    # Tokenize the question and generate; retrieval happens inside generate().
    inputs = tokenizer(query, return_tensors="pt")
    generated = model.generate(
        input_ids=inputs["input_ids"],
        attention_mask=inputs["attention_mask"],
    )
    answer = tokenizer.decode(generated[0], skip_special_tokens=True)

    st.write("Answer:", answer)
58
+
59
+ # Run the Streamlit app
60
+ if __name__ == "__main__":
61
+ main()