sbstagiare commited on
Commit
be54a79
·
verified ·
1 Parent(s): 57a6f1a

Create trainer

Browse files
Files changed (1) hide show
  1. trainer +60 -0
trainer ADDED
@@ -0,0 +1,60 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments
4
+ from peft import LoraConfig, get_peft_model
5
+ from trl import SFTTrainer
6
+ from datasets import load_dataset
7
+
8
+ # Charger le modèle et le tokenizer
9
+ model_name = "mistralai/Mistral-7B-v0.1" # Tu peux changer pour DeepSeek R1 7B/8B
10
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
11
+ model = AutoModelForCausalLM.from_pretrained(
12
+ model_name,
13
+ torch_dtype=torch.float16,
14
+ load_in_4bit=True, # QLoRA
15
+ device_map="auto"
16
+ )
17
+
18
+ # Charger le dataset (peut être un dataset HF ou un CSV local)
19
+ dataset = load_dataset("your_dataset") # Remplace par ton dataset HF
20
+
21
+ # Configurer LoRA (adapté pour QLoRA)
22
+ lora_config = LoraConfig(
23
+ r=16,
24
+ lora_alpha=32,
25
+ target_modules=["q_proj", "v_proj"],
26
+ lora_dropout=0.05,
27
+ bias="none",
28
+ task_type="CAUSAL_LM"
29
+ )
30
+ model = get_peft_model(model, lora_config)
31
+
32
+ # Arguments d'entraînement
33
+ training_args = TrainingArguments(
34
+ output_dir="./results",
35
+ per_device_train_batch_size=2,
36
+ gradient_accumulation_steps=4,
37
+ num_train_epochs=3,
38
+ learning_rate=2e-4,
39
+ fp16=True,
40
+ optim="paged_adamw_8bit",
41
+ logging_dir="./logs",
42
+ save_strategy="epoch"
43
+ )
44
+
45
+ # Fine-tuning avec SFTTrainer
46
+ trainer = SFTTrainer(
47
+ model=model,
48
+ train_dataset=dataset["train"],
49
+ dataset_text_field="question", # Adapter selon le format du dataset
50
+ peft_config=lora_config,
51
+ args=training_args
52
+ )
53
+
54
+ # Interface Gradio
55
+ def train():
56
+ trainer.train()
57
+ model.push_to_hub("your_hf_username/fine-tuned-model")
58
+ return "Fine-tuning terminé et modèle uploadé sur Hugging Face !"
59
+
60
+ gr.Interface(fn=train, inputs=[], outputs="text").launch()