# Enhanced_Chat_Interface / lmstudio_gradio.py
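"""Gradio chat front-end for a local LM Studio server.

Streams chat completions from LM Studio's OpenAI-compatible API, requests an
embedding for every user message, and keeps a rolling message history in
per-session state. The server is expected at LMSTUDIO_API_BASE_URL
(default: http://localhost:1234/v1).
"""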
import gradio as gr
import httpx
import json
import os
import numpy as np
import torch
import asyncio
import logging
# ===========================
# Logging Configuration
# ===========================
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)
# ===========================
# Configuration and Constants
# ===========================
BASE_URL = os.getenv("LMSTUDIO_API_BASE_URL", "http://localhost:1234/v1")
USE_GPU = torch.cuda.is_available()
DEVICE = torch.device("cuda" if USE_GPU else "cpu")
logger.info(f"GPU Available: {USE_GPU}, Device: {DEVICE}")
MODEL_MAX_TOKENS = 32768        # Context window of the target model
AVERAGE_CHARS_PER_TOKEN = 4     # Rough heuristic: ~4 characters per token
BUFFER_TOKENS = 1500            # Safety margin reserved for prompt overhead
MIN_OUTPUT_TOKENS = 500         # Never request fewer completion tokens than this
MAX_EMBEDDINGS = 100            # Cap on stored embeddings / history entries
HTTPX_TIMEOUT = 3000            # Request timeout in seconds; generous for long streamed generations
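# One AsyncClient instance is shared by every request to the LM Studio server.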
client = httpx.AsyncClient(timeout=HTTPX_TIMEOUT)
# ===========================
# Utility Functions
# ===========================
def calculate_max_tokens(message_history, model_max_tokens=MODEL_MAX_TOKENS,
buffer=BUFFER_TOKENS, avg_chars_per_token=AVERAGE_CHARS_PER_TOKEN,
min_tokens=MIN_OUTPUT_TOKENS):
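    """Estimate how many completion tokens can be requested for this turn.

    Input size is approximated as total characters / avg_chars_per_token and
    subtracted, together with a safety buffer, from the model's context window.
    Example: 4,000 characters of history ≈ 1,000 input tokens, leaving
    32768 - 1000 - 1500 = 30268 tokens for the reply. The result is clamped
    to at least `min_tokens`.
    """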
total_length = sum(len(message["content"]) for message in message_history)
input_tokens = total_length / avg_chars_per_token
max_tokens = model_max_tokens - int(input_tokens) - buffer
calculated_max = max(max_tokens, min_tokens)
logger.info(f"Calculated max tokens: {calculated_max}")
return calculated_max
async def get_embeddings(text):
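    """Fetch an embedding for `text` from LM Studio's /embeddings endpoint.

    Returns a NumPy array (or a plain list after an optional GPU round-trip),
    or None when the request fails or the response carries no data.
    """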
url = f"{BASE_URL}/embeddings"
payload = {"model": "nomic_embed_text_v1_5_f16.gguf", "input": text}
try:
response = await client.post(url, json=payload, headers={"Content-Type": "application/json"})
response.raise_for_status()
data = response.json()
if "data" in data and len(data["data"]) > 0:
embedding = np.array(data["data"][0]["embedding"])
if USE_GPU:
embedding = torch.tensor(embedding, device=DEVICE).tolist()
logger.info("Successfully retrieved embeddings.")
return embedding
except (httpx.RequestError, httpx.HTTPStatusError, json.JSONDecodeError) as e:
logger.error(f"Error occurred while getting embeddings: {e}")
return None
def calculate_similarity(vec1, vec2):
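    """Cosine similarity between two embedding vectors (lists, arrays, or tensors).

    Returns 0.0 if either vector is missing. Not yet wired into the UI, whose
    context display is still a placeholder.
    """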
if vec1 is None or vec2 is None:
logger.warning("One or both vectors are None. Returning similarity as 0.0.")
return 0.0
vec1_tensor = torch.tensor(vec1, device=DEVICE) if not isinstance(vec1, torch.Tensor) else vec1.to(DEVICE)
vec2_tensor = torch.tensor(vec2, device=DEVICE) if not isinstance(vec2, torch.Tensor) else vec2.to(DEVICE)
similarity = torch.nn.functional.cosine_similarity(vec1_tensor.unsqueeze(0), vec2_tensor.unsqueeze(0)).item()
logger.info(f"Calculated similarity: {similarity}")
return similarity
async def chat_with_lmstudio(messages, max_tokens):
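    """Stream a chat completion from the OpenAI-compatible /chat/completions endpoint.

    With "stream": True the server replies with server-sent events, one
    `data: {json}` line per delta; the final `data: [DONE]` sentinel is not
    valid JSON and is silently skipped by the JSONDecodeError handler below.
    Yields content fragments as they arrive.
    """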
url = f"{BASE_URL}/chat/completions"
payload = {
"model": "Qwen2.5-Coder-32B-Instruct-IQ2_M.gguf",
"messages": messages,
"temperature": 1,
"max_tokens": max_tokens,
"stream": True,
}
try:
logger.info("Sending chat completion request to LM Studio API.")
async with client.stream("POST", url, json=payload, headers={"Content-Type": "application/json"}) as response:
response.raise_for_status()
async for line in response.aiter_lines():
if line:
try:
decoded_line = line.strip()
if decoded_line.startswith("data: "):
data = json.loads(decoded_line[6:])
content = data.get("choices", [{}])[0].get("delta", {}).get("content", "")
if content:
yield content
except json.JSONDecodeError:
continue
except (httpx.RequestError, httpx.HTTPStatusError) as e:
logger.error(f"Error occurred while streaming chat completion: {e}")
yield "An error occurred while generating a response."
# ===========================
# Gradio Interface with Dynamic Resizing
# ===========================
def gradio_chat_interface():
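    """Build and launch the Gradio Blocks chat UI."""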
css = """
.gradio-container {
background-color: #1e1e1e;
color: #f0f0f0;
font-family: 'Arial', sans-serif;
}
.gr-button {
background-color: #6200ea;
color: white;
font-weight: bold;
}
.gr-textbox {
border: 2px solid #6200ea;
resize: both; /* Allow resizing */
}
.gr-chat-message {
border-radius: 8px;
padding: 10px;
}
"""
js = """
function resizeTextarea(event) {
const textarea = event.data[0].querySelector('textarea');
if (textarea) {
textarea.style.height = 'auto';
textarea.style.height = textarea.scrollHeight + 'px';
}
}
// Trigger resize on input change
document.addEventListener('input', function(event) {
if (event.target.classList.contains('gr-input')) {
resizeTextarea([event.target]);
}
});
// Trigger resize on response updates
const chatbot = document.querySelector('.gradio-container .gradio-chatbot');
if (chatbot) {
const observer = new MutationObserver((mutationsList) => {
mutationsList.forEach((mutation) => {
if (mutation.type === 'childList') {
mutation.addedNodes.forEach((node) => {
if (node.classList && node.classList.contains('gr-chat-message')) {
resizeTextarea([node]);
}
});
}
});
});
observer.observe(chatbot, { childList: true, subtree: true });
}
return [resizeTextarea];
"""
    with gr.Blocks(css=css, js=js, theme="default") as interface:
gr.Markdown("# 🌟 **Enhanced Chat Interface**\nBeautiful and functional AI-powered chat.")
chatbot = gr.Chatbot(label="Conversation", type="messages")
user_input = gr.Textbox(
label="Your Message",
placeholder="Type your message here...",
lines=1, # Start with a smaller number of lines
interactive=True,
container=False, # Avoid additional padding
)
send_button = gr.Button("Send", elem_id="send_button")
context_display = gr.Textbox(
label="Relevant Context",
interactive=False,
elem_id="context_display"
)
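        # Per-session state: stored embeddings plus the OpenAI-style message history.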
embeddings_state = gr.State({"embeddings": [], "messages_history": []})
        async def chat_handler(message, state):
            embeddings = state.get("embeddings", [])
            messages_history = state.get("messages_history", [])
            user_embedding = await get_embeddings(message)
            # get_embeddings() returns None on failure; avoid truth-testing a NumPy array.
            if user_embedding is None:
                yield [], state, "Failed to generate embeddings."
                return
            embeddings.append(user_embedding)
            messages_history.append({"role": "user", "content": message})
            # Keep only the most recent MAX_EMBEDDINGS turns.
            if len(embeddings) > MAX_EMBEDDINGS:
                embeddings = embeddings[-MAX_EMBEDDINGS:]
                messages_history = messages_history[-MAX_EMBEDDINGS:]
            max_tokens = calculate_max_tokens(messages_history)
            # Placeholder until real similarity-based context retrieval is wired in.
            context_display_text = f"Context: {message}"
            response = ""
            async for chunk in chat_with_lmstudio(messages_history, max_tokens):
                response += chunk
                # Rebuild the transcript from the tracked history plus the partial reply;
                # chatbot.value only holds the component's initial value, not the live chat.
                updated_chat = list(messages_history) + [{"role": "assistant", "content": response}]
                yield updated_chat, {"embeddings": embeddings, "messages_history": messages_history}, context_display_text
            # Persist the assistant's reply so the next turn sees the full conversation.
            messages_history.append({"role": "assistant", "content": response})
            yield list(messages_history), {"embeddings": embeddings, "messages_history": messages_history}, context_display_text
send_button.click(
chat_handler,
inputs=[user_input, embeddings_state],
outputs=[chatbot, embeddings_state, context_display],
            show_progress="full"
)
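    # share=True requests a public gradio.live link; binding to 0.0.0.0 makes the
    # app reachable from other machines on the local network.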
interface.launch(share=True, server_name="0.0.0.0", server_port=7860)
if __name__ == "__main__":
    # gradio_chat_interface() is synchronous and launch() blocks, so no asyncio wrapper is needed.
    gradio_chat_interface()