import gradio as gr
import httpx
import json
import os
import numpy as np
import torch
import asyncio
import logging

# ===========================
# Logging Configuration
# ===========================
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

# ===========================
# Configuration and Constants
# ===========================
BASE_URL = os.getenv("LMSTUDIO_API_BASE_URL", "http://localhost:1234/v1")
USE_GPU = torch.cuda.is_available()
DEVICE = torch.device("cuda" if USE_GPU else "cpu")
logger.info(f"GPU Available: {USE_GPU}, Device: {DEVICE}")

MODEL_MAX_TOKENS = 32768         # Context window of the served model
AVERAGE_CHARS_PER_TOKEN = 4      # Rough heuristic for estimating token usage
BUFFER_TOKENS = 1500             # Safety margin reserved for prompt overhead
MIN_OUTPUT_TOKENS = 500          # Never request fewer completion tokens than this
MAX_EMBEDDINGS = 100             # Cap on stored embeddings / history entries
HTTPX_TIMEOUT = 3000             # Seconds; generous timeout for long generations

client = httpx.AsyncClient(timeout=HTTPX_TIMEOUT)

# ===========================
# Utility Functions
# ===========================
def calculate_max_tokens(message_history, model_max_tokens=MODEL_MAX_TOKENS, buffer=BUFFER_TOKENS,
                         avg_chars_per_token=AVERAGE_CHARS_PER_TOKEN, min_tokens=MIN_OUTPUT_TOKENS):
    """Estimate how many completion tokens still fit in the model's context window."""
    total_length = sum(len(message["content"]) for message in message_history)
    input_tokens = total_length / avg_chars_per_token
    max_tokens = model_max_tokens - int(input_tokens) - buffer
    calculated_max = max(max_tokens, min_tokens)
    logger.info(f"Calculated max tokens: {calculated_max}")
    return calculated_max


async def get_embeddings(text):
    """Fetch an embedding vector for `text` from the LM Studio embeddings endpoint."""
    url = f"{BASE_URL}/embeddings"
    payload = {"model": "nomic_embed_text_v1_5_f16.gguf", "input": text}
    try:
        response = await client.post(url, json=payload, headers={"Content-Type": "application/json"})
        response.raise_for_status()
        data = response.json()
        if "data" in data and len(data["data"]) > 0:
            embedding = np.array(data["data"][0]["embedding"])
            if USE_GPU:
                embedding = torch.tensor(embedding, device=DEVICE).tolist()
            logger.info("Successfully retrieved embeddings.")
            return embedding
        logger.warning("Embeddings response contained no data.")
        return None
    except (httpx.RequestError, httpx.HTTPStatusError, json.JSONDecodeError) as e:
        logger.error(f"Error occurred while getting embeddings: {e}")
        return None


def calculate_similarity(vec1, vec2):
    """Cosine similarity between two embedding vectors; returns 0.0 if either is missing."""
    if vec1 is None or vec2 is None:
        logger.warning("One or both vectors are None. Returning similarity as 0.0.")
        return 0.0
    vec1_tensor = torch.tensor(vec1, device=DEVICE) if not isinstance(vec1, torch.Tensor) else vec1.to(DEVICE)
    vec2_tensor = torch.tensor(vec2, device=DEVICE) if not isinstance(vec2, torch.Tensor) else vec2.to(DEVICE)
    similarity = torch.nn.functional.cosine_similarity(vec1_tensor.unsqueeze(0), vec2_tensor.unsqueeze(0)).item()
    logger.info(f"Calculated similarity: {similarity}")
    return similarity


async def chat_with_lmstudio(messages, max_tokens):
    """Stream a chat completion from the LM Studio API, yielding content chunks as they arrive."""
    url = f"{BASE_URL}/chat/completions"
    payload = {
        "model": "Qwen2.5-Coder-32B-Instruct-IQ2_M.gguf",
        "messages": messages,
        "temperature": 1,
        "max_tokens": max_tokens,
        "stream": True,
    }
    try:
        logger.info("Sending chat completion request to LM Studio API.")
        async with client.stream("POST", url, json=payload, headers={"Content-Type": "application/json"}) as response:
            response.raise_for_status()
            async for line in response.aiter_lines():
                if line:
                    try:
                        decoded_line = line.strip()
                        if decoded_line.startswith("data: "):
                            data = json.loads(decoded_line[6:])
                            content = data.get("choices", [{}])[0].get("delta", {}).get("content", "")
                            if content:
                                yield content
                    except json.JSONDecodeError:
                        # Skip keep-alive lines and the final "data: [DONE]" marker.
                        continue
    except (httpx.RequestError, httpx.HTTPStatusError) as e:
        logger.error(f"Error occurred while streaming chat completion: {e}")
        yield "An error occurred while generating a response."
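
# ---------------------------------------------------------------
# Sketch: using calculate_similarity() to surface relevant context
# ---------------------------------------------------------------
# calculate_similarity() above is not yet called anywhere in the app (the
# "Relevant Context" box in the Gradio UI below is still a placeholder).
# The helper that follows is a minimal, hypothetical sketch of one way the
# stored embeddings could rank earlier messages by similarity to the current
# one. The function name and the assumption that previous_embeddings[i] was
# computed from previous_messages[i] are illustrative, not part of the
# original application.
def find_most_relevant_context(current_embedding, previous_embeddings, previous_messages):
    """Return the content of the earlier message most similar to the current one."""
    if current_embedding is None or not previous_embeddings:
        return ""
    scores = [calculate_similarity(current_embedding, emb) for emb in previous_embeddings]
    best_index = int(np.argmax(scores))
    return previous_messages[best_index]["content"]
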
# ===========================
# Gradio Interface with Dynamic Resizing
# ===========================
def gradio_chat_interface():
    css = """
    .gradio-container { background-color: #1e1e1e; color: #f0f0f0; font-family: 'Arial', sans-serif; }
    .gr-button { background-color: #6200ea; color: white; font-weight: bold; }
    .gr-textbox { border: 2px solid #6200ea; resize: both; /* Allow resizing */ }
    .gr-chat-message { border-radius: 8px; padding: 10px; }
    """

    # Wrapped as a single function so it can be passed to gr.Blocks(js=...) and run on page load.
    js = """
    () => {
        function resizeTextarea(element) {
            const textarea = element.tagName === 'TEXTAREA' ? element : element.querySelector('textarea');
            if (textarea) {
                textarea.style.height = 'auto';
                textarea.style.height = textarea.scrollHeight + 'px';
            }
        }

        // Trigger resize on input change
        document.addEventListener('input', function(event) {
            if (event.target.classList.contains('gr-input')) {
                resizeTextarea(event.target);
            }
        });

        // Trigger resize on response updates
        const chatbot = document.querySelector('.gradio-container .gradio-chatbot');
        if (chatbot) {
            const observer = new MutationObserver((mutationsList) => {
                mutationsList.forEach((mutation) => {
                    if (mutation.type === 'childList') {
                        mutation.addedNodes.forEach((node) => {
                            if (node.classList && node.classList.contains('gr-chat-message')) {
                                resizeTextarea(node);
                            }
                        });
                    }
                });
            });
            observer.observe(chatbot, { childList: true, subtree: true });
        }
    }
    """

    # The js= argument requires a Gradio version that supports custom load-time JS on Blocks.
    with gr.Blocks(css=css, js=js, theme="default") as interface:
        gr.Markdown("# 🌟 **Enhanced Chat Interface**\nBeautiful and functional AI-powered chat.")
        chatbot = gr.Chatbot(label="Conversation", type="messages")
        user_input = gr.Textbox(
            label="Your Message",
            placeholder="Type your message here...",
            lines=1,           # Start with a smaller number of lines
            interactive=True,
            container=False,   # Avoid additional padding
        )
        send_button = gr.Button("Send", elem_id="send_button")
        context_display = gr.Textbox(
            label="Relevant Context",
            interactive=False,
            elem_id="context_display",
        )
        embeddings_state = gr.State({"embeddings": [], "messages_history": []})

        async def chat_handler(message, state):
            embeddings = state.get("embeddings", [])
            messages_history = state.get("messages_history", [])

            user_embedding = await get_embeddings(message)
            # get_embeddings() can return a NumPy array, so test explicitly for None
            # instead of relying on truthiness.
            if user_embedding is None:
                yield [[], state, "Failed to generate embeddings."]
                return

            embeddings.append(user_embedding)
            messages_history.append({"role": "user", "content": message})

            # Keep the stored embeddings and history bounded.
            if len(embeddings) > MAX_EMBEDDINGS:
                embeddings = embeddings[-MAX_EMBEDDINGS:]
                messages_history = messages_history[-MAX_EMBEDDINGS:]

            max_tokens = calculate_max_tokens(messages_history)
            response = ""
            async for chunk in chat_with_lmstudio(messages_history, max_tokens):
                response += chunk
                # Rebuild the displayed conversation from the stored history rather than
                # from chatbot.value, which only reflects the component's initial value.
                updated_chat = messages_history + [{"role": "assistant", "content": response}]
                # Update the context display with some relevant context logic (placeholder for now)
                context_display_text = f"Context: {message}"
                yield [updated_chat, {"embeddings": embeddings, "messages_history": messages_history}, context_display_text]

            # Record the assistant's reply so the next turn sends the full conversation.
            messages_history.append({"role": "assistant", "content": response})
            yield [messages_history, {"embeddings": embeddings, "messages_history": messages_history}, f"Context: {message}"]

        send_button.click(
            chat_handler,
            inputs=[user_input, embeddings_state],
            outputs=[chatbot, embeddings_state, context_display],
            show_progress=True,
        )

    interface.launch(share=True, server_name="0.0.0.0", server_port=7860)


if __name__ == "__main__":
    # gradio_chat_interface() is synchronous and Gradio manages its own event loop
    # for async handlers, so it is called directly rather than via asyncio.run().
    gradio_chat_interface()