# app_hybrid_llm.py
import os
import re
import numpy as np
import faiss
import gradio as gr
from openai import OpenAI
from langchain.text_splitter import CharacterTextSplitter
from sentence_transformers import SentenceTransformer


DARTMOUTH_CHAT_API_KEY = os.getenv('DARTMOUTH_CHAT_API_KEY')
if DARTMOUTH_CHAT_API_KEY is None:
    raise ValueError("DARTMOUTH_CHAT_API_KEY not set.")

MODEL = "openai.gpt-4o-2024-08-06"

client = OpenAI(
    base_url="https://chat.dartmouth.edu/api",  # Replace with your endpoint URL
    api_key=DARTMOUTH_CHAT_API_KEY,  # Replace with your API key, if required
)
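# The client assumes the Dartmouth endpoint exposes an OpenAI-compatible
# chat-completions API; base_url simply points the stock OpenAI client at it.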

# --- Load and Prepare Data ---
with open("gen_agents.txt", "r", encoding="utf-8") as f:
    full_text = f.read()

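# Split on blank lines into chunks of at most ~512 characters, with a small
# overlap so content cut at a boundary also appears in the neighboring chunk.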
text_splitter = CharacterTextSplitter(separator="\n\n", chunk_size=512, chunk_overlap=20)
docs = text_splitter.create_documents([full_text])
passages = [doc.page_content for doc in docs]

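# all-MiniLM-L6-v2 produces 384-dimensional sentence embeddings; small and
# fast enough to encode the full document at startup.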
embedder = SentenceTransformer('all-MiniLM-L6-v2')
passage_embeddings = embedder.encode(passages, convert_to_tensor=False, show_progress_bar=True)
passage_embeddings = np.array(passage_embeddings).astype("float32")
d = passage_embeddings.shape[1]
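# IndexFlatL2 does exact (brute-force) L2 search over all vectors; at this
# corpus size that is fast, and it avoids the training step that approximate
# indexes require.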
index = faiss.IndexFlatL2(d)
index.add(passage_embeddings)

# --- Provided Functions ---
def retrieve_passages(query, embedder, index, passages, top_k=3):
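    """Embed the query and return the top_k nearest passages by L2 distance."""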
    query_embedding = embedder.encode([query], convert_to_tensor=False)
    query_embedding = np.array(query_embedding).astype('float32')
    distances, indices = index.search(query_embedding, top_k)
    retrieved = [passages[i] for i in indices[0]]
    return retrieved

def process_llm_output_with_references(text, passages):
    """
    Replace tokens like <<PASSAGE_1>> in the LLM output with HTML block quotes.
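    Example: the token <<PASSAGE_0>> in the model output is replaced by the
    first retrieved passage wrapped in a styled <blockquote>.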
    """
    def replacement(match):
        num = int(match.group(1))
        if 0 <= num < len(passages):
            # Tokens are zero-indexed (<<PASSAGE_0>> is the first retrieved
            # passage), matching enumerate() in generate_answer_with_references.
            passage_text = passages[num]
            return (f"<blockquote style='background: #ffffff; color: #000000; padding: 10px; "
                    f"border-left: 5px solid #ccc; margin: 10px 0; font-size: 14px;'>{passage_text}</blockquote>")
        return match.group(0)
    return re.sub(r"<<PASSAGE_(\d+)>>", replacement, text)

def generate_answer_with_references(query, retrieved_text):
    """
    Generate an answer with the configured chat model, instructing it to cite
    passages via <<PASSAGE_n>> tokens.
    """
    context_str = "\n".join([f"<<PASSAGE_{i}>>: \"{passage}\"" for i, passage in enumerate(retrieved_text)])
    messages = [
        {"role": "system", "content": "You are a knowledgeable technical assistant."},
        {"role": "user", "content": (
            f"Using the following textbook passages as reference:\n{context_str}\n\n"
            "In your answer, include passage block quotes as references. "
            "Refer to the passages using tokens such as <<PASSAGE_0>>, <<PASSAGE_1>>, etc. "
            "They should appear after complete thoughts on a new line.\n\n"
            f"Answer the question: {query}"
        )}
    ]
    response = client.chat.completions.create(
        model=MODEL,
        messages=messages,
    )
    # message.content is typed Optional in the OpenAI SDK; guard against None.
    answer = (response.choices[0].message.content or "").strip()
    return answer

# --- Gradio App Function ---
def get_hybrid_output(query):
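    # Retrieve supporting passages, generate an answer that cites them by
    # token, then expand each token into an HTML block quote.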
    retrieved = retrieve_passages(query, embedder, index, passages, top_k=3)
    hybrid_raw = generate_answer_with_references(query, retrieved)
    hybrid_processed = process_llm_output_with_references(hybrid_raw, retrieved)
    return f"<div style='white-space: pre-wrap;'>{hybrid_processed}</div>"

def clear_output():
    return ""

# --- Custom CSS ---
custom_css = """
body {
    background-color: #343541 !important;
    color: #ECECEC !important;
    margin: 0;
    padding: 0;
    font-family: 'Inter', sans-serif;
}
#container {
    max-width: 900px;
    margin: 0 auto;
    padding: 20px;
}
label {
    color: #ECECEC;
    font-weight: 600;
}
textarea, input {
    background-color: #40414F;
    color: #ECECEC;
    border: 1px solid #565869;
}
button {
    background-color: #565869;
    color: #ECECEC;
    border: none;
    font-weight: 600;
    transition: background-color 0.2s ease;
}
button:hover {
    background-color: #6e7283;
}
.output-box {
    border: 1px solid #565869;
    border-radius: 4px;
    padding: 10px;
    margin-top: 8px;
    background-color: #40414F;
}
"""

# --- Build Gradio Interface ---
with gr.Blocks(css=custom_css) as demo:
    with gr.Column(elem_id="container"):
        gr.Markdown("## Anonymous Chatbot\n### Loaded Article: Generative Agents - Interactive Simulacra of Human Behavior (Park et al. 2023)\n [https://arxiv.org/pdf/2304.03442](https://arxiv.org/pdf/2304.03442)")
        gr.Markdown("Enter any questions about the article above in the prompt!")
        query_input = gr.Textbox(label="Query", placeholder="Enter your query here...", lines=1)
        with gr.Column():
            submit_button = gr.Button("Submit")
            clear_button = gr.Button("Clear")
        output_box = gr.HTML(label="Output", elem_classes="output-box")
        
        submit_button.click(fn=get_hybrid_output, inputs=query_input, outputs=output_box)
        clear_button.click(fn=clear_output, inputs=[], outputs=output_box)

if __name__ == "__main__":
    demo.launch()