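"""Streamlit app for chatting with a GitHub repository via RAG.

gitingest flattens the repo into a single markdown digest, llama_index
builds a vector index over it, and a SambaNova-hosted LLM answers
questions grounded in the retrieved context.
"""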
import os
import gc
import tempfile
import uuid
import logging
import streamlit as st
from dotenv import load_dotenv
from gitingest import ingest
from llama_index.core import Settings, PromptTemplate, VectorStoreIndex, SimpleDirectoryReader
from llama_index.core.node_parser import MarkdownNodeParser
from llama_index.llms.sambanovasystems import SambaNovaCloud
from llama_index.embeddings.huggingface import HuggingFaceEmbedding
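# NOTE: besides streamlit, gitingest, and python-dotenv, the llama_index
# imports above assume the integration packages are installed -- most
# likely llama-index-llms-sambanovasystems and
# llama-index-embeddings-huggingface alongside llama-index-core.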
# Load environment variables from .env
load_dotenv()
# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# Custom exception for application errors
class GitHubRAGError(Exception):
"""Custom exception for GitHub RAG application errors"""
pass
# Fetch API key for SambaNova
SAMBANOVA_API_KEY = os.getenv("SAMBANOVA_API_KEY")
if not SAMBANOVA_API_KEY:
    raise ValueError("SAMBANOVA_API_KEY is not set in environment variables")
# Initialize Streamlit session state
if "id" not in st.session_state:
    st.session_state.id = uuid.uuid4()
    st.session_state.file_cache = {}
    st.session_state.messages = []
session_id = st.session_state.id
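# st.cache_resource keeps a single client per server process, shared across
# sessions and reruns, so the LLM connection below is created only once.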
@st.cache_resource
def load_llm():
"""
Load and cache the SambaNova LLM predictor
"""
return SambaNovaCloud(
api_key=SAMBANOVA_API_KEY,
model="DeepSeek-R1-Distill-Llama-70B",
temperature=0.1,
top_p=0.1,
)
def reset_chat():
"""Clear chat history and free resources"""
    st.session_state.messages = []
    gc.collect()
def process_with_gitingest(github_url: str):
    """Use gitingest to fetch and summarize the GitHub repository."""
    summary, tree, content = ingest(github_url)
    return summary, tree, content
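# For reference: gitingest's ingest() is expected to return three strings
# here -- a summary, an ASCII directory tree, and one text digest of all
# file contents (written to a .md file below so it can be re-ingested).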
# --- Sidebar: Load Repository ---
with st.sidebar:
st.header("Add your GitHub repository!")
github_url = st.text_input(
"GitHub repo URL", placeholder="https://github.com/user/repo"
)
load_btn = st.button("Load Repository")
if github_url and load_btn:
try:
repo_name = github_url.rstrip("/").split("/")[-1]
cache_key = f"{session_id}-{repo_name}"
# Only process if not cached
if cache_key not in st.session_state.file_cache:
with st.spinner("Processing repository..."):
summary, tree, content = process_with_gitingets(github_url)
with tempfile.TemporaryDirectory() as tmpdir:
md_path = os.path.join(tmpdir, f"{repo_name}.md")
with open(md_path, "w", encoding="utf-8") as f:
f.write(content)
loader = SimpleDirectoryReader(input_dir=tmpdir)
docs = loader.load_data()
embed_model = HuggingFaceEmbedding(
model_name="nomic-ai/nomic-embed-text-v2-moe",
trust_remote_code=True,
)
Settings.embed_model = embed_model
llm_predictor = load_llm()
Settings.llm = llm_predictor
node_parser = MarkdownNodeParser()
index = VectorStoreIndex.from_documents(
documents=docs,
transformations=[node_parser],
show_progress=True,
)
qa_prompt = PromptTemplate(
"You are an AI assistant specialized in analyzing GitHub repositories.\n"
"Repository structure:\n{tree}\n---\n"
"Context:\n{context_str}\n---\n"
"Question: {query_str}\nAnswer:"
)
query_engine = index.as_query_engine(streaming=True)
query_engine.update_prompts({
"response_synthesizer:text_qa_template": qa_prompt
})
st.session_state.file_cache[cache_key] = (query_engine, tree)
st.success("Repository loaded and indexed. Ready to chat!")
else:
st.info("Repository already loaded.")
except Exception as e:
st.error(f"Error loading repository: {e}")
logger.error(f"Load error: {e}")
# --- Main UI: Chat Interface ---
col1, col2 = st.columns([6, 1])
with col1:
    st.header("Chat with GitHub RAG")
with col2:
    st.button("Clear Chat ↺", on_click=reset_chat)
# Display chat history
for msg in st.session_state.messages:
    with st.chat_message(msg["role"]):
        st.markdown(msg["content"])
# Chat input box
if prompt := st.chat_input("Ask a question about the repository..."):
st.session_state.messages.append({"role": "user", "content": prompt})
with st.chat_message("user"):
st.markdown(prompt)
repo_name = github_url.rstrip("/").split("/")[-1]
cache_key = f"{session_id}-{repo_name}"
if cache_key not in st.session_state.file_cache:
st.error("Please load a repository first!")
else:
query_engine, tree = st.session_state.file_cache[cache_key]
with st.chat_message("assistant"):
placeholder = st.empty()
response_text = ""
try:
response = query_engine.query(prompt)
if hasattr(response, 'response_gen'):
for chunk in response.response_gen:
response_text += chunk
placeholder.markdown(response_text + "▌")
else:
response_text = str(response)
placeholder.markdown(response_text)
except GitHubRAGError as e:
st.error(str(e))
logger.error(f"Error in chat processing: {e}")
response_text = "Sorry, I couldn't process that request."
except Exception as e:
st.error("An unexpected error occurred while processing your query")
logger.error(f"Unexpected error in chat: {e}")
response_text = "Sorry, something went wrong."
placeholder.markdown(response_text)
st.session_state.messages.append({"role": "assistant", "content": response_text})