amasood commited on
Commit
401d7df
·
verified ·
1 Parent(s): 229fd5d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +51 -36
app.py CHANGED
@@ -12,74 +12,89 @@ MODEL_NAME = "all-MiniLM-L6-v2"
12
  INDEX_FILE = "faiss_index.pkl"
13
  DOCS_FILE = "contexts.pkl"
14
 
15
- # Set up Groq client
16
  client = Groq(api_key=os.environ.get("MY_KEY"))
17
 
18
- # UI
19
- st.set_page_config(page_title="RAG App with Groq", layout="wide")
20
- st.title("🧠 Retrieval-Augmented Generation (RAG) App")
21
 
22
- # Load or create vector DB
23
  @st.cache_resource
24
  def setup_database():
25
- st.info("Loading dataset and setting up database...")
26
  progress = st.progress(0)
27
 
 
28
  dataset = load_dataset(DATASET_NAME, split="train")
29
  contexts = [entry["context"] for entry in dataset]
 
30
 
 
31
  embedder = SentenceTransformer(MODEL_NAME)
32
  embeddings = embedder.encode(contexts, show_progress_bar=True)
 
33
 
 
34
  dimension = embeddings[0].shape[0]
35
- index = faiss.IndexFlatL2(dimension)
36
- index.add(embeddings)
 
37
 
38
- # Save index and contexts
39
  with open(INDEX_FILE, "wb") as f:
40
- pickle.dump(index, f)
41
  with open(DOCS_FILE, "wb") as f:
42
  pickle.dump(contexts, f)
43
 
44
  progress.progress(100)
45
- return index, contexts
 
46
 
47
- # Load existing index or build
48
  if os.path.exists(INDEX_FILE) and os.path.exists(DOCS_FILE):
49
  with open(INDEX_FILE, "rb") as f:
50
  faiss_index = pickle.load(f)
51
  with open(DOCS_FILE, "rb") as f:
52
  all_contexts = pickle.load(f)
 
53
  else:
54
  faiss_index, all_contexts = setup_database()
55
 
56
- # Sample questions
57
  sample_questions = [
58
- "What is the role of Falcon RefinedWeb in this dataset?",
59
- "How can retrieval improve language generation?",
60
- "Explain the purpose of the RAG dataset."
 
61
  ]
62
 
63
  st.subheader("Ask a question based on the dataset:")
64
- question = st.text_input("Your question", value=sample_questions[0])
65
 
66
  if st.button("Ask"):
67
- with st.spinner("Retrieving relevant context and generating answer..."):
68
- embedder = SentenceTransformer(MODEL_NAME)
69
- question_embedding = embedder.encode([question])
70
- D, I = faiss_index.search(question_embedding, k=1)
71
-
72
- retrieved_context = all_contexts[I[0][0]]
73
- prompt = f"Context: {retrieved_context}\n\nQuestion: {question}\n\nAnswer:"
74
-
75
- response = client.chat.completions.create(
76
- messages=[{"role": "user", "content": prompt}],
77
- model="llama-3-70b-8192"
78
- )
79
-
80
- answer = response.choices[0].message.content
81
- st.success("Answer:")
82
- st.write(answer)
83
-
84
- with st.expander("Retrieved Context"):
85
- st.markdown(retrieved_context)
 
 
 
 
 
 
 
12
  INDEX_FILE = "faiss_index.pkl"
13
  DOCS_FILE = "contexts.pkl"
14
 
15
+ # Groq API client
16
  client = Groq(api_key=os.environ.get("MY_KEY"))
17
 
18
+ # Streamlit page setup
19
+ st.set_page_config(page_title="RAG App", layout="wide")
20
+ st.title("🧠 Retrieval-Augmented Generation (RAG) with Groq")
21
 
22
+ # Function to load or create database
23
  @st.cache_resource
24
  def setup_database():
25
+ st.info("Setting up vector database...")
26
  progress = st.progress(0)
27
 
28
+ # Step 1: Load dataset
29
  dataset = load_dataset(DATASET_NAME, split="train")
30
  contexts = [entry["context"] for entry in dataset]
31
+ progress.progress(25)
32
 
33
+ # Step 2: Compute embeddings
34
  embedder = SentenceTransformer(MODEL_NAME)
35
  embeddings = embedder.encode(contexts, show_progress_bar=True)
36
+ progress.progress(50)
37
 
38
+ # Step 3: Build FAISS index
39
  dimension = embeddings[0].shape[0]
40
+ faiss_index = faiss.IndexFlatL2(dimension)
41
+ faiss_index.add(embeddings)
42
+ progress.progress(75)
43
 
44
+ # Step 4: Save index and contexts for future use
45
  with open(INDEX_FILE, "wb") as f:
46
+ pickle.dump(faiss_index, f)
47
  with open(DOCS_FILE, "wb") as f:
48
  pickle.dump(contexts, f)
49
 
50
  progress.progress(100)
51
+ st.success("Database setup complete!")
52
+ return faiss_index, contexts
53
 
54
+ # Check if the index and contexts are saved, otherwise set up
55
  if os.path.exists(INDEX_FILE) and os.path.exists(DOCS_FILE):
56
  with open(INDEX_FILE, "rb") as f:
57
  faiss_index = pickle.load(f)
58
  with open(DOCS_FILE, "rb") as f:
59
  all_contexts = pickle.load(f)
60
+ st.info("Loaded existing database.")
61
  else:
62
  faiss_index, all_contexts = setup_database()
63
 
64
+ # UI for sample questions
65
  sample_questions = [
66
+ "What is the purpose of the RAG dataset?",
67
+ "How does Falcon RefinedWeb contribute to this dataset?",
68
+ "What are the benefits of using retrieval-augmented generation?",
69
+ "Explain the structure of the RAG-1200 dataset.",
70
  ]
71
 
72
  st.subheader("Ask a question based on the dataset:")
73
+ question = st.text_input("Enter your question:", value=sample_questions[0])
74
 
75
  if st.button("Ask"):
76
+ if question.strip() == "":
77
+ st.warning("Please enter a question.")
78
+ else:
79
+ with st.spinner("Retrieving and generating answer..."):
80
+ # Embed user query
81
+ embedder = SentenceTransformer(MODEL_NAME)
82
+ query_embedding = embedder.encode([question])
83
+ D, I = faiss_index.search(query_embedding, k=1)
84
+
85
+ # Get closest context
86
+ context = all_contexts[I[0][0]]
87
+ prompt = f"Context: {context}\n\nQuestion: {question}\n\nAnswer:"
88
+
89
+ # Call Groq model
90
+ response = client.chat.completions.create(
91
+ messages=[{"role": "user", "content": prompt}],
92
+ model="llama3-70b-8192"
93
+ )
94
+
95
+ answer = response.choices[0].message.content
96
+ st.success("Answer:")
97
+ st.markdown(answer)
98
+
99
+ with st.expander("🔍 Retrieved Context"):
100
+ st.markdown(context)