File size: 3,795 Bytes
5889992
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
from typing import Optional, Dict, Any
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_core.messages import AIMessage
import logging
from src.llm.core.config import settings
from src.llm.utils.logging import TheryBotLogger

class LLMError(Exception):
    """Raised for failures in LLM initialization, generation, or validation."""

class TheryLLM:
    """Enhanced LLM wrapper with safety checks and response validation.

    Wraps ``ChatGoogleGenerativeAI`` with structured logging of every
    generation attempt, lazy re-initialization after a failed setup, and
    validation of responses before they are returned to callers.
    """

    def __init__(
        self,
        model_name: str = "gemini-1.5-flash",
        temperature: float = 0.3,
        max_retries: int = 3,
        safety_threshold: float = 0.75,
        logger: Optional[TheryBotLogger] = None
    ):
        """Store configuration and eagerly initialize the model client.

        Args:
            model_name: Gemini model identifier passed to the client.
            temperature: Sampling temperature for generation.
            max_retries: Retry count handed to the underlying client.
            safety_threshold: Stored threshold for safety checks.
            logger: Optional structured logger; a default is created if omitted.

        Raises:
            LLMError: if the underlying client cannot be constructed.
        """
        self.model_name = model_name
        self.temperature = temperature
        self.max_retries = max_retries
        # NOTE(review): safety_threshold is stored but never consulted in
        # this class — confirm whether validation is meant to use it.
        self.safety_threshold = safety_threshold
        self.logger = logger or TheryBotLogger()
        self._initialize_llm()

    def _initialize_llm(self) -> None:
        """Initialize the LLM client with proper error handling.

        Sets ``self._session_active`` so ``generate`` can lazily retry
        initialization after a failure.

        Raises:
            LLMError: if the client cannot be constructed.
        """
        try:
            self.llm = ChatGoogleGenerativeAI(
                model=self.model_name,
                temperature=self.temperature,
                max_retries=self.max_retries,
                google_api_key=settings.GOOGLE_API_KEY,
                max_tokens=settings.MAX_TOKENS,
            )
            self._session_active = True
        except Exception as e:
            self._session_active = False
            self.logger.log_interaction(
                interaction_type="llm_initialization_failed",
                data={"error": str(e)},
                level=logging.ERROR
            )
            # Chain the original exception so the root cause is preserved.
            raise LLMError(f"LLM initialization failed: {str(e)}") from e

    def generate(self, prompt: str, **kwargs) -> AIMessage:
        """Generate a response with safety checks and validation.

        Args:
            prompt: Text prompt sent to the model.
            **kwargs: Extra invocation options forwarded to ``llm.invoke``.

        Returns:
            The validated ``AIMessage`` response.

        Raises:
            LLMError: on generation failure, invalid response type, or
                empty response content.
        """
        # Lazily recover from a previously failed initialization.
        if not self._session_active:
            self._initialize_llm()

        try:
            self.logger.log_interaction(
                interaction_type="llm_generation_attempt",
                data={"prompt": prompt, "kwargs": kwargs},
                level=logging.INFO
            )

            # BUG FIX: kwargs were logged but never forwarded to the model,
            # so caller-supplied options were silently ignored.
            response = self.llm.invoke(prompt, **kwargs)

            validated_response = self._validate_response(response)

            self.logger.log_interaction(
                interaction_type="llm_generation_success",
                data={"prompt": prompt, "response": str(validated_response)},
                level=logging.INFO
            )

            return validated_response

        except LLMError:
            # Validation failures are already logged in _validate_response;
            # re-raise unchanged instead of double-logging and re-wrapping.
            raise
        except Exception as e:
            self.logger.log_interaction(
                interaction_type="llm_generation_error",
                data={"prompt": prompt, "error": str(e)},
                level=logging.ERROR
            )
            raise LLMError(f"Generation failed: {str(e)}") from e

    def _validate_response(
        self,
        response: AIMessage
    ) -> AIMessage:
        """Validate response content and format.

        Args:
            response: Raw message returned by the model.

        Returns:
            The same message, once it passes both checks.

        Raises:
            LLMError: if the response is not an ``AIMessage`` or its
                content is empty/whitespace-only.
        """
        if not isinstance(response, AIMessage):
            self.logger.log_interaction(
                interaction_type="llm_invalid_response_type",
                data={"response": response},
                level=logging.ERROR
            )
            raise LLMError("Invalid response type")

        if not response.content.strip():
            self.logger.log_interaction(
                interaction_type="llm_empty_response",
                data={"response": response},
                level=logging.ERROR
            )
            raise LLMError("Empty response content")

        return response