# r1-agents / app.py — Hugging Face Space by wuhp
# Last commit: b9d6d53 "Update app.py" (verified); file size 18.6 kB
import logging
import os
import re
from threading import Thread
from typing import Dict, Generator, List, Optional, Tuple

import gradio as gr
import spaces  # Import the spaces library
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, TextIteratorStreamer
# --- Logging Configuration ---
# Root logger: INFO and above, each record prefixed with timestamp and level.
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
# --- Model & Quantization Settings ---
# Hugging Face Hub id of the checkpoint loaded by get_model_and_tokenizer().
MODEL_ID = "unsloth/DeepSeek-R1-Distill-Qwen-7B-unsloth-bnb-4bit"
# Process-wide caches filled lazily on first use (get_model_and_tokenizer writes the "7B" key).
models: Dict[str, AutoModelForCausalLM] = dict()
tokenizers: Dict[str, AutoTokenizer] = dict()
# bitsandbytes 4-bit NF4 quantization; matmuls are computed in bfloat16.
bnb_config_4bit = BitsAndBytesConfig(
    bnb_4bit_compute_dtype=torch.bfloat16,  # Or torch.float16 if needed
    bnb_4bit_quant_type="nf4",
    load_in_4bit=True,
)
def get_model_and_tokenizer() -> Tuple[AutoModelForCausalLM, AutoTokenizer]:
    """
    Lazy-load the model and tokenizer if not already loaded.
    On the first call, MODEL_ID is loaded with the 4-bit bitsandbytes config, the
    model is switched to eval mode, and both objects are cached under the "7B" key
    of the module-level `models` / `tokenizers` dicts. Later calls return the
    cached pair without reloading.
    Returns:
        Tuple[model, tokenizer]: The loaded model and tokenizer.
    Raises:
        Exception: Whatever `from_pretrained` raises. It is logged with its
            traceback and re-raised unchanged, and nothing is cached.
    """
    if "7B" not in models:
        logging.info(f"Loading 7B model: {MODEL_ID} on demand")
        try:
            # NOTE(review): trust_remote_code=True executes Python code shipped in the
            # Hub repo. This checkpoint is presumably a plain Qwen2 architecture that
            # transformers supports natively — confirm whether the flag is still needed.
            tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)
            model = AutoModelForCausalLM.from_pretrained(
                MODEL_ID,
                quantization_config=bnb_config_4bit,
                torch_dtype=torch.bfloat16,  # Or torch.float16 if needed
                device_map='auto',
                trust_remote_code=True,
            )
        except Exception as e:
            # logging.exception records the traceback; bare `raise` re-raises the
            # original exception without adding an extra frame.
            logging.exception(f"Failed to load model and tokenizer: {e}")
            raise
        model.eval()  # Set the model to evaluation mode
        # Cache the tokenizer first so that "7B" in models always implies a cached tokenizer.
        tokenizers["7B"] = tokenizer
        models["7B"] = model
        logging.info("Loaded 7B model on demand.")
    return models["7B"], tokenizers["7B"]
# --- Default Prompt Templates ---
# These defaults seed both the Prompt Config tab and the prompt_state State.
# Templates (default or user-edited) are filled with str.format, so any literal
# braces in an edited template must be doubled ({{ and }}).

