File size: 7,111 Bytes
b34efa5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
"""
Hugging Face API integration for Norwegian RAG chatbot.
Provides functions to interact with Hugging Face Inference API for both LLM and embedding models.
"""

import os
import json
import time
import requests
from typing import Dict, List, Optional, Union, Any

from .config import (
    LLM_MODELS, 
    DEFAULT_LLM_MODEL, 
    EMBEDDING_MODELS, 
    DEFAULT_EMBEDDING_MODEL,
    HF_API_ENDPOINTS,
    API_PARAMS
)

class HuggingFaceAPI:
    """
    Client for interacting with Hugging Face Inference API.
    Supports both text generation (LLM) and embedding generation.

    Failed requests never raise out of this class: ``_make_api_request``
    returns a dict of the form ``{"error": "<message>"}`` once its retries
    are exhausted, and the public methods pass that through as described in
    their own docstrings.
    """

    def __init__(
        self,
        api_key: Optional[str] = None,
        llm_model: str = DEFAULT_LLM_MODEL,
        embedding_model: str = DEFAULT_EMBEDDING_MODEL
    ):
        """
        Initialize the Hugging Face API client.

        Args:
            api_key: Hugging Face API key (optional, can use HF_API_KEY env var)
            llm_model: LLM model identifier from config
            embedding_model: Embedding model identifier from config
        """
        self.api_key = api_key or os.environ.get("HF_API_KEY", "")

        # Unknown identifiers silently fall back to the configured defaults.
        if llm_model not in LLM_MODELS:
            llm_model = DEFAULT_LLM_MODEL
        if embedding_model not in EMBEDDING_MODELS:
            embedding_model = DEFAULT_EMBEDDING_MODEL
        self.llm_model_id = LLM_MODELS[llm_model]["model_id"]
        self.embedding_model_id = EMBEDDING_MODELS[embedding_model]["model_id"]

        # Only send an Authorization header when there is a key to put in it.
        if self.api_key:
            self.headers = {"Authorization": f"Bearer {self.api_key}"}
        else:
            print("Warning: No API key provided. API calls may be rate limited.")
            self.headers = {}

    def generate_text(
        self,
        prompt: str,
        max_length: int = API_PARAMS["max_length"],
        temperature: float = API_PARAMS["temperature"],
        top_p: float = API_PARAMS["top_p"],
        top_k: int = API_PARAMS["top_k"],
        repetition_penalty: float = API_PARAMS["repetition_penalty"],
        wait_for_model: bool = True
    ) -> str:
        """
        Generate text using the LLM model.

        Args:
            prompt: Input text prompt
            max_length: Maximum length of generated text
            temperature: Sampling temperature
            top_p: Top-p sampling parameter
            top_k: Top-k sampling parameter
            repetition_penalty: Penalty for repetition
            wait_for_model: Whether to wait for model to load

        Returns:
            Generated text response. An empty string is returned when the
            response is a dict without a "generated_text" key, which includes
            the error dict produced by a failed request. Any other response
            shape is returned as its ``str()`` representation.
        """
        payload = {
            "inputs": prompt,
            "parameters": {
                "max_length": max_length,
                "temperature": temperature,
                "top_p": top_p,
                "top_k": top_k,
                "repetition_penalty": repetition_penalty
            }
        }

        api_url = f"{HF_API_ENDPOINTS['inference']}{self.llm_model_id}"

        # Make API request
        response = self._make_api_request(api_url, payload, wait_for_model)

        # Parse response
        if isinstance(response, list) and len(response) > 0:
            first = response[0]
            # Guard against non-dict items (e.g. nested lists or strings);
            # those fall through to the string fallback instead of raising
            # AttributeError on ``.get``.
            if isinstance(first, dict):
                if "generated_text" in first:
                    return first["generated_text"]
                return first.get("text", "")
        elif isinstance(response, dict):
            return response.get("generated_text", "")

        # Fallback
        return str(response)

    def generate_embeddings(
        self,
        texts: Union[str, List[str]],
        wait_for_model: bool = True
    ) -> List[List[float]]:
        """
        Generate embeddings for text using the embedding model.

        Args:
            texts: Single text or list of texts to embed
            wait_for_model: Whether to wait for model to load

        Returns:
            List of embedding vectors, exactly as decoded from the API
            response. An empty input list yields an empty list without
            contacting the API. NOTE: when the request fails, the
            ``{"error": ...}`` dict from ``_make_api_request`` is returned
            unchanged, so callers should check the result type.
        """
        # Ensure texts is a list
        if isinstance(texts, str):
            texts = [texts]

        # Nothing to embed: skip the network round trip entirely.
        if isinstance(texts, list) and not texts:
            return []

        payload = {
            "inputs": texts,
        }

        api_url = f"{HF_API_ENDPOINTS['feature-extraction']}{self.embedding_model_id}"

        # Make API request
        response = self._make_api_request(api_url, payload, wait_for_model)

        # Return embeddings
        return response

    @staticmethod
    def _estimated_wait_seconds(response: Any, default: float = 20) -> float:
        """
        Extract the "estimated_time" value from a 503 response body.

        Args:
            response: Response object exposing a ``json()`` method
            default: Value used when no usable estimate is present

        Returns:
            The estimate in seconds, or ``default`` when the body is not
            JSON, is not a JSON object, or carries a value that is not a
            finite non-negative number.
        """
        try:
            body = response.json()
        except ValueError:
            return default
        if not isinstance(body, dict):
            return default

        estimate = body.get("estimated_time", default)
        if isinstance(estimate, bool) or not isinstance(estimate, (int, float)):
            return default
        # The chained comparison is False for NaN as well as for negative
        # and infinite values, none of which time.sleep() accepts.
        if not (0 <= estimate < float("inf")):
            return default
        return estimate

    def _make_api_request(
        self,
        api_url: str,
        payload: Dict[str, Any],
        wait_for_model: bool = True,
        max_retries: int = 5,
        retry_delay: int = 1,
        timeout: float = 120
    ) -> Any:
        """
        Make a request to the Hugging Face API with retry logic.

        Args:
            api_url: API endpoint URL
            payload: Request payload
            wait_for_model: Whether to wait for model to load
            max_retries: Maximum number of retries
            retry_delay: Delay between retries in seconds
            timeout: Seconds to wait for each HTTP request before it is
                treated as a failed attempt

        Returns:
            The decoded JSON body on HTTP 200. Otherwise a dict with a
            single "error" key describing the last failure.
        """
        for attempt in range(max_retries):
            try:
                # Without a timeout a stalled connection would block forever
                # and the retry logic below would never get a chance to run.
                response = requests.post(
                    api_url,
                    headers=self.headers,
                    json=payload,
                    timeout=timeout,
                )

                # Check if model is still loading
                if response.status_code == 503 and wait_for_model:
                    # Model is loading, wait and retry
                    estimated_time = self._estimated_wait_seconds(response)
                    print(f"Model is loading. Waiting {estimated_time} seconds...")
                    time.sleep(estimated_time)
                    continue

                # Check for other errors
                if response.status_code != 200:
                    print(f"API request failed with status code {response.status_code}: {response.text}")
                    if attempt < max_retries - 1:
                        time.sleep(retry_delay * (2 ** attempt))  # Exponential backoff
                        continue
                    return {"error": response.text}

                return response.json()

            except Exception as e:
                # Deliberate best-effort boundary: network errors, timeouts
                # and undecodable bodies are reported and retried, then
                # surfaced to the caller as an error dict instead of raising.
                print(f"API request failed: {str(e)}")
                if attempt < max_retries - 1:
                    time.sleep(retry_delay * (2 ** attempt))  # Exponential backoff
                    continue
                return {"error": str(e)}

        # Reached when the final attempt ended in a "model loading" wait.
        return {"error": "Max retries exceeded"}


# Example RAG prompt template for Norwegian
def create_rag_prompt(query: str, context: List[str]) -> str:
    """
    Build the Norwegian RAG prompt that is sent to the LLM.

    Each retrieved chunk is labelled "Dokument N:" (numbered from 1) and the
    labelled chunks are separated by a blank line. The prompt consists of a
    fixed instruction sentence, a "KONTEKST:" section holding the chunks, a
    "SPØRSMÅL:" section holding the query, and a trailing "SVAR:" header for
    the model to complete.

    Args:
        query: User query
        context: List of retrieved document chunks

    Returns:
        Formatted prompt with context
    """
    labelled_chunks = []
    for number, passage in enumerate(context, start=1):
        labelled_chunks.append(f"Dokument {number}:\n{passage}")
    context_text = "\n\n".join(labelled_chunks)

    # Assembled line by line; the whitespace-only second line is part of the
    # established prompt text and is kept as-is.
    return (
        "Du er en hjelpsom assistent som svarer på norsk. "
        "Bruk følgende kontekst for å svare på spørsmålet.\n"
        "    \n"
        "KONTEKST:\n"
        f"{context_text}\n"
        "\n"
        "SPØRSMÅL:\n"
        f"{query}\n"
        "\n"
        "SVAR:\n"
    )