from transformers import BlenderbotTokenizer, BlenderbotForConditionalGeneration, Trainer, TrainingArguments from datasets import load_dataset import torch import gradio as gr # ✅ Charger le modèle et le tokenizer model_name = "facebook/blenderbot-400M-distill" tokenizer = BlenderbotTokenizer.from_pretrained(model_name) model = BlenderbotForConditionalGeneration.from_pretrained(model_name) # ✅ Charger le dataset "fka/awesome-chatgpt-prompts" dataset = load_dataset("fka/awesome-chatgpt-prompts") # ✅ Préparer les données en adaptant les colonnes disponibles def preprocess_function(examples): inputs = examples["act"] targets = examples["prompt"] # Tokenisation avec padding et tronquage model_inputs = tokenizer(inputs, max_length=128, truncation=True, padding="max_length", return_tensors="pt") labels = tokenizer(targets, max_length=128, truncation=True, padding="max_length", return_tensors="pt") model_inputs["labels"] = labels["input_ids"] return model_inputs # Appliquer le prétraitement en retirant les colonnes inutiles tokenized_dataset = dataset.map(preprocess_function, batched=True, remove_columns=["act", "prompt"]) # ✅ Vérifier si un fine-tuning est nécessaire do_training = False # Change à True pour entraîner le modèle if do_training: training_args = TrainingArguments( output_dir="./results", evaluation_strategy="epoch", learning_rate=5e-5, per_device_train_batch_size=8, per_device_eval_batch_size=8, num_train_epochs=3, weight_decay=0.01, save_total_limit=2, push_to_hub=False, ) trainer = Trainer( model=model, args=training_args, train_dataset=tokenized_dataset["train"], eval_dataset=tokenized_dataset["test"], ) trainer.train() # ✅ Pipeline d'inférence pour le chatbot def chatbot_response(user_input, history=None): # Accepte l'historique pour éviter l'erreur inputs = tokenizer(user_input, return_tensors="pt", max_length=128, truncation=True) with torch.no_grad(): output = model.generate(**inputs, max_length=128, num_beams=5, early_stopping=True) return tokenizer.decode(output[0], skip_special_tokens=True) # ✅ Interface Gradio pour dialoguer avec le chatbot chatbot = gr.ChatInterface( fn=chatbot_response, title="Chatbot Service Après-Vente", description="Posez vos questions sur le service après-vente et obtenez des réponses instantanées.", ) chatbot.launch()