# Uploaded by Abs6187: "Upload 5 files" (commit 2ac4ccc, verified), 18.1 kB
import base64
import html
import json
import os
import re
import time
from datetime import datetime

import pandas as pd
import plotly.graph_objects as go
import streamlit as st
from dotenv import load_dotenv
from llama_index.core.agent import ReActAgent
from llama_index.core.tools import FunctionTool
from llama_index.llms.groq import Groq
from llama_index.tools.tavily_research.base import TavilyToolSpec
# Pull variables from a local .env file (if one exists) into os.environ.
load_dotenv()

# Per-session defaults. Only keys that are not yet present in st.session_state
# are filled in, so values written by earlier reruns are left untouched.
_SESSION_DEFAULTS = {
    'conversation_history': [],
    'api_key': "",
    'current_response': None,
    'feedback_data': [],
    'current_sources': [],
}
for _state_key, _default_value in _SESSION_DEFAULTS.items():
    if _state_key not in st.session_state:
        st.session_state[_state_key] = _default_value
# Page-wide stylesheet. It defines the CSS classes (main-header, team-member,
# source-item, history-item, ...) used by the HTML snippets that the page
# renders with unsafe_allow_html=True.
_CUSTOM_CSS = """
<style>
.main-header {
font-size: 2.5rem;
color: #4527A0;
text-align: center;
margin-bottom: 1rem;
font-weight: bold;
}
.sub-header {
font-size: 1.5rem;
color: #5E35B1;
margin-bottom: 0.5rem;
}
.team-header {
font-size: 1.2rem;
color: #673AB7;
font-weight: bold;
margin-top: 1rem;
}
.team-member {
font-size: 1rem;
margin-left: 1rem;
color: #7E57C2;
}
.api-section {
background-color: #EDE7F6;
padding: 1rem;
border-radius: 10px;
margin-bottom: 1rem;
}
.response-container {
background-color: #F3E5F5;
padding: 1rem;
border-radius: 5px;
margin-top: 1rem;
}
.footer {
text-align: center;
margin-top: 2rem;
font-size: 0.8rem;
color: #9575CD;
}
.error-msg {
color: #D32F2F;
font-weight: bold;
}
.success-msg {
color: #388E3C;
font-weight: bold;
}
.history-item {
padding: 0.5rem;
border-radius: 5px;
margin-bottom: 0.5rem;
}
.query-text {
font-weight: bold;
color: #303F9F;
}
.response-text {
color: #1A237E;
}
.feedback-container {
background-color: #E8EAF6;
padding: 1rem;
border-radius: 5px;
margin-top: 1rem;
}
.feedback-btn {
margin-right: 0.5rem;
}
.star-rating {
display: flex;
justify-content: center;
margin-top: 0.5rem;
}
.analytics-container {
background-color: #E1F5FE;
padding: 1rem;
border-radius: 5px;
margin-top: 1rem;
}
.sources-container {
background-color: #E0F7FA;
padding: 1rem;
border-radius: 5px;
margin-top: 1rem;
}
.source-item {
background-color: #B2EBF2;
padding: 0.5rem;
border-radius: 5px;
margin-bottom: 0.5rem;
}
.source-url {
font-style: italic;
color: #0277BD;
word-break: break-all;
}
</style>
"""
st.markdown(_CUSTOM_CSS, unsafe_allow_html=True)
# Page heading (styled through the .main-header CSS class) and a short intro.
_TITLE_HTML = '<div class="main-header">TechMatrix AI Web Search Agent</div>'
_INTRO_MARKDOWN = '''
This intelligent agent uses state-of-the-art LLM technology to search the web and provide comprehensive answers to your questions.
Simply enter your query, and let our AI handle the rest!
'''
st.markdown(_TITLE_HTML, unsafe_allow_html=True)
st.markdown(_INTRO_MARKDOWN)
# Sidebar: team information, model settings, history control and analytics.
# Team roster as (emoji, display name, LinkedIn profile URL), in display order.
# The emoji were previously stored as mis-decoded UTF-8 (mojibake).
TEAM_MEMBERS = [
    ("👑", "Abhay Gupta (Team Leader)", "https://www.linkedin.com/in/abhay-gupta-197b17264/"),
    ("🧠", "Mayank Das Bairagi", "https://www.linkedin.com/in/mayank-das-bairagi-18639525a/"),
    ("💻", "Kripanshu Gupta", "https://www.linkedin.com/in/kripanshu-gupta-a66349261/"),
    ("🔍", "Bhumika Patel", "https://www.linkedin.com/in/bhumika-patel-ml/"),
]

