import logging
import os
from threading import Thread
from typing import Dict, Generator, List, Tuple

import gradio as gr
import spaces
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, TextIteratorStreamer
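
# DeepSeek Agent Swarm Chat: a Hugging Face Spaces app that runs a three-round
# "agent swarm" on top of a single 4-bit quantized DeepSeek-R1 distill model.
# Round 1 brainstorms, Round 2 does advanced reasoning and code generation, and
# Round 3 synthesizes a final answer that is streamed to a Gradio chat UI.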

logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

MODEL_ID = "unsloth/DeepSeek-R1-Distill-Qwen-7B-unsloth-bnb-4bit"

models: Dict[str, AutoModelForCausalLM] = {}
tokenizers: Dict[str, AutoTokenizer] = {}
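
# 4-bit NF4 quantization with bfloat16 compute keeps the ~7B-parameter model
# small enough to load on a single Space GPU while preserving generation quality.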
bnb_config_4bit = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
)


def get_model_and_tokenizer() -> Tuple[AutoModelForCausalLM, AutoTokenizer]:
    """
    Lazy-load the model and tokenizer if not already loaded.

    Returns:
        Tuple[model, tokenizer]: The loaded model and tokenizer.
    """
    if "7B" not in models:
        logging.info(f"Loading 7B model: {MODEL_ID} on demand")
        try:
            tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)
            model = AutoModelForCausalLM.from_pretrained(
                MODEL_ID,
                quantization_config=bnb_config_4bit,
                torch_dtype=torch.bfloat16,
                device_map="auto",
                trust_remote_code=True,
            )
            model.eval()
            models["7B"] = model
            tokenizers["7B"] = tokenizer
            logging.info("Loaded 7B model on demand.")
        except Exception as e:
            logging.error(f"Failed to load model and tokenizer: {e}")
            raise e
    return models["7B"], tokenizers["7B"]
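

# Default prompt templates for the three rounds. These are only defaults: the
# "Prompt Config" tab lets the user edit each template at runtime, and the
# {user_prompt}, {brainstorm_response} and {code_response} placeholders are
# filled in by swarm_agent_iterative().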
default_prompt_brainstorm = """**Brainstorming Task (Round 1)**
As a Senior Code Analyst, provide an initial analysis of the problem below.

**User Request:**
{user_prompt}

**Guidelines:**
1. Identify key challenges and constraints.
2. Suggest multiple potential approaches.
3. Outline any potential edge cases or critical considerations.
"""

default_prompt_code_generation = """**Advanced Reasoning & Code Generation (Round 2)**
Based on the initial analysis below:

**Initial Analysis:**
{brainstorm_response}

**User Request:**
{user_prompt}

**Task:**
1. Develop a detailed solution that includes production-ready code.
2. Explain the reasoning behind the chosen approach.
3. Incorporate advanced reasoning to handle edge cases.
4. Provide commented code that is clear and maintainable.
"""

default_prompt_synthesis = """**Synthesis & Final Refinement (Round 3)**
Review the detailed code generation and reasoning below, and produce a final, refined response that:
1. Synthesizes the brainstorming insights and advanced reasoning.
2. Provides a concise summary of the solution.
3. Highlights any potential improvements or considerations.

**Detailed Response:**
{code_response}
"""
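
# Shared conversation memory. Retrieval is a plain case-insensitive substring
# match over stored strings (no embeddings or vector search), which is enough
# for keyword follow-ups such as "explain the puns".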
class MemoryManager:
    """Encapsulate shared memory for storing and retrieving conversation items."""

    def __init__(self) -> None:
        self.shared_memory: List[str] = []

    def store(self, item: str) -> None:
        """
        Store a memory item and log an excerpt.

        Args:
            item (str): The memory content to store.
        """
        self.shared_memory.append(item)
        logging.info(f"[Memory Stored]: {item[:50]}...")

    def retrieve(self, query: str, top_k: int = 3) -> List[str]:
        """
        Retrieve memory items that contain the query text (case-insensitive).

        Args:
            query (str): The text query to search for.
            top_k (int): Maximum number of memory items to return.

        Returns:
            List[str]: A list of up to top_k memory items.
        """
        query_lower = query.lower()
        relevant = [item for item in self.shared_memory if query_lower in item.lower()]
        if not relevant:
            logging.info("[Memory Retrieval]: No relevant memories found.")
        else:
            logging.info(f"[Memory Retrieval]: Found {len(relevant)} relevant memories.")
        return relevant[:top_k]
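
# Illustrative usage (hypothetical values):
#   global_memory_manager.store("Final Synthesis Response: a pun about Python ...")
#   global_memory_manager.retrieve("pun", top_k=3)  # -> [that stored string]
# Any stored item containing the query (case-insensitively) is returned.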


global_memory_manager = MemoryManager()
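

# Core agent loop: three sequential model.generate() calls, each launched on a
# background thread and consumed through a TextIteratorStreamer. Rounds 1 and 2
# are collected internally; only the Round 3 synthesis is streamed back to the
# caller. memory_top_k is accepted for parity with the UI but is not used here
# yet; memory retrieval currently happens only in handle_explanation_request().
# @spaces.GPU requests GPU time (up to 180 s) for each call on Spaces.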
@spaces.GPU(duration=180)
def swarm_agent_iterative(user_prompt: str, temp: float, top_p: float, max_new_tokens: int, memory_top_k: int,
                          prompt_brainstorm_text: str, prompt_code_generation_text: str, prompt_synthesis_text: str
                          ) -> Generator[str, None, None]:
    """
    A three-round iterative process that uses the provided prompt templates:
      - Round 1: Brainstorming.
      - Round 2: Advanced reasoning & code generation.
      - Round 3: Synthesis & refinement.

    This generator yields the response from the final round as it is produced.

    Yields:
        str: Progressive updates of the final response.
    """
    model, tokenizer = get_model_and_tokenizer()

    logging.info("--- Round 1: Brainstorming ---")
    prompt_r1 = prompt_brainstorm_text.format(user_prompt=user_prompt)
    input_ids_r1 = tokenizer.encode(prompt_r1, return_tensors="pt").to(model.device)
    streamer_r1 = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
    kwargs_r1 = dict(
        input_ids=input_ids_r1,
        streamer=streamer_r1,
        max_new_tokens=max_new_tokens,
        do_sample=True,
        temperature=temp,
        top_p=top_p,
    )
    try:
        # model.generate() manages its own no-grad context, so the worker thread
        # is started directly and the streamer is consumed below.
        thread_r1 = Thread(target=model.generate, kwargs=kwargs_r1)
        thread_r1.start()
    except Exception as e:
        logging.error(f"Error starting Round 1 thread: {e}")
        raise e

    brainstorm_response = ""
    try:
        for text in streamer_r1:
            logging.info(text)
            brainstorm_response += text
    except Exception as e:
        logging.error(f"Error during Round 1 generation: {e}")
        raise e
    thread_r1.join()
    global_memory_manager.store(f"Brainstorm Response: {brainstorm_response[:200]}...")
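
    # Round 2: feed the Round 1 analysis and the original request into the
    # code-generation template, with a slightly larger token budget (+100 tokens).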
    logging.info("--- Round 2: Code Generation ---")
    prompt_r2 = prompt_code_generation_text.format(
        brainstorm_response=brainstorm_response,
        user_prompt=user_prompt
    )
    input_ids_r2 = tokenizer.encode(prompt_r2, return_tensors="pt").to(model.device)
    streamer_r2 = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
    kwargs_r2 = dict(
        input_ids=input_ids_r2,
        streamer=streamer_r2,
        max_new_tokens=max_new_tokens + 100,
        do_sample=True,  # keep sampling on so temperature/top_p take effect, as in Round 1
        temperature=temp,
        top_p=top_p,
    )
    try:
        thread_r2 = Thread(target=model.generate, kwargs=kwargs_r2)
        thread_r2.start()
    except Exception as e:
        logging.error(f"Error starting Round 2 thread: {e}")
        raise e

    code_response = ""
    try:
        for text in streamer_r2:
            logging.info(text)
            code_response += text
    except Exception as e:
        logging.error(f"Error during Round 2 generation: {e}")
        raise e
    thread_r2.join()
    global_memory_manager.store(f"Code Generation Response: {code_response[:200]}...")
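
    # Round 3: synthesize and refine the Round 2 output, streaming partial text
    # to the caller and using roughly half the token budget.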
    logging.info("--- Round 3: Synthesis & Refinement ---")
    prompt_r3 = prompt_synthesis_text.format(code_response=code_response)
    input_ids_r3 = tokenizer.encode(prompt_r3, return_tensors="pt").to(model.device)
    streamer_r3 = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
    kwargs_r3 = dict(
        input_ids=input_ids_r3,
        streamer=streamer_r3,
        max_new_tokens=max_new_tokens // 2,
        do_sample=True,  # keep sampling on so temperature/top_p take effect, as in Round 1
        temperature=temp,
        top_p=top_p,
    )
    try:
        thread_r3 = Thread(target=model.generate, kwargs=kwargs_r3)
        thread_r3.start()
    except Exception as e:
        logging.error(f"Error starting Round 3 thread: {e}")
        raise e

    final_response = ""
    try:
        for text in streamer_r3:
            logging.info(text)
            final_response += text
            yield final_response
    except Exception as e:
        logging.error(f"Error during Round 3 generation: {e}")
        raise e
    thread_r3.join()
    global_memory_manager.store(f"Final Synthesis Response: {final_response[:200]}...")
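

# Follow-up handler: when the user asks to explain the puns, pull any stored
# memory items mentioning "pun" and ask the model for a detailed explanation
# (blocking, non-streaming). Note that the incoming user message itself is not
# inserted into the generation prompt; only the retrieved memories are.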
def handle_explanation_request(user_prompt: str) -> str:
    """
    If the user asks for an explanation of the puns, this function retrieves
    relevant stored memory items (which are expected to include pun examples)
    and constructs a new prompt to generate a detailed explanation.

    Args:
        user_prompt (str): The user request (e.g. "explain the different puns you mentioned").

    Returns:
        str: The explanation generated by the model.
    """
    retrieved = global_memory_manager.retrieve("pun", top_k=3)
    if not retrieved:
        explanation_prompt = "No previous puns found to explain. Please provide the pun examples."
    else:
        explanation_prompt = "Please explain the following coding puns in detail:\n\n"
        for item in retrieved:
            explanation_prompt += f"- {item}\n"
        explanation_prompt += "\nProvide a detailed explanation for each pun."

    model, tokenizer = get_model_and_tokenizer()
    input_ids = tokenizer.encode(explanation_prompt, return_tensors="pt").to(model.device)
    streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
    kwargs = dict(
        input_ids=input_ids,
        streamer=streamer,
        max_new_tokens=300,
        do_sample=True,  # sampling must be enabled for temperature/top_p to apply
        temperature=0.7,
        top_p=0.9,
    )
    try:
        thread = Thread(target=model.generate, kwargs=kwargs)
        thread.start()
    except Exception as e:
        logging.error(f"Error starting explanation thread: {e}")
        raise e

    explanation = ""
    try:
        for text in streamer:
            explanation += text
    except Exception as e:
        logging.error(f"Error during explanation generation: {e}")
        raise e
    thread.join()
    return explanation
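

# Illustrative conversion (hypothetical content):
#   format_history([["hi", "hello!"], {"role": "user", "content": "thanks"}])
#   -> [{"role": "user", "content": "hi"},
#       {"role": "assistant", "content": "hello!"},
#       {"role": "user", "content": "thanks"}]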
def format_history(history: List) -> List[Dict[str, str]]:
    """
    Convert history (which might be a list of [user, assistant] pairs or already
    formatted dictionaries) into a list of OpenAI-style message dictionaries.

    Args:
        history (List): List of conversation items.

    Returns:
        List[Dict[str, str]]: A list of formatted message dictionaries.
    """
    messages = []
    for item in history:
        if isinstance(item, (list, tuple)) and len(item) == 2:
            user_msg, assistant_msg = item
            messages.append({"role": "user", "content": user_msg})
            if assistant_msg:
                messages.append({"role": "assistant", "content": assistant_msg})
        elif isinstance(item, dict):
            messages.append(item)
    return messages
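

# Chat entry point wired into gr.ChatInterface. Requests that mention both
# "explain" and "pun" are routed to handle_explanation_request(); everything
# else goes through the three-round swarm, with the in-progress Round 3 text
# streamed into the last assistant message.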
def gradio_interface(message: str, history: List, param_state: Dict, prompt_state: Dict) -> Generator[List[Dict[str, str]], None, None]:
    """
    Called by Gradio's ChatInterface. Uses the currently saved generation
    parameters and prompt templates. If the user request appears to ask for an
    explanation of puns, the request is routed to the explanation function.

    Args:
        message (str): The user message.
        history (List): The conversation history.
        param_state (Dict): Generation parameters.
        prompt_state (Dict): Prompt templates.

    Yields:
        List[Dict[str, str]]: The updated history as OpenAI-style message dictionaries.
    """
    if "explain" in message.lower() and "pun" in message.lower():
        explanation = handle_explanation_request(message)
        history = history + [[message, explanation]]
        yield format_history(history)
        return

    try:
        temp = float(param_state.get("temperature", 0.5))
        top_p = float(param_state.get("top_p", 0.9))
        max_new_tokens = int(param_state.get("max_new_tokens", 300))
        memory_top_k = int(param_state.get("memory_top_k", 2))
    except Exception as e:
        logging.error(f"Parameter conversion error: {e}")
        temp, top_p, max_new_tokens, memory_top_k = 0.5, 0.9, 300, 2

    prompt_brainstorm_text = prompt_state.get("prompt_brainstorm", default_prompt_brainstorm)
    prompt_code_generation_text = prompt_state.get("prompt_code_generation", default_prompt_code_generation)
    prompt_synthesis_text = prompt_state.get("prompt_synthesis", default_prompt_synthesis)

    # Append a placeholder [user, assistant] pair; format_history() converts it
    # to message dictionaries on every yield.
    history = history + [[message, ""]]

    for partial_response in swarm_agent_iterative(
        user_prompt=message,
        temp=temp,
        top_p=top_p,
        max_new_tokens=max_new_tokens,
        memory_top_k=memory_top_k,
        prompt_brainstorm_text=prompt_brainstorm_text,
        prompt_code_generation_text=prompt_code_generation_text,
        prompt_synthesis_text=prompt_synthesis_text,
    ):
        history[-1][1] = partial_response
        yield format_history(history)


ui_description = '''
<div>
  <h1 style="text-align: center;">DeepSeek Agent Swarm Chat</h1>
  <p style="text-align: center;">
    Multi-round agent:
    <br>- Brainstorming
    <br>- Advanced reasoning & code generation
    <br>- Synthesis & refinement
  </p>
</div>
'''

ui_license = """
<p/>

---
"""

ui_placeholder = """
<div style="padding: 30px; text-align: center; display: flex; flex-direction: column; align-items: center;">
  <h1 style="font-size: 28px; margin-bottom: 2px; opacity: 0.55;">DeepSeek Agent Swarm</h1>
  <p style="font-size: 18px; margin-bottom: 2px; opacity: 0.65;">Ask me anything...</p>
</div>
"""

css = """
h1 {
  text-align: center;
  display: block;
}
#duplicate-button {
  margin: auto;
  color: white;
  background: #1565c0;
  border-radius: 100vh;
}
"""
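
# Gradio Blocks UI: a Chat tab (gr.ChatInterface bound to gradio_interface), a
# Parameters tab, and a Prompt Config tab. param_state and prompt_state are
# per-session gr.State dictionaries; the "Save" buttons overwrite them and
# gradio_interface() reads them on every message.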
with gr.Blocks(css=css, title="DeepSeek Agent Swarm Chat") as demo:
    gr.Markdown(ui_description)

    param_state = gr.State({
        "temperature": 0.5,
        "top_p": 0.9,
        "max_new_tokens": 300,
        "memory_top_k": 2,
    })
    prompt_state = gr.State({
        "prompt_brainstorm": default_prompt_brainstorm,
        "prompt_code_generation": default_prompt_code_generation,
        "prompt_synthesis": default_prompt_synthesis,
    })

    with gr.Tabs():
        with gr.Tab("Chat"):
            chatbot = gr.Chatbot(height=450, placeholder=ui_placeholder, label="Agent Swarm Output", type="messages")
            gr.ChatInterface(
                fn=gradio_interface,
                chatbot=chatbot,
                additional_inputs=[param_state, prompt_state],
                examples=[
                    ['How can we build a robust web service that scales efficiently under load?'],
                    ['Explain how to design a fault-tolerant distributed system.'],
                    ['Develop a streamlit app that visualizes real-time financial data.'],
                    ['Create a pun-filled birthday message with a coding twist.'],
                    ['Design a system that uses machine learning to optimize resource allocation.'],
                ],
                cache_examples=False,
                type="messages",
            )

        with gr.Tab("Parameters"):
            gr.Markdown("### Generation Parameters")
            temp_slider = gr.Slider(minimum=0.1, maximum=1.0, step=0.1, value=0.5, label="Temperature")
            top_p_slider = gr.Slider(minimum=0.01, maximum=1.0, step=0.05, value=0.9, label="Top P")
            max_tokens_num = gr.Number(value=300, label="Max new tokens", precision=0)
            memory_topk_slider = gr.Slider(minimum=1, maximum=5, step=1, value=2, label="Memory Retrieval Top K")
            save_params_btn = gr.Button("Save Parameters")
            save_params_btn.click(
                lambda t, p, m, k: {"temperature": t, "top_p": p, "max_new_tokens": m, "memory_top_k": k},
                inputs=[temp_slider, top_p_slider, max_tokens_num, memory_topk_slider],
                outputs=param_state,
            )

        with gr.Tab("Prompt Config"):
            gr.Markdown("### Configure Prompt Templates")
            prompt_brainstorm_box = gr.Textbox(
                value=default_prompt_brainstorm,
                label="Brainstorm Prompt",
                lines=8,
            )
            prompt_code_generation_box = gr.Textbox(
                value=default_prompt_code_generation,
                label="Code Generation Prompt",
                lines=8,
            )
            prompt_synthesis_box = gr.Textbox(
                value=default_prompt_synthesis,
                label="Synthesis Prompt",
                lines=8,
            )
            save_prompts_btn = gr.Button("Save Prompts")
            save_prompts_btn.click(
                lambda b, c, s: {
                    "prompt_brainstorm": b,
                    "prompt_code_generation": c,
                    "prompt_synthesis": s,
                },
                inputs=[prompt_brainstorm_box, prompt_code_generation_box, prompt_synthesis_box],
                outputs=prompt_state,
            )

    gr.Markdown(ui_license)


if __name__ == "__main__":
    demo.launch()