amasood commited on
Commit
319855f
·
verified ·
1 Parent(s): b48c2eb

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +71 -0
app.py ADDED
@@ -0,0 +1,71 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import pandas as pd
3
+ from datasets import load_dataset
4
+ from sentence_transformers import SentenceTransformer
5
+ import faiss
6
+ import numpy as np
7
+ from groq import Groq
8
+
9
+ # Load dataset
10
+ @st.cache_data
11
+ def load_data():
12
+ dataset = load_dataset("FreedomIntelligence/RAG-Instruct", split="train")
13
+ df = pd.DataFrame(dataset)
14
+ return df[["instruction", "response"]]
15
+
16
+ # Generate embeddings and index
17
+ @st.cache_resource
18
+ def setup_faiss(data):
19
+ model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
20
+ embeddings = model.encode(data["instruction"].tolist())
21
+ index = faiss.IndexFlatL2(embeddings.shape[1])
22
+ index.add(np.array(embeddings))
23
+ return model, index, embeddings
24
+
25
+ # Retrieve relevant context
26
+ def retrieve_context(query, model, index, data, top_k=1):
27
+ query_vec = model.encode([query])
28
+ distances, indices = index.search(np.array(query_vec), top_k)
29
+ results = [data.iloc[i]["instruction"] + "\n\n" + data.iloc[i]["response"] for i in indices[0]]
30
+ return "\n\n".join(results)
31
+
32
+ # Call Groq LLM
33
+ def query_groq(context, query):
34
+ prompt = f"Context:\n{context}\n\nQuestion: {query}\n\nAnswer:"
35
+ client = Groq(api_key=st.secrets["GROQ_API_KEY"])
36
+ response = client.chat.completions.create(
37
+ messages=[{"role": "user", "content": prompt}],
38
+ model="llama-3-70b-8192"
39
+ )
40
+ return response.choices[0].message.content
41
+
42
+ # Streamlit UI
43
+ st.set_page_config(page_title="RAG Demo with Groq", layout="wide")
44
+ st.title("🧠 RAG App using Groq API + RAG-Instruct Dataset")
45
+
46
+ data = load_data()
47
+ model, index, _ = setup_faiss(data)
48
+
49
+ st.markdown("Ask a question based on the instruction-response knowledge base.")
50
+
51
+ # Optional queries
52
+ optional_queries = [
53
+ "How to use a specific API function?",
54
+ "Explain how to fine-tune a model.",
55
+ "What is the difference between pretraining and finetuning?",
56
+ "How does retrieval-augmented generation work?",
57
+ "Explain self-supervised learning."
58
+ ]
59
+
60
+ query = st.text_input("Enter your question:", value=optional_queries[0])
61
+ if st.button("Ask"):
62
+ with st.spinner("Retrieving and generating response..."):
63
+ context = retrieve_context(query, model, index, data)
64
+ answer = query_groq(context, query)
65
+ st.subheader("📄 Retrieved Context")
66
+ st.write(context)
67
+ st.subheader("💬 Answer from Groq LLM")
68
+ st.write(answer)
69
+
70
+ st.markdown("### Optional Queries to Try:")
71
+ st.write(", ".join(optional_queries))