rahideer commited on
Commit
9eac318
Β·
verified Β·
1 Parent(s): d1f387c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +67 -57
app.py CHANGED
@@ -1,81 +1,91 @@
1
  import streamlit as st
2
- import pandas as pd
3
- import numpy as np
4
- import faiss
5
- import os
6
- import zipfile
7
  from langdetect import detect
 
 
8
  from sentence_transformers import SentenceTransformer
9
  from transformers import MBartForConditionalGeneration, MBart50Tokenizer
 
 
 
10
 
11
-
12
- st.set_page_config(page_title="Multilingual RAG Translator/Answer Bot", layout="wide")
13
-
14
- st.title("🌍 Multilingual RAG Translator/Answer Bot")
15
- st.markdown("Ask in Urdu, French, Hindi, etc., and get culturally-aware, context-grounded answers.")
16
-
17
- # --- ZIP extraction ---
18
- zip_file = "all_languages_test.csv.zip"
19
- csv_file = "all_languages_test.csv"
20
-
21
- if not os.path.exists(csv_file):
22
- with zipfile.ZipFile(zip_file, "r") as zip_ref:
23
- zip_ref.extractall()
24
-
25
- # --- Language map and translation model ---
26
- lang_map = {
27
- "en": "en_XX", "fr": "fr_XX", "ur": "ur_PK", "hi": "hi_IN",
28
- "es": "es_XX", "de": "de_DE", "zh-cn": "zh_CN", "ar": "ar_AR"
29
- }
30
 
31
  @st.cache_resource
32
  def load_resources():
33
- df = pd.read_csv(csv_file).dropna()
34
- df["context"] = df["premise"] + " " + df["hypothesis"]
35
- corpus = df["context"].tolist()
36
 
37
- embedder = SentenceTransformer("distiluse-base-multilingual-cased-v2")
38
- embeddings = embedder.encode(corpus, show_progress_bar=True)
39
 
40
- index = faiss.IndexFlatL2(embeddings.shape[1])
41
- index.add(np.array(embeddings))
42
 
43
- tokenizer = MBart50Tokenizer.from_pretrained("facebook/mbart-large-50-many-to-many-mmt")
 
44
 
45
- model = MBartForConditionalGeneration.from_pretrained("facebook/mbart-large-50-many-to-many-mmt")
 
 
 
46
 
47
  return embedder, index, corpus, tokenizer, model
48
 
49
- embedder, index, corpus, tokenizer, model = load_resources()
 
 
 
 
50
 
51
- # --- Answer generation ---
52
- def generate_answer(query, k=3):
53
- lang = detect(query)
54
- token_lang = lang_map.get(lang, "en_XX")
55
 
56
- query_vec = embedder.encode([query])
57
- D, I = index.search(np.array(query_vec), k)
58
- contexts = [corpus[i] for i in I[0]]
59
- context = " ".join(contexts)
60
 
61
- full_input = f"question: {query} context: {context}"
62
- tokenizer.src_lang = token_lang
63
- encoded_input = tokenizer(full_input, return_tensors="pt")
64
  generated_tokens = model.generate(
65
- **encoded_input,
66
- forced_bos_token_id=tokenizer.lang_code_to_id[token_lang]
 
 
 
67
  )
68
-
69
  return tokenizer.decode(generated_tokens[0], skip_special_tokens=True)
70
 
71
- # --- UI ---
72
- user_input = st.text_area("πŸ’¬ Enter your question in any supported language:")
 
 
73
 
74
- if st.button("Get Answer"):
75
- if user_input.strip():
76
- with st.spinner("Generating answer..."):
77
- response = generate_answer(user_input)
78
- st.success("Answer:")
79
- st.write(response)
80
  else:
81
- st.warning("Please enter a question first.")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import streamlit as st
 
 
 
 
 
2
  from langdetect import detect
3
+ import faiss
4
+ import torch
5
  from sentence_transformers import SentenceTransformer
6
  from transformers import MBartForConditionalGeneration, MBart50Tokenizer
