amasood committed
Commit b7416a2 · verified · 1 Parent(s): d2ffcaf

Update app.py

Files changed (1)
  1. app.py +34 -23
app.py CHANGED
@@ -3,19 +3,25 @@ import pandas as pd
 import os
 import faiss
 import pickle
+import random
+from datasets import load_dataset
 from sentence_transformers import SentenceTransformer
 from groq import Groq
-from datasets import load_dataset
+from dotenv import load_dotenv
 
 # Load environment variables
-from dotenv import load_dotenv
 load_dotenv()
 
 # Setup Groq client
 client = Groq(api_key=os.getenv("GROQ_API_KEY"))
-MODEL_NAME = "llama-3-70b-8192" # Or use "llama-3-8b-8192", "llama-3-3b-8192"
+MODEL_NAME = "llama-3-70b-8192" # or try "llama-3-8b-8192" or "llama-3-3b-8192"
 
-# Load dataset
+# Streamlit UI
+st.set_page_config(page_title="RAG with Groq", layout="wide")
+st.title("📚 RAG App using Groq API")
+st.markdown("Ask enterprise, financial, and legal questions using Retrieval-Augmented Generation (RAG).")
+
+# Load dataset from Hugging Face
 @st.cache_data
 def load_data():
     dataset = load_dataset("llmware/rag_instruct_benchmark_tester", split="train")
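The hunk above cuts off inside `load_data`, so only the `load_dataset` call is visible. A minimal sketch of how the rest of the function presumably finishes; the `to_pandas()` conversion and the helper name are assumptions, not the committed code:

```python
# Hypothetical completion of load_data() -- the diff truncates the body,
# so the conversion below is an assumption, not the committed code.
from datasets import load_dataset

def load_data_sketch():
    dataset = load_dataset("llmware/rag_instruct_benchmark_tester", split="train")
    # datasets.Dataset.to_pandas() yields a DataFrame; the app later reads
    # its 'query' and 'context' columns.
    return dataset.to_pandas()
```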
@@ -34,16 +40,16 @@ def load_embeddings(df):
 
     return index, embeddings, embed_model
 
-# Retrieve top k similar context passages
+# Retrieve top-k relevant context
 def retrieve_context(query, embed_model, index, df, k=3):
     query_embedding = embed_model.encode([query])
     D, I = index.search(query_embedding, k)
     context_passages = df.iloc[I[0]]['context'].tolist()
     return context_passages
 
-# Ask Groq LLM
+# Ask the Groq LLM
 def ask_groq(query, context):
-    prompt = f"""You are a helpful assistant. Use the provided context to answer the question.
+    prompt = f"""You are a helpful assistant. Use the context to answer the question.
 
 Context:
 {context}
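`retrieve_context` leans on FAISS search semantics: `index.search` takes an `(n, dim)` float32 array and returns distances `D` and row indices `I`, each shaped `(n, k)`. A self-contained sketch of that step; the `IndexFlatL2` index type and the `all-MiniLM-L6-v2` model are assumptions, since `load_embeddings` is not shown in this diff:

```python
# Standalone sketch of the retrieval step used by retrieve_context().
# Index type and model name are assumptions; load_embeddings is elided above.
import faiss
from sentence_transformers import SentenceTransformer

embed_model = SentenceTransformer("all-MiniLM-L6-v2")
passages = [
    "The fiscal year ended on March 31.",
    "The lease term is 24 months.",
    "Net revenue grew 12% year over year.",
]

# encode() returns an (n, dim) numpy array; FAISS expects float32.
embeddings = embed_model.encode(passages).astype("float32")
index = faiss.IndexFlatL2(embeddings.shape[1])  # exact L2 search
index.add(embeddings)

query_embedding = embed_model.encode(["When did the fiscal year end?"]).astype("float32")
D, I = index.search(query_embedding, 2)  # D: distances, I: indices, both (1, 2)
print([passages[i] for i in I[0]])  # top-2 passages, nearest first
```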
@@ -52,38 +58,43 @@ Question:
 {query}
 
 Answer:"""
+
     response = client.chat.completions.create(
         messages=[{"role": "user", "content": prompt}],
         model=MODEL_NAME
     )
     return response.choices[0].message.content
 
-# Streamlit UI
-st.title("📚 RAG App with Groq API")
-st.markdown("Use this Retrieval-Augmented Generation app to ask enterprise, legal, and financial questions.")
-
+# Load everything
 df = load_data()
 index, embeddings, embed_model = load_embeddings(df)
 
+# User input
+st.subheader("🔍 Ask your question")
 sample_queries = df['query'].dropna().unique().tolist()
-
-query = st.text_input("Enter your question:", "")
-if st.button("Use Random Sample"):
-    import random
-    query = random.choice(sample_queries)
-    st.session_state["query"] = query
-    st.experimental_rerun()
-
+col1, col2 = st.columns([3, 1])
+with col1:
+    query = st.text_input("Enter your question here:")
+with col2:
+    if st.button("🎲 Random Sample"):
+        query = random.choice(sample_queries)
+        st.experimental_rerun()
+
+# Handle query
 if query:
     st.markdown(f"**Your Query:** {query}")
-    with st.spinner("Retrieving relevant context..."):
+
+    with st.spinner("🔎 Retrieving relevant context..."):
         contexts = retrieve_context(query, embed_model, index, df)
         combined_context = "\n\n".join(contexts)
-    with st.spinner("Getting answer from Groq..."):
+
+    with st.spinner("🤖 Querying Groq LLM..."):
         answer = ask_groq(query, combined_context)
+
     st.markdown("### 💡 Answer")
     st.write(answer)
+
     st.markdown("### 📄 Retrieved Context")
     for i, ctx in enumerate(contexts, 1):
-        st.markdown(f"**Context {i}:**")
-        st.write(ctx)
+        with st.expander(f"Context {i}"):
+            st.write(ctx)
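One behavioral note on the random-sample button: the previous code stored the sampled query in `st.session_state["query"]` before calling `st.experimental_rerun()`, while the new code assigns it to a local variable and reruns. Locals do not survive a Streamlit rerun, so the sampled query is likely discarded before it can be used. A sketch of a rerun-safe variant, not the committed code:

```python
# Rerun-safe sampling sketch (assumption, not the committed code): persist the
# choice in st.session_state so it survives st.experimental_rerun().
# The button must run before the text_input is instantiated, because a widget's
# session_state key cannot be assigned after the widget exists in the same run.
if st.button("🎲 Random Sample"):
    st.session_state["query"] = random.choice(sample_queries)
    st.experimental_rerun()  # newer Streamlit versions call this st.rerun()

# key="query" binds the widget to the same session_state entry.
query = st.text_input("Enter your question here:", key="query")
```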
 
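For reference, the `ask_groq` call path in isolation. This is a minimal sketch assuming `GROQ_API_KEY` is set in the environment (for example via the `.env` file that `load_dotenv()` reads); the model id is copied from `MODEL_NAME` in app.py, so verify it against the names Groq currently serves:

```python
# Minimal standalone version of the ask_groq() request/response shape.
import os
from groq import Groq

client = Groq(api_key=os.getenv("GROQ_API_KEY"))

response = client.chat.completions.create(
    messages=[{"role": "user", "content": "Answer in one word: what is 2 + 2?"}],
    model="llama-3-70b-8192",  # MODEL_NAME from app.py; check Groq's current model list
)
print(response.choices[0].message.content)
```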