amasood committed on
Commit
8b835fd
·
verified ·
1 Parent(s): 576a1d5

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +48 -16
app.py CHANGED
@@ -1,11 +1,16 @@
 
1
  import streamlit as st
2
  import pandas as pd
 
 
3
  from datasets import load_dataset
4
  from sentence_transformers import SentenceTransformer
5
- import faiss
6
- import numpy as np
7
  from groq import Groq
8
 
 
 
 
 
9
  # Load dataset
10
  @st.cache_data
11
  def load_data():
@@ -13,26 +18,51 @@ def load_data():
13
  df = pd.DataFrame(dataset)
14
  return df[["question", "answer"]]
15
 
16
- # Generate embeddings and index
17
  @st.cache_resource
18
  def setup_faiss(data):
19
  model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
20
- embeddings = model.encode(data["question"].tolist())
21
- index = faiss.IndexFlatL2(embeddings.shape[1])
22
- index.add(np.array(embeddings))
23
- return model, index, embeddings
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
24
 
25
  # Retrieve relevant context
26
- def retrieve_context(query, model, index, data, top_k=1):
27
  query_vec = model.encode([query])
28
  distances, indices = index.search(np.array(query_vec), top_k)
29
- results = [data.iloc[i]["question"] + "\n\n" + data.iloc[i]["answer"] for i in indices[0]]
30
  return "\n\n".join(results)
31
 
32
  # Call Groq LLM
33
  def query_groq(context, query):
34
  prompt = f"Context:\n{context}\n\nQuestion: {query}\n\nAnswer:"
35
- #client = Groq(api_key=st.secrets[GROQ_API_KEY])
36
  client = Groq(api_key=GROQ_API_KEY)
37
  response = client.chat.completions.create(
38
  messages=[{"role": "user", "content": prompt}],
@@ -41,11 +71,12 @@ def query_groq(context, query):
41
  return response.choices[0].message.content
42
 
43
  # Streamlit UI
44
- st.set_page_config(page_title="RAG Demo with Groq", layout="wide")
45
- st.title("🧠 RAG App using Groq API + RAG-Instruct Dataset")
46
 
 
47
  data = load_data()
48
- model, index, _ = setup_faiss(data)
49
 
50
  st.markdown("Ask a question based on the QA knowledge base.")
51
 
@@ -61,12 +92,13 @@ optional_queries = [
61
  query = st.text_input("Enter your question:", value=optional_queries[0])
62
  if st.button("Ask"):
63
  with st.spinner("Retrieving and generating response..."):
64
- context = retrieve_context(query, model, index, data)
65
  answer = query_groq(context, query)
66
  st.subheader("📄 Retrieved Context")
67
  st.write(context)
68
  st.subheader("💬 Answer from Groq LLM")
69
  st.write(answer)
70
 
71
- st.markdown("### Optional Queries to Try:")
72
- st.write(", ".join(optional_queries))
 
 
1
+ import os
2
  import streamlit as st
3
  import pandas as pd
4
+ import numpy as np
5
+ import faiss
6
  from datasets import load_dataset
7
  from sentence_transformers import SentenceTransformer
 
 
8
  from groq import Groq
9
 
10
+ # Constants for saving/loading index
11
+ INDEX_FILE = "faiss_index.index"
12
+ QUESTIONS_FILE = "questions.npy"
13
+
14
  # Load dataset
15
  @st.cache_data
16
  def load_data():
 
18
  df = pd.DataFrame(dataset)
19
  return df[["question", "answer"]]
20
 
21
+ # Build or load FAISS index
22
  @st.cache_resource
23
  def setup_faiss(data):
24
  model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
25
+
26
+ if os.path.exists(INDEX_FILE) and os.path.exists(QUESTIONS_FILE):
27
+ st.info("πŸ” Loading FAISS index from disk...")
28
+ index = faiss.read_index(INDEX_FILE)
29
+ questions = np.load(QUESTIONS_FILE, allow_pickle=True)
30
+ else:
31
+ st.info("⚙️ FAISS index not found. Building new index...")
32
+
33
+ questions = data["question"].tolist()
34
+ embeddings = []
35
+ progress_bar = st.progress(0, text="Embedding questions...")
36
+ total = len(questions)
37
+
38
+ for i, chunk in enumerate(np.array_split(questions, 10)):
39
+ emb = model.encode(chunk)
40
+ embeddings.extend(emb)
41
+ progress_bar.progress((i + 1) / 10, text=f"Embedding... {int((i + 1) * 10)}%")
42
+
43
+ embeddings = np.array(embeddings)
44
+ index = faiss.IndexFlatL2(embeddings.shape[1])
45
+ index.add(embeddings)
46
+
47
+ faiss.write_index(index, INDEX_FILE)
48
+ np.save(QUESTIONS_FILE, np.array(questions, dtype=object))
49
+
50
+ progress_bar.empty()
51
+ st.success("✅ FAISS index built and saved!")
52
+
53
+ return model, index, questions
54
+
55
 
56
  # Retrieve relevant context
57
+ def retrieve_context(query, model, index, questions, data, top_k=1):
58
  query_vec = model.encode([query])
59
  distances, indices = index.search(np.array(query_vec), top_k)
60
+ results = [questions[i] + "\n\n" + data.iloc[i]["answer"] for i in indices[0]]
61
  return "\n\n".join(results)
62
 
63
  # Call Groq LLM
64
  def query_groq(context, query):
65
  prompt = f"Context:\n{context}\n\nQuestion: {query}\n\nAnswer:"
 
66
  client = Groq(api_key=GROQ_API_KEY)
67
  response = client.chat.completions.create(
68
  messages=[{"role": "user", "content": prompt}],
 
71
  return response.choices[0].message.content
72
 
73
  # Streamlit UI
74
+ st.set_page_config(page_title="RAG App with Groq", layout="wide")
75
+ st.title("πŸ” RAG App using Groq API + RAG-Instruct Dataset")
76
 
77
+ # Load data and setup
78
  data = load_data()
79
+ model, index, questions = setup_faiss(data)
80
 
81
  st.markdown("Ask a question based on the QA knowledge base.")
82
 
 
92
  query = st.text_input("Enter your question:", value=optional_queries[0])
93
  if st.button("Ask"):
94
  with st.spinner("Retrieving and generating response..."):
95
+ context = retrieve_context(query, model, index, questions, data)
96
  answer = query_groq(context, query)
97
  st.subheader("📄 Retrieved Context")
98
  st.write(context)
99
  st.subheader("💬 Answer from Groq LLM")
100
  st.write(answer)
101
 
102
+ st.markdown("### πŸ’‘ Optional Queries to Try:")
103
+ for q in optional_queries:
104
+ st.markdown(f"- {q}")