Spaces:
Runtime error
Runtime error
File size: 1,938 Bytes
a3db5d7 0293ad8 7090deb b800391 0293ad8 b800391 a3db5d7 b800391 381aae1 b800391 a3db5d7 7090deb b800391 0293ad8 381aae1 7090deb 381aae1 b800391 7090deb |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 |
import os
from fastapi import FastAPI
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
from pydantic import BaseModel
app = FastAPI()
# 📌 Définir un dossier cache accessible
os.environ["TRANSFORMERS_CACHE"] = "/tmp"
# 📌 Charger le modèle et le tokenizer avec cache local
MODEL_NAME = "fatmata/psybot"
local_dir = "/tmp/model"
os.makedirs(local_dir, exist_ok=True)
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, cache_dir=local_dir)
model = AutoModelForCausalLM.from_pretrained(MODEL_NAME, cache_dir=local_dir, torch_dtype=torch.float32)
# 📌 Définition du modèle pour recevoir l'entrée utilisateur
class PromptRequest(BaseModel):
prompt: str
@app.get("/")
def home():
return {"message": "Bienvenue sur l'API PsyBot !"}
@app.post("/generate")
def generate_text(request: PromptRequest):
""" Génère une réponse du chatbot PsyBot """
user_input = request.prompt
# 📌 Ajouter les balises pour respecter le format du modèle
formatted_prompt = f"<|startoftext|><|user|> {user_input} <|bot|>"
# 📌 Encodage du texte et génération de la réponse
inputs = tokenizer(formatted_prompt, return_tensors="pt").input_ids.to(model.device)
with torch.no_grad():
output = model.generate(
inputs,
max_new_tokens=100,
pad_token_id=tokenizer.eos_token_id,
eos_token_id=tokenizer.eos_token_id,
do_sample=True, # Activation du sampling
temperature=0.7, # Génération plus naturelle
top_k=50,
top_p=0.9,
repetition_penalty=1.2 # Réduction de la répétition
)
response = tokenizer.decode(output[0], skip_special_tokens=True)
# 🔍 Nettoyage : récupérer uniquement la réponse du bot après <|bot|>
if "<|bot|>" in response:
response = response.split("<|bot|>")[-1].strip()
return {"response": response}
|