File size: 6,567 Bytes
63c9177
0588a72
3ad05c5
d0d60f4
0588a72
280275c
7f94b06
0588a72
280275c
0588a72
 
079c63d
616220d
280275c
5ad9b18
616220d
280275c
5ad9b18
 
 
 
280275c
 
5ad9b18
 
0588a72
280275c
0588a72
 
de624eb
0588a72
 
 
 
 
 
280275c
0588a72
 
 
3ad05c5
 
 
0588a72
 
 
3ad05c5
280275c
3ad05c5
280275c
3ad05c5
 
280275c
 
 
 
 
 
 
3ad05c5
 
 
 
 
5ad9b18
280275c
5ad9b18
0588a72
616220d
 
5ad9b18
 
280275c
5ad9b18
616220d
280275c
5ad9b18
616220d
280275c
616220d
 
 
 
 
 
280275c
616220d
 
5ad9b18
616220d
995a0e5
280275c
0588a72
079c63d
d0d60f4
079c63d
 
3ad05c5
 
 
 
 
280275c
88b9dfd
 
280275c
88b9dfd
d0d60f4
939e16f
 
d0d60f4
280275c
d0d60f4
 
5ad9b18
280275c
d0d60f4
280275c
616220d
0588a72
280275c
3ad05c5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
aa05415
280275c
0588a72
5ad9b18
d0d60f4
5ad9b18
aa05415
5ad9b18
aa05415
5ad9b18
0588a72
280275c
0588a72
 
280275c
3ad05c5
5ad9b18
d0d60f4
5ad9b18
3ad05c5
5ad9b18
3ad05c5
5ad9b18
3ad05c5
280275c
3ad05c5
 
280275c
0588a72
5ad9b18
0588a72
d0d60f4
0588a72
 
280275c
0588a72
d0d60f4
3ad05c5
 
 
0588a72
 
280275c
2a08ae8
88b9dfd
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
import gradio as gr
import requests
import re
import os

# Load from environment or fallback to default
# API_ENDPOINT: full URL of the OpenAI-compatible chat-completions endpoint.
# No fallback — if unset, requests.post(None, ...) in get_ai_response fails
# and surfaces as an "Error: ..." chat message.
API_ENDPOINT = os.getenv("API_ENDPOINT")
# API_TOKEN: bearer token sent in the Authorization header (may be None).
API_TOKEN = os.getenv("API_TOKEN")
# MODEL_ID: model name sent in the request payload and used as the UI title.
MODEL_ID = os.getenv("MODEL_ID", "none")

def get_ai_response(message, history):
    """Fetch an assistant reply from the OpenAI-compatible chat API.

    Parameters:
        message: The newest user message (str).
        history: Prior turns, either as Gradio-style (user_msg, ai_msg)
            tuples or as OpenAI-style {"role": ..., "content": ...} dicts.

    Returns:
        The assistant reply converted to collapsible HTML, or an
        "Error: ..." string if the request fails for any reason.
    """
    messages = [{"role": "system", "content": "You are a helpful assistant."}]

    for entry in history:
        # Accept both history formats: generate_response_from_history passes
        # ready-made role/content dicts, while Gradio state holds tuples.
        # (Previously a dict entry was unpacked as a tuple, which silently
        # sent the literal strings "role"/"content" to the API.)
        if isinstance(entry, dict):
            messages.append(entry)
            continue
        user_msg, ai_msg = entry
        if ai_msg != "⏳ Thinking...":
            # Strip the collapsible HTML we injected earlier so the model
            # does not see (and start imitating) our presentation markup.
            clean_ai_msg = re.sub(r'<details>.*?</details>', '', ai_msg, flags=re.DOTALL)
            clean_ai_msg = re.sub(r'<[^>]*>', '', clean_ai_msg)
            messages.append({"role": "user", "content": user_msg})
            messages.append({"role": "assistant", "content": clean_ai_msg})

    # Add the latest user message unless a dict-format history already ended
    # with that exact turn (avoids sending the same message twice).
    latest = {"role": "user", "content": message}
    if messages[-1] != latest:
        messages.append(latest)

    payload = {
        "model": MODEL_ID,
        "messages": messages,
        "stream": False,
        "max_tokens": 10000,
        "temperature": 0.7
    }
    headers = {
        "Authorization": f"Bearer {API_TOKEN}",
        "Content-Type": "application/json"
    }

    try:
        # Explicit timeout so a stalled server cannot hang the handler forever.
        response = requests.post(API_ENDPOINT, headers=headers, json=payload, timeout=120)
        response.raise_for_status()
        raw_response = response.json()["choices"][0]["message"]["content"]
        return convert_reasoning_to_collapsible(raw_response)
    except Exception as e:
        # Best-effort UI: surface the failure as a chat message, don't crash.
        return f"Error: {str(e)}"

def convert_reasoning_to_collapsible(text):
    """Wrap every <reasoning>...</reasoning> section in a collapsible
    <details> element and strip all <sep> markers from the text."""

    def to_details(match):
        inner = match.group(1).strip()
        return (
            '<details>'
            '<summary><strong>See reasoning</strong></summary>'
            f'<div class="reasoning-content">{inner}</div>'
            '</details>'
        )

    result = re.sub(r'<reasoning>(.*?)</reasoning>', to_details, text, flags=re.DOTALL)
    # Drop paired <sep>...</sep> blocks first, then any stray unmatched tags.
    result = re.sub(r'<sep>.*?</sep>', '', result, flags=re.DOTALL)
    for stray in ('<sep>', '</sep>'):
        result = result.replace(stray, '')
    return result