7
+ import numpy as np
8
+ import pandas as pd
9
+ import os
10
 
11
+ st.set_page_config(page_title="🌍 Multilingual RAG Translator/Answer Bot", layout="centered")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
12
 
13
  @st.cache_resource
14
  def load_resources():
15
+ embedder = SentenceTransformer("sentence-transformers/distiluse-base-multilingual-cased-v1")
16
+ tokenizer = MBart50Tokenizer.from_pretrained("facebook/mbart-large-50-many-to-many-mmt")
17
+ model = MBartForConditionalGeneration.from_pretrained("facebook/mbart-large-50-many-to-many-mmt")
18
 
19
+ # Load multilingual dataset CSV
20
+ df = pd.read_csv("all_languages_test.csv")
21
 
22
+ # Construct corpus
23
+ corpus = (df["premise"] + " " + df["hypothesis"]).fillna("").tolist()
24
 
25
+ # Compute embeddings for corpus
26
+ corpus_embeddings = embedder.encode(corpus, convert_to_numpy=True, show_progress_bar=True)
27
 
28
+ # Create FAISS index
29
+ dimension = corpus_embeddings.shape[1]
30
+ index = faiss.IndexFlatL2(dimension)
31
+ index.add(corpus_embeddings)
32
 
33
  return embedder, index, corpus, tokenizer, model
34
 
35
+ def detect_lang(text):
36
+ try:
37
+ return detect(text)
38
+ except:
39
+ return "en"
40
 
41
+ def get_top_k_passages(query, embedder, index, corpus, k=3):
42
+ query_embedding = embedder.encode([query], convert_to_numpy=True)
43
+ distances, indices = index.search(query_embedding, k)
44
+ return [corpus[i] for i in indices[0] if i < len(corpus)]
45
 
46
+ def generate_answer(query, context, tokenizer, model, src_lang):
47
+ model.eval()
48
+ tokenizer.src_lang = src_lang
49
+ joined_context = " ".join(context)
50
 
51
+ inputs = tokenizer(query + " " + joined_context, return_tensors="pt", max_length=1024, truncation=True)
 
 
52
  generated_tokens = model.generate(
53
+ **inputs,
54
+ forced_bos_token_id=tokenizer.lang_code_to_id["en_XX"],
55
+ max_length=256,
56
+ num_beams=5,
57
+ early_stopping=True
58
  )
 
59
  return tokenizer.decode(generated_tokens[0], skip_special_tokens=True)
60
 
61
+ st.title("🌍 Multilingual RAG Translator/Answer Bot")
62
+ st.caption("Ask in Urdu, French, Hindi, etc., and get culturally-aware, context-grounded answers.")
63
+
64
+ query = st.text_input("πŸ’¬ Enter your question in any supported language:")
65
 
66
+ if query:
67
+ if len(query.strip()) < 3:
68
+ st.warning("Please enter a more complete question for better results.")
 
 
 
69
  else:
70
+ with st.spinner("Thinking..."):
71
+ embedder, index, corpus, tokenizer, model = load_resources()
72
+ lang = detect_lang(query)
73
+
74
+ lang_map = {
75
+ "en": "en_XX", "fr": "fr_XX", "ur": "ur_PK", "hi": "hi_IN",
76
+ "es": "es_XX", "de": "de_DE", "zh": "zh_CN", "ar": "ar_AR",
77
+ "ru": "ru_RU", "tr": "tr_TR", "it": "it_IT", "pt": "pt_XX",
78
+ }
79
+
80
+ src_lang = lang_map.get(lang, "en_XX")
81
+ context = get_top_k_passages(query, embedder, index, corpus)
82
+
83
+ if not context:
84
+ st.error("⚠️ Could not find any relevant context to answer your question.")
85
+ else:
86
+ try:
87
+ answer = generate_answer(query, context, tokenizer, model, src_lang)
88
+ st.markdown("### πŸ“Œ Answer:")
89
+ st.success(answer)
90
+ except Exception as e:
91
+ st.error(f"⚠️ Something went wrong while generating the answer.\n\n{e}")