psybot-api / app.py
fatmata's picture
Update app.py
b800391 verified
import os
from fastapi import FastAPI
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
from pydantic import BaseModel
app = FastAPI()
# 📌 Définir un dossier cache accessible
os.environ["TRANSFORMERS_CACHE"] = "/tmp"
# 📌 Charger le modèle et le tokenizer avec cache local
MODEL_NAME = "fatmata/psybot"
local_dir = "/tmp/model"
os.makedirs(local_dir, exist_ok=True)
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, cache_dir=local_dir)
model = AutoModelForCausalLM.from_pretrained(MODEL_NAME, cache_dir=local_dir, torch_dtype=torch.float32)
# 📌 Définition du modèle pour recevoir l'entrée utilisateur
class PromptRequest(BaseModel):
prompt: str
@app.get("/")
def home():
return {"message": "Bienvenue sur l'API PsyBot !"}
@app.post("/generate")
def generate_text(request: PromptRequest):
""" Génère une réponse du chatbot PsyBot """
user_input = request.prompt
# 📌 Ajouter les balises pour respecter le format du modèle
formatted_prompt = f"<|startoftext|><|user|> {user_input} <|bot|>"
# 📌 Encodage du texte et génération de la réponse
inputs = tokenizer(formatted_prompt, return_tensors="pt").input_ids.to(model.device)
with torch.no_grad():
output = model.generate(
inputs,
max_new_tokens=100,
pad_token_id=tokenizer.eos_token_id,
eos_token_id=tokenizer.eos_token_id,
do_sample=True, # Activation du sampling
temperature=0.7, # Génération plus naturelle
top_k=50,
top_p=0.9,
repetition_penalty=1.2 # Réduction de la répétition
)
response = tokenizer.decode(output[0], skip_special_tokens=True)
# 🔍 Nettoyage : récupérer uniquement la réponse du bot après <|bot|>
if "<|bot|>" in response:
response = response.split("<|bot|>")[-1].strip()
return {"response": response}