with st.sidebar:
    st.markdown('<div class="team-header">TechMatrix Solvers</div>', unsafe_allow_html=True)
    for member_emoji, member_name, member_url in TEAM_MEMBERS:
        st.markdown(f'<div class="team-member">{member_emoji} {member_name}</div>', unsafe_allow_html=True)
        st.markdown(f'[LinkedIn Profile]({member_url})')
    st.markdown('---')

    # Advanced Settings: the chosen model is passed to Groq, and search_depth is
    # used as max_results by the web-search tool.
    st.markdown('<div class="sub-header">Advanced Settings</div>', unsafe_allow_html=True)
    model_option = st.selectbox(
        'LLM Model',
        ('gemma2-9b-it', 'llama3-8b-8192', 'mixtral-8x7b-32768'),
        index=0,
    )
    search_depth = st.slider('Search Depth', min_value=1, max_value=5, value=3,
                             help="Higher values will search more thoroughly but take longer")

    # Clear history button
    if st.button('Clear Conversation History'):
        st.session_state.conversation_history = []
        st.success('Conversation history cleared!')

    # Analytics section: average star rating plus helpful / not-helpful counts.
    if st.session_state.feedback_data:
        st.markdown('---')
        st.markdown('<div class="sub-header">Response Analytics</div>', unsafe_allow_html=True)
        ratings = [item['rating'] for item in st.session_state.feedback_data if 'rating' in item]
        avg_rating = sum(ratings) / len(ratings) if ratings else 0
        if ratings:
            fig = go.Figure(go.Indicator(
                mode="gauge+number",
                value=avg_rating,
                title={'text': "Average Rating"},
                domain={'x': [0, 1], 'y': [0, 1]},
                gauge={
                    'axis': {'range': [0, 5]},
                    'bar': {'color': "#6200EA"},
                    'steps': [
                        {'range': [0, 2], 'color': "#FFD0D0"},
                        {'range': [2, 3.5], 'color': "#FFFFCC"},
                        {'range': [3.5, 5], 'color': "#D0FFD0"},
                    ],
                },
            ))
            fig.update_layout(height=250, margin=dict(l=20, r=20, t=30, b=20))
            st.plotly_chart(fig, use_container_width=True)
        else:
            # Ratings are 1-5. A gauge reading 0 would look like a real,
            # terrible score when the user has only clicked helpful/not helpful.
            st.caption("No ratings submitted yet.")

        # Map stored feedback values to their display labels. Dict order fixes
        # the order of the summary lines.
        feedback_labels = {'helpful': "👍 Helpful", 'not_helpful': "👎 Not Helpful"}
        feedback_counts = {label: 0 for label in feedback_labels.values()}
        for item in st.session_state.feedback_data:
            label = feedback_labels.get(item.get('feedback'))
            if label is not None:
                feedback_counts[label] += 1
        st.markdown("### Feedback Summary")
        for key, value in feedback_counts.items():
            st.markdown(f"**{key}:** {value}")
# API credentials: Groq key (required for the LLM) and Tavily key (web search).
st.markdown('<div class="sub-header">API Credentials</div>', unsafe_allow_html=True)
with st.expander("Configure API Keys"):
    st.markdown('<div class="api-section">', unsafe_allow_html=True)
    api_key = st.text_input(
        "Enter your Groq API key:",
        type="password",
        value=st.session_state.api_key,
        help="Get your API key from https://console.groq.com/keys",
    )
    tavily_key = st.text_input(
        "Enter your Tavily API key (optional):",
        type="password",
        help="Get your Tavily API key from https://tavily.com/#api",
    )
    # Export every non-empty key to the environment variable the clients read.
    # NOTE(review): os.environ is process-wide; presumably all Streamlit
    # sessions served by this process share it — confirm this is acceptable.
    for key_value, env_name in ((api_key, 'GROQ_API_KEY'), (tavily_key, 'TAVILY_API_KEY')):
        if key_value:
            os.environ[env_name] = key_value
    # Keep the Groq key so the input is pre-filled on later reruns.
    if api_key:
        st.session_state.api_key = api_key
    st.markdown('</div>', unsafe_allow_html=True)
