fatmata commited on
Commit
c091f94
·
verified ·
1 Parent(s): 42748ef

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +20 -16
app.py CHANGED
@@ -1,16 +1,26 @@
1
- import gradio as gr
2
  import torch
3
- from transformers import AutoModelForCausalLM, AutoTokenizer
 
 
4
 
5
- # Charger ton modèle fine-tuné depuis Hugging Face
6
  MODEL_NAME = "fatmata/psybot"
7
  tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
8
  model = AutoModelForCausalLM.from_pretrained(MODEL_NAME)
9
 
10
- # Fonction pour générer une réponse
 
 
 
 
 
 
 
11
  def generate_response(user_input):
12
  """ Génère une réponse du chatbot PsyBot """
13
  prompt = f"<|startoftext|><|user|> {user_input} <|bot|>"
 
 
14
  inputs = tokenizer(prompt, return_tensors="pt").input_ids.to(model.device)
15
 
16
  with torch.no_grad():
@@ -25,22 +35,16 @@ def generate_response(user_input):
25
  top_p=0.9,
26
  repetition_penalty=1.2
27
  )
28
-
29
  response = tokenizer.decode(output[0], skip_special_tokens=True)
 
30
 
31
  if "<|bot|>" in response:
32
  response = response.split("<|bot|>")[-1].strip()
33
 
34
  return response
35
 
36
- # Interface Gradio
37
- iface = gr.Interface(
38
- fn=generate_response, # Fonction qui génère la réponse
39
- inputs="text", # Champ d'entrée texte
40
- outputs="text", # Champ de sortie texte
41
- title="PsyBot - Chatbot Psychologue",
42
- description="Posez vos questions et obtenez une réponse de PsyBot."
43
- )
44
-
45
- # Lancer l'application
46
- iface.launch(server_name="0.0.0.0", server_port=7860)
 
 
1
  import torch
2
+ from transformers import AutoTokenizer, AutoModelForCausalLM
3
+ from fastapi import FastAPI
4
+ from pydantic import BaseModel
5
 
6
+ # Charger le modèle fine-tuné
7
  MODEL_NAME = "fatmata/psybot"
8
  tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
9
  model = AutoModelForCausalLM.from_pretrained(MODEL_NAME)
10
 
11
+ print("✅ Modèle et tokenizer chargés avec succès !") # Vérification du chargement
12
+
13
+ # Définir l'API avec FastAPI
14
+ app = FastAPI()
15
+
16
+ class UserInput(BaseModel):
17
+ text: str
18
+
19
  def generate_response(user_input):
20
  """ Génère une réponse du chatbot PsyBot """
21
  prompt = f"<|startoftext|><|user|> {user_input} <|bot|>"
22
+ print(f"🔹 Prompt envoyé au modèle : {prompt}") # Debugging
23
+
24
  inputs = tokenizer(prompt, return_tensors="pt").input_ids.to(model.device)
25
 
26
  with torch.no_grad():
 
35
  top_p=0.9,
36
  repetition_penalty=1.2
37
  )
38
+
39
  response = tokenizer.decode(output[0], skip_special_tokens=True)
40
+ print(f"🔹 Réponse brute du modèle : {response}") # Debugging
41
 
42
  if "<|bot|>" in response:
43
  response = response.split("<|bot|>")[-1].strip()
44
 
45
  return response
46
 
47
+ @app.post("/generate/")
48
+ def generate(user_input: UserInput):
49
+ response = generate_response(user_input.text)
50
+ return {"response": response}