import os
import time
import random
import logging
from openai import OpenAI
from dotenv import load_dotenv
from utils import read_config

# --- Load environment & config ---
load_dotenv()
_config = read_config()["llm"]
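
# For reference, the "llm" section read above is expected to look roughly
# like this sketch (illustrative values; the real config.yaml may carry
# additional keys):
#
#   llm:
#     model: "chutesai/Llama-4-Maverick-17B-128E-Instruct-FP8"
#     system_prompt: "You are {char}, a concise assistant."
#     char: "Eve"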

# --- Logging setup ---
LOG_LEVEL = os.getenv("LOG_LEVEL", "INFO").upper()
logger = logging.getLogger("polLLM")
logger.setLevel(LOG_LEVEL)
handler = logging.StreamHandler()
handler.setFormatter(logging.Formatter("%(asctime)s [%(levelname)s] %(message)s"))
logger.addHandler(handler)

# --- LLM settings from config.yaml ---
_DEFAULT_MODEL = _config.get("model", "chutesai/Llama-4-Maverick-17B-128E-Instruct-FP8")
_SYSTEM_TEMPLATE = _config.get("system_prompt", "")
_CHAR = _config.get("char", "Eve")
_CHUTES_API_KEY = os.getenv("CHUTES_API_KEY")
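
# If CHUTES_API_KEY is unset, the client below is built with api_key=None and
# every request will fail authentication. This optional guard (a sketch, not
# required by the rest of the module) surfaces that problem early:
if not _CHUTES_API_KEY:
    logger.warning("CHUTES_API_KEY is not set; LLM calls will fail to authenticate")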

# --- Custom exception ---
class LLMBadRequestError(Exception):
    """Raised when the LLM returns HTTP 400 (Bad Request)."""
    pass


# --- OpenAI client init ---
client = OpenAI(
    base_url="https://llm.chutes.ai/v1/",
    api_key=_CHUTES_API_KEY,
)
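# Note: the Chutes endpoint is OpenAI-compatible, which is why the stock
# OpenAI client can simply be pointed at it via base_url above.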


def _build_system_prompt() -> str:
    """Substitute {char} into the system prompt template."""
    return _SYSTEM_TEMPLATE.replace("{char}", _CHAR)
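# For example, with system_prompt "You are {char}." and char "Eve",
# _build_system_prompt() returns "You are Eve." (illustrative values,
# not necessarily the repo's actual prompt).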


def generate_llm(prompt: str) -> str:
    """
    Send a chat-completion request to the LLM, with retries and exponential backoff.

    The model and system prompt are read from config.yaml at import time.
    Raises LLMBadRequestError on HTTP 400, and re-raises the last error
    after five failed attempts.
    """
    model = _DEFAULT_MODEL
    system_prompt = _build_system_prompt()
    messages = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": prompt},
    ]
    backoff = 1
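    # Up to 5 attempts, with sleeps of 1s, 2s, 4s, and 8s between them.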
    for attempt in range(1, 6):
        try:
            # New seed each attempt, so a retry is not forced to reproduce
            # the same completion.
            seed = random.randint(0, 2**31 - 1)
            logger.debug(f"LLM call attempt={attempt}, model={model}, seed={seed}")
            resp = client.chat.completions.create(
                model=model,
                messages=messages,
                seed=seed,
            )
            # content can be None on some responses, so guard before strip()
            text = (resp.choices[0].message.content or "").strip()
            logger.debug("LLM response received")
            return text
        except Exception as e:
            if getattr(e, "status_code", None) == 400:
                logger.error("LLM error 400 (Bad Request): not retrying.")
                raise LLMBadRequestError("LLM returned HTTP 400") from e
            logger.error(f"LLM error on attempt {attempt}: {e}")
            if attempt < 5:
                time.sleep(backoff)
                backoff *= 2
            else:
                logger.critical("LLM failed after 5 attempts, raising")
                raise


# Example local test
if __name__ == "__main__":
    logger.info("Testing generate_llm() with a sample prompt")
    try:
        print(generate_llm("generate 4 images of 1:1 profile picture"))
    except LLMBadRequestError as e:
        logger.warning(f"Test failed with bad request: {e}")