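"""Streamlit app for chatting with a GitHub repository via RAG.

gitingest flattens the repo into Markdown, LlamaIndex indexes it with
HuggingFace embeddings, and a SambaNova-hosted LLM answers questions
through a streaming chat UI.
"""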
import os
import gc
import tempfile
import uuid
import logging

import streamlit as st
from dotenv import load_dotenv

from gitingest import ingest
from llama_index.core import Settings, PromptTemplate, VectorStoreIndex, SimpleDirectoryReader
from llama_index.core.node_parser import MarkdownNodeParser
from llama_index.llms.sambanovasystems import SambaNovaCloud
from llama_index.embeddings.huggingface import HuggingFaceEmbedding

# Load environment variables from .env
load_dotenv()
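# load_dotenv() reads a .env file from the working directory; it should
# define SAMBANOVA_API_KEY=<your SambaNova Cloud key>.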

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Custom exception for application errors
class GitHubRAGError(Exception):
    """Custom exception for GitHub RAG application errors"""
    pass

# Fetch API key for SambaNova
SAMBANOVA_API_KEY = os.getenv("SAMBANOVA_API_KEY")
if not SAMBANOVA_API_KEY:
    raise ValueError("SAMBANOVA_API_KEY is not set in environment variables")

# Initialize Streamlit session state
if "id" not in st.session_state:
    st.session_state.id = uuid.uuid4()
    st.session_state.file_cache = {}
    st.session_state.messages = []

session_id = st.session_state.id

@st.cache_resource
def load_llm():
    """
    Load and cache the SambaNova LLM predictor
    """
    return SambaNovaCloud(
        api_key=SAMBANOVA_API_KEY,
        model="DeepSeek-R1-Distill-Llama-70B",
        temperature=0.1,
        top_p=0.1,
    )


def reset_chat():
    """Clear chat history and free resources"""
    st.session_state.messages = []
    gc.collect()


def process_with_gitingest(github_url: str):
    """Fetch and flatten the GitHub repository with gitingest.

    Returns (summary, tree, content); the sidebar loader uses tree and content.
    """
    summary, tree, content = ingest(github_url)
    return summary, tree, content

# --- Sidebar: Load Repository ---
with st.sidebar:
    st.header("Add your GitHub repository!")
    github_url = st.text_input(
        "GitHub repo URL", placeholder="https://github.com/user/repo"
    )
    load_btn = st.button("Load Repository")

    if github_url and load_btn:
        try:
            repo_name = github_url.rstrip("/").split("/")[-1]
            cache_key = f"{session_id}-{repo_name}"

            # Only process if not cached
            if cache_key not in st.session_state.file_cache:
                with st.spinner("Processing repository..."):
                    summary, tree, content = process_with_gitingest(github_url)

                    with tempfile.TemporaryDirectory() as tmpdir:
                        md_path = os.path.join(tmpdir, f"{repo_name}.md")
                        with open(md_path, "w", encoding="utf-8") as f:
                            f.write(content)

                        loader = SimpleDirectoryReader(input_dir=tmpdir)
                        docs = loader.load_data()

                        # nomic-embed checkpoints ship custom modeling code,
                        # hence trust_remote_code=True.
                        embed_model = HuggingFaceEmbedding(
                            model_name="nomic-ai/nomic-embed-text-v2-moe",
                            trust_remote_code=True,
                        )
                        Settings.embed_model = embed_model

                        llm_predictor = load_llm()
                        Settings.llm = llm_predictor

                        node_parser = MarkdownNodeParser()
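                        # MarkdownNodeParser splits the flattened repo along its
                        # Markdown structure, so files/sections become separate nodes.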
                        index = VectorStoreIndex.from_documents(
                            documents=docs,
                            transformations=[node_parser],
                            show_progress=True,
                        )

                        # The query engine only supplies {context_str} and {query_str},
                        # so {tree} must be pre-filled or formatting fails at query time.
                        qa_prompt = PromptTemplate(
                            "You are an AI assistant specialized in analyzing GitHub repositories.\n"
                            "Repository structure:\n{tree}\n---\n"
                            "Context:\n{context_str}\n---\n"
                            "Question: {query_str}\nAnswer:"
                        ).partial_format(tree=tree)
                        query_engine = index.as_query_engine(streaming=True)
                        query_engine.update_prompts({
                            "response_synthesizer:text_qa_template": qa_prompt
                        })

                        st.session_state.file_cache[cache_key] = (query_engine, tree)
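                        # Keyed by (session, repo) so reloading the same URL skips re-indexing.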
                st.success("Repository loaded and indexed. Ready to chat!")
            else:
                st.info("Repository already loaded.")
        except Exception as e:
            st.error(f"Error loading repository: {e}")
            logger.error(f"Load error: {e}")

# --- Main UI: Chat Interface ---
col1, col2 = st.columns([6, 1])
with col1:
    st.header("Chat with GitHub RAG")
with col2:
    st.button("Clear Chat ↺", on_click=reset_chat)

# Display chat history
for msg in st.session_state.messages:
    with st.chat_message(msg["role"]):
        st.markdown(msg["content"])

# Chat input box
if prompt := st.chat_input("Ask a question about the repository..."):
    st.session_state.messages.append({"role": "user", "content": prompt})
    with st.chat_message("user"):
        st.markdown(prompt)

    repo_name = github_url.rstrip("/").split("/")[-1]
    cache_key = f"{session_id}-{repo_name}"

    if cache_key not in st.session_state.file_cache:
        st.error("Please load a repository first!")
    else:
        query_engine, tree = st.session_state.file_cache[cache_key]
        with st.chat_message("assistant"):
            placeholder = st.empty()
            response_text = ""
            try:
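                # Stream tokens into the placeholder with a cursor; non-streaming
                # responses fall back to str(response).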
                response = query_engine.query(prompt)
                if hasattr(response, "response_gen"):
                    for chunk in response.response_gen:
                        response_text += chunk
                        placeholder.markdown(response_text + "▌")
                else:
                    response_text = str(response)
                    placeholder.markdown(response_text)
            except GitHubRAGError as e:
                st.error(str(e))
                logger.error(f"Error in chat processing: {e}")
                response_text = "Sorry, I couldn't process that request."
            except Exception as e:
                st.error("An unexpected error occurred while processing your query")
                logger.error(f"Unexpected error in chat: {e}")
                response_text = "Sorry, something went wrong."
            placeholder.markdown(response_text)
            st.session_state.messages.append({"role": "assistant", "content": response_text})
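
# To launch (filename is a placeholder for this script):
#   streamlit run app.py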