File size: 3,792 Bytes
7f5ef51 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 |
import aiohttp
import json
import logging
import torch
import faiss
import numpy as np
from transformers import AutoModelForCausalLM, AutoTokenizer
from typing import List, Dict, Any
from cryptography.fernet import Fernet
from jwt import encode, decode, ExpiredSignatureError
from datetime import datetime, timedelta
from components.adaptive_learning import AdaptiveLearningEnvironment
from components.real_time_data import RealTimeDataIntegrator
from components.sentiment_analysis import EnhancedSentimentAnalyzer
from components.self_improving_ai import SelfImprovingAI
from utils.database import Database
from utils.logger import logger
class AICore:
    """Coordinate a local causal language model with supporting services.

    An instance wires together a tokenizer/model pair, a FAISS vector index,
    a database handle, a sentiment analyzer, a real-time data fetcher and a
    pair of JWT helpers.  ``generate_response`` is the main entry point.
    """

    # Fallback signing secret kept only for backward compatibility; it is a
    # publicly known placeholder and must be overridden via configuration.
    _PLACEHOLDER_JWT_SECRET = "your_jwt_secret_key"

    # Backing field for the lazily created ``http_session`` property.
    _http_session = None

    def __init__(self, config_path: str = "config.json"):
        """Load configuration and construct every collaborating component.

        Args:
            config_path: Path of the JSON configuration file.  It must
                contain a ``"model_name"`` entry and may contain a
                ``"jwt_secret"`` entry.
        """
        self.config = self._load_config(config_path)
        self.models = self._initialize_models()
        self.context_memory = self._initialize_vector_memory()
        # Reuse the instances loaded by _initialize_models instead of calling
        # from_pretrained a second time for the same checkpoint.
        self.tokenizer = self.models["tokenizer"]
        self.model = self.models["mistralai"]
        self.database = Database()
        self.sentiment_analyzer = EnhancedSentimentAnalyzer()
        self.data_fetcher = RealTimeDataIntegrator()
        self.self_improving_ai = SelfImprovingAI()
        self._encryption_key = Fernet.generate_key()
        self.jwt_secret = self.config.get(
            "jwt_secret", self._PLACEHOLDER_JWT_SECRET
        )
        if self.jwt_secret == self._PLACEHOLDER_JWT_SECRET:
            logging.getLogger(__name__).warning(
                "No 'jwt_secret' configured; using the insecure placeholder"
            )

    @property
    def http_session(self):
        """Return the shared ``aiohttp.ClientSession``, creating it on demand.

        The session is created on first access rather than in ``__init__``
        so that it is built inside the caller's running event loop.
        """
        if self._http_session is None or self._http_session.closed:
            self._http_session = aiohttp.ClientSession()
        return self._http_session

    @http_session.setter
    def http_session(self, session):
        """Allow callers to inject their own session object."""
        self._http_session = session

    async def close(self) -> None:
        """Close the HTTP session if one was created and is still open."""
        session = self._http_session
        if session is not None and not session.closed:
            await session.close()
        self._http_session = None

    def _load_config(self, config_path: str) -> dict:
        """Read ``config_path`` and return the parsed JSON content."""
        with open(config_path, "r", encoding="utf-8") as file:
            return json.load(file)

    def _initialize_models(self):
        """Load the model and tokenizer named by ``config["model_name"]``.

        Returns:
            A dict with the model under ``"mistralai"`` and the tokenizer
            under ``"tokenizer"``.
        """
        model_name = self.config["model_name"]
        return {
            "mistralai": AutoModelForCausalLM.from_pretrained(model_name),
            "tokenizer": AutoTokenizer.from_pretrained(model_name),
        }

    def _initialize_vector_memory(self):
        """Create an empty 768-dimensional FAISS L2 index."""
        return faiss.IndexFlatL2(768)

    async def generate_response(self, query: str, user_id: int) -> Dict[str, Any]:
        """Answer ``query`` and return the reply with accompanying metadata.

        The query's token ids are stored in the vector memory, the model
        generates a reply, the reply is filtered and the interaction is
        logged through ``self.database``.

        Args:
            query: Text to answer.
            user_id: Identifier passed to ``database.log_interaction``.

        Returns:
            A dict with the keys ``"response"``, ``"sentiment"``,
            ``"security_level"``, ``"real_time_data"`` and
            ``"token_optimized"``.  If any step raises, the error is logged
            and a dict holding only an ``"error"`` message is returned.
        """
        try:
            token_ids = self._vectorize_query(query)
            # The index only accepts float32 rows of exactly its own
            # dimension, while the token-id array varies with query length.
            self.context_memory.add(
                self._fit_to_dimension(token_ids, self.context_memory.d)
            )
            model_response = await self._generate_local_model_response(query)
            sentiment = self.sentiment_analyzer.detailed_analysis(query)
            final_response = self._apply_security_filters(model_response)
            self.database.log_interaction(user_id, query, final_response)
            return {
                "response": final_response,
                "sentiment": sentiment,
                "security_level": self._evaluate_risk(final_response),
                "real_time_data": self.data_fetcher.fetch_latest_data(),
                "token_optimized": True,
            }
        except Exception as e:
            logger.error(f"Response generation failed: {e}")
            return {"error": "Processing failed - safety protocols engaged"}

    def _vectorize_query(self, query: str):
        """Tokenize ``query`` and return its ``input_ids`` as a NumPy array."""
        tokenized = self.tokenizer(query, return_tensors="pt")
        return tokenized["input_ids"].detach().numpy()

    @staticmethod
    def _fit_to_dimension(vector, dim: int) -> np.ndarray:
        """Return ``vector`` as one float32 row of exactly ``dim`` values.

        The input is flattened, truncated when longer than ``dim`` and
        zero-padded when shorter.

        Args:
            vector: Array-like of numbers with any shape.
            dim: Required row length.

        Returns:
            A float32 array of shape ``(1, dim)``.
        """
        flat = np.asarray(vector, dtype=np.float32).reshape(-1)[:dim]
        fitted = np.zeros((1, dim), dtype=np.float32)
        fitted[0, :flat.size] = flat
        return fitted

    def _apply_security_filters(self, response: str):
        """Replace each occurrence of ``"malicious"`` with ``"[filtered]"``."""
        return response.replace("malicious", "[filtered]")

    def _evaluate_risk(self, response: str) -> str:
        """Classify a filtered response as ``"low"`` or ``"elevated"`` risk.

        Minimal heuristic: a response is ``"elevated"`` when it contains the
        ``"[filtered]"`` marker written by ``_apply_security_filters``.
        NOTE(review): this method was called but not defined in this class;
        confirm the intended risk levels.
        """
        return "elevated" if "[filtered]" in response else "low"

    async def _generate_local_model_response(self, query: str) -> str:
        """Generate a reply to ``query`` with the local model and decode it."""
        inputs = self.tokenizer(query, return_tensors="pt")
        # NOTE(review): generate() is a synchronous call inside a coroutine,
        # so the event loop is presumably blocked for its duration - confirm
        # whether it should be moved to an executor.
        outputs = self.model.generate(**inputs)
        return self.tokenizer.decode(outputs[0], skip_special_tokens=True)

    def generate_jwt(self, user_id: int):
        """Return an HS256-signed token for ``user_id`` expiring in one hour."""
        # Local import: ``timezone`` is not among the module-level imports.
        from datetime import timezone

        payload = {
            "user_id": user_id,
            # Timezone-aware replacement for the deprecated utcnow().
            "exp": datetime.now(timezone.utc) + timedelta(hours=1),
        }
        return encode(payload, self.jwt_secret, algorithm="HS256")

    def verify_jwt(self, token: str):
        """Decode ``token`` and return its payload, or ``None`` if expired.

        Only ``ExpiredSignatureError`` is handled here; any other exception
        raised by ``decode`` propagates to the caller.
        """
        try:
            return decode(token, self.jwt_secret, algorithms=["HS256"])
        except ExpiredSignatureError:
            return None
|