# Round 1 template; swarm_agent_iterative formats it with the `user_prompt` field.
default_prompt_brainstorm = """**Brainstorming Task (Round 1)**
As a Senior Code Analyst, provide an initial analysis of the problem below.
**User Request:**
{user_prompt}
**Guidelines:**
1. Identify key challenges and constraints.
2. Suggest multiple potential approaches.
3. Outline any potential edge cases or critical considerations.
"""
# Round 2 template; formatted with `brainstorm_response` (the full Round 1 output)
# and `user_prompt`.
default_prompt_code_generation = """**Advanced Reasoning & Code Generation (Round 2)**
Based on the initial analysis below:
**Initial Analysis:**
{brainstorm_response}
**User Request:**
{user_prompt}
**Task:**
1. Develop a detailed solution that includes production-ready code.
2. Explain the reasoning behind the chosen approach.
3. Incorporate advanced reasoning to handle edge cases.
4. Provide commented code that is clear and maintainable.
"""
# Round 3 template; formatted with `code_response` (the full Round 2 output) only.
# It does not receive the original user request directly.
default_prompt_synthesis = """**Synthesis & Final Refinement (Round 3)**
Review the detailed code generation and reasoning below, and produce a final, refined response that:
1. Synthesizes the brainstorming insights and advanced reasoning.
2. Provides a concise summary of the solution.
3. Highlights any potential improvements or considerations.
**Detailed Response:**
{code_response}
"""
# --- Memory Management ---
class MemoryManager:
    """
    Encapsulate shared memory for storing and retrieving conversation items.
    Items are kept in insertion order in `shared_memory`. Retrieval is a plain
    case-insensitive substring match; there is no embedding or ranking.
    Args:
        max_items (Optional[int]): Maximum number of items to keep. When the limit
            is exceeded, the oldest items are dropped first. None (the default)
            keeps every item.
    Raises:
        ValueError: If max_items is negative.
    """
    def __init__(self, max_items: Optional[int] = None) -> None:
        if max_items is not None and max_items < 0:
            raise ValueError(f"max_items must be non-negative or None, got {max_items}")
        self.shared_memory: List[str] = []
        self.max_items = max_items

    def store(self, item: str) -> None:
        """
        Store a memory item and log an excerpt.
        If a `max_items` bound is set, the oldest items are evicted to respect it.
        Args:
            item (str): The memory content to store.
        """
        self.shared_memory.append(item)
        if self.max_items is not None and len(self.shared_memory) > self.max_items:
            # Drop the oldest overflow so the list never exceeds the configured bound.
            del self.shared_memory[:len(self.shared_memory) - self.max_items]
        logging.info("[Memory Stored]: %s...", item[:50])

    def retrieve(self, query: str, top_k: int = 3) -> List[str]:
        """
        Retrieve memory items that contain the query text (case-insensitive).
        Args:
            query (str): The text query to search for. An empty query matches every item.
            top_k (int): Maximum number of memory items to return. Values <= 0 return [].
        Returns:
            List[str]: Up to top_k matching items, oldest first.
        """
        query_lower = query.lower()
        relevant = [item for item in self.shared_memory if query_lower in item.lower()]
        if not relevant:
            logging.info("[Memory Retrieval]: No relevant memories found.")
        else:
            logging.info("[Memory Retrieval]: Found %d relevant memories.", len(relevant))
        # Clamp: a negative top_k would otherwise slice off items from the end
        # (relevant[:-1]) instead of returning nothing.
        return relevant[:max(top_k, 0)]


# Create a global memory manager instance for RAG purposes.
# This module-level instance is shared by every request handled by this process,
# so items stored for one user can be retrieved for another.
global_memory_manager = MemoryManager()
# --- Multi-Round Swarm Agent Function ---
def _stream_round(model: AutoModelForCausalLM, tokenizer: AutoTokenizer, prompt: str, *,
                  label: str, max_new_tokens: int, temp: float, top_p: float) -> Generator[str, None, None]:
    """
    Run one sampled generation round and stream its output.
    The prompt is tokenized and `model.generate` runs in a background thread that
    feeds a TextIteratorStreamer. The text accumulated so far is yielded after
    every chunk the streamer emits.
    Args:
        model: Loaded causal LM (from get_model_and_tokenizer).
        tokenizer: The matching tokenizer.
        prompt (str): Fully formatted prompt text.
        label (str): Round name used in log messages (e.g. "Round 1").
        max_new_tokens (int): Generation budget. It is clamped to at least 1,
            because a zero or negative budget is rejected by generate.
        temp (float): Sampling temperature.
        top_p (float): Nucleus-sampling probability mass.
    Yields:
        str: The accumulated response text after each streamed chunk.
    Raises:
        Exception: Re-raises any error from starting the thread or from reading
            the streamer. The streamer's 10 s timeout presumably surfaces here when
            generation stalls or dies.
    """
    encoded = tokenizer(prompt, return_tensors="pt").to(model.device)
    streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
    generation_kwargs = dict(
        input_ids=encoded["input_ids"],
        # Pass the mask explicitly instead of letting generate() infer it.
        attention_mask=encoded.get("attention_mask"),
        streamer=streamer,
        max_new_tokens=max(1, int(max_new_tokens)),
        # Every round samples explicitly. Without do_sample=True, temperature and
        # top_p only apply if the model's generation config happens to enable sampling.
        do_sample=True,
        temperature=temp,
        top_p=top_p,
    )
    # No torch.no_grad() wrapper is needed here: generate() already disables
    # gradients itself, and grad mode is thread-local anyway.
    thread = Thread(target=model.generate, kwargs=generation_kwargs)
    try:
        thread.start()
    except Exception as e:
        logging.error(f"Error starting {label} thread: {e}")
        raise
    response = ""
    try:
        for chunk in streamer:
            logging.debug(chunk)  # per-chunk output is debug-level to avoid log spam
            response += chunk
            yield response
    except Exception as e:
        logging.error(f"Error during {label} generation: {e}")
        raise
    # NOTE(review): if the consumer stops early, the generation thread keeps running
    # until it reaches max_new_tokens; there is no stopping criterion wired in.
    thread.join()
    logging.info(f"{label} finished ({len(response)} characters).")


