amasood committed
Commit 5e8a326 · verified · 1 Parent(s): 1907b77

Update app.py

Files changed (1)
app.py +71 -90
app.py CHANGED
@@ -1,104 +1,85 @@
import os
import streamlit as st
- import pandas as pd
- import numpy as np
import faiss
+ import pickle
from datasets import load_dataset
from sentence_transformers import SentenceTransformer
from groq import Groq

- # Constants for saving/loading index
- INDEX_FILE = "faiss_index.index"
- QUESTIONS_FILE = "questions.npy"
-
- # Load dataset
- @st.cache_data
- def load_data():
-     dataset = load_dataset("FreedomIntelligence/RAG-Instruct", split="train")
-     df = pd.DataFrame(dataset)
-     return df[["question", "answer"]]
-
- # Build or load FAISS index
- @st.cache_resource
- def setup_faiss(data):
-     model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
-
-     if os.path.exists(INDEX_FILE) and os.path.exists(QUESTIONS_FILE):
-         st.info("🔍 Loading FAISS index from disk...")
-         index = faiss.read_index(INDEX_FILE)
-         questions = np.load(QUESTIONS_FILE, allow_pickle=True)
-     else:
-         st.info("⚙️ FAISS index not found. Building new index...")
-
-         questions = data["question"].tolist()
-         embeddings = []
-         progress_bar = st.progress(0, text="Embedding questions...")
-         total = len(questions)
-
-         for i, chunk in enumerate(np.array_split(questions, 10)):
-             emb = model.encode(chunk)
-             embeddings.extend(emb)
-             progress_bar.progress((i + 1) / 10, text=f"Embedding... {int((i + 1) * 10)}%")
-
-         embeddings = np.array(embeddings)
-         index = faiss.IndexFlatL2(embeddings.shape[1])
-         index.add(embeddings)
-
-         faiss.write_index(index, INDEX_FILE)
-         np.save(QUESTIONS_FILE, np.array(questions, dtype=object))
-
-         progress_bar.empty()
-         st.success("✅ FAISS index built and saved!")
-
-     return model, index, questions
-
- # Retrieve relevant context
- def retrieve_context(query, model, index, questions, data, top_k=1):
-     query_vec = model.encode([query])
-     distances, indices = index.search(np.array(query_vec), top_k)
-     results = [questions[i] + "\n\n" + data.iloc[i]["answer"] for i in indices[0]]
-     return "\n\n".join(results)
-
- # Call Groq LLM
- def query_groq(context, query):
-     prompt = f"Context:\n{context}\n\nQuestion: {query}\n\nAnswer:"
-     client = Groq(api_key=st.secrets["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"])
-     response = client.chat.completions.create(
-         messages=[{"role": "user", "content": prompt}],
-         model="llama-3-70b-8192"
-     )
-     return response.choices[0].message.content
-
- # Streamlit UI
+ # Constants
+ DATASET_NAME = "neural-bridge/rag-dataset-1200"
+ MODEL_NAME = "all-MiniLM-L6-v2"
+ INDEX_FILE = "faiss_index.pkl"
+ DOCS_FILE = "contexts.pkl"
+
+ # Set up Groq client
+ client = Groq(api_key=os.environ.get("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"))
+
+ # UI
st.set_page_config(page_title="RAG App with Groq", layout="wide")
- st.title("🔍 RAG App using Groq API + RAG-Instruct Dataset")
+ st.title("🧠 Retrieval-Augmented Generation (RAG) App")

- # Load data and setup
- data = load_data()
- model, index, questions = setup_faiss(data)
-
- st.markdown("Ask a question based on the QA knowledge base.")
-
- # Optional queries
- optional_queries = [
-     "What is retrieval-augmented generation?",
-     "How can I fine-tune a language model?",
-     "What are the components of a RAG system?",
-     "Explain prompt engineering basics.",
-     "How does FAISS indexing help in RAG?"
+ # Load or create vector DB
+ @st.cache_resource
+ def setup_database():
+     st.info("Loading dataset and setting up database...")
+     progress = st.progress(0)
+
+     dataset = load_dataset(DATASET_NAME, split="train")
+     contexts = [entry["context"] for entry in dataset]
+
+     embedder = SentenceTransformer(MODEL_NAME)
+     embeddings = embedder.encode(contexts, show_progress_bar=True)
+
+     dimension = embeddings[0].shape[0]
+     index = faiss.IndexFlatL2(dimension)
+     index.add(embeddings)
+
+     # Save index and contexts
+     with open(INDEX_FILE, "wb") as f:
+         pickle.dump(index, f)
+     with open(DOCS_FILE, "wb") as f:
+         pickle.dump(contexts, f)
+
+     progress.progress(100)
+     return index, contexts
+
+ # Load existing index or build
+ if os.path.exists(INDEX_FILE) and os.path.exists(DOCS_FILE):
+     with open(INDEX_FILE, "rb") as f:
+         faiss_index = pickle.load(f)
+     with open(DOCS_FILE, "rb") as f:
+         all_contexts = pickle.load(f)
+ else:
+     faiss_index, all_contexts = setup_database()
+
+ # Sample questions
+ sample_questions = [
+     "What is the role of Falcon RefinedWeb in this dataset?",
+     "How can retrieval improve language generation?",
+     "Explain the purpose of the RAG dataset."
]

- query = st.text_input("Enter your question:", value=optional_queries[0])
+ st.subheader("Ask a question based on the dataset:")
+ question = st.text_input("Your question", value=sample_questions[0])
+
if st.button("Ask"):
-     with st.spinner("Retrieving and generating response..."):
-         context = retrieve_context(query, model, index, questions, data)
-         answer = query_groq(context, query)
-         st.subheader("📄 Retrieved Context")
-         st.write(context)
-         st.subheader("💬 Answer from Groq LLM")
-         st.write(answer)
-
- st.markdown("### 💡 Optional Queries to Try:")
- for q in optional_queries:
-     st.markdown(f"- {q}")
+     with st.spinner("Retrieving relevant context and generating answer..."):
+         embedder = SentenceTransformer(MODEL_NAME)
+         question_embedding = embedder.encode([question])
+         D, I = faiss_index.search(question_embedding, k=1)
+
+         retrieved_context = all_contexts[I[0][0]]
+         prompt = f"Context: {retrieved_context}\n\nQuestion: {question}\n\nAnswer:"
+
+         response = client.chat.completions.create(
+             messages=[{"role": "user", "content": prompt}],
+             model="llama-3-70b-8192"
+         )
+
+         answer = response.choices[0].message.content
+         st.success("Answer:")
+         st.write(answer)
+
+         with st.expander("Retrieved Context"):
+             st.markdown(retrieved_context)
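
A note on the Groq client: both the old and the new revision pass the literal API key string as the name to look up (st.secrets["gsk_..."] before, os.environ.get("gsk_...") after), so the lookup returns nothing and the real key sits hard-coded in the repo. A minimal sketch of the usual pattern, assuming the key is stored under a hypothetical GROQ_API_KEY Space secret:

import os
from groq import Groq

def make_groq_client():
    # Look the key up by name; GROQ_API_KEY is an assumed secret name,
    # not something this repo defines. Never commit the key itself.
    api_key = os.environ.get("GROQ_API_KEY")
    if not api_key:
        raise RuntimeError("GROQ_API_KEY is not set")
    return Groq(api_key=api_key)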
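The commit also replaces faiss.write_index/faiss.read_index with pickle.dump/pickle.load on the index object itself. Whether a raw FAISS index (a SWIG-wrapped C++ object) survives pickling depends on the faiss version; FAISS's own persistence API sidesteps the question. A sketch of the safer split, keeping pickle only for the plain list of context strings:

import pickle
import faiss

def save_store(index, contexts, index_path, docs_path):
    # FAISS-native serialization for the index...
    faiss.write_index(index, index_path)
    # ...and pickle only for plain Python data.
    with open(docs_path, "wb") as f:
        pickle.dump(contexts, f)

def load_store(index_path, docs_path):
    index = faiss.read_index(index_path)
    with open(docs_path, "rb") as f:
        contexts = pickle.load(f)
    return index, contexts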
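Two smaller points in the query path: the button handler builds a fresh SentenceTransformer on every click even though the app already uses st.cache_resource elsewhere, and the model id "llama-3-70b-8192" does not match Groq's published id "llama3-70b-8192" (no hyphen after "llama"), so the chat call would likely be rejected. A sketch of caching the embedder once per process:

import streamlit as st
from sentence_transformers import SentenceTransformer

@st.cache_resource
def get_embedder(name="all-MiniLM-L6-v2"):
    # Loaded once per process and reused across Streamlit reruns,
    # instead of being re-instantiated inside the button handler.
    return SentenceTransformer(name)

# Inside the handler:
# question_embedding = get_embedder().encode([question])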