import os from fastapi import FastAPI from transformers import AutoModelForCausalLM, AutoTokenizer import torch from pydantic import BaseModel app = FastAPI() # 📌 Définir un dossier cache accessible os.environ["TRANSFORMERS_CACHE"] = "/tmp" # 📌 Charger le modèle et le tokenizer avec cache local MODEL_NAME = "fatmata/psybot" local_dir = "/tmp/model" os.makedirs(local_dir, exist_ok=True) tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, cache_dir=local_dir) model = AutoModelForCausalLM.from_pretrained(MODEL_NAME, cache_dir=local_dir, torch_dtype=torch.float32) # 📌 Définition du modèle pour recevoir l'entrée utilisateur class PromptRequest(BaseModel): prompt: str @app.get("/") def home(): return {"message": "Bienvenue sur l'API PsyBot !"} @app.post("/generate") def generate_text(request: PromptRequest): """ Génère une réponse du chatbot PsyBot """ user_input = request.prompt # 📌 Ajouter les balises pour respecter le format du modèle formatted_prompt = f"<|startoftext|><|user|> {user_input} <|bot|>" # 📌 Encodage du texte et génération de la réponse inputs = tokenizer(formatted_prompt, return_tensors="pt").input_ids.to(model.device) with torch.no_grad(): output = model.generate( inputs, max_new_tokens=100, pad_token_id=tokenizer.eos_token_id, eos_token_id=tokenizer.eos_token_id, do_sample=True, # Activation du sampling temperature=0.7, # Génération plus naturelle top_k=50, top_p=0.9, repetition_penalty=1.2 # Réduction de la répétition ) response = tokenizer.decode(output[0], skip_special_tokens=True) # 🔍 Nettoyage : récupérer uniquement la réponse du bot après <|bot|> if "<|bot|>" in response: response = response.split("<|bot|>")[-1].strip() return {"response": response}