def add_user_message(message, history):
    """Record a user turn paired with a temporary '⏳ Thinking...' slot.

    Returns the (mutated) history twice: once for the State component and
    once for the Chatbot display.
    """
    updated = [] if history is None else history
    updated.append((message, "⏳ Thinking..."))
    return updated, updated

def generate_response_from_history(history):
    """Fill in the trailing '⏳ Thinking...' placeholder with a real reply.

    Args:
        history: Gradio chat state — a list of (user_msg, ai_msg) tuples
            whose last entry is the placeholder added by add_user_message.

    Returns:
        (history, history): the updated state, duplicated for the State
        and Chatbot outputs.
    """
    if not history:
        # Event fired on an empty chat — nothing to do.
        return history, history

    last_user_message = history[-1][0]

    # Pass the tuple-format history straight through: get_ai_response
    # already skips the placeholder turn and strips injected HTML.
    # (Previously a dict-format copy was built here, which get_ai_response
    # then mis-parsed as (user, ai) tuples — corrupting the context sent
    # to the model and duplicating the final user message.)
    ai_response = get_ai_response(last_user_message, history)

    history[-1] = (last_user_message, ai_response)
    return history, history

# CSS for dark mode + collapsible sections. Applies to the whole Blocks app;
# the details/summary/.reasoning-content rules style the collapsible blocks
# injected by convert_reasoning_to_collapsible.
custom_css = """
body { background-color: #1a1a1a; color: #ffffff; font-family: 'Arial', sans-serif; }
#chatbot { height: 80vh; background-color: #2d2d2d; border: 1px solid #404040; border-radius: 8px; }
input, button { background-color: #333333; color: #ffffff; border: 1px solid #404040; border-radius: 5px; }
button:hover { background-color: #404040; }
details { background-color: #333333; padding: 10px; margin: 5px 0; border-radius: 5px; }
summary { cursor: pointer; color: #70a9e6; }
.reasoning-content { padding: 10px; margin-top: 5px; background-color: #404040; border-radius: 5px; }
"""

# Set model name for UI title (browser tab title of the Blocks app).
model_display_name = MODEL_ID

# Gradio UI definition: dark-themed chat with HTML-capable messages.
with gr.Blocks(css=custom_css, title=model_display_name) as demo:
    with gr.Column():
        gr.Markdown("## nvidia-Llama-3_1-Nemotron-Ultra-253B-v1 Demo")
        gr.Markdown("This is a demo of nvidia-Llama-3_1-Nemotron-Ultra-253B-v1")
        # render_markdown=False: messages carry raw HTML (<details> blocks)
        # that the MutationObserver below re-injects as innerHTML.
        chatbot = gr.Chatbot(elem_id="chatbot", render_markdown=False, bubble_full_width=True)
        
        with gr.Row():
            message = gr.Textbox(placeholder="Type your message...", show_label=False, container=False)
            submit_btn = gr.Button("Send", size="lg")
        
        clear_chat_btn = gr.Button("Clear Chat")

    # Conversation state: a list of (user_msg, ai_msg) tuples shared by all handlers.
    chat_state = gr.State([])

    # JS to allow rendering HTML in the chat: observes the chatbot DOM for
    # newly added messages and swaps their escaped text in as real HTML.
    js = """
    function() {
        const observer = new MutationObserver(function(mutations) {
            mutations.forEach(function(mutation) {
                if (mutation.addedNodes.length) {
                    document.querySelectorAll('#chatbot .message:not(.processed)').forEach(msg => {
                        msg.classList.add('processed');
                        const content = msg.querySelector('.content');
                        if (content) {
                            content.innerHTML = content.textContent;
                        }
                    });
                }
            });
        });
        const chatbot = document.getElementById('chatbot');
        if (chatbot) {
            observer.observe(chatbot, { childList: true, subtree: true });
        }
        return [];
    }
    """

    # Event: Send button clicked. Two-step chain so the '⏳ Thinking...'
    # placeholder renders immediately, before the blocking API call runs.
    submit_btn.click(
        add_user_message,
        [message, chat_state],
        [chat_state, chatbot]
    ).then(
        generate_response_from_history,
        chat_state,
        [chat_state, chatbot]
    ).then(
        lambda: "", None, message  # clear textbox
    )

    # Event: Pressing Enter key in Textbox — same chain as the Send button.
    message.submit(
        add_user_message,
        [message, chat_state],
        [chat_state, chatbot]
    ).then(
        generate_response_from_history,
        chat_state,
        [chat_state, chatbot]
    ).then(
        lambda: "", None, message
    )

    # Clear chat: reset both the stored state and the visible chatbot.
    clear_chat_btn.click(
        lambda: ([], []),
        None,
        [chat_state, chatbot]
    )

    # Load JS on UI load (fn is a no-op; only the js hook matters here).
    demo.load(
        fn=lambda: None,
        inputs=None,
        outputs=None,
        js=js
    )

# Launch Gradio interface
demo.queue()  # enable request queuing so concurrent/long generations are serialized
demo.launch()