fatmata commited on
Commit
42748ef
·
verified ·
1 Parent(s): 7795d64

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +41 -4
app.py CHANGED
@@ -1,9 +1,46 @@
1
  import gradio as gr
 
 
2
 
3
- def chatbot_response(message):
4
- return f"Tu as dit : {message}"
 
 
5
 
6
- # Désactiver le flagging pour éviter l'erreur de permission
7
- iface = gr.Interface(fn=chatbot_response, inputs="text", outputs="text", allow_flagging="never")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9
  iface.launch(server_name="0.0.0.0", server_port=7860)
 
1
  import gradio as gr
2
+ import torch
3
+ from transformers import AutoModelForCausalLM, AutoTokenizer
4
 
5
+ # Charger ton modèle fine-tuné depuis Hugging Face
6
+ MODEL_NAME = "fatmata/psybot"
7
+ tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
8
+ model = AutoModelForCausalLM.from_pretrained(MODEL_NAME)
9
 
10
+ # Fonction pour générer une réponse
11
+ def generate_response(user_input):
12
+ """ Génère une réponse du chatbot PsyBot """
13
+ prompt = f"<|startoftext|><|user|> {user_input} <|bot|>"
14
+ inputs = tokenizer(prompt, return_tensors="pt").input_ids.to(model.device)
15
+
16
+ with torch.no_grad():
17
+ output = model.generate(
18
+ inputs,
19
+ max_new_tokens=100,
20
+ pad_token_id=tokenizer.eos_token_id,
21
+ eos_token_id=tokenizer.eos_token_id,
22
+ do_sample=True,
23
+ temperature=0.7,
24
+ top_k=50,
25
+ top_p=0.9,
26
+ repetition_penalty=1.2
27
+ )
28
+
29
+ response = tokenizer.decode(output[0], skip_special_tokens=True)
30
 
31
+ if "<|bot|>" in response:
32
+ response = response.split("<|bot|>")[-1].strip()
33
+
34
+ return response
35
+
36
+ # Interface Gradio
37
+ iface = gr.Interface(
38
+ fn=generate_response, # Fonction qui génère la réponse
39
+ inputs="text", # Champ d'entrée texte
40
+ outputs="text", # Champ de sortie texte
41
+ title="PsyBot - Chatbot Psychologue",
42
+ description="Posez vos questions et obtenez une réponse de PsyBot."
43
+ )
44
+
45
+ # Lancer l'application
46
  iface.launch(server_name="0.0.0.0", server_port=7860)