Spaces:
Running
Running
File size: 5,185 Bytes
152d958 9b29b9a 152d958 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 |
# app_hybrid_llm.py
import html
import os
import re

import faiss
import gradio as gr
import numpy as np
import openai
from langchain.text_splitter import CharacterTextSplitter
from openai import OpenAI
from sentence_transformers import SentenceTransformer
# --- Configuration ---
# The API key is read from the environment rather than hard-coded.
DARTMOUTH_CHAT_API_KEY = os.getenv('DARTMOUTH_CHAT_API_KEY')
# `not` also rejects an empty string (a variable that is exported but blank),
# which an `is None` check would let through; such a key would presumably only
# surface later as an authentication error on the first request.
if not DARTMOUTH_CHAT_API_KEY:
    raise ValueError("DARTMOUTH_CHAT_API_KEY not set.")

# Model identifier sent with every chat-completion request.
MODEL = "openai.gpt-4o-2024-08-06"

# OpenAI SDK client pointed at the Dartmouth Chat API base URL.
client = OpenAI(
    base_url="https://chat.dartmouth.edu/api",
    api_key=DARTMOUTH_CHAT_API_KEY,
)
# --- Load and Prepare Data ---
# Read the article text, split it on blank lines into chunks of at most ~512
# characters (20 characters of overlap), embed every chunk, and build a flat
# (exact-search) L2 FAISS index whose vector i corresponds to passages[i].
with open("gen_agents.txt", "r", encoding="utf-8") as source_file:
    full_text = source_file.read()

text_splitter = CharacterTextSplitter(
    separator="\n\n",
    chunk_size=512,
    chunk_overlap=20,
)
docs = text_splitter.create_documents([full_text])
passages = [chunk.page_content for chunk in docs]

embedder = SentenceTransformer("all-MiniLM-L6-v2")
# FAISS expects float32 vectors.
passage_embeddings = np.array(
    embedder.encode(passages, convert_to_tensor=False, show_progress_bar=True)
).astype("float32")

d = passage_embeddings.shape[1]  # embedding dimensionality
index = faiss.IndexFlatL2(d)
index.add(passage_embeddings)
# --- Provided Functions ---
def retrieve_passages(query, embedder, index, passages, top_k=3):
    """
    Return up to ``top_k`` passages nearest to ``query`` in embedding space.

    Args:
        query: Free-text question to embed.
        embedder: Object exposing ``encode(list_of_str, convert_to_tensor=False)``
            (a SentenceTransformer in this app).
        index: FAISS index built over the embeddings of ``passages``, in the
            same order, so label ``i`` refers to ``passages[i]``.
        passages: Passage strings backing the index.
        top_k: Maximum number of passages to return.

    Returns:
        A list of at most ``top_k`` passages, in the order FAISS returns them.
    """
    query_embedding = np.array(
        embedder.encode([query], convert_to_tensor=False)
    ).astype("float32")
    _distances, indices = index.search(query_embedding, top_k)
    # FAISS pads the result with label -1 when the index holds fewer than
    # top_k vectors; passages[-1] would silently return the last passage, so
    # keep only labels that are valid positions.
    return [passages[i] for i in indices[0] if 0 <= i < len(passages)]
def process_llm_output_with_references(text, passages):
    """
    Replace tokens like <<PASSAGE_0>> in the LLM output with HTML block quotes.

    Tokens are 0-based, matching the numbering used when the passages are
    listed in the prompt: ``<<PASSAGE_0>>`` expands to ``passages[0]``.
    Tokens whose number is out of range are left in the text unchanged.

    Args:
        text: Raw LLM answer that may contain ``<<PASSAGE_n>>`` tokens.
        passages: The passages given to the LLM, in the same order.

    Returns:
        ``text`` with every valid token replaced by a styled ``<blockquote>``.
    """
    def replacement(match):
        num = int(match.group(1))
        if 0 <= num < len(passages):
            # Index directly: the tokens are 0-based. The old passages[num - 1]
            # showed the previous passage and wrapped <<PASSAGE_0>> to the last one.
            # Escape so characters such as "<" or "&" in the passage render as
            # text instead of being interpreted as HTML markup.
            passage_text = html.escape(passages[num])
            return (f"<blockquote style='background: #ffffff; color: #000000; padding: 10px; "
                    f"border-left: 5px solid #ccc; margin: 10px 0; font-size: 14px;'>{passage_text}</blockquote>")
        return match.group(0)

    return re.sub(r"<<PASSAGE_(\d+)>>", replacement, text)
