import os import gradio as gr import requests import json import base64 import logging import io import time from typing import List, Dict, Any, Union, Tuple, Optional from dotenv import load_dotenv # Load environment variables from .env file load_dotenv() # Configure logging logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s') logger = logging.getLogger(__name__) # Gracefully import libraries with fallbacks try: from PIL import Image HAS_PIL = True except ImportError: logger.warning("PIL not installed. Image processing will be limited.") HAS_PIL = False try: import PyPDF2 HAS_PYPDF2 = True except ImportError: logger.warning("PyPDF2 not installed. PDF processing will be limited.") HAS_PYPDF2 = False try: import markdown HAS_MARKDOWN = True except ImportError: logger.warning("Markdown not installed. Markdown processing will be limited.") HAS_MARKDOWN = False try: import openai HAS_OPENAI = True except ImportError: logger.warning("OpenAI package not installed. OpenAI models will be unavailable.") HAS_OPENAI = False try: from groq import Groq HAS_GROQ = True except ImportError: logger.warning("Groq client not installed. Groq API will be unavailable.") HAS_GROQ = False try: import cohere HAS_COHERE = True except ImportError: logger.warning("Cohere package not installed. Cohere models will be unavailable.") HAS_COHERE = False try: from huggingface_hub import InferenceClient HAS_HF = True except ImportError: logger.warning("HuggingFace hub not installed. HuggingFace models will be limited.") HAS_HF = False # API keys from environment OPENROUTER_API_KEY = os.environ.get("OPENROUTER_API_KEY", "") OPENAI_API_KEY = os.environ.get("OPENAI_API_KEY", "") GROQ_API_KEY = os.environ.get("GROQ_API_KEY", "") COHERE_API_KEY = os.environ.get("COHERE_API_KEY", "") GLHF_API_KEY = os.environ.get("GLHF_API_KEY", "") HF_API_KEY = os.environ.get("HF_API_KEY", "") # ========================================================== # MODEL DEFINITIONS # ========================================================== # OPENROUTER MODELS # These are the original models from the provided code OPENROUTER_MODELS = [ # 1M+ Context Models {"category": "1M+ Context", "models": [ #("Google: Gemini Pro 2.0 Experimental", "google/gemini-2.0-pro-exp-02-05:free", 2000000), ("Google: Gemini 2.0 Flash Thinking Experimental 01-21", "google/gemini-2.0-flash-thinking-exp:free", 1048576), ("Google: Gemini Flash 2.0 Experimental", "google/gemini-2.0-flash-exp:free", 1048576), ("Google: Gemini Pro 2.5 Experimental", "google/gemini-2.5-pro-exp-03-25:free", 1000000), ("Google: Gemini Flash 1.5 8B Experimental", "google/gemini-flash-1.5-8b-exp", 1000000), ]}, # 100K-1M Context Models {"category": "100K+ Context", "models": [ ("DeepSeek: DeepSeek R1 Zero", "deepseek/deepseek-r1-zero:free", 163840), ("DeepSeek: R1", "deepseek/deepseek-r1:free", 163840), ("DeepSeek: DeepSeek V3 Base", "deepseek/deepseek-v3-base:free", 131072), ("DeepSeek: DeepSeek V3 0324", "deepseek/deepseek-chat-v3-0324:free", 131072), ("Google: Gemma 3 4B", "google/gemma-3-4b-it:free", 131072), ("Google: Gemma 3 12B", "google/gemma-3-12b-it:free", 131072), ("Nous: DeepHermes 3 Llama 3 8B Preview", "nousresearch/deephermes-3-llama-3-8b-preview:free", 131072), ("Qwen: Qwen2.5 VL 72B Instruct", "qwen/qwen2.5-vl-72b-instruct:free", 131072), ("DeepSeek: DeepSeek V3", "deepseek/deepseek-chat:free", 131072), ("NVIDIA: Llama 3.1 Nemotron 70B Instruct", "nvidia/llama-3.1-nemotron-70b-instruct:free", 131072), ("Meta: Llama 3.2 1B Instruct", "meta-llama/llama-3.2-1b-instruct:free", 131072), ("Meta: Llama 3.2 11B Vision Instruct", "meta-llama/llama-3.2-11b-vision-instruct:free", 131072), ("Meta: Llama 3.1 8B Instruct", "meta-llama/llama-3.1-8b-instruct:free", 131072), ("Mistral: Mistral Nemo", "mistralai/mistral-nemo:free", 128000), ]}, # 64K-100K Context Models {"category": "64K-100K Context", "models": [ ("Mistral: Mistral Small 3.1 24B", "mistralai/mistral-small-3.1-24b-instruct:free", 96000), ("Google: Gemma 3 27B", "google/gemma-3-27b-it:free", 96000), ("Qwen: Qwen2.5 VL 3B Instruct", "qwen/qwen2.5-vl-3b-instruct:free", 64000), ("DeepSeek: R1 Distill Qwen 14B", "deepseek/deepseek-r1-distill-qwen-14b:free", 64000), ("Qwen: Qwen2.5-VL 7B Instruct", "qwen/qwen-2.5-vl-7b-instruct:free", 64000), ]}, # 32K-64K Context Models {"category": "32K-64K Context", "models": [ ("Google: LearnLM 1.5 Pro Experimental", "google/learnlm-1.5-pro-experimental:free", 40960), ("Qwen: QwQ 32B", "qwen/qwq-32b:free", 40000), ("Google: Gemini 2.0 Flash Thinking Experimental", "google/gemini-2.0-flash-thinking-exp-1219:free", 40000), ("Bytedance: UI-TARS 72B", "bytedance-research/ui-tars-72b:free", 32768), ("Qwerky 72b", "featherless/qwerky-72b:free", 32768), ("OlympicCoder 7B", "open-r1/olympiccoder-7b:free", 32768), ("OlympicCoder 32B", "open-r1/olympiccoder-32b:free", 32768), ("Google: Gemma 3 1B", "google/gemma-3-1b-it:free", 32768), ("Reka: Flash 3", "rekaai/reka-flash-3:free", 32768), ("Dolphin3.0 R1 Mistral 24B", "cognitivecomputations/dolphin3.0-r1-mistral-24b:free", 32768), ("Dolphin3.0 Mistral 24B", "cognitivecomputations/dolphin3.0-mistral-24b:free", 32768), ("Mistral: Mistral Small 3", "mistralai/mistral-small-24b-instruct-2501:free", 32768), ("Qwen2.5 Coder 32B Instruct", "qwen/qwen-2.5-coder-32b-instruct:free", 32768), ("Qwen2.5 72B Instruct", "qwen/qwen-2.5-72b-instruct:free", 32768), ]}, # 8K-32K Context Models {"category": "8K-32K Context", "models": [ ("Meta: Llama 3.2 3B Instruct", "meta-llama/llama-3.2-3b-instruct:free", 20000), ("Qwen: QwQ 32B Preview", "qwen/qwq-32b-preview:free", 16384), ("DeepSeek: R1 Distill Qwen 32B", "deepseek/deepseek-r1-distill-qwen-32b:free", 16000), ("Qwen: Qwen2.5 VL 32B Instruct", "qwen/qwen2.5-vl-32b-instruct:free", 8192), ("Moonshot AI: Moonlight 16B A3B Instruct", "moonshotai/moonlight-16b-a3b-instruct:free", 8192), ("DeepSeek: R1 Distill Llama 70B", "deepseek/deepseek-r1-distill-llama-70b:free", 8192), ("Qwen 2 7B Instruct", "qwen/qwen-2-7b-instruct:free", 8192), ("Google: Gemma 2 9B", "google/gemma-2-9b-it:free", 8192), ("Mistral: Mistral 7B Instruct", "mistralai/mistral-7b-instruct:free", 8192), ("Microsoft: Phi-3 Mini 128K Instruct", "microsoft/phi-3-mini-128k-instruct:free", 8192), ("Microsoft: Phi-3 Medium 128K Instruct", "microsoft/phi-3-medium-128k-instruct:free", 8192), ("Meta: Llama 3 8B Instruct", "meta-llama/llama-3-8b-instruct:free", 8192), ("OpenChat 3.5 7B", "openchat/openchat-7b:free", 8192), ("Meta: Llama 3.3 70B Instruct", "meta-llama/llama-3.3-70b-instruct:free", 8000), ]}, # <8K Context Models {"category": "4K Context", "models": [ ("AllenAI: Molmo 7B D", "allenai/molmo-7b-d:free", 4096), ("Rogue Rose 103B v0.2", "sophosympatheia/rogue-rose-103b-v0.2:free", 4096), ("Toppy M 7B", "undi95/toppy-m-7b:free", 4096), ("Hugging Face: Zephyr 7B", "huggingfaceh4/zephyr-7b-beta:free", 4096), ("MythoMax 13B", "gryphe/mythomax-l2-13b:free", 4096), ]}, # Vision-capable Models {"category": "Vision Models", "models": [ #("Google: Gemini Pro 2.0 Experimental", "google/gemini-2.0-pro-exp-02-05:free", 2000000), ("Google: Gemini 2.0 Flash Thinking Experimental 01-21", "google/gemini-2.0-flash-thinking-exp:free", 1048576), ("Google: Gemini Flash 2.0 Experimental", "google/gemini-2.0-flash-exp:free", 1048576), ("Google: Gemini Pro 2.5 Experimental", "google/gemini-2.5-pro-exp-03-25:free", 1000000), ("Google: Gemini Flash 1.5 8B Experimental", "google/gemini-flash-1.5-8b-exp", 1000000), ("Google: Gemma 3 4B", "google/gemma-3-4b-it:free", 131072), ("Google: Gemma 3 12B", "google/gemma-3-12b-it:free", 131072), ("Qwen: Qwen2.5 VL 72B Instruct", "qwen/qwen2.5-vl-72b-instruct:free", 131072), ("Meta: Llama 3.2 11B Vision Instruct", "meta-llama/llama-3.2-11b-vision-instruct:free", 131072), ("Mistral: Mistral Small 3.1 24B", "mistralai/mistral-small-3.1-24b-instruct:free", 96000), ("Google: Gemma 3 27B", "google/gemma-3-27b-it:free", 96000), ("Qwen: Qwen2.5 VL 3B Instruct", "qwen/qwen2.5-vl-3b-instruct:free", 64000), ("Qwen: Qwen2.5-VL 7B Instruct", "qwen/qwen-2.5-vl-7b-instruct:free", 64000), ("Google: LearnLM 1.5 Pro Experimental", "google/learnlm-1.5-pro-experimental:free", 40960), ("Google: Gemini 2.0 Flash Thinking Experimental", "google/gemini-2.0-flash-thinking-exp-1219:free", 40000), ("Bytedance: UI-TARS 72B", "bytedance-research/ui-tars-72b:free", 32768), ("Google: Gemma 3 1B", "google/gemma-3-1b-it:free", 32768), ("Qwen: Qwen2.5 VL 32B Instruct", "qwen/qwen2.5-vl-32b-instruct:free", 8192), ("AllenAI: Molmo 7B D", "allenai/molmo-7b-d:free", 4096), ]}, ] # Flatten OpenRouter model list for easier access OPENROUTER_ALL_MODELS = [] for category in OPENROUTER_MODELS: for model in category["models"]: if model not in OPENROUTER_ALL_MODELS: # Avoid duplicates OPENROUTER_ALL_MODELS.append(model) # OPENAI MODELS OPENAI_MODELS = { "gpt-3.5-turbo": 16385, "gpt-3.5-turbo-0125": 16385, "gpt-3.5-turbo-1106": 16385, "gpt-3.5-turbo-instruct": 4096, "gpt-4": 8192, "gpt-4-0314": 8192, "gpt-4-0613": 8192, "gpt-4-turbo": 128000, "gpt-4-turbo-2024-04-09": 128000, "gpt-4-turbo-preview": 128000, "gpt-4-0125-preview": 128000, "gpt-4-1106-preview": 128000, "gpt-4o": 128000, "gpt-4o-2024-11-20": 128000, "gpt-4o-2024-08-06": 128000, "gpt-4o-2024-05-13": 128000, "chatgpt-4o-latest": 128000, "gpt-4o-mini": 128000, "gpt-4o-mini-2024-07-18": 128000, "gpt-4o-realtime-preview": 128000, "gpt-4o-realtime-preview-2024-10-01": 128000, "gpt-4o-audio-preview": 128000, "gpt-4o-audio-preview-2024-10-01": 128000, "o1-preview": 128000, "o1-preview-2024-09-12": 128000, "o1-mini": 128000, "o1-mini-2024-09-12": 128000, } # HUGGINGFACE MODELS HUGGINGFACE_MODELS = { "microsoft/phi-3-mini-4k-instruct": 4096, "microsoft/Phi-3-mini-128k-instruct": 131072, "HuggingFaceH4/zephyr-7b-beta": 8192, "deepseek-ai/DeepSeek-Coder-V2-Instruct": 8192, "mistralai/Mistral-7B-Instruct-v0.3": 32768, "NousResearch/Nous-Hermes-2-Mixtral-8x7B-DPO": 32768, "microsoft/Phi-3.5-mini-instruct": 4096, "HuggingFaceTB/SmolLM2-1.7B-Instruct": 2048, "google/gemma-2-2b-it": 2048, "openai-community/gpt2": 1024, "microsoft/phi-2": 2048, "TinyLlama/TinyLlama-1.1B-Chat-v1.0": 2048, "VAGOsolutions/Llama-3-SauerkrautLM-8b-Instruct": 2048, "VAGOsolutions/Llama-3.1-SauerkrautLM-8b-Instruct": 4096, "VAGOsolutions/SauerkrautLM-Nemo-12b-Instruct": 4096, "openGPT-X/Teuken-7B-instruct-research-v0.4": 4096, "Qwen/Qwen2.5-7B-Instruct": 131072, "tiiuae/falcon-7b-instruct": 8192, "Qwen/QwQ-32B-preview": 32768, } # GROQ MODELS - We'll populate this dynamically DEFAULT_GROQ_MODELS = { "gemma2-9b-it": 8192, "gemma-7b-it": 8192, "llama-3.3-70b-versatile": 131072, "llama-3.1-70b-versatile": 131072, "llama-3.1-8b-instant": 131072, "llama-guard-3-8b": 8192, "llama3-70b-8192": 8192, "llama3-8b-8192": 8192, "mixtral-8x7b-32768": 32768, "llama3-groq-70b-8192-tool-use-preview": 8192, "llama3-groq-8b-8192-tool-use-preview": 8192, "llama-3.3-70b-specdec": 131072, "llama-3.1-70b-specdec": 131072, "llama-3.2-1b-preview": 131072, "llama-3.2-3b-preview": 131072, } # COHERE MODELS COHERE_MODELS = { "command-r-plus-08-2024": 131072, "command-r-plus-04-2024": 131072, "command-r-plus": 131072, "command-r-08-2024": 131072, "command-r-03-2024": 131072, "command-r": 131072, "command": 4096, "command-nightly": 131072, "command-light": 4096, "command-light-nightly": 4096, "c4ai-aya-expanse-8b": 8192, "c4ai-aya-expanse-32b": 131072, } # GLHF MODELS GLHF_MODELS = { "mistralai/Mixtral-8x7B-Instruct-v0.1": 32768, "01-ai/Yi-34B-Chat": 32768, "mistralai/Mistral-7B-Instruct-v0.3": 32768, "microsoft/phi-3-mini-4k-instruct": 4096, "microsoft/Phi-3.5-mini-instruct": 4096, "microsoft/Phi-3-mini-128k-instruct": 131072, "HuggingFaceH4/zephyr-7b-beta": 8192, "NousResearch/Nous-Hermes-2-Mixtral-8x7B-DPO": 32768, "google/gemma-2-2b-it": 2048, "microsoft/phi-2": 2048, } # ========================================================== # HELPER FUNCTIONS # ========================================================== def fetch_groq_models(): """Fetch available Groq models with proper error handling""" try: if not HAS_GROQ or not GROQ_API_KEY: logger.warning("Groq client not available or no API key. Using default model list.") return DEFAULT_GROQ_MODELS client = Groq(api_key=GROQ_API_KEY) models = client.models.list() # Create dictionary of model_id -> context size model_dict = {} for model in models.data: model_id = model.id # Map known context sizes or use a default if "llama-3" in model_id and "70b" in model_id: context_size = 131072 elif "llama-3" in model_id and "8b" in model_id: context_size = 131072 elif "mixtral" in model_id: context_size = 32768 elif "gemma" in model_id: context_size = 8192 else: context_size = 8192 # Default assumption model_dict[model_id] = context_size # Ensure we have models by combining with defaults if not model_dict: return DEFAULT_GROQ_MODELS return {**DEFAULT_GROQ_MODELS, **model_dict} except Exception as e: logger.error(f"Error fetching Groq models: {e}") return DEFAULT_GROQ_MODELS # Initialize Groq models GROQ_MODELS = fetch_groq_models() def encode_image_to_base64(image_path): """Encode an image file to base64 string""" try: if isinstance(image_path, str): # File path as string with open(image_path, "rb") as image_file: encoded_string = base64.b64encode(image_file.read()).decode('utf-8') file_extension = image_path.split('.')[-1].lower() mime_type = f"image/{file_extension}" if file_extension in ["jpg", "jpeg"]: mime_type = "image/jpeg" elif file_extension == "png": mime_type = "image/png" elif file_extension == "webp": mime_type = "image/webp" return f"data:{mime_type};base64,{encoded_string}" elif hasattr(image_path, 'name'): # Handle Gradio file objects directly with open(image_path.name, "rb") as image_file: encoded_string = base64.b64encode(image_file.read()).decode('utf-8') file_extension = image_path.name.split('.')[-1].lower() mime_type = f"image/{file_extension}" if file_extension in ["jpg", "jpeg"]: mime_type = "image/jpeg" elif file_extension == "png": mime_type = "image/png" elif file_extension == "webp": mime_type = "image/webp" return f"data:{mime_type};base64,{encoded_string}" else: # Handle file object or other types logger.error(f"Unsupported image type: {type(image_path)}") return None except Exception as e: logger.error(f"Error encoding image: {str(e)}") return None def extract_text_from_file(file_path): """Extract text from various file types""" try: file_extension = file_path.split('.')[-1].lower() if file_extension == 'pdf': if HAS_PYPDF2: text = "" with open(file_path, 'rb') as file: pdf_reader = PyPDF2.PdfReader(file) for page_num in range(len(pdf_reader.pages)): page = pdf_reader.pages[page_num] text += page.extract_text() + "\n\n" return text else: return "PDF processing is not available (PyPDF2 not installed)" elif file_extension == 'md': with open(file_path, 'r', encoding='utf-8') as file: return file.read() elif file_extension == 'txt': with open(file_path, 'r', encoding='utf-8') as file: return file.read() else: return f"Unsupported file type: {file_extension}" except Exception as e: logger.error(f"Error extracting text from file: {str(e)}") return f"Error processing file: {str(e)}" def prepare_message_with_media(text, images=None, documents=None): """Prepare a message with text, images, and document content""" # If no media, return text only if not images and not documents: return text # Start with text content if documents and len(documents) > 0: # If there are documents, append their content to the text document_texts = [] for doc in documents: if doc is None: continue # Make sure to handle file objects properly doc_path = doc.name if hasattr(doc, 'name') else doc doc_text = extract_text_from_file(doc_path) if doc_text: document_texts.append(doc_text) # Add document content to text if document_texts: if not text: text = "Please analyze these documents:" else: text = f"{text}\n\nDocument content:\n\n" text += "\n\n".join(document_texts) # If no images, return text only if not images: return text # If we have images, create a multimodal content array content = [{"type": "text", "text": text}] # Add images if any if images: # Check if images is a list of image paths or file objects if isinstance(images, list): for img in images: if img is None: continue encoded_image = encode_image_to_base64(img) if encoded_image: content.append({ "type": "image_url", "image_url": {"url": encoded_image} }) else: # For single image or Gallery component logger.warning(f"Images is not a list: {type(images)}") # Try to handle as single image encoded_image = encode_image_to_base64(images) if encoded_image: content.append({ "type": "image_url", "image_url": {"url": encoded_image} }) return content def format_to_message_dict(history): """Convert history to proper message format""" messages = [] for pair in history: if len(pair) == 2: human, ai = pair if human: messages.append({"role": "user", "content": human}) if ai: messages.append({"role": "assistant", "content": ai}) return messages def process_uploaded_images(files): """Process uploaded image files""" file_paths = [] for file in files: if hasattr(file, 'name'): file_paths.append(file.name) return file_paths def filter_models(provider, search_term): """Filter models based on search term and provider""" if provider == "OpenRouter": all_models = [model[0] for model in OPENROUTER_ALL_MODELS] elif provider == "OpenAI": all_models = list(OPENAI_MODELS.keys()) elif provider == "HuggingFace": all_models = list(HUGGINGFACE_MODELS.keys()) elif provider == "Groq": all_models = list(GROQ_MODELS.keys()) elif provider == "Cohere": all_models = list(COHERE_MODELS.keys()) elif provider == "GLHF": all_models = list(GLHF_MODELS.keys()) else: return [], None if not search_term: return all_models, all_models[0] if all_models else None filtered_models = [model for model in all_models if search_term.lower() in model.lower()] if filtered_models: return filtered_models, filtered_models[0] else: return all_models, all_models[0] if all_models else None def get_model_info(provider, model_choice): """Get model ID and context size based on provider and model name""" if provider == "OpenRouter": for name, model_id, ctx_size in OPENROUTER_ALL_MODELS: if name == model_choice: return model_id, ctx_size elif provider == "OpenAI": if model_choice in OPENAI_MODELS: return model_choice, OPENAI_MODELS[model_choice] elif provider == "HuggingFace": if model_choice in HUGGINGFACE_MODELS: return model_choice, HUGGINGFACE_MODELS[model_choice] elif provider == "Groq": if model_choice in GROQ_MODELS: return model_choice, GROQ_MODELS[model_choice] elif provider == "Cohere": if model_choice in COHERE_MODELS: return model_choice, COHERE_MODELS[model_choice] elif provider == "GLHF": if model_choice in GLHF_MODELS: return model_choice, GLHF_MODELS[model_choice] return None, 0 def update_context_display(provider, model_name): """Update context size display for the selected model""" _, ctx_size = get_model_info(provider, model_name) return f"{ctx_size:,}" if ctx_size else "Unknown" def update_model_info(provider, model_name): """Generate HTML info display for the selected model""" model_id, ctx_size = get_model_info(provider, model_name) if not model_id: return "

