import gradio as gr
import httpx
import json
import os
import numpy as np
import torch
import logging
# ===========================
# Logging Configuration
# ===========================
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)
# ===========================
# Configuration and Constants
# ===========================
BASE_URL = os.getenv("LMSTUDIO_API_BASE_URL", "http://localhost:1234/v1")
USE_GPU = torch.cuda.is_available()
DEVICE = torch.device("cuda" if USE_GPU else "cpu")
logger.info(f"GPU Available: {USE_GPU}, Device: {DEVICE}")

MODEL_MAX_TOKENS = 32768        # Context window of the target model
AVERAGE_CHARS_PER_TOKEN = 4     # Rough heuristic for estimating token counts
BUFFER_TOKENS = 1500            # Safety margin reserved for prompt overhead
MIN_OUTPUT_TOKENS = 500         # Never request fewer completion tokens than this
MAX_EMBEDDINGS = 100            # Cap on stored embeddings / history entries
HTTPX_TIMEOUT = 3000            # Seconds; generous to allow long streamed completions

client = httpx.AsyncClient(timeout=HTTPX_TIMEOUT)
# ===========================
# Utility Functions
# ===========================
def calculate_max_tokens(message_history, model_max_tokens=MODEL_MAX_TOKENS,
                         buffer=BUFFER_TOKENS, avg_chars_per_token=AVERAGE_CHARS_PER_TOKEN,
                         min_tokens=MIN_OUTPUT_TOKENS):
    """Estimate how many completion tokens remain after accounting for the prompt."""
    total_length = sum(len(message["content"]) for message in message_history)
    input_tokens = total_length / avg_chars_per_token
    max_tokens = model_max_tokens - int(input_tokens) - buffer
    calculated_max = max(max_tokens, min_tokens)
    logger.info(f"Calculated max tokens: {calculated_max}")
    return calculated_max
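
# Worked example (using the constants above): 20,000 characters of history is
# roughly 20,000 / 4 = 5,000 tokens, so the completion budget would be
# 32,768 - 5,000 - 1,500 = 26,268 tokens, well above MIN_OUTPUT_TOKENS.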
async def get_embeddings(text):
    """Request an embedding vector for `text` from the LM Studio embeddings endpoint."""
    url = f"{BASE_URL}/embeddings"
    payload = {"model": "nomic_embed_text_v1_5_f16.gguf", "input": text}
    try:
        response = await client.post(url, json=payload, headers={"Content-Type": "application/json"})
        response.raise_for_status()
        data = response.json()
        if "data" in data and len(data["data"]) > 0:
            embedding = np.array(data["data"][0]["embedding"])
            if USE_GPU:
                embedding = torch.tensor(embedding, device=DEVICE).tolist()
            logger.info("Successfully retrieved embeddings.")
            return embedding
        logger.warning("Embedding response contained no data.")
        return None
    except (httpx.RequestError, httpx.HTTPStatusError, json.JSONDecodeError) as e:
        logger.error(f"Error occurred while getting embeddings: {e}")
        return None
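
# For reference, an OpenAI-compatible embeddings endpoint (which LM Studio exposes)
# typically returns JSON shaped like the following; exact fields may vary by version:
#
#   {"object": "list",
#    "data": [{"object": "embedding", "index": 0, "embedding": [0.01, -0.02, ...]}],
#    "model": "nomic_embed_text_v1_5_f16.gguf"}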
def calculate_similarity(vec1, vec2):
    if vec1 is None or vec2 is None:
        logger.warning("One or both vectors are None. Returning similarity as 0.0.")
        return 0.0
    vec1_tensor = torch.tensor(vec1, device=DEVICE) if not isinstance(vec1, torch.Tensor) else vec1.to(DEVICE)
    vec2_tensor = torch.tensor(vec2, device=DEVICE) if not isinstance(vec2, torch.Tensor) else vec2.to(DEVICE)
    similarity = torch.nn.functional.cosine_similarity(vec1_tensor.unsqueeze(0), vec2_tensor.unsqueeze(0)).item()
    logger.info(f"Calculated similarity: {similarity}")
    return similarity
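
# The chat handler below currently shows only a placeholder in the "Relevant Context"
# box. One possible way to put calculate_similarity to work for that purpose is
# sketched here; find_relevant_context is a hypothetical helper, not part of the
# original app, and assumes `embeddings` and `messages_history` are parallel lists.
def find_relevant_context(query_embedding, embeddings, messages_history, top_k=3):
    """Return the stored messages whose embeddings are most similar to the query."""
    if query_embedding is None or not embeddings:
        return []
    scored = [
        (calculate_similarity(query_embedding, emb), msg["content"])
        for emb, msg in zip(embeddings, messages_history)
    ]
    scored.sort(key=lambda pair: pair[0], reverse=True)
    return [content for _, content in scored[:top_k]]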
async def chat_with_lmstudio(messages, max_tokens):
    """Stream a chat completion from the LM Studio API, yielding content chunks."""
    url = f"{BASE_URL}/chat/completions"
    payload = {
        "model": "Qwen2.5-Coder-32B-Instruct-IQ2_M.gguf",
        "messages": messages,
        "temperature": 1,
        "max_tokens": max_tokens,
        "stream": True,
    }
    try:
        logger.info("Sending chat completion request to LM Studio API.")
        async with client.stream("POST", url, json=payload, headers={"Content-Type": "application/json"}) as response:
            response.raise_for_status()
            async for line in response.aiter_lines():
                if not line:
                    continue
                decoded_line = line.strip()
                if not decoded_line.startswith("data: "):
                    continue
                if decoded_line[6:].strip() == "[DONE]":
                    break
                try:
                    data = json.loads(decoded_line[6:])
                    content = data.get("choices", [{}])[0].get("delta", {}).get("content", "")
                    if content:
                        yield content
                except json.JSONDecodeError:
                    continue
    except (httpx.RequestError, httpx.HTTPStatusError) as e:
        logger.error(f"Error occurred while streaming chat completion: {e}")
        yield "An error occurred while generating a response."
# ===========================
# Gradio Interface with Dynamic Resizing
# ===========================
def gradio_chat_interface():
    css = """
    .gradio-container {
        background-color: #1e1e1e;
        color: #f0f0f0;
        font-family: 'Arial', sans-serif;
    }
    .gr-button {
        background-color: #6200ea;
        color: white;
        font-weight: bold;
    }
    .gr-textbox {
        border: 2px solid #6200ea;
        resize: both; /* Allow resizing */
    }
    .gr-chat-message {
        border-radius: 8px;
        padding: 10px;
    }
    """
js = """ | |
function resizeTextarea(event) { | |
const textarea = event.data[0].querySelector('textarea'); | |
if (textarea) { | |
textarea.style.height = 'auto'; | |
textarea.style.height = textarea.scrollHeight + 'px'; | |
} | |
} | |
// Trigger resize on input change | |
document.addEventListener('input', function(event) { | |
if (event.target.classList.contains('gr-input')) { | |
resizeTextarea([event.target]); | |
} | |
}); | |
// Trigger resize on response updates | |
const chatbot = document.querySelector('.gradio-container .gradio-chatbot'); | |
if (chatbot) { | |
const observer = new MutationObserver((mutationsList) => { | |
mutationsList.forEach((mutation) => { | |
if (mutation.type === 'childList') { | |
mutation.addedNodes.forEach((node) => { | |
if (node.classList && node.classList.contains('gr-chat-message')) { | |
resizeTextarea([node]); | |
} | |
}); | |
} | |
}); | |
}); | |
observer.observe(chatbot, { childList: true, subtree: true }); | |
} | |
return [resizeTextarea]; | |
""" | |
    with gr.Blocks(css=css, js=js, theme="default") as interface:
gr.Markdown("# π **Enhanced Chat Interface**\nBeautiful and functional AI-powered chat.") | |
chatbot = gr.Chatbot(label="Conversation", type="messages") | |
user_input = gr.Textbox( | |
label="Your Message", | |
placeholder="Type your message here...", | |
lines=1, # Start with a smaller number of lines | |
interactive=True, | |
container=False, # Avoid additional padding | |
) | |
send_button = gr.Button("Send", elem_id="send_button") | |
context_display = gr.Textbox( | |
label="Relevant Context", | |
interactive=False, | |
elem_id="context_display" | |
) | |
embeddings_state = gr.State({"embeddings": [], "messages_history": []}) | |
        async def chat_handler(message, state):
            embeddings = state.get("embeddings", [])
            messages_history = state.get("messages_history", [])
            user_embedding = await get_embeddings(message)
            if user_embedding is None:  # Avoid truthiness checks on array-like embeddings
                yield [[], state, "Failed to generate embeddings."]
                return
            embeddings.append(user_embedding)
            messages_history.append({"role": "user", "content": message})
            if len(embeddings) > MAX_EMBEDDINGS:
                embeddings = embeddings[-MAX_EMBEDDINGS:]
                messages_history = messages_history[-MAX_EMBEDDINGS:]
            max_tokens = calculate_max_tokens(messages_history)
            response = ""
            async for chunk in chat_with_lmstudio(messages_history, max_tokens):
                response += chunk
                # Rebuild the visible conversation from the stored history plus the partial reply.
                updated_chat = messages_history + [{"role": "assistant", "content": response}]
                # Update the context display with some relevant context logic (placeholder for now)
                context_display_text = f"Context: {message}"
                yield [updated_chat, {"embeddings": embeddings, "messages_history": messages_history}, context_display_text]
        send_button.click(
            chat_handler,
            inputs=[user_input, embeddings_state],
            outputs=[chatbot, embeddings_state, context_display],
            show_progress="full"
        )
    interface.launch(share=True, server_name="0.0.0.0", server_port=7860)


if __name__ == "__main__":
    # gradio_chat_interface() is a regular function (Gradio runs the async
    # handlers itself), so it is called directly rather than via asyncio.run().
    gradio_chat_interface()