def _final_text(stream: Generator[str, None, None]) -> str:
    """Exhaust a round's stream and return its last, complete accumulated text ("" if it yielded nothing)."""
    text = ""
    for text in stream:
        pass
    return text


@spaces.GPU(duration=180)  # Adjust duration as needed
def swarm_agent_iterative(user_prompt: str, temp: float, top_p: float, max_new_tokens: int, memory_top_k: int,
                          prompt_brainstorm_text: str, prompt_code_generation_text: str, prompt_synthesis_text: str
                          ) -> Generator[str, None, None]:
    """
    A three-round iterative process that uses the provided prompt templates:
    - Round 1: Brainstorming. `prompt_brainstorm_text` is formatted with user_prompt.
    - Round 2: Advanced reasoning & code generation. `prompt_code_generation_text`
      is formatted with brainstorm_response and user_prompt. Its budget is
      max_new_tokens + 100.
    - Round 3: Synthesis & refinement. `prompt_synthesis_text` is formatted with
      code_response. Its budget is max_new_tokens // 2.
    Only Round 3 is streamed to the caller; Rounds 1 and 2 are consumed internally.
    A 200-character excerpt of each round's output is stored in global_memory_manager.
    Args:
        memory_top_k (int): Currently unused; kept for interface compatibility.
    Yields:
        str: Progressive updates of the final (Round 3) response.
    Raises:
        KeyError/ValueError/IndexError: From str.format if a template references
            unknown fields or contains unescaped braces.
    """
    model, tokenizer = get_model_and_tokenizer()
    sampling = dict(temp=temp, top_p=top_p)

    # ----- Round 1: Brainstorming -----
    logging.info("--- Round 1: Brainstorming ---")
    prompt_r1 = prompt_brainstorm_text.format(user_prompt=user_prompt)
    brainstorm_response = _final_text(_stream_round(
        model, tokenizer, prompt_r1, label="Round 1", max_new_tokens=max_new_tokens, **sampling))
    global_memory_manager.store(f"Brainstorm Response: {brainstorm_response[:200]}...")

    # ----- Round 2: Code Generation -----
    logging.info("--- Round 2: Code Generation ---")
    prompt_r2 = prompt_code_generation_text.format(
        brainstorm_response=brainstorm_response,
        user_prompt=user_prompt
    )
    code_response = _final_text(_stream_round(
        model, tokenizer, prompt_r2, label="Round 2",
        max_new_tokens=max_new_tokens + 100,  # extra tokens for detail
        **sampling))
    global_memory_manager.store(f"Code Generation Response: {code_response[:200]}...")

    # ----- Round 3: Synthesis & Refinement -----
    logging.info("--- Round 3: Synthesis & Refinement ---")
    prompt_r3 = prompt_synthesis_text.format(code_response=code_response)
    final_response = ""
    for final_response in _stream_round(model, tokenizer, prompt_r3, label="Round 3",
                                        max_new_tokens=max_new_tokens // 2, **sampling):
        yield final_response  # Yield progressive updates
    global_memory_manager.store(f"Final Synthesis Response: {final_response[:200]}...")
# --- Explanation Function for Puns ---
@spaces.GPU(duration=120)  # same GPU mechanism as swarm_agent_iterative; covers a possible first-time model load
def _generate_explanation(explanation_prompt: str) -> str:
    """
    Generate text for `explanation_prompt` with the shared model and return it.
    Decorated with spaces.GPU so that, on a ZeroGPU Space, a GPU is attached while
    model.generate runs.
    Args:
        explanation_prompt (str): Prompt listing the puns to explain.
    Returns:
        str: The generated explanation (up to 300 new tokens, sampled at
            temperature 0.7 / top_p 0.9).
    Raises:
        Exception: Re-raises errors from starting the generation thread or
            reading the streamer.
    """
    model, tokenizer = get_model_and_tokenizer()
    encoded = tokenizer(explanation_prompt, return_tensors="pt").to(model.device)
    streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
    kwargs = dict(
        input_ids=encoded["input_ids"],
        attention_mask=encoded.get("attention_mask"),
        streamer=streamer,
        max_new_tokens=300,
        # Explicit, so that temperature/top_p actually take effect.
        do_sample=True,
        temperature=0.7,
        top_p=0.9,
    )
    thread = Thread(target=model.generate, kwargs=kwargs)
    try:
        thread.start()
    except Exception as e:
        logging.error(f"Error starting explanation thread: {e}")
        raise
    try:
        explanation = "".join(streamer)
    except Exception as e:
        logging.error(f"Error during explanation generation: {e}")
        raise
    thread.join()
    return explanation


def handle_explanation_request(user_prompt: str) -> str:
    """
    Explain puns found in stored memory items.
    Looks up to three memory items containing "pun" (case-insensitive substring)
    in global_memory_manager and asks the model to explain them. If none are
    found, a fixed message is returned to the user without invoking the model.
    Args:
        user_prompt (str): The user request (e.g. "explain the different puns you
            mentioned"). It is not used when building the prompt; it is kept for
            interface compatibility.
    Returns:
        str: The explanation generated by the model, or the "no puns found" message.
    """
    # Retrieve memory items that contain "pun" (assuming previous outputs include puns)
    retrieved = global_memory_manager.retrieve("pun", top_k=3)
    if not retrieved:
        # Return this message to the user directly instead of feeding it to the model as a prompt.
        return "No previous puns found to explain. Please provide the pun examples."
    bullet_list = "".join(f"- {item}\n" for item in retrieved)
    explanation_prompt = (
        "Please explain the following coding puns in detail:\n\n"
        + bullet_list
        + "\nProvide a detailed explanation for each pun."
    )
    return _generate_explanation(explanation_prompt)
# --- Helper to Format History ---
def format_history(history: List) -> List[Dict[str, str]]:
    """
    Normalize a chat history into OpenAI-style message dictionaries.
    Each entry may be one of the following:
    - a dict, which is passed through as-is (the same object, not a copy);
    - a 2-element list/tuple (user, assistant), which becomes a user message,
      followed by an assistant message only when the assistant text is truthy.
    Any other entry (wrong length, other types) is skipped.
    Args:
        history (List): Conversation entries in either format, possibly mixed.
    Returns:
        List[Dict[str, str]]: The flattened message list, in order.
    """
    formatted: List[Dict[str, str]] = []
    for entry in history:
        if isinstance(entry, dict):
            formatted.append(entry)
            continue
        if not isinstance(entry, (list, tuple)) or len(entry) != 2:
            continue
        user_text, bot_text = entry
        formatted.append({"role": "user", "content": user_text})
        if bot_text:
            formatted.append({"role": "assistant", "content": bot_text})
    return formatted
# --- Gradio Chat Interface Function ---
# Whole-word match for "pun"/"puns"/"punny"/"punning", so that words such as
# "punctuation" or "punch" do not trigger the pun-explanation path.
_PUN_PATTERN = re.compile(r"\bpun(?:s|ny|ning)?\b")


def gradio_interface(message: str, history: List, param_state: Dict, prompt_state: Dict) -> Generator[List[Dict[str, str]], None, None]:
    """
    This function is called by Gradio's ChatInterface.
    It uses the current saved generation parameters and prompt templates.
    If the message contains "explain" and the whole word "pun" (or "puns",
    "punny", "punning"), the request is routed to handle_explanation_request.
    Otherwise it runs the three-round swarm agent and streams its final round.
    Args:
        message (str): The user message.
        history (List): The conversation history (message dicts and/or [user, assistant] pairs).
        param_state (Dict): Generation parameters. Missing or unparsable values fall
            back to the defaults (0.5, 0.9, 300, 2).
        prompt_state (Dict): Prompt templates. A missing, empty, or whitespace-only
            template falls back to the corresponding default.
    Yields:
        Generator[List[Dict[str, str]]]: Updated history in OpenAI-style message dictionaries.
    """
    param_state = param_state or {}
    prompt_state = prompt_state or {}

    # Check if the user is asking to explain puns.
    lowered = message.lower()
    if "explain" in lowered and _PUN_PATTERN.search(lowered):
        explanation = handle_explanation_request(message)
        history = history + [[message, explanation]]
        yield format_history(history)
        return

    try:
        temp = float(param_state.get("temperature", 0.5))
        top_p = float(param_state.get("top_p", 0.9))
        max_new_tokens = int(param_state.get("max_new_tokens", 300))
        memory_top_k = int(param_state.get("memory_top_k", 2))
    except (TypeError, ValueError, OverflowError) as e:
        # e.g. None from an emptied Number box, or a non-finite value
        logging.error(f"Parameter conversion error: {e}")
        temp, top_p, max_new_tokens, memory_top_k = 0.5, 0.9, 300, 2

    def _template(key: str, default: str) -> str:
        # An empty saved template would send an empty prompt to the model; use the default instead.
        text = prompt_state.get(key)
        return text if isinstance(text, str) and text.strip() else default

    prompt_brainstorm_text = _template("prompt_brainstorm", default_prompt_brainstorm)
    prompt_code_generation_text = _template("prompt_code_generation", default_prompt_code_generation)
    prompt_synthesis_text = _template("prompt_synthesis", default_prompt_synthesis)

    # Append the new user message with an empty assistant reply (as a two-item list)
    history = history + [[message, ""]]
    # Call the multi-round agent as a generator (for streaming)
    for partial_response in swarm_agent_iterative(
        user_prompt=message,
        temp=temp,
        top_p=top_p,
        max_new_tokens=max_new_tokens,
        memory_top_k=memory_top_k,
        prompt_brainstorm_text=prompt_brainstorm_text,
        prompt_code_generation_text=prompt_code_generation_text,
        prompt_synthesis_text=prompt_synthesis_text
    ):
        # Update the last assistant message with the new partial response.
        history[-1][1] = partial_response
        yield format_history(history)
# --- UI Settings & Styling ---
# HTML header rendered with gr.Markdown at the top of the Blocks layout.
ui_description = '''
<div>
<h1 style="text-align: center;">DeepSeek Agent Swarm Chat</h1>
<p style="text-align: center;">
Multi-round agent:
<br>- Brainstorming
<br>- Advanced reasoning & code generation
<br>- Synthesis & refinement
</p>
</div>
'''
# Footer rendered with gr.Markdown at the end of the Blocks layout. It currently
# holds only an empty paragraph and a Markdown horizontal rule (no license text).
ui_license = """
<p/>
---
"""
# HTML passed as the Chatbot's `placeholder` (shown while the chat is empty).
ui_placeholder = """
<div style="padding: 30px; text-align: center; display: flex; flex-direction: column; align-items: center;">
<h1 style="font-size: 28px; margin-bottom: 2px; opacity: 0.55;">DeepSeek Agent Swarm</h1>
<p style="font-size: 18px; margin-bottom: 2px; opacity: 0.65;">Ask me anything...</p>
</div>
"""
# Custom CSS passed to gr.Blocks.
# NOTE(review): no component in this file sets elem_id="duplicate-button"; that
# rule looks like a leftover — confirm before removing.
css = """
h1 {
text-align: center;
display: block;
}
#duplicate-button {
margin: auto;
color: white;
background: #1565c0;
border-radius: 100vh;
}
"""
# --- Gradio UI ---
# Layout: a header, three tabs (Chat / Parameters / Prompt Config), and a footer.
# NOTE(review): the original indentation of this section was lost. The nesting
# below is reconstructed from the `with` statements — confirm against the source.
with gr.Blocks(css=css, title="DeepSeek Agent Swarm Chat") as demo:
    gr.Markdown(ui_description)
    # Hidden States to hold parameters and prompt configuration.
    # Their values change only when a "Save ..." button below writes to them;
    # moving a slider or editing a textbox alone has no effect on generation.
    param_state = gr.State({
        "temperature": 0.5,
        "top_p": 0.9,
        "max_new_tokens": 300,
        "memory_top_k": 2,
    })
    prompt_state = gr.State({
        "prompt_brainstorm": default_prompt_brainstorm,
        "prompt_code_generation": default_prompt_code_generation,
        "prompt_synthesis": default_prompt_synthesis,
    })
    # Create top-level Tabs
    with gr.Tabs():
        # --- Chat Tab ---
        with gr.Tab("Chat"):
            # Messages-format chatbot, matching the dict history produced by format_history.
            chatbot = gr.Chatbot(height=450, placeholder=ui_placeholder, label="Agent Swarm Output", type="messages")
            # gradio_interface receives (message, history, param_state, prompt_state).
            gr.ChatInterface(
                fn=gradio_interface,
                chatbot=chatbot,
                additional_inputs=[param_state, prompt_state],
                examples=[
                    ['How can we build a robust web service that scales efficiently under load?'],
                    ['Explain how to design a fault-tolerant distributed system.'],
                    ['Develop a streamlit app that visualizes real-time financial data.'],
                    ['Create a pun-filled birthday message with a coding twist.'],
                    ['Design a system that uses machine learning to optimize resource allocation.']
                ],
                cache_examples=False,
                type="messages",
            )
        # --- Parameters Tab ---
        with gr.Tab("Parameters"):
            gr.Markdown("### Generation Parameters")
            temp_slider = gr.Slider(minimum=0.1, maximum=1.0, step=0.1, value=0.5, label="Temperature")
            top_p_slider = gr.Slider(minimum=0.01, maximum=1.0, step=0.05, value=0.9, label="Top P")
            # NOTE(review): no minimum is set, so 0 or negative budgets can be saved.
            max_tokens_num = gr.Number(value=300, label="Max new tokens", precision=0)
            memory_topk_slider = gr.Slider(minimum=1, maximum=5, step=1, value=2, label="Memory Retrieval Top K")
            save_params_btn = gr.Button("Save Parameters")
            # Copy the current widget values into param_state, using the keys gradio_interface reads.
            save_params_btn.click(
                lambda t, p, m, k: {"temperature": t, "top_p": p, "max_new_tokens": m, "memory_top_k": k},
                inputs=[temp_slider, top_p_slider, max_tokens_num, memory_topk_slider],
                outputs=param_state,
            )
        # --- Prompt Config Tab ---
        with gr.Tab("Prompt Config"):
            gr.Markdown("### Configure Prompt Templates")
            # Each template is later filled via str.format with its round's placeholders.
            prompt_brainstorm_box = gr.Textbox(
                value=default_prompt_brainstorm,
                label="Brainstorm Prompt",
                lines=8,
            )
            prompt_code_generation_box = gr.Textbox(
                value=default_prompt_code_generation,
                label="Code Generation Prompt",
                lines=8,
            )
            prompt_synthesis_box = gr.Textbox(
                value=default_prompt_synthesis,
                label="Synthesis Prompt",
                lines=8,
            )
            save_prompts_btn = gr.Button("Save Prompts")
            # Copy the three textbox contents into prompt_state.
            save_prompts_btn.click(
                lambda b, c, s: {
                    "prompt_brainstorm": b,
                    "prompt_code_generation": c,
                    "prompt_synthesis": s,
                },
                inputs=[prompt_brainstorm_box, prompt_code_generation_box, prompt_synthesis_box],
                outputs=prompt_state,
            )
    # Footer. NOTE(review): placement at Blocks level (after the tabs) is reconstructed — confirm.
    gr.Markdown(ui_license)
# Start the Gradio server only when run as a script.
if __name__ == "__main__":
    demo.launch()