"""
Hugging Face API integration for the Norwegian RAG chatbot.

Provides a client for the Hugging Face Inference API, covering both LLM
text generation and embedding generation.
"""

import os
import time
from typing import Any, Dict, List, Optional, Union

import requests

from .config import (
    LLM_MODELS,
    DEFAULT_LLM_MODEL,
    EMBEDDING_MODELS,
    DEFAULT_EMBEDDING_MODEL,
    HF_API_ENDPOINTS,
    API_PARAMS
)
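
# NOTE: the shapes below are assumptions about what .config exposes, inferred
# from how the values are used in this module, not the actual definitions:
#   LLM_MODELS = {"<name>": {"model_id": "mistralai/Mistral-7B-Instruct-v0.2", ...}, ...}
#   EMBEDDING_MODELS = {"<name>": {"model_id": "NbAiLab/nb-sbert-base", ...}, ...}
#   HF_API_ENDPOINTS = {
#       "inference": "https://api-inference.huggingface.co/models/",
#       "feature-extraction": "https://api-inference.huggingface.co/pipeline/feature-extraction/",
#   }
#   API_PARAMS = {"max_length": ..., "temperature": ..., "top_p": ...,
#                 "top_k": ..., "repetition_penalty": ...}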


class HuggingFaceAPI:
    """
    Client for interacting with the Hugging Face Inference API.
    Supports both text generation (LLM) and embedding generation.
    """

    def __init__(
        self,
        api_key: Optional[str] = None,
        llm_model: str = DEFAULT_LLM_MODEL,
        embedding_model: str = DEFAULT_EMBEDDING_MODEL
    ):
        """
        Initialize the Hugging Face API client.

        Args:
            api_key: Hugging Face API key (optional; falls back to the
                HF_API_KEY environment variable)
            llm_model: LLM model identifier from config
            embedding_model: Embedding model identifier from config
        """
        self.api_key = api_key or os.environ.get("HF_API_KEY", "")

        # Fall back to the default model when an unknown identifier is given.
        self.llm_model_id = LLM_MODELS.get(
            llm_model, LLM_MODELS[DEFAULT_LLM_MODEL]
        )["model_id"]
        self.embedding_model_id = EMBEDDING_MODELS.get(
            embedding_model, EMBEDDING_MODELS[DEFAULT_EMBEDDING_MODEL]
        )["model_id"]

        if self.api_key:
            self.headers = {"Authorization": f"Bearer {self.api_key}"}
        else:
            print("Warning: No API key provided. API calls may be rate limited.")
            self.headers = {}

    def generate_text(
        self,
        prompt: str,
        max_length: int = API_PARAMS["max_length"],
        temperature: float = API_PARAMS["temperature"],
        top_p: float = API_PARAMS["top_p"],
        top_k: int = API_PARAMS["top_k"],
        repetition_penalty: float = API_PARAMS["repetition_penalty"],
        wait_for_model: bool = True
    ) -> str:
        """
        Generate text using the LLM model.

        Args:
            prompt: Input text prompt
            max_length: Maximum length of the generated text
            temperature: Sampling temperature
            top_p: Top-p (nucleus) sampling parameter
            top_k: Top-k sampling parameter
            repetition_penalty: Penalty for repeated tokens
            wait_for_model: Whether to wait for the model to load

        Returns:
            Generated text response
        """
        payload = {
            "inputs": prompt,
            "parameters": {
                "max_length": max_length,
                "temperature": temperature,
                "top_p": top_p,
                "top_k": top_k,
                "repetition_penalty": repetition_penalty
            }
        }

        api_url = f"{HF_API_ENDPOINTS['inference']}{self.llm_model_id}"
        response = self._make_api_request(api_url, payload, wait_for_model)

        # The Inference API returns either a list of generation dicts or a
        # single dict, depending on the model.
        if isinstance(response, list) and len(response) > 0:
            if isinstance(response[0], dict):
                if "generated_text" in response[0]:
                    return response[0]["generated_text"]
                return response[0].get("text", "")
        elif isinstance(response, dict):
            return response.get("generated_text", "")

        return str(response)

    def generate_embeddings(
        self,
        texts: Union[str, List[str]],
        wait_for_model: bool = True
    ) -> List[List[float]]:
        """
        Generate embeddings for text using the embedding model.

        Args:
            texts: Single text or list of texts to embed
            wait_for_model: Whether to wait for the model to load

        Returns:
            List of embedding vectors, one per input text (or a dict with
            an "error" key if the request failed)
        """
        # Normalize to a list so the API always receives a batch.
        if isinstance(texts, str):
            texts = [texts]

        payload = {"inputs": texts}
        api_url = f"{HF_API_ENDPOINTS['feature-extraction']}{self.embedding_model_id}"

        return self._make_api_request(api_url, payload, wait_for_model)

    def _make_api_request(
        self,
        api_url: str,
        payload: Dict[str, Any],
        wait_for_model: bool = True,
        max_retries: int = 5,
        retry_delay: int = 1
    ) -> Any:
        """
        Make a request to the Hugging Face API with retry logic.

        Args:
            api_url: API endpoint URL
            payload: Request payload
            wait_for_model: Whether to wait for the model to load
            max_retries: Maximum number of attempts
            retry_delay: Base delay between retries in seconds,
                doubled after each failed attempt

        Returns:
            Parsed JSON response, or a dict with an "error" key on failure
        """
        for attempt in range(max_retries):
            try:
                response = requests.post(api_url, headers=self.headers, json=payload)

                # A 503 means the model is still loading; the response body
                # reports an estimated loading time.
                if response.status_code == 503 and wait_for_model:
                    estimated_time = response.json().get("estimated_time", 20)
                    print(f"Model is loading. Waiting {estimated_time} seconds...")
                    time.sleep(estimated_time)
                    continue

                if response.status_code != 200:
                    print(f"API request failed with status code {response.status_code}: {response.text}")
                    if attempt < max_retries - 1:
                        # Exponential backoff before retrying.
                        time.sleep(retry_delay * (2 ** attempt))
                        continue
                    return {"error": response.text}

                return response.json()

            except Exception as e:
                print(f"API request failed: {str(e)}")
                if attempt < max_retries - 1:
                    time.sleep(retry_delay * (2 ** attempt))
                    continue
                return {"error": str(e)}

        return {"error": "Max retries exceeded"}
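

# A minimal sketch (not part of the original module's API) of how the vectors
# from HuggingFaceAPI.generate_embeddings() might be compared during
# retrieval; kept dependency-free, and assuming the endpoint returns one
# sentence-level vector per input text.
def cosine_similarity(a: List[float], b: List[float]) -> float:
    """Cosine similarity between two embedding vectors."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = sum(x * x for x in a) ** 0.5
    norm_b = sum(x * x for x in b) ** 0.5
    if norm_a == 0.0 or norm_b == 0.0:
        return 0.0
    return dot / (norm_a * norm_b)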


def create_rag_prompt(query: str, context: List[str]) -> str:
    """
    Create a RAG prompt with retrieved context for the LLM.

    Args:
        query: User query
        context: List of retrieved document chunks

    Returns:
        Formatted prompt with context
    """
    # Number each chunk ("Dokument 1", "Dokument 2", ...) so the model can
    # refer back to individual sources.
    context_text = "\n\n".join(
        f"Dokument {i + 1}:\n{chunk}" for i, chunk in enumerate(context)
    )

    prompt = f"""Du er en hjelpsom assistent som svarer på norsk. Bruk følgende kontekst for å svare på spørsmålet.

KONTEKST:
{context_text}

SPØRSMÅL:
{query}

SVAR:
"""
    return prompt
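

if __name__ == "__main__":
    # Minimal usage sketch under stated assumptions: a valid HF_API_KEY in
    # the environment, config defaults that resolve to real models, and the
    # hypothetical documents below. Not part of the module's public API.
    api = HuggingFaceAPI()

    documents = [
        "Oslo er hovedstaden i Norge.",
        "Bergen er kjent for regn og de syv fjell.",
    ]
    query = "Hva er hovedstaden i Norge?"

    # Embed documents and query, then rank documents by cosine similarity.
    doc_embeddings = api.generate_embeddings(documents)
    query_embedding = api.generate_embeddings(query)[0]
    ranked = sorted(
        zip(documents, doc_embeddings),
        key=lambda pair: cosine_similarity(query_embedding, pair[1]),
        reverse=True,
    )

    # Build a RAG prompt from the best-matching chunk and generate an answer.
    prompt = create_rag_prompt(query, [ranked[0][0]])
    print(api.generate_text(prompt))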