# Function to create download link for text data
def get_download_link(text, filename, link_text, mime="text/plain;charset=utf-8"):
    """Return an HTML ``<a>`` tag that downloads *text* as a file named *filename*.

    The text is UTF-8 encoded and embedded as a base64 ``data:`` URI. The
    result is meant to be rendered with ``st.markdown(..., unsafe_allow_html=True)``.

    Args:
        text: File contents.
        filename: File name suggested to the browser via the ``download`` attribute.
        link_text: Visible label of the link.
        mime: Media type written into the data URI. Defaults to UTF-8 plain
            text. The previous hard-coded ``file/txt`` is not a valid media type.

    Returns:
        The anchor element as a string.
    """
    b64 = base64.b64encode(text.encode("utf-8")).decode("ascii")
    # Escape attribute values and the label so a quote or '<' in them cannot
    # break out of the generated tag.
    safe_mime = html.escape(mime, quote=True)
    safe_filename = html.escape(filename, quote=True)
    safe_label = html.escape(link_text)
    return f'<a href="data:{safe_mime};base64,{b64}" download="{safe_filename}">{safe_label}</a>'
# Function to handle feedback submission
def submit_feedback(feedback_type, query, response):
    """Record a helpful / not-helpful vote for a query/response pair.

    Works like ``submit_rating``: if ``st.session_state.feedback_data`` already
    holds an entry for the same query and response, its ``feedback`` field is
    overwritten. Repeated clicks, or switching between the two buttons, no
    longer add extra entries that inflate the sidebar counts. Otherwise a new
    timestamped entry is appended.

    Args:
        feedback_type: Feedback value; the sidebar counts 'helpful' and 'not_helpful'.
        query: The user's question.
        response: The agent's answer text.

    Returns:
        True once the feedback has been stored.
    """
    for entry in st.session_state.feedback_data:
        if entry.get('query') == query and entry.get('response') == response:
            entry['feedback'] = feedback_type
            return True
    st.session_state.feedback_data.append({
        "timestamp": datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
        "query": query,
        "response": response,
        "feedback": feedback_type,
    })
    return True
# Function to submit rating
def submit_rating(rating, query, response):
    """Attach a star *rating* to the feedback record for *query*/*response*.

    The first record in ``st.session_state.feedback_data`` whose ``query`` and
    ``response`` both match gets its ``rating`` set, replacing any earlier
    value. If nothing matches, a new timestamped record holding only the
    rating is appended.

    Returns:
        Always True.
    """
    feedback_log = st.session_state.feedback_data
    existing = next(
        (record for record in feedback_log
         if record.get('query') == query and record.get('response') == response),
        None,
    )
    if existing is None:
        feedback_log.append({
            "timestamp": datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
            "query": query,
            "response": response,
            "rating": rating,
        })
    else:
        existing['rating'] = rating
    return True
# Function to extract URLs from text
# Candidate URL: the scheme plus everything up to whitespace, a quote, a
# backtick, or an angle/square/curly bracket. Parentheses are allowed here and
# balanced afterwards, so links such as .../Python_(programming_language) survive.
_URL_CANDIDATE = re.compile(r"""https?://[^\s<>"'`\[\]{}]+""")
# Prose and markdown punctuation that commonly trails a URL but is almost
# never its real last character.
_URL_TRAILING_PUNCTUATION = ".,;:!?*_"


def _trim_url(url):
    """Strip trailing punctuation and unbalanced closing parentheses from *url*."""
    while True:
        trimmed = url.rstrip(_URL_TRAILING_PUNCTUATION)
        # A ')' with no matching '(' is the end of a markdown link "[t](url)".
        if trimmed.endswith(")") and trimmed.count("(") < trimmed.count(")"):
            trimmed = trimmed[:-1]
        if trimmed == url:
            return url
        url = trimmed


def extract_urls(text):
    """Return every http(s) URL found in *text*, in order of appearance.

    The previous pattern matched only the scheme and host, so
    ``https://example.com/a/b`` came back as ``https://example.com``. Paths,
    queries and fragments are now kept. Trailing sentence punctuation and a
    markdown link's closing parenthesis are dropped. Duplicates are kept; the
    search handler in this file de-duplicates against known sources.

    Args:
        text: Text to scan. ``None`` or ``""`` yields an empty list.

    Returns:
        List of URL strings.
    """
    if not text:
        return []
    urls = []
    for candidate in _URL_CANDIDATE.findall(text):
        url = _trim_url(candidate)
        # Skip degenerate matches such as "https://..." that trim to a bare scheme.
        if url.split("://", 1)[1]:
            urls.append(url)
    return urls
