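"""
Manual smoke tests for the LLM providers wired up through utils.get_llm_model.

Each test_*() function builds an LLMConfig for one provider and sends a short
text (and optionally image) query; responses are printed rather than asserted,
so the script is intended to be run directly and inspected by hand.
"""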
import os
import pdb
import sys
from dataclasses import dataclass
from typing import Optional

from dotenv import load_dotenv
from langchain_core.messages import HumanMessage, SystemMessage
from langchain_ollama import ChatOllama

sys.path.append(".")

load_dotenv()
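# API keys and endpoints are read from the .env file loaded above (see
# get_env_value below for the exact variable names, e.g. OPENAI_API_KEY /
# OPENAI_ENDPOINT, DEEPSEEK_API_KEY / DEEPSEEK_ENDPOINT, GOOGLE_API_KEY).
# Only the providers you actually test need to be configured.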


@dataclass
class LLMConfig:
    provider: str
    model_name: str
    temperature: float = 0.8
    base_url: Optional[str] = None
    api_key: Optional[str] = None


def create_message_content(text, image_path=None):
    """Build LangChain message content, optionally attaching a base64-encoded image."""
    content = [{"type": "text", "text": text}]
    if image_path:
        from src.utils import utils

        image_format = "png" if image_path.endswith(".png") else "jpeg"
        image_data = utils.encode_image(image_path)
        content.append({
            "type": "image_url",
            "image_url": {"url": f"data:image/{image_format};base64,{image_data}"},
        })
    return content


def get_env_value(key, provider):
    env_mappings = {
        "openai": {"api_key": "OPENAI_API_KEY", "base_url": "OPENAI_ENDPOINT"},
        "azure_openai": {"api_key": "AZURE_OPENAI_API_KEY", "base_url": "AZURE_OPENAI_ENDPOINT"},
        "google": {"api_key": "GOOGLE_API_KEY"},
        "deepseek": {"api_key": "DEEPSEEK_API_KEY", "base_url": "DEEPSEEK_ENDPOINT"},
        "mistral": {"api_key": "MISTRAL_API_KEY", "base_url": "MISTRAL_ENDPOINT"},
        "alibaba": {"api_key": "ALIBABA_API_KEY", "base_url": "ALIBABA_ENDPOINT"},
        "moonshot": {"api_key": "MOONSHOT_API_KEY", "base_url": "MOONSHOT_ENDPOINT"},
    }
    if provider in env_mappings and key in env_mappings[provider]:
        return os.getenv(env_mappings[provider][key], "")
    return ""


def test_llm(config, query, image_path=None, system_message=None):
    """Send a query (optionally with an image) to the configured provider and print the reply."""
    from src.utils import utils

    # Special handling for Ollama-based models
    if config.provider == "ollama":
        if "deepseek-r1" in config.model_name:
            # DeepSeek-R1 models go through a dedicated ChatOllama wrapper
            from src.utils.llm import DeepSeekR1ChatOllama
            llm = DeepSeekR1ChatOllama(model=config.model_name)
        else:
            llm = ChatOllama(model=config.model_name)
        ai_msg = llm.invoke(query)
        print(ai_msg.content)
        if "deepseek-r1" in config.model_name:
            pdb.set_trace()
        return

    # For other providers, use the standard configuration
    llm = utils.get_llm_model(
        provider=config.provider,
        model_name=config.model_name,
        temperature=config.temperature,
        base_url=config.base_url or get_env_value("base_url", config.provider),
        api_key=config.api_key or get_env_value("api_key", config.provider),
    )

    # Prepare messages for non-Ollama models
    messages = []
    if system_message:
        messages.append(SystemMessage(content=create_message_content(system_message)))
    messages.append(HumanMessage(content=create_message_content(query, image_path)))
    ai_msg = llm.invoke(messages)

    # Handle different response types
    if hasattr(ai_msg, "reasoning_content"):
        print(ai_msg.reasoning_content)
    print(ai_msg.content)

    if config.provider == "deepseek" and "deepseek-reasoner" in config.model_name:
        print(llm.model_name)
        pdb.set_trace()


def test_openai_model():
    config = LLMConfig(provider="openai", model_name="gpt-4o")
    test_llm(config, "Describe this image", "assets/examples/test.png")


def test_google_model():
    # Enable your API key first if you haven't: https://ai.google.dev/palm_docs/oauth_quickstart
    config = LLMConfig(provider="google", model_name="gemini-2.0-flash-exp")
    test_llm(config, "Describe this image", "assets/examples/test.png")


def test_azure_openai_model():
    config = LLMConfig(provider="azure_openai", model_name="gpt-4o")
    test_llm(config, "Describe this image", "assets/examples/test.png")


def test_deepseek_model():
    config = LLMConfig(provider="deepseek", model_name="deepseek-chat")
    test_llm(config, "Who are you?")


def test_deepseek_r1_model():
    config = LLMConfig(provider="deepseek", model_name="deepseek-reasoner")
    test_llm(config, "Which is greater, 9.11 or 9.8?", system_message="You are a helpful AI assistant.")


def test_ollama_model():
    config = LLMConfig(provider="ollama", model_name="qwen2.5:7b")
    test_llm(config, "Sing a ballad of LangChain.")


def test_deepseek_r1_ollama_model():
    config = LLMConfig(provider="ollama", model_name="deepseek-r1:14b")
    test_llm(config, "How many 'r's are in the word 'strawberry'?")


def test_mistral_model():
    config = LLMConfig(provider="mistral", model_name="pixtral-large-latest")
    test_llm(config, "Describe this image", "assets/examples/test.png")


def test_moonshot_model():
    config = LLMConfig(provider="moonshot", model_name="moonshot-v1-32k-vision-preview")
    test_llm(config, "Describe this image", "assets/examples/test.png")
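

# Uncomment the test you want to run. Image tests expect assets/examples/test.png
# to exist, and each provider needs its API key (and endpoint, where applicable)
# set in .env.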
if __name__ == "__main__":
    # test_openai_model()
    # test_google_model()
    # test_azure_openai_model()
    # test_deepseek_model()
    # test_ollama_model()
    test_deepseek_r1_model()
    # test_deepseek_r1_ollama_model()
    # test_mistral_model()
    # test_moonshot_model()