def generate_answer_with_references(query, retrieved_text):
    """
    Ask the chat model (``MODEL``) to answer ``query`` using retrieved passages.

    Each passage is listed in the prompt under a 0-based ``<<PASSAGE_i>>``
    token, and the model is asked to cite passages by emitting those tokens
    on their own lines.

    Args:
        query: The user's question.
        retrieved_text: List of passage strings (in this app, the output of
            ``retrieve_passages``).

    Returns:
        The model's answer with surrounding whitespace stripped, or an empty
        string if the response carried no text content.
    """
    context_str = "\n".join(
        f"<<PASSAGE_{i}>>: \"{passage}\"" for i, passage in enumerate(retrieved_text)
    )
    messages = [
        {"role": "system", "content": "You are a knowledgeable technical assistant."},
        {"role": "user", "content": (
            f"Using the following textbook passages as reference:\n{context_str}\n\n"
            "In your answer, include passage block quotes as references. "
            "Refer to the passages using tokens such as <<PASSAGE_0>>, <<PASSAGE_1>>, etc. "
            "They should appear after complete thoughts on a new line.\n\n"
            f"Answer the question: {query}"
        )},
    ]
    response = client.chat.completions.create(
        model=MODEL,
        messages=messages,
    )
    # `message.content` is typed Optional in the OpenAI SDK (it can be None,
    # e.g. for refusal or tool-call responses), so guard before .strip().
    content = response.choices[0].message.content
    return (content or "").strip()
# --- Gradio App Function ---
def get_hybrid_output(query):
    """
    Gradio callback: answer ``query`` with cited passages rendered inline.

    Retrieves the top-3 passages, asks the model for an answer that cites them
    with ``<<PASSAGE_n>>`` tokens, and expands those tokens into block quotes.
    The result is wrapped in a ``white-space: pre-wrap`` div so the model's
    line breaks are kept.

    Returns:
        An HTML string for the output panel. A blank (or None) query yields an
        empty string, the same as the Clear action, without calling the
        embedder or the chat API.
    """
    # Skip the embedding and paid chat call when there is nothing to ask.
    if not query or not query.strip():
        return ""
    retrieved = retrieve_passages(query, embedder, index, passages, top_k=3)
    hybrid_raw = generate_answer_with_references(query, retrieved)
    hybrid_processed = process_llm_output_with_references(hybrid_raw, retrieved)
    return f"<div style='white-space: pre-wrap;'>{hybrid_processed}</div>"
def clear_output():
    """Gradio callback for the Clear button: blank out the HTML output panel."""
    cleared_html = ""
    return cleared_html
# --- Custom CSS ---
# Dark-theme stylesheet passed to gr.Blocks(css=...). `#container` matches the
# column created with elem_id="container"; `.output-box` matches the HTML
# output component created with elem_classes="output-box".
# NOTE(review): the `!important` on body presumably overrides Gradio's theme
# defaults — confirm against the Gradio version in use.
custom_css = """
body {
background-color: #343541 !important;
color: #ECECEC !important;
margin: 0;
padding: 0;
font-family: 'Inter', sans-serif;
}
#container {
max-width: 900px;
margin: 0 auto;
padding: 20px;
}
label {
color: #ECECEC;
font-weight: 600;
}
textarea, input {
background-color: #40414F;
color: #ECECEC;
border: 1px solid #565869;
}
button {
background-color: #565869;
color: #ECECEC;
border: none;
font-weight: 600;
transition: background-color 0.2s ease;
}
button:hover {
background-color: #6e7283;
}
.output-box {
border: 1px solid #565869;
border-radius: 4px;
padding: 10px;
margin-top: 8px;
background-color: #40414F;
}
"""
# --- Build Gradio Interface ---
# Layout: header text, a single-line query box, Submit/Clear buttons, and an
# HTML panel that shows the answer with passages expanded into block quotes.
with gr.Blocks(css=custom_css) as demo:
    with gr.Column(elem_id="container"):
        gr.Markdown("## Anonymous Chatbot\n### Loaded Article: Generative Agents - Interactive Simulacra of Human Behavior (Park et al. 2023)\n [https://arxiv.org/pdf/2304.03442](https://arxiv.org/pdf/2304.03442)")
        gr.Markdown("Enter any questions about the article above in the prompt!")
        query_input = gr.Textbox(label="Query", placeholder="Enter your query here...", lines=1)
        with gr.Column():
            submit_button = gr.Button("Submit")
            clear_button = gr.Button("Clear")
        output_box = gr.HTML(label="Output", elem_classes="output-box")

    submit_button.click(fn=get_hybrid_output, inputs=query_input, outputs=output_box)
    # Pressing Enter in the single-line textbox submits the query too;
    # without this listener, Enter does nothing.
    query_input.submit(fn=get_hybrid_output, inputs=query_input, outputs=output_box)
    clear_button.click(fn=clear_output, inputs=[], outputs=output_box)

# Launch only when run as a script, so importing this module (e.g. by a
# reload runner or a test) does not start a server.
if __name__ == "__main__":
    demo.launch()
|