File size: 1,938 Bytes
a3db5d7
0293ad8
7090deb
 
b800391
0293ad8
 
 
b800391
a3db5d7
 
b800391
381aae1
b800391
 
a3db5d7
 
 
7090deb
b800391
 
 
 
0293ad8
381aae1
 
7090deb
381aae1
b800391
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7090deb
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
import os
from fastapi import FastAPI
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
from pydantic import BaseModel

app = FastAPI()

# 📌 Définir un dossier cache accessible
os.environ["TRANSFORMERS_CACHE"] = "/tmp"

# 📌 Charger le modèle et le tokenizer avec cache local
MODEL_NAME = "fatmata/psybot"
local_dir = "/tmp/model"
os.makedirs(local_dir, exist_ok=True)

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, cache_dir=local_dir)
model = AutoModelForCausalLM.from_pretrained(MODEL_NAME, cache_dir=local_dir, torch_dtype=torch.float32)

# 📌 Définition du modèle pour recevoir l'entrée utilisateur
class PromptRequest(BaseModel):
    prompt: str

@app.get("/")
def home():
    return {"message": "Bienvenue sur l'API PsyBot !"}

@app.post("/generate")
def generate_text(request: PromptRequest):
    """ Génère une réponse du chatbot PsyBot """
    user_input = request.prompt

    # 📌 Ajouter les balises pour respecter le format du modèle
    formatted_prompt = f"<|startoftext|><|user|> {user_input} <|bot|>"

    # 📌 Encodage du texte et génération de la réponse
    inputs = tokenizer(formatted_prompt, return_tensors="pt").input_ids.to(model.device)

    with torch.no_grad():
        output = model.generate(
            inputs,
            max_new_tokens=100,
            pad_token_id=tokenizer.eos_token_id,
            eos_token_id=tokenizer.eos_token_id,
            do_sample=True,  # Activation du sampling
            temperature=0.7,  # Génération plus naturelle
            top_k=50,
            top_p=0.9,
            repetition_penalty=1.2  # Réduction de la répétition
        )

    response = tokenizer.decode(output[0], skip_special_tokens=True)

    # 🔍 Nettoyage : récupérer uniquement la réponse du bot après <|bot|>
    if "<|bot|>" in response:
        response = response.split("<|bot|>")[-1].strip()

    return {"response": response}