import torch
import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM

# Load the fine-tuned model
MODEL_NAME = "fatmata/psybot"
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForCausalLM.from_pretrained(MODEL_NAME)
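# Optional: move the model to a GPU when one is available (a minimal sketch;
# generation also works, more slowly, on CPU). The inputs below already follow
# model.device, so they land on whatever device the model ends up on.
device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)
model.eval()  # inference only: disable dropout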

print("✅ Modèle chargé avec succès :", model.config)  # Debugging

def generate_response(user_input):
    """ Génère une réponse du chatbot PsyBot """
    prompt = f"<|startoftext|><|user|> {user_input} <|bot|>"
    print(f"🔹 Prompt envoyé au modèle : {prompt}")  # Debugging
    
    inputs = tokenizer(prompt, return_tensors="pt").input_ids.to(model.device)
    
    with torch.no_grad():
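        # Sampling configuration: temperature, top_k and top_p trade
        # determinism for variety; repetition_penalty discourages the model
        # from repeating itself; max_new_tokens caps the reply length.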
        output = model.generate(
            inputs,
            max_new_tokens=100,
            pad_token_id=tokenizer.eos_token_id,
            eos_token_id=tokenizer.eos_token_id,
            do_sample=True,
            temperature=0.7,
            top_k=50,
            top_p=0.9,
            repetition_penalty=1.2
        )

    response = tokenizer.decode(output[0], skip_special_tokens=True)
    print(f"🔹 Raw model response: {response}")  # debugging

    # Keep only the text after the bot marker. Note that skip_special_tokens=True
    # strips the marker when it is registered as a special token, in which case
    # the full decoded text is returned as-is.
    if "<|bot|>" in response:
        response = response.split("<|bot|>")[-1].strip()

    return response
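
# Example (hypothetical input): quick command-line check of the pipeline
# before launching the UI.
# print(generate_response("I have been feeling anxious lately."))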

# Gradio text-in/text-out interface wired to the fine-tuned model
iface = gr.Interface(fn=generate_response, inputs="text", outputs="text")

# Bind to all interfaces on port 7860, the port Hugging Face Spaces expects
iface.launch(server_name="0.0.0.0", server_port=7860)