Maharshi Gor
Adds support for caching LLM calls to a SQLite DB and an HF dataset. Refactors repo creation logic and fixes the unused temperature param.
3a1af80
# %%
import json
import os
from typing import Any, Optional

import cohere
import numpy as np
from langchain_anthropic import ChatAnthropic
from langchain_cohere import ChatCohere
from langchain_core.language_models import BaseChatModel
from langchain_openai import ChatOpenAI
from loguru import logger
from openai import OpenAI
from pydantic import BaseModel, Field
from rich import print as rprint

from .configs import AVAILABLE_MODELS
from .llmcache import LLMCache

# Initialize global cache
llm_cache = LLMCache(cache_dir=".", hf_repo="umdclip/advcal-llm-cache")
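# LLMCache interface assumed by this module (see .llmcache for the actual implementation):
#   get(model, system, prompt, response_format, temperature) returns a previously cached
#   response dict or None; set(...) stores a new response; get_all_entries(), sync_to_hf(),
#   and hf_repo_id are used below to inspect the SQLite-backed cache and sync it to the
#   HF dataset repo.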
def _openai_is_json_mode_supported(model_name: str) -> bool:
    if model_name.startswith("gpt-4"):
        return True
    if model_name.startswith("gpt-3.5"):
        return False
    logger.warning(f"OpenAI model {model_name} is not available in this app, skipping JSON mode, returning False")
    return False


class LLMOutput(BaseModel):
    content: str = Field(description="The content of the response")
    logprob: Optional[float] = Field(None, description="The log probability of the response")
def _get_langchain_chat_output(llm: BaseChatModel, system: str, prompt: str) -> dict[str, Any]:
    output = llm.invoke([("system", system), ("human", prompt)])
    ai_message = output["raw"]
    content = {"content": ai_message.content, "tool_calls": ai_message.tool_calls}
    content_str = json.dumps(content)
    return {"content": content_str, "output": output["parsed"].model_dump()}


def _cohere_completion(
    model: str, system: str, prompt: str, response_model, temperature: float | None = None, logprobs: bool = True
) -> dict[str, Any]:
    messages = [
        {"role": "system", "content": system},
        {"role": "user", "content": prompt},
    ]
    client = cohere.ClientV2(api_key=os.getenv("COHERE_API_KEY"))
    response = client.chat(
        model=model,
        messages=messages,
        response_format={"type": "json_schema", "json_schema": response_model.model_json_schema()},
        logprobs=logprobs,
        temperature=temperature,
    )
    output = {}
    output["content"] = response.message.content[0].text
    output["output"] = response_model.model_validate_json(response.message.content[0].text).model_dump()
    if logprobs:
        output["logprob"] = sum(lp.logprobs[0] for lp in response.logprobs)
        output["prob"] = np.exp(output["logprob"])
    return output


def _openai_langchain_completion(
    model: str, system: str, prompt: str, response_model, temperature: float | None = None
) -> dict[str, Any]:
    llm = ChatOpenAI(model=model, temperature=temperature).with_structured_output(response_model, include_raw=True)
    return _get_langchain_chat_output(llm, system, prompt)


def _openai_completion(
    model: str, system: str, prompt: str, response_model, temperature: float | None = None, logprobs: bool = True
) -> dict[str, Any]:
    messages = [
        {"role": "system", "content": system},
        {"role": "user", "content": prompt},
    ]
    client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))
    response = client.beta.chat.completions.parse(
        model=model,
        messages=messages,
        response_format=response_model,
        logprobs=logprobs,
        temperature=temperature,
    )
    output = {}
    output["content"] = response.choices[0].message.content
    output["output"] = response.choices[0].message.parsed.model_dump()
    if logprobs:
        output["logprob"] = sum(lp.logprob for lp in response.choices[0].logprobs.content)
        output["prob"] = np.exp(output["logprob"])
    return output


def _anthropic_completion(
    model: str, system: str, prompt: str, response_model, temperature: float | None = None
) -> dict[str, Any]:
    llm = ChatAnthropic(model=model, temperature=temperature).with_structured_output(response_model, include_raw=True)
    return _get_langchain_chat_output(llm, system, prompt)
def _llm_completion(
    model: str, system: str, prompt: str, response_format, temperature: float | None = None, logprobs: bool = False
) -> dict[str, Any]:
    """
    Generate a completion from an LLM provider with structured output, without caching.

    Args:
        model (str): Provider and model name in format "provider/model" (e.g. "OpenAI/gpt-4")
        system (str): System prompt/instructions for the model
        prompt (str): User prompt/input
        response_format: Pydantic model defining the expected response structure
        temperature (float, optional): Sampling temperature passed to the provider. Defaults to None.
        logprobs (bool, optional): Whether to return log probabilities. Defaults to False.
            Note: Not supported by Anthropic models.

    Returns:
        dict: Contains:
            - content: Raw response content as a string
            - output: The structured response matching response_format
            - logprob: (optional) Sum of log probabilities if logprobs=True
            - prob: (optional) Exponential of logprob if logprobs=True

    Raises:
        ValueError: If the model or provider is not supported, or if logprobs=True
            with a model that does not support log probabilities
    """
    if model not in AVAILABLE_MODELS:
        raise ValueError(f"Model {model} not supported")
    model_name = AVAILABLE_MODELS[model]["model"]
    provider = model.split("/")[0]
    if provider == "Cohere":
        return _cohere_completion(model_name, system, prompt, response_format, temperature, logprobs)
    elif provider == "OpenAI":
        if _openai_is_json_mode_supported(model_name):
            return _openai_completion(model_name, system, prompt, response_format, temperature, logprobs)
        elif logprobs:
            raise ValueError(f"{model} does not support logprobs feature.")
        else:
            return _openai_langchain_completion(model_name, system, prompt, response_format, temperature)
    elif provider == "Anthropic":
        if logprobs:
            raise ValueError("Anthropic models do not support logprobs")
        return _anthropic_completion(model_name, system, prompt, response_format, temperature)
    else:
        raise ValueError(f"Provider {provider} not supported")


def completion(
    model: str, system: str, prompt: str, response_format, temperature: float | None = None, logprobs: bool = False
) -> dict[str, Any]:
    """
    Generate a completion from an LLM provider with structured output, with caching.

    Args:
        model (str): Provider and model name in format "provider/model" (e.g. "OpenAI/gpt-4")
        system (str): System prompt/instructions for the model
        prompt (str): User prompt/input
        response_format: Pydantic model defining the expected response structure
        temperature (float, optional): Sampling temperature passed to the provider. Defaults to None.
        logprobs (bool, optional): Whether to return log probabilities. Defaults to False.
            Note: Not supported by Anthropic models.

    Returns:
        dict: Contains:
            - content: Raw response content as a string
            - output: The structured response matching response_format
            - logprob: (optional) Sum of log probabilities if logprobs=True
            - prob: (optional) Exponential of logprob if logprobs=True

    Raises:
        ValueError: If the model or provider is not supported, or if logprobs=True
            with a model that does not support log probabilities
    """
    # Check cache first
    cached_response = llm_cache.get(model, system, prompt, response_format, temperature)
    if cached_response is not None:
        logger.info(f"Cache hit for model {model}")
        return cached_response

    logger.info(f"Cache miss for model {model}, calling API")

    # Continue with the original implementation for cache miss
    response = _llm_completion(model, system, prompt, response_format, temperature, logprobs)

    # Update cache with the new response
    llm_cache.set(
        model,
        system,
        prompt,
        response_format,
        temperature,
        response,
    )
    return response
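# For reference, completion() and _llm_completion() return a dict shaped like the following
# (illustrative values only):
#   {"content": "<raw provider response string>",
#    "output": {...},          # parsed fields of the response_format model
#    "logprob": -1.23,         # present only when logprobs=True
#    "prob": 0.29}             # present only when logprobs=True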
# %%
if __name__ == "__main__":
    from tqdm import tqdm

    class ExplainedAnswer(BaseModel):
        """
        The answer to the question and a terse explanation of the answer.
        """

        answer: str = Field(description="The short answer to the question")
        explanation: str = Field(description="5 words terse best explanation of the answer.")

    models = list(AVAILABLE_MODELS.keys())[:1]  # Just use the first model for testing
    system = "You are an accurate and concise explainer of scientific concepts."
    prompt = "Which planet is closest to the sun in the Milky Way galaxy? Answer directly, no explanation needed."

    llm_cache = LLMCache(cache_dir=".", hf_repo="umdclip/advcal-llm-cache", reset=True)

    # First call - should be a cache miss
    logger.info("First call - should be a cache miss")
    for model in tqdm(models):
        response = completion(model, system, prompt, ExplainedAnswer, logprobs=False)
        rprint(response)

    # Second call - should be a cache hit
    logger.info("Second call - should be a cache hit")
    for model in tqdm(models):
        response = completion(model, system, prompt, ExplainedAnswer, logprobs=False)
        rprint(response)

    # Slightly different prompt - should be a cache miss
    logger.info("Different prompt - should be a cache miss")
    prompt2 = "Which planet is closest to the sun? Answer directly."
    for model in tqdm(models):
        response = completion(model, system, prompt2, ExplainedAnswer, logprobs=False)
        rprint(response)

    # Get cache entries count from SQLite
    try:
        cache_entries = llm_cache.get_all_entries()
        logger.info(f"Cache now has {len(cache_entries)} items")
    except Exception as e:
        logger.error(f"Failed to get cache entries: {e}")

    # Test adding entry with temperature parameter
    logger.info("Testing with temperature parameter")
    response = completion(models[0], system, "What is Mars?", ExplainedAnswer, temperature=0.7, logprobs=False)
    rprint(response)

    # Demonstrate forced sync to HF if repo is configured
    if llm_cache.hf_repo_id:
        logger.info("Forcing sync to HF dataset")
        try:
            llm_cache.sync_to_hf()
            logger.info("Successfully synced to HF dataset")
        except Exception as e:
            logger.exception(f"Failed to sync to HF: {e}")
    else:
        logger.info("HF repo not configured, skipping sync test")

# %%
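For context on the caching behavior above, here is a minimal sketch of what a SQLite-backed cache with an HF dataset sync hook could look like. The class name, table layout, and file name below are assumptions for illustration, not the actual .llmcache implementation shipped in this commit.

import hashlib
import json
import sqlite3
from typing import Any, Optional


class SketchLLMCache:
    """Illustrative only: a SQLite-backed cache keyed on a hash of the request inputs."""

    def __init__(self, cache_dir: str = ".", hf_repo: Optional[str] = None, reset: bool = False):
        self.hf_repo_id = hf_repo
        self.conn = sqlite3.connect(f"{cache_dir}/llm_cache.db")  # hypothetical file name
        if reset:
            self.conn.execute("DROP TABLE IF EXISTS cache")
        self.conn.execute("CREATE TABLE IF NOT EXISTS cache (key TEXT PRIMARY KEY, response TEXT)")
        self.conn.commit()

    def _key(self, model: str, system: str, prompt: str, response_format, temperature: Optional[float]) -> str:
        # Hash every input that should distinguish cache entries.
        payload = json.dumps(
            [model, system, prompt, response_format.model_json_schema(), temperature], sort_keys=True
        )
        return hashlib.sha256(payload.encode()).hexdigest()

    def get(self, model, system, prompt, response_format, temperature) -> Optional[dict[str, Any]]:
        key = self._key(model, system, prompt, response_format, temperature)
        row = self.conn.execute("SELECT response FROM cache WHERE key = ?", (key,)).fetchone()
        return json.loads(row[0]) if row else None

    def set(self, model, system, prompt, response_format, temperature, response: dict[str, Any]) -> None:
        key = self._key(model, system, prompt, response_format, temperature)
        self.conn.execute("INSERT OR REPLACE INTO cache (key, response) VALUES (?, ?)", (key, json.dumps(response)))
        self.conn.commit()

    def get_all_entries(self) -> list[tuple[str, str]]:
        return self.conn.execute("SELECT key, response FROM cache").fetchall()

    def sync_to_hf(self) -> None:
        # The real cache pushes its entries to the configured HF dataset repo; omitted here.
        raise NotImplementedError("HF dataset sync is specific to the actual .llmcache implementation")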