Model information not available

" # Check if this is a vision model is_vision_model = False # For OpenRouter, check the vision models category if provider == "OpenRouter": for cat in OPENROUTER_MODELS: if cat["category"] == "Vision Models": if any(m[0] == model_name for m in cat["models"]): is_vision_model = True break # For other providers, use heuristics elif provider == "OpenAI" and any(x in model_name.lower() for x in ["gpt-4", "gpt-4o"]): is_vision_model = True elif provider == "HuggingFace" and any(x in model_name.lower() for x in ["vl", "vision"]): is_vision_model = True vision_badge = 'Vision' if is_vision_model else '' # For OpenRouter, show the model ID model_id_html = f"

Model ID: {model_id}

" if provider == "OpenRouter" else "" # For others, the ID is the same as the name if provider != "OpenRouter": model_id_html = "" return f"""

{model_name} {vision_badge}

{model_id_html}

Context Size: {ctx_size:,} tokens

Provider: {provider}

{f'

Features: Supports image understanding

' if is_vision_model else ''}
""" # ========================================================== # API HANDLERS # ========================================================== def call_openrouter_api(payload, api_key_override=None): """Make a call to OpenRouter API with error handling""" try: api_key = api_key_override if api_key_override else OPENROUTER_API_KEY if not api_key: raise ValueError("OpenRouter API key is required") response = requests.post( "https://openrouter.ai/api/v1/chat/completions", headers={ "Content-Type": "application/json", "Authorization": f"Bearer {api_key}", "HTTP-Referer": "https://huggingface.co/spaces/user/MultiProviderCrispChat" }, json=payload, timeout=180 # Longer timeout for document processing ) return response except requests.RequestException as e: logger.error(f"OpenRouter API request error: {str(e)}") raise e def call_openai_api(payload, api_key_override=None): """Make a call to OpenAI API with error handling""" try: if not HAS_OPENAI: raise ImportError("OpenAI package not installed") api_key = api_key_override if api_key_override else OPENAI_API_KEY if not api_key: raise ValueError("OpenAI API key is required") client = openai.OpenAI(api_key=api_key) # Extract parameters from payload model = payload.get("model", "gpt-3.5-turbo") messages = payload.get("messages", []) temperature = payload.get("temperature", 0.7) max_tokens = payload.get("max_tokens", 1000) stream = payload.get("stream", False) top_p = payload.get("top_p", 0.9) presence_penalty = payload.get("presence_penalty", 0) frequency_penalty = payload.get("frequency_penalty", 0) # Handle response format if specified response_format = None if payload.get("response_format") == "json_object": response_format = {"type": "json_object"} # Create completion response = client.chat.completions.create( model=model, messages=messages, temperature=temperature, max_tokens=max_tokens, stream=stream, top_p=top_p, presence_penalty=presence_penalty, frequency_penalty=frequency_penalty, response_format=response_format ) return response except Exception as e: logger.error(f"OpenAI API error: {str(e)}") raise e def call_huggingface_api(payload, api_key_override=None): """Make a call to HuggingFace API with error handling""" try: if not HAS_HF: raise ImportError("HuggingFace hub not installed") api_key = api_key_override if api_key_override else HF_API_KEY # Extract parameters from payload model_id = payload.get("model", "mistralai/Mistral-7B-Instruct-v0.3") messages = payload.get("messages", []) temperature = payload.get("temperature", 0.7) max_tokens = payload.get("max_tokens", 500) # Create a prompt from messages prompt = "" for msg in messages: role = msg["role"].upper() content = msg["content"] # Handle multimodal content if isinstance(content, list): text_parts = [] for item in content: if item["type"] == "text": text_parts.append(item["text"]) content = "\n".join(text_parts) prompt += f"{role}: {content}\n" prompt += "ASSISTANT: " # Create client with or without API key client = InferenceClient(token=api_key) if api_key else InferenceClient() # Generate response response = client.text_generation( prompt, model=model_id, max_new_tokens=max_tokens, temperature=temperature, repetition_penalty=1.1 ) return {"generated_text": str(response)} except Exception as e: logger.error(f"HuggingFace API error: {str(e)}") raise e def call_groq_api(payload, api_key_override=None): """Make a call to Groq API with error handling""" try: if not HAS_GROQ: raise ImportError("Groq client not installed") api_key = api_key_override if api_key_override else GROQ_API_KEY if not api_key: raise ValueError("Groq API key is required") client = Groq(api_key=api_key) # Extract parameters from payload model = payload.get("model", "llama-3.1-8b-instant") messages = payload.get("messages", []) temperature = payload.get("temperature", 0.7) max_tokens = payload.get("max_tokens", 1000) stream = payload.get("stream", False) top_p = payload.get("top_p", 0.9) # Create completion response = client.chat.completions.create( model=model, messages=messages, temperature=temperature, max_tokens=max_tokens, stream=stream, top_p=top_p ) return response except Exception as e: logger.error(f"Groq API error: {str(e)}") raise e def call_cohere_api(payload, api_key_override=None): """Make a call to Cohere API with error handling""" try: if not HAS_COHERE: raise ImportError("Cohere package not installed") api_key = api_key_override if api_key_override else COHERE_API_KEY if not api_key: raise ValueError("Cohere API key is required") client = cohere.Client(api_key=api_key) # Extract parameters from payload model = payload.get("model", "command-r-plus") messages = payload.get("messages", []) temperature = payload.get("temperature", 0.7) max_tokens = payload.get("max_tokens", 1000) # Format messages for Cohere chat_history = [] user_message = "" for msg in messages: if msg["role"] == "system": # For system message, we'll prepend to the user's first message system_content = msg["content"] if isinstance(system_content, list): # Handle multimodal content system_parts = [] for item in system_content: if item["type"] == "text": system_parts.append(item["text"]) system_content = "\n".join(system_parts) user_message = f"System: {system_content}\n\n" + user_message elif msg["role"] == "user": content = msg["content"] # Handle multimodal content if isinstance(content, list): text_parts = [] for item in content: if item["type"] == "text": text_parts.append(item["text"]) content = "\n".join(text_parts) user_message = content elif msg["role"] == "assistant": content = msg["content"] if content: chat_history.append({"role": "ASSISTANT", "message": content}) # Create chat completion response = client.chat( message=user_message, chat_history=chat_history, model=model, temperature=temperature, max_tokens=max_tokens ) return response except Exception as e: logger.error(f"Cohere API error: {str(e)}") raise e def call_glhf_api(payload, api_key_override=None): """Make a call to GLHF API with error handling""" try: if not HAS_OPENAI: raise ImportError("OpenAI package not installed (required for GLHF API)") api_key = api_key_override if api_key_override else GLHF_API_KEY if not api_key: raise ValueError("GLHF API key is required") client = openai.OpenAI( api_key=api_key, base_url="https://glhf.chat/api/openai/v1" ) # Extract parameters from payload model_name = payload.get("model", "mistralai/Mistral-7B-Instruct-v0.3") # Add "hf:" prefix if not already there if not model_name.startswith("hf:"): model = f"hf:{model_name}" else: model = model_name messages = payload.get("messages", []) temperature = payload.get("temperature", 0.7) max_tokens = payload.get("max_tokens", 1000) stream = payload.get("stream", False) # Create completion response = client.chat.completions.create( model=model, messages=messages, temperature=temperature, max_tokens=max_tokens, stream=stream ) return response except Exception as e: logger.error(f"GLHF API error: {str(e)}") raise e def extract_ai_response(result, provider): """Extract AI response based on provider format""" try: if provider == "OpenRouter": if isinstance(result, dict): if "choices" in result and len(result["choices"]) > 0: if "message" in result["choices"][0]: message = result["choices"][0]["message"] if message.get("reasoning") and not message.get("content"): reasoning = message.get("reasoning") lines = reasoning.strip().split('\n') for line in lines: if line and not line.startswith('I should') and not line.startswith('Let me'): return line.strip() for line in lines: if line.strip(): return line.strip() return message.get("content", "") elif "delta" in result["choices"][0]: return result["choices"][0]["delta"].get("content", "") elif provider == "OpenAI": if hasattr(result, "choices") and len(result.choices) > 0: return result.choices[0].message.content elif provider == "HuggingFace": return result.get("generated_text", "") elif provider == "Groq": if hasattr(result, "choices") and len(result.choices) > 0: return result.choices[0].message.content elif provider == "Cohere": if hasattr(result, "text"): return result.text elif provider == "GLHF": if hasattr(result, "choices") and len(result.choices) > 0: return result.choices[0].message.content logger.error(f"Unexpected response structure from {provider}: {result}") return f"Error: Could not extract response from {provider} API result" except Exception as e: logger.error(f"Error extracting AI response: {str(e)}") return f"Error: {str(e)}" # ========================================================== # STREAMING HANDLERS # ========================================================== def openrouter_streaming_handler(response, chatbot, message_idx, message): try: # First add the user message if needed if len(chatbot) == message_idx: chatbot.append([message, ""]) for line in response.iter_lines(): if not line: continue line = line.decode('utf-8') if not line.startswith('data: '): continue data = line[6:] if data.strip() == '[DONE]': break try: chunk = json.loads(data) if "choices" in chunk and len(chunk["choices"]) > 0: delta = chunk["choices"][0].get("delta", {}) if "content" in delta and delta["content"]: # Update the current response chatbot[-1][1] += delta["content"] yield chatbot except json.JSONDecodeError: logger.error(f"Failed to parse JSON from chunk: {data}") except Exception as e: logger.error(f"Error in streaming handler: {str(e)}") # Add error message to the current response if len(chatbot) > message_idx: chatbot[-1][1] += f"\n\nError during streaming: {str(e)}" yield chatbot def openai_streaming_handler(response, chatbot, message_idx, message): try: # First add the user message if needed if len(chatbot) == message_idx: chatbot.append([message, ""]) full_response = "" for chunk in response: if hasattr(chunk.choices[0].delta, "content") and chunk.choices[0].delta.content is not None: content = chunk.choices[0].delta.content full_response += content chatbot[-1][1] = full_response yield chatbot except Exception as e: logger.error(f"Error in OpenAI streaming handler: {str(e)}") # Add error message to the current response chatbot[-1][1] += f"\n\nError during streaming: {str(e)}" yield chatbot def groq_streaming_handler(response, chatbot, message_idx, message): try: # First add the user message if needed if len(chatbot) == message_idx: chatbot.append([message, ""]) full_response = "" for chunk in response: if hasattr(chunk.choices[0].delta, "content") and chunk.choices[0].delta.content is not None: content = chunk.choices[0].delta.content full_response += content chatbot[-1][1] = full_response yield chatbot except Exception as e: logger.error(f"Error in Groq streaming handler: {str(e)}") # Add error message to the current response chatbot[-1][1] += f"\n\nError during streaming: {str(e)}" yield chatbot def glhf_streaming_handler(response, chatbot, message_idx, message): try: # First add the user message if needed if len(chatbot) == message_idx: chatbot.append([message, ""]) full_response = "" for chunk in response: if hasattr(chunk.choices[0].delta, "content") and chunk.choices[0].delta.content is not None: content = chunk.choices[0].delta.content full_response += content chatbot[-1][1] = full_response yield chatbot except Exception as e: logger.error(f"Error in GLHF streaming handler: {str(e)}") # Add error message to the current response chatbot[-1][1] += f"\n\nError during streaming: {str(e)}" yield chatbot # ========================================================== # MAIN FUNCTION TO ASK AI # ========================================================== def ask_ai(message, history, provider, model_choice, temperature, max_tokens, top_p, frequency_penalty, presence_penalty, repetition_penalty, top_k, min_p, seed, top_a, stream_output, response_format, images, documents, reasoning_effort, system_message, transforms, api_key_override=None): """Enhanced AI query function with support for multiple providers""" # Validate input if not message.strip() and not images and not documents: return history # Copy history to new list to avoid modifying the original chat_history = list(history) # Create messages from chat history messages = format_to_message_dict(chat_history) # Add system message if provided if system_message and system_message.strip(): # Remove any existing system message messages = [msg for msg in messages if msg.get("role") != "system"] # Add new system message at the beginning messages.insert(0, {"role": "system", "content": system_message.strip()}) # Prepare message with images and documents if any content = prepare_message_with_media(message, images, documents) # Add current message messages.append({"role": "user", "content": content}) # Common parameters for all providers common_params = { "temperature": temperature, "max_tokens": max_tokens, "top_p": top_p, "frequency_penalty": frequency_penalty, "presence_penalty": presence_penalty, "stream": stream_output } try: # Process based on provider if provider == "OpenRouter": # Get model ID from registry model_id, _ = get_model_info(provider, model_choice) if not model_id: error_message = f"Error: Model '{model_choice}' not found in OpenRouter" chat_history.append([message, error_message]) return chat_history # Build OpenRouter payload payload = { "model": model_id, "messages": messages, **common_params } # Add optional parameters if set if repetition_penalty != 1.0: payload["repetition_penalty"] = repetition_penalty if top_k > 0: payload["top_k"] = top_k if min_p > 0: payload["min_p"] = min_p if seed > 0: payload["seed"] = seed if top_a > 0: payload["top_a"] = top_a # Add response format if JSON is requested if response_format == "json_object": payload["response_format"] = {"type": "json_object"} # Add reasoning if selected if reasoning_effort != "none": payload["reasoning"] = { "effort": reasoning_effort } # Add transforms if selected if transforms: payload["transforms"] = transforms # Call OpenRouter API logger.info(f"Sending request to OpenRouter model: {model_id}") response = call_openrouter_api(payload, api_key_override) # Handle streaming response if stream_output and response.status_code == 200: # Add empty response slot to history chat_history.append([message, ""]) # Set up generator for streaming updates def streaming_generator(): for updated_history in openrouter_streaming_handler(response, chat_history, len(chat_history) - 1, message): yield updated_history return streaming_generator() # Handle normal response elif response.status_code == 200: result = response.json() logger.info(f"Response content: {result}") # Extract AI response ai_response = extract_ai_response(result, provider) # Add response to history chat_history.append([message, ai_response]) return chat_history # Handle error response else: error_message = f"Error: Status code {response.status_code}" try: response_data = response.json() error_message += f"\n\nDetails: {json.dumps(response_data, indent=2)}" except: error_message += f"\n\nResponse: {response.text}" logger.error(error_message) chat_history.append([message, error_message]) return chat_history elif provider == "OpenAI": # Get model ID from registry model_id, _ = get_model_info(provider, model_choice) if not model_id: error_message = f"Error: Model '{model_choice}' not found in OpenAI" chat_history.append([message, error_message]) return chat_history # Build OpenAI payload payload = { "model": model_id, "messages": messages, **common_params } # Add response format if JSON is requested if response_format == "json_object": payload["response_format"] = {"type": "json_object"} # Call OpenAI API logger.info(f"Sending request to OpenAI model: {model_id}") try: response = call_openai_api(payload, api_key_override) # Handle streaming response if stream_output: # Add empty response slot to history chat_history.append([message, ""]) # Set up generator for streaming updates def streaming_generator(): for updated_history in openai_streaming_handler(response, chat_history, len(chat_history) - 1, message): yield updated_history return streaming_generator() # Handle normal response else: ai_response = extract_ai_response(response, provider) chat_history.append([message, ai_response]) return chat_history except Exception as e: error_message = f"OpenAI API Error: {str(e)}" logger.error(error_message) chat_history.append([message, error_message]) return chat_history elif provider == "HuggingFace": # Get model ID from registry model_id, _ = get_model_info(provider, model_choice) if not model_id: error_message = f"Error: Model '{model_choice}' not found in HuggingFace" chat_history.append([message, error_message]) return chat_history # Build HuggingFace payload payload = { "model": model_id, "messages": messages, "temperature": temperature, "max_tokens": max_tokens } # Call HuggingFace API logger.info(f"Sending request to HuggingFace model: {model_id}") try: response = call_huggingface_api(payload, api_key_override) # Extract response ai_response = extract_ai_response(response, provider) chat_history.append([message, ai_response]) return chat_history except Exception as e: error_message = f"HuggingFace API Error: {str(e)}" logger.error(error_message) chat_history.append([message, error_message]) return chat_history elif provider == "Groq": # Get model ID from registry model_id, _ = get_model_info(provider, model_choice) if not model_id: error_message = f"Error: Model '{model_choice}' not found in Groq" chat_history.append([message, error_message]) return chat_history # Build Groq payload payload = { "model": model_id, "messages": messages, "temperature": temperature, "max_tokens": max_tokens, "top_p": top_p, "stream": stream_output } # Call Groq API logger.info(f"Sending request to Groq model: {model_id}") try: response = call_groq_api(payload, api_key_override) # Handle streaming response if stream_output: # Add empty response slot to history chat_history.append([message, ""]) # Set up generator for streaming updates def streaming_generator(): for updated_history in groq_streaming_handler(response, chat_history, len(chat_history) - 1, message): yield updated_history return streaming_generator() # Handle normal response else: ai_response = extract_ai_response(response, provider) chat_history.append([message, ai_response]) return chat_history except Exception as e: error_message = f"Groq API Error: {str(e)}" logger.error(error_message) chat_history.append([message, error_message]) return chat_history elif provider == "Cohere": # Get model ID from registry model_id, _ = get_model_info(provider, model_choice) if not model_id: error_message = f"Error: Model '{model_choice}' not found in Cohere" chat_history.append([message, error_message]) return chat_history # Build Cohere payload (doesn't support streaming the same way) payload = { "model": model_id, "messages": messages, "temperature": temperature, "max_tokens": max_tokens } # Call Cohere API logger.info(f"Sending request to Cohere model: {model_id}") try: response = call_cohere_api(payload, api_key_override) # Extract response ai_response = extract_ai_response(response, provider) chat_history.append([message, ai_response]) return chat_history except Exception as e: error_message = f"Cohere API Error: {str(e)}" logger.error(error_message) chat_history.append([message, error_message]) return chat_history elif provider == "GLHF": # Get model ID from registry model_id, _ = get_model_info(provider, model_choice) if not model_id: error_message = f"Error: Model '{model_choice}' not found in GLHF" chat_history.append([message, error_message]) return chat_history # Build GLHF payload payload = { "model": model_id, # The hf: prefix will be added in the API call "messages": messages, "temperature": temperature, "max_tokens": max_tokens, "stream": stream_output } # Call GLHF API logger.info(f"Sending request to GLHF model: {model_id}") try: response = call_glhf_api(payload, api_key_override) # Handle streaming response if stream_output: # Add empty response slot to history chat_history.append([message, ""]) # Set up generator for streaming updates def streaming_generator(): for updated_history in glhf_streaming_handler(response, chat_history, len(chat_history) - 1, message): yield updated_history return streaming_generator() # Handle normal response else: ai_response = extract_ai_response(response, provider) chat_history.append([message, ai_response]) return chat_history except Exception as e: error_message = f"GLHF API Error: {str(e)}" logger.error(error_message) chat_history.append([message, error_message]) return chat_history else: error_message = f"Error: Unsupported provider '{provider}'" chat_history.append([message, error_message]) return chat_history except Exception as e: error_message = f"Error: {str(e)}" logger.error(f"Exception during API call: {error_message}") chat_history.append([message, error_message]) return chat_history def clear_chat(): """Reset all inputs""" return [], "", [], [], 0.7, 1000, 0.8, 0.0, 0.0, 1.0, 40, 0.1, 0, 0.0, False, "default", "none", "", [] # ========================================================== # UI CREATION # ========================================================== def create_app(): """Create the Multi-Provider CrispChat Gradio application""" with gr.Blocks( title="Multi-Provider CrispChat", css=""" .context-size { font-size: 0.9em; color: #666; margin-left: 10px; } footer { display: none !important; } .model-selection-row { display: flex; align-items: center; } .parameter-grid { display: grid; grid-template-columns: 1fr 1fr; gap: 10px; } .vision-badge { background-color: #4CAF50; color: white; padding: 3px 6px; border-radius: 3px; font-size: 0.8em; margin-left: 5px; } .provider-selection { margin-bottom: 10px; padding: 10px; border-radius: 5px; background-color: #f5f5f5; } """ ) as demo: gr.Markdown(""" # 🤖 Multi-Provider CrispChat Chat with AI models from multiple providers: OpenRouter, OpenAI, HuggingFace, Groq, Cohere, and GLHF. """) with gr.Row(): with gr.Column(scale=2): # Chatbot interface chatbot = gr.Chatbot( height=500, show_copy_button=True, show_label=False, avatar_images=(None, "https://upload.wikimedia.org/wikipedia/commons/0/04/ChatGPT_logo.svg"), type="messages", elem_id="chat-window" ) with gr.Row(): message = gr.Textbox( placeholder="Type your message here...", label="Message", lines=2, elem_id="message-input", scale=4 ) with gr.Row(): with gr.Column(scale=3): submit_btn = gr.Button("Send", variant="primary", elem_id="send-btn") with gr.Column(scale=1): clear_btn = gr.Button("Clear Chat", variant="secondary") with gr.Row(): # Image upload with gr.Accordion("Upload Images (for vision models)", open=False): images = gr.File( label="Uploaded Images", file_types=["image"], file_count="multiple" ) image_upload_btn = gr.UploadButton( label="Upload Images", file_types=["image"], file_count="multiple" ) # Document upload with gr.Accordion("Upload Documents (PDF, MD, TXT)", open=False): documents = gr.File( label="Uploaded Documents", file_types=[".pdf", ".md", ".txt"], file_count="multiple" ) with gr.Column(scale=1): with gr.Group(elem_classes="provider-selection"): gr.Markdown("### Provider Selection") # Provider selection provider_choice = gr.Radio( choices=["OpenRouter", "OpenAI", "HuggingFace", "Groq", "Cohere", "GLHF"], value="OpenRouter", label="AI Provider" ) # API key input api_key_override = gr.Textbox( placeholder="Override API key (leave empty to use environment variable)", label="API Key Override", type="password" ) with gr.Group(): gr.Markdown("### Model Selection") with gr.Row(elem_classes="model-selection-row"): model_search = gr.Textbox( placeholder="Search models...", label="", show_label=False ) # Provider-specific model dropdowns openrouter_model = gr.Dropdown( choices=[model[0] for model in OPENROUTER_ALL_MODELS], value=OPENROUTER_ALL_MODELS[0][0] if OPENROUTER_ALL_MODELS else None, label="OpenRouter Model", elem_id="openrouter-model-choice", visible=True ) openai_model = gr.Dropdown( choices=list(OPENAI_MODELS.keys()), value="gpt-3.5-turbo" if "gpt-3.5-turbo" in OPENAI_MODELS else None, label="OpenAI Model", elem_id="openai-model-choice", visible=False ) hf_model = gr.Dropdown( choices=list(HUGGINGFACE_MODELS.keys()), value="mistralai/Mistral-7B-Instruct-v0.3" if "mistralai/Mistral-7B-Instruct-v0.3" in HUGGINGFACE_MODELS else None, label="HuggingFace Model", elem_id="hf-model-choice", visible=False ) groq_model = gr.Dropdown( choices=list(GROQ_MODELS.keys()), value="llama-3.1-8b-instant" if "llama-3.1-8b-instant" in GROQ_MODELS else None, label="Groq Model", elem_id="groq-model-choice", visible=False ) cohere_model = gr.Dropdown( choices=list(COHERE_MODELS.keys()), value="command-r-plus" if "command-r-plus" in COHERE_MODELS else None, label="Cohere Model", elem_id="cohere-model-choice", visible=False ) glhf_model = gr.Dropdown( choices=list(GLHF_MODELS.keys()), value="mistralai/Mistral-7B-Instruct-v0.3" if "mistralai/Mistral-7B-Instruct-v0.3" in GLHF_MODELS else None, label="GLHF Model", elem_id="glhf-model-choice", visible=False ) context_display = gr.Textbox( value=update_context_display("OpenRouter", OPENROUTER_ALL_MODELS[0][0]), label="Context Size", interactive=False, elem_classes="context-size" ) with gr.Accordion("Generation Parameters", open=False): with gr.Group(elem_classes="parameter-grid"): temperature = gr.Slider( minimum=0.0, maximum=2.0, value=0.7, step=0.1, label="Temperature" ) max_tokens = gr.Slider( minimum=100, maximum=4000, value=1000, step=100, label="Max Tokens" ) top_p = gr.Slider( minimum=0.1, maximum=1.0, value=0.8, step=0.1, label="Top P" ) frequency_penalty = gr.Slider( minimum=-2.0, maximum=2.0, value=0.0, step=0.1, label="Frequency Penalty" ) presence_penalty = gr.Slider( minimum=-2.0, maximum=2.0, value=0.0, step=0.1, label="Presence Penalty" ) reasoning_effort = gr.Radio( ["none", "low", "medium", "high"], value="none", label="Reasoning Effort (OpenRouter)" ) with gr.Accordion("Advanced Options", open=False): with gr.Row(): with gr.Column(): repetition_penalty = gr.Slider( minimum=0.1, maximum=2.0, value=1.0, step=0.1, label="Repetition Penalty" ) top_k = gr.Slider( minimum=1, maximum=100, value=40, step=1, label="Top K" ) min_p = gr.Slider( minimum=0.0, maximum=1.0, value=0.1, step=0.05, label="Min P" ) with gr.Column(): seed = gr.Number( value=0, label="Seed (0 for random)", precision=0 ) top_a = gr.Slider( minimum=0.0, maximum=1.0, value=0.0, step=0.05, label="Top A" ) stream_output = gr.Checkbox( label="Stream Output", value=False ) with gr.Row(): response_format = gr.Radio( ["default", "json_object"], value="default", label="Response Format" ) gr.Markdown(""" * **json_object**: Forces the model to respond with valid JSON only. * Only available on certain models - check model support. """) # Custom instructing options with gr.Accordion("Custom Instructions", open=False): system_message = gr.Textbox( placeholder="Enter a system message to guide the model's behavior...", label="System Message", lines=3 ) transforms = gr.CheckboxGroup( ["prompt_optimize", "prompt_distill", "prompt_compress"], label="Prompt Transforms (OpenRouter specific)" ) gr.Markdown(""" * **prompt_optimize**: Improve prompt for better responses. * **prompt_distill**: Compress prompt to use fewer tokens without changing meaning. * **prompt_compress**: Aggressively compress prompt to fit larger contexts. """) # Add a model information section with gr.Accordion("About Selected Model", open=False): model_info_display = gr.HTML( value=update_model_info("OpenRouter", OPENROUTER_ALL_MODELS[0][0]) ) # Add usage instructions with gr.Accordion("Usage Instructions", open=False): gr.Markdown(""" ## Basic Usage 1. Type your message in the input box 2. Select a provider and model 3. Click "Send" or press Enter ## Working with Files - **Images**: Upload images to use with vision-capable models - **Documents**: Upload PDF, Markdown, or text files to analyze their content ## Provider Information - **OpenRouter**: Free access to various models with context window sizes up to 2M tokens - **OpenAI**: Requires an API key, includes GPT-3.5 and GPT-4 models - **HuggingFace**: Direct access to open models, some models require API key - **Groq**: High-performance inference, requires API key - **Cohere**: Specialized in language understanding, requires API key - **GLHF**: Access to HuggingFace models, requires API key ## Advanced Parameters - **Temperature**: Controls randomness (higher = more creative, lower = more deterministic) - **Max Tokens**: Maximum length of the response - **Top P**: Nucleus sampling threshold (higher = consider more tokens) - **Reasoning Effort**: Some models can show their reasoning process (OpenRouter only) """) # Add a footer with version info footer_md = gr.Markdown(""" --- ### Multi-Provider CrispChat v1.0 Built with ❤️ using Gradio and multiple AI provider APIs | Context sizes shown next to model names """) # Define event handlers def toggle_model_dropdowns(provider): """Show/hide model dropdowns based on provider selection""" return [ gr.update(visible=(provider == "OpenRouter")), gr.update(visible=(provider == "OpenAI")), gr.update(visible=(provider == "HuggingFace")), gr.update(visible=(provider == "Groq")), gr.update(visible=(provider == "Cohere")), gr.update(visible=(provider == "GLHF")) ] def update_context_for_provider(provider, openrouter_model, openai_model, hf_model, groq_model, cohere_model, glhf_model): """Update context display based on selected provider and model""" if provider == "OpenRouter": return update_context_display(provider, openrouter_model) elif provider == "OpenAI": return update_context_display(provider, openai_model) elif provider == "HuggingFace": return update_context_display(provider, hf_model) elif provider == "Groq": return update_context_display(provider, groq_model) elif provider == "Cohere": return update_context_display(provider, cohere_model) elif provider == "GLHF": return update_context_display(provider, glhf_model) return "Unknown" def update_model_info_for_provider(provider, openrouter_model, openai_model, hf_model, groq_model, cohere_model, glhf_model): """Update model info based on selected provider and model""" if provider == "OpenRouter": return update_model_info(provider, openrouter_model) elif provider == "OpenAI": return update_model_info(provider, openai_model) elif provider == "HuggingFace": return update_model_info(provider, hf_model) elif provider == "Groq": return update_model_info(provider, groq_model) elif provider == "Cohere": return update_model_info(provider, cohere_model) elif provider == "GLHF": return update_model_info(provider, glhf_model) return "

Model information not available

" # Handling model search function - Fixed compared to previous implementation def search_models(provider, search_term): """Filter models for the selected provider based on search term""" filtered_models = [] if provider == "OpenRouter": all_models = [model[0] for model in OPENROUTER_ALL_MODELS] if search_term: filtered_models = [model for model in all_models if search_term.lower() in model.lower()] else: filtered_models = all_models return gr.update(choices=filtered_models, value=filtered_models[0] if filtered_models else None) elif provider == "OpenAI": all_models = list(OPENAI_MODELS.keys()) if search_term: filtered_models = [model for model in all_models if search_term.lower() in model.lower()] else: filtered_models = all_models return gr.update(choices=filtered_models, value=filtered_models[0] if filtered_models else None) elif provider == "HuggingFace": all_models = list(HUGGINGFACE_MODELS.keys()) if search_term: filtered_models = [model for model in all_models if search_term.lower() in model.lower()] else: filtered_models = all_models return gr.update(choices=filtered_models, value=filtered_models[0] if filtered_models else None) elif provider == "Groq": all_models = list(GROQ_MODELS.keys()) if search_term: filtered_models = [model for model in all_models if search_term.lower() in model.lower()] else: filtered_models = all_models return gr.update(choices=filtered_models, value=filtered_models[0] if filtered_models else None) elif provider == "Cohere": all_models = list(COHERE_MODELS.keys()) if search_term: filtered_models = [model for model in all_models if search_term.lower() in model.lower()] else: filtered_models = all_models return gr.update(choices=filtered_models, value=filtered_models[0] if filtered_models else None) elif provider == "GLHF": all_models = list(GLHF_MODELS.keys()) if search_term: filtered_models = [model for model in all_models if search_term.lower() in model.lower()] else: filtered_models = all_models return gr.update(choices=filtered_models, value=filtered_models[0] if filtered_models else None) # Default return in case of unknown provider return gr.update(choices=[], value=None) def refresh_groq_models_list(): """Refresh the list of Groq models""" global GROQ_MODELS GROQ_MODELS = fetch_groq_models() return gr.update(choices=list(GROQ_MODELS.keys())) def get_current_model(provider, openrouter_model, openai_model, hf_model, groq_model, cohere_model, glhf_model): """Get the currently selected model based on provider""" if provider == "OpenRouter": return openrouter_model elif provider == "OpenAI": return openai_model elif provider == "HuggingFace": return hf_model elif provider == "Groq": return groq_model elif provider == "Cohere": return cohere_model elif provider == "GLHF": return glhf_model return None # Process uploaded images image_upload_btn.upload( fn=lambda files: files, inputs=image_upload_btn, outputs=images ) # Set up provider selection event provider_choice.change( fn=toggle_model_dropdowns, inputs=provider_choice, outputs=[openrouter_model, openai_model, hf_model, groq_model, cohere_model, glhf_model] ).then( fn=update_context_for_provider, inputs=[provider_choice, openrouter_model, openai_model, hf_model, groq_model, cohere_model, glhf_model], outputs=context_display ).then( fn=update_model_info_for_provider, inputs=[provider_choice, openrouter_model, openai_model, hf_model, groq_model, cohere_model, glhf_model], outputs=model_info_display ) # Set up model search event - FIXED VERSION # Important: We need to return a proper Gradio component update for each dropdown model_search.change( fn=search_models, inputs=[provider_choice, model_search], outputs=[openrouter_model] # This will be handled by the JS forwarding logic ) # Set up model change events openrouter_model.change( fn=lambda model: update_context_display("OpenRouter", model), inputs=openrouter_model, outputs=context_display ).then( fn=lambda model: update_model_info("OpenRouter", model), inputs=openrouter_model, outputs=model_info_display ) openai_model.change( fn=lambda model: update_context_display("OpenAI", model), inputs=openai_model, outputs=context_display ).then( fn=lambda model: update_model_info("OpenAI", model), inputs=openai_model, outputs=model_info_display ) hf_model.change( fn=lambda model: update_context_display("HuggingFace", model), inputs=hf_model, outputs=context_display ).then( fn=lambda model: update_model_info("HuggingFace", model), inputs=hf_model, outputs=model_info_display ) groq_model.change( fn=lambda model: update_context_display("Groq", model), inputs=groq_model, outputs=context_display ).then( fn=lambda model: update_model_info("Groq", model), inputs=groq_model, outputs=model_info_display ) cohere_model.change( fn=lambda model: update_context_display("Cohere", model), inputs=cohere_model, outputs=context_display ).then( fn=lambda model: update_model_info("Cohere", model), inputs=cohere_model, outputs=model_info_display ) glhf_model.change( fn=lambda model: update_context_display("GLHF", model), inputs=glhf_model, outputs=context_display ).then( fn=lambda model: update_model_info("GLHF", model), inputs=glhf_model, outputs=model_info_display ) # Add custom JavaScript for routing model search to visible dropdown gr.HTML(""" """) # Set up submission event def submit_message(message, history, provider, openrouter_model, openai_model, hf_model, groq_model, cohere_model, glhf_model, temperature, max_tokens, top_p, frequency_penalty, presence_penalty, repetition_penalty, top_k, min_p, seed, top_a, stream_output, response_format, images, documents, reasoning_effort, system_message, transforms, api_key_override): """Submit message to selected provider and model""" # Get the currently selected model model_choice = get_current_model(provider, openrouter_model, openai_model, hf_model, groq_model, cohere_model, glhf_model) # Check if model is selected if not model_choice: history.append([message, f"Error: No model selected for provider {provider}"]) return history # Call the ask_ai function with the appropriate parameters return ask_ai( message=message, history=history, provider=provider, model_choice=model_choice, temperature=temperature, max_tokens=max_tokens, top_p=top_p, frequency_penalty=frequency_penalty, presence_penalty=presence_penalty, repetition_penalty=repetition_penalty, top_k=top_k, min_p=min_p, seed=seed, top_a=top_a, stream_output=stream_output, response_format=response_format, images=images, documents=documents, reasoning_effort=reasoning_effort, system_message=system_message, transforms=transforms, api_key_override=api_key_override ) # Submit button click event submit_btn.click( fn=submit_message, inputs=[ message, chatbot, provider_choice, openrouter_model, openai_model, hf_model, groq_model, cohere_model, glhf_model, temperature, max_tokens, top_p, frequency_penalty, presence_penalty, repetition_penalty, top_k, min_p, seed, top_a, stream_output, response_format, images, documents, reasoning_effort, system_message, transforms, api_key_override ], outputs=chatbot, show_progress="minimal", ).then( fn=lambda: "", # Clear message box after sending inputs=None, outputs=message ) # Also submit on Enter key message.submit( fn=submit_message, inputs=[ message, chatbot, provider_choice, openrouter_model, openai_model, hf_model, groq_model, cohere_model, glhf_model, temperature, max_tokens, top_p, frequency_penalty, presence_penalty, repetition_penalty, top_k, min_p, seed, top_a, stream_output, response_format, images, documents, reasoning_effort, system_message, transforms, api_key_override ], outputs=chatbot, show_progress="minimal", ).then( fn=lambda: "", # Clear message box after sending inputs=None, outputs=message ) # Clear chat button clear_btn.click( fn=clear_chat, inputs=[], outputs=[ chatbot, message, images, documents, temperature, max_tokens, top_p, frequency_penalty, presence_penalty, repetition_penalty, top_k, min_p, seed, top_a, stream_output, response_format, reasoning_effort, system_message, transforms ] ) return demo # Launch the app if __name__ == "__main__": # Check API keys and print status missing_keys = [] if not OPENROUTER_API_KEY: logger.warning("WARNING: OPENROUTER_API_KEY environment variable is not set") missing_keys.append("OpenRouter") if not OPENAI_API_KEY: logger.warning("WARNING: OPENAI_API_KEY environment variable is not set") missing_keys.append("OpenAI") if not GROQ_API_KEY: logger.warning("WARNING: GROQ_API_KEY environment variable is not set") missing_keys.append("Groq") if not COHERE_API_KEY: logger.warning("WARNING: COHERE_API_KEY environment variable is not set") missing_keys.append("Cohere") if not GLHF_API_KEY: logger.warning("WARNING: GLHF_API_KEY environment variable is not set") missing_keys.append("GLHF") if missing_keys: print("Missing API keys for the following providers:") for key in missing_keys: print(f"- {key}") print("\nYou can still use the application, but some providers will require API keys.") print("You can provide API keys through environment variables or use the API Key Override field.") if "OpenRouter" in missing_keys: print("\nNote: OpenRouter offers free tier access to many models!") print("\nStarting Multi-Provider CrispChat application...") demo = create_app() demo.launch( server_name="0.0.0.0", server_port=7860, debug=True, show_error=True )