Spaces:
Runtime error
Runtime error
import os | |
from fastapi import FastAPI | |
from transformers import AutoModelForCausalLM, AutoTokenizer | |
import torch | |
from pydantic import BaseModel | |
app = FastAPI() | |
# 📌 Définir un dossier cache accessible | |
os.environ["TRANSFORMERS_CACHE"] = "/tmp" | |
# 📌 Charger le modèle et le tokenizer avec cache local | |
MODEL_NAME = "fatmata/psybot" | |
local_dir = "/tmp/model" | |
os.makedirs(local_dir, exist_ok=True) | |
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, cache_dir=local_dir) | |
model = AutoModelForCausalLM.from_pretrained(MODEL_NAME, cache_dir=local_dir, torch_dtype=torch.float32) | |
# 📌 Définition du modèle pour recevoir l'entrée utilisateur | |
class PromptRequest(BaseModel): | |
prompt: str | |
def home(): | |
return {"message": "Bienvenue sur l'API PsyBot !"} | |
def generate_text(request: PromptRequest): | |
""" Génère une réponse du chatbot PsyBot """ | |
user_input = request.prompt | |
# 📌 Ajouter les balises pour respecter le format du modèle | |
formatted_prompt = f"<|startoftext|><|user|> {user_input} <|bot|>" | |
# 📌 Encodage du texte et génération de la réponse | |
inputs = tokenizer(formatted_prompt, return_tensors="pt").input_ids.to(model.device) | |
with torch.no_grad(): | |
output = model.generate( | |
inputs, | |
max_new_tokens=100, | |
pad_token_id=tokenizer.eos_token_id, | |
eos_token_id=tokenizer.eos_token_id, | |
do_sample=True, # Activation du sampling | |
temperature=0.7, # Génération plus naturelle | |
top_k=50, | |
top_p=0.9, | |
repetition_penalty=1.2 # Réduction de la répétition | |
) | |
response = tokenizer.decode(output[0], skip_special_tokens=True) | |
# 🔍 Nettoyage : récupérer uniquement la réponse du bot après <|bot|> | |
if "<|bot|>" in response: | |
response = response.split("<|bot|>")[-1].strip() | |
return {"response": response} | |