fatmata commited on
Commit
b800391
·
verified ·
1 Parent(s): a10a707

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +38 -8
app.py CHANGED
@@ -2,27 +2,57 @@ import os
2
  from fastapi import FastAPI
3
  from transformers import AutoModelForCausalLM, AutoTokenizer
4
  import torch
 
5
 
6
  app = FastAPI()
7
 
8
- # Définir un dossier cache accessible
9
  os.environ["TRANSFORMERS_CACHE"] = "/tmp"
10
 
11
- # Charger le modèle et le tokenizer depuis Hugging Face avec cache local
12
  MODEL_NAME = "fatmata/psybot"
13
- local_dir = "/tmp/model" # Changer le dossier vers /tmp/model
14
- os.makedirs(local_dir, exist_ok=True) # Crée le dossier si nécessaire
15
 
16
  tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, cache_dir=local_dir)
17
  model = AutoModelForCausalLM.from_pretrained(MODEL_NAME, cache_dir=local_dir, torch_dtype=torch.float32)
18
 
 
 
 
 
19
  @app.get("/")
20
  def home():
21
  return {"message": "Bienvenue sur l'API PsyBot !"}
22
 
23
  @app.post("/generate")
24
- def generate_text(prompt: str):
25
- inputs = tokenizer(prompt, return_tensors="pt")
26
- outputs = model.generate(**inputs, max_length=100)
27
- response = tokenizer.decode(outputs[0], skip_special_tokens=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
28
  return {"response": response}
 
2
  from fastapi import FastAPI
3
  from transformers import AutoModelForCausalLM, AutoTokenizer
4
  import torch
5
+ from pydantic import BaseModel
6
 
7
  app = FastAPI()
8
 
9
+ # 📌 Définir un dossier cache accessible
10
  os.environ["TRANSFORMERS_CACHE"] = "/tmp"
11
 
12
+ # 📌 Charger le modèle et le tokenizer avec cache local
13
  MODEL_NAME = "fatmata/psybot"
14
+ local_dir = "/tmp/model"
15
+ os.makedirs(local_dir, exist_ok=True)
16
 
17
  tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, cache_dir=local_dir)
18
  model = AutoModelForCausalLM.from_pretrained(MODEL_NAME, cache_dir=local_dir, torch_dtype=torch.float32)
19
 
20
+ # 📌 Définition du modèle pour recevoir l'entrée utilisateur
21
+ class PromptRequest(BaseModel):
22
+ prompt: str
23
+
24
  @app.get("/")
25
  def home():
26
  return {"message": "Bienvenue sur l'API PsyBot !"}
27
 
28
  @app.post("/generate")
29
+ def generate_text(request: PromptRequest):
30
+ """ Génère une réponse du chatbot PsyBot """
31
+ user_input = request.prompt
32
+
33
+ # 📌 Ajouter les balises pour respecter le format du modèle
34
+ formatted_prompt = f"<|startoftext|><|user|> {user_input} <|bot|>"
35
+
36
+ # 📌 Encodage du texte et génération de la réponse
37
+ inputs = tokenizer(formatted_prompt, return_tensors="pt").input_ids.to(model.device)
38
+
39
+ with torch.no_grad():
40
+ output = model.generate(
41
+ inputs,
42
+ max_new_tokens=100,
43
+ pad_token_id=tokenizer.eos_token_id,
44
+ eos_token_id=tokenizer.eos_token_id,
45
+ do_sample=True, # Activation du sampling
46
+ temperature=0.7, # Génération plus naturelle
47
+ top_k=50,
48
+ top_p=0.9,
49
+ repetition_penalty=1.2 # Réduction de la répétition
50
+ )
51
+
52
+ response = tokenizer.decode(output[0], skip_special_tokens=True)
53
+
54
+ # 🔍 Nettoyage : récupérer uniquement la réponse du bot après <|bot|>
55
+ if "<|bot|>" in response:
56
+ response = response.split("<|bot|>")[-1].strip()
57
+
58
  return {"response": response}