# Setup search tools
def _source_from_result(result):
    """Return a ``{'title': ..., 'url': ...}`` dict for one search result, or None if it has no URL.

    NOTE(review): llama-index's ``TavilyToolSpec.search`` appears to return
    ``Document`` objects that carry the page URL in ``metadata['url']``, not as
    a ``url`` attribute. Confirm against the installed version. Both places are
    checked so either result shape works.
    """
    metadata = getattr(result, 'metadata', None)
    if not isinstance(metadata, dict):
        metadata = {}
    url = getattr(result, 'url', None) or metadata.get('url')
    if not url:
        return None
    title = getattr(result, 'title', None) or metadata.get('title') or "Unknown Source"
    return {'title': title, 'url': url}


try:
    tavily_api_key = os.environ.get('TAVILY_API_KEY')
    if not tavily_api_key:
        # There is no bundled fallback key: the old fallback re-read this same
        # unset variable. Say so instead of implying a default key is in use.
        st.warning("No Tavily API key found. Web search will not work until you provide your own key under 'Configure API Keys'.")
    search = TavilyToolSpec(api_key=tavily_api_key)

    def search_tool(prompt: str) -> list:
        """Search the web for information about the given prompt."""
        # NOTE: FunctionTool.from_defaults presumably uses the docstring above
        # as the tool description shown to the agent, so it is left unchanged.
        try:
            search_results = search.search(prompt, max_results=search_depth)
            # The agent may call this tool several times for one question, so
            # merge new sources into those already collected rather than
            # replacing them. A new list is built to avoid mutating one that
            # may already be stored elsewhere.
            collected = list(st.session_state.current_sources)
            seen_urls = {source['url'] for source in collected}
            for result in search_results:
                source = _source_from_result(result)
                if source is not None and source['url'] not in seen_urls:
                    seen_urls.add(source['url'])
                    collected.append(source)
            st.session_state.current_sources = collected
            return [result.text for result in search_results]
        except Exception as e:
            # Hand the failure back to the agent as tool output instead of
            # aborting the whole chat turn.
            return [f"Error during search: {str(e)}"]

    search_toolkit = FunctionTool.from_defaults(fn=search_tool)
except Exception as e:
    st.error(f"Error setting up search tools: {str(e)}")
    search_toolkit = None
# Query input
query = st.text_input("What would you like to know?",
                      placeholder="Enter your question here...",
                      help="Ask any question, and our AI will search the web for answers")
# Search button. The label's emoji was previously stored as mis-decoded UTF-8.
search_button = st.button("🔍 Search")
# Process the search when button is clicked
if search_button and not query.strip():
    st.warning("Please enter a question before searching.")
elif search_button:
    # Check that everything the agent needs is available before building it.
    if not st.session_state.api_key:
        st.error("Please enter your Groq API key first!")
    elif search_toolkit is None:
        # Tool setup failed earlier and its error is already shown. Passing
        # [None] to ReActAgent.from_tools would only add a confusing error.
        st.error("Web search is unavailable because the search tool could not be set up. Please check your Tavily API key.")
    else:
        try:
            with st.spinner("🧠 Searching the web and analyzing results..."):
                # Pass this session's key explicitly instead of relying on
                # GROQ_API_KEY in os.environ, which is process-wide.
                llm = Groq(model=model_option, api_key=st.session_state.api_key)
                agent = ReActAgent.from_tools([search_toolkit], llm=llm, verbose=True)
                # Clear current sources before the new search.
                st.session_state.current_sources = []

                start_time = time.perf_counter()
                response = agent.chat(query)
                duration = round(time.perf_counter() - start_time, 2)
                response_text = response.response or ""

                # Add URLs cited in the answer that are not already recorded.
                known_urls = {source['url'] for source in st.session_state.current_sources}
                for url in extract_urls(response_text):
                    if url not in known_urls:
                        known_urls.add(url)
                        st.session_state.current_sources.append({
                            'title': "Referenced Source",
                            'url': url,
                        })

                st.session_state.current_response = {
                    "query": query,
                    "response": response_text,
                    "time": datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
                    "duration": duration,
                    # Snapshot, so a later search cannot change this response's sources.
                    "sources": list(st.session_state.current_sources),
                }
                st.session_state.conversation_history.append(st.session_state.current_response)
            st.success(f"Found results in {duration} seconds!")
        except Exception as e:
            st.error(f"An error occurred: {str(e)}")
# Display current response if available
if st.session_state.current_response:
    current = st.session_state.current_response
    with st.container():
        st.markdown('<div class="response-container">', unsafe_allow_html=True)
        st.markdown("### Response:")
        st.write(current["response"])

        # Export options. Both downloads share one timestamp in their file names.
        export_stamp = datetime.now().strftime('%Y%m%d_%H%M%S')
        col1, col2 = st.columns(2)
        with col1:
            st.markdown(
                get_download_link(current["response"], f"search_result_{export_stamp}.txt", "Download as Text"),
                unsafe_allow_html=True,
            )
        with col2:
            # Export the sources stored with *this* response. The previous code
            # used st.session_state.current_sources, which is reset at the start
            # of every new search and can therefore disagree with it.
            json_data = json.dumps({
                "query": current["query"],
                "response": current["response"],
                "timestamp": current["time"],
                "processing_time": current["duration"],
                "sources": current.get("sources", []),
            }, indent=4)
            st.markdown(
                get_download_link(json_data, f"search_result_{export_stamp}.json", "Download as JSON"),
                unsafe_allow_html=True,
            )
        st.markdown('</div>', unsafe_allow_html=True)

        # Display sources if available
        shown_sources = current.get("sources") or []
        if shown_sources:
            with st.expander("View Sources", expanded=True):
                st.markdown('<div class="sources-container">', unsafe_allow_html=True)
                for position, source in enumerate(shown_sources, start=1):
                    st.markdown('<div class="source-item">', unsafe_allow_html=True)
                    st.markdown(f"**Source {position}:** {source.get('title', 'Unknown Source')}")
                    # URLs come from web results and model output. Escape them
                    # before embedding in raw HTML.
                    safe_url = html.escape(source["url"], quote=True)
                    st.markdown(
                        f'<div class="source-url"><a href="{safe_url}" target="_blank" rel="noopener noreferrer">{safe_url}</a></div>',
                        unsafe_allow_html=True,
                    )
                    st.markdown('</div>', unsafe_allow_html=True)
                st.markdown('</div>', unsafe_allow_html=True)

        # Feedback section (button emoji were previously mis-decoded UTF-8).
        st.markdown('<div class="feedback-container">', unsafe_allow_html=True)
        st.markdown("### Was this response helpful?")
        col1, col2 = st.columns(2)
        with col1:
            if st.button("👍 Helpful", key="helpful_btn"):
                if submit_feedback("helpful", current["query"], current["response"]):
                    st.success("Thank you for your feedback!")
        with col2:
            if st.button("👎 Not Helpful", key="not_helpful_btn"):
                if submit_feedback("not_helpful", current["query"], current["response"]):
                    st.success("Thank you for your feedback! We'll work to improve our responses.")

        st.markdown("### Rate this response:")
        # A non-empty label is kept for accessibility but hidden visually,
        # since the heading above already describes the control.
        rating = st.slider("Rating", min_value=1, max_value=5, value=4,
                           label_visibility="collapsed",
                           help="Rate the quality of this response from 1 (poor) to 5 (excellent)")
        if st.button("Submit Rating"):
            if submit_rating(rating, current["query"], current["response"]):
                st.success("Rating submitted! Thank you for helping us improve.")
        st.markdown('</div>', unsafe_allow_html=True)
# Display conversation history, newest first, with a 200-character answer preview.
if st.session_state.conversation_history:
    with st.expander("View Conversation History"):
        history = st.session_state.conversation_history
        for i, item in enumerate(reversed(history)):
            answer = item["response"]
            preview = answer[:200] + ("..." if len(answer) > 200 else "")
            # Query and answer text are user/model supplied. Escape them before
            # embedding in raw HTML. Truncation happens first so an entity is
            # never cut in half.
            st.markdown('<div class="history-item">', unsafe_allow_html=True)
            st.markdown(
                f'<span class="query-text">Q: {html.escape(item["query"])}</span> '
                f'<small>({html.escape(item["time"])})</small>',
                unsafe_allow_html=True,
            )
            st.markdown(f'<div class="response-text">A: {html.escape(preview)}</div>', unsafe_allow_html=True)
            st.markdown('</div>', unsafe_allow_html=True)
            # Separator between entries, but not after the last one.
            if i < len(history) - 1:
                st.markdown('---')
# Attribution footer, styled through the .footer CSS class.
_FOOTER_HTML = '''
<div class="footer">
<p>Powered by Groq + Llama-Index + Tavily Search | Created by TechMatrix Solvers | 2024</p>
</div>
'''
st.markdown(_FOOTER_HTML, unsafe_allow_html=True)