Spaces:
Sleeping
Sleeping
from transformers import BlenderbotTokenizer, BlenderbotForConditionalGeneration, Trainer, TrainingArguments | |
from datasets import load_dataset | |
import torch | |
import gradio as gr | |
# ✅ Charger le modèle et le tokenizer | |
model_name = "facebook/blenderbot-400M-distill" | |
tokenizer = BlenderbotTokenizer.from_pretrained(model_name) | |
model = BlenderbotForConditionalGeneration.from_pretrained(model_name) | |
# ✅ Charger le dataset "fka/awesome-chatgpt-prompts" | |
dataset = load_dataset("fka/awesome-chatgpt-prompts") | |
# ✅ Préparer les données en adaptant les colonnes disponibles | |
def preprocess_function(examples): | |
inputs = examples["act"] | |
targets = examples["prompt"] | |
# Tokenisation avec padding et tronquage | |
model_inputs = tokenizer(inputs, max_length=128, truncation=True, padding="max_length", return_tensors="pt") | |
labels = tokenizer(targets, max_length=128, truncation=True, padding="max_length", return_tensors="pt") | |
model_inputs["labels"] = labels["input_ids"] | |
return model_inputs | |
# Appliquer le prétraitement en retirant les colonnes inutiles | |
tokenized_dataset = dataset.map(preprocess_function, batched=True, remove_columns=["act", "prompt"]) | |
# ✅ Vérifier si un fine-tuning est nécessaire | |
do_training = False # Change à True pour entraîner le modèle | |
if do_training: | |
training_args = TrainingArguments( | |
output_dir="./results", | |
evaluation_strategy="epoch", | |
learning_rate=5e-5, | |
per_device_train_batch_size=8, | |
per_device_eval_batch_size=8, | |
num_train_epochs=3, | |
weight_decay=0.01, | |
save_total_limit=2, | |
push_to_hub=False, | |
) | |
trainer = Trainer( | |
model=model, | |
args=training_args, | |
train_dataset=tokenized_dataset["train"], | |
eval_dataset=tokenized_dataset["test"], | |
) | |
trainer.train() | |
# ✅ Pipeline d'inférence pour le chatbot | |
def chatbot_response(user_input, history=None): # Accepte l'historique pour éviter l'erreur | |
inputs = tokenizer(user_input, return_tensors="pt", max_length=128, truncation=True) | |
with torch.no_grad(): | |
output = model.generate(**inputs, max_length=128, num_beams=5, early_stopping=True) | |
return tokenizer.decode(output[0], skip_special_tokens=True) | |
# ✅ Interface Gradio pour dialoguer avec le chatbot | |
chatbot = gr.ChatInterface( | |
fn=chatbot_response, | |
title="Chatbot Service Après-Vente", | |
description="Posez vos questions sur le service après-vente et obtenez des réponses instantanées.", | |
) | |
chatbot.launch() | |