rahideer committed on
Commit
2c19e76
·
verified ·
1 Parent(s): 97c48f3

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +34 -60
app.py CHANGED
@@ -1,61 +1,35 @@
1
  import streamlit as st
2
- from datasets import load_dataset
3
- from transformers import RagTokenizer, RagRetriever, RagSequenceForGeneration
4
-
5
- # Load a multilingual dataset (xnli or tydi_qa)
6
- def load_data():
7
- try:
8
- # Use a specific version of the dataset
9
- dataset = load_dataset("xnli", "all_languages", split="validation") # Using a direct name instead of a wildcard pattern
10
- st.write(f"Loaded {len(dataset)} examples from the 'validation' split.")
11
- return dataset
12
- except Exception as e:
13
- st.write(f"Error loading 'xnli' dataset: {e}")
14
- return None
15
-
16
- # Initialize RAG model components
17
- def initialize_rag():
18
- try:
19
- # Initialize tokenizer and retriever
20
- tokenizer = RagTokenizer.from_pretrained("facebook/rag-token-nq")
21
- retriever = RagRetriever.from_pretrained("facebook/rag-token-nq", index_name="compressed", passages_path="./path_to_data")
22
- model = RagSequenceForGeneration.from_pretrained("facebook/rag-token-nq")
23
- return tokenizer, retriever, model
24
- except Exception as e:
25
- st.write(f"Error initializing RAG components: {e}")
26
- return None, None, None
27
-
28
- # Main function to run the app
29
- def main():
30
- st.title("Multilingual RAG Translator/Answer Bot")
31
-
32
- # Load the dataset
33
- dataset = load_data()
34
- if dataset is None:
35
- st.write("Dataset could not be loaded.")
36
- return
37
-
38
- # Initialize RAG model components
39
- tokenizer, retriever, model = initialize_rag()
40
- if tokenizer is None or retriever is None or model is None:
41
- st.write("RAG components could not be initialized.")
42
- return
43
-
44
- # UI to input a query
45
- query = st.text_input("Enter your question in Urdu, Hindi, or French:")
46
-
47
- if query:
48
- # Tokenize the input query
49
- inputs = tokenizer(query, return_tensors="pt")
50
-
51
- # Retrieve relevant documents
52
- retrieved_docs = retriever.retrieve(query)
53
- # Generate an answer using the model
54
- generated = model.generate(input_ids=inputs['input_ids'], context_input_ids=retrieved_docs['input_ids'])
55
- answer = tokenizer.decode(generated[0], skip_special_tokens=True)
56
-
57
- st.write("Answer:", answer)
58
-
59
- # Run the Streamlit app
60
- if __name__ == "__main__":
61
- main()
 
import streamlit as st
from transformers import pipeline
import groq

# Initialize Groq API
# NOTE(review): groq_client is not referenced anywhere else in this script.
# Presumably the client reads its credentials from the environment — confirm
# it is still needed, otherwise this line can be dropped.
groq_client = groq.Client()


@st.cache_resource
def _load_classifier():
    """Build the zero-shot classification pipeline once and reuse it.

    Streamlit re-executes the whole script on every widget interaction, so
    constructing the pipeline at module level would rebuild the model on each
    rerun. Caching the loader keeps a single shared instance alive.
    """
    return pipeline(
        "zero-shot-classification",
        model="joeddav/xlm-roberta-large-xnli",
    )


# Initialize the zero-shot classification pipeline from Hugging Face
classifier = _load_classifier()


def classify_text(sequence, candidate_labels):
    """Classify ``sequence`` against ``candidate_labels`` with the pipeline.

    Args:
        sequence: The text to classify.
        candidate_labels: List of label strings to score the text against.

    Returns:
        Whatever the pipeline returns; the UI below reads its ``'labels'``
        and ``'scores'`` entries.
    """
    return classifier(sequence, candidate_labels)


# Streamlit UI elements
st.title("Zero-Shot Text Classification with XLM-RoBERTa")
st.markdown("Enter a text and select candidate labels for classification.")

# Text input from the user
sequence = st.text_area("Enter text to classify", "", height=150)

# Candidate labels: split on commas and drop blank entries so that a trailing
# comma or an emptied field does not produce an empty-string label.
raw_labels = st.text_input(
    "Enter candidate labels (comma separated)",
    "politics, health, education",
)
candidate_labels = [
    label.strip() for label in raw_labels.split(",") if label.strip()
]

# When the classify button is pressed
if st.button("Classify Text"):
    if not sequence.strip():
        # Whitespace-only input counts as empty.
        st.error("Please enter text to classify.")
    elif not candidate_labels:
        # Do not call the pipeline with an empty label list.
        st.error("Please enter at least one candidate label.")
    else:
        result = classify_text(sequence, candidate_labels)
        st.write("Classification Results:")
        st.write(f"Labels: {result['labels']}")
        st.write(f"Scores: {result['scores']}")