from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, TrainingArguments
from peft import prepare_model_for_kbit_training, LoraConfig, get_peft_model
from trl import SFTTrainer
import torch
from datasets import load_dataset

# Load the base model (Hermes-3, a Llama 3.2 3B fine-tune) with 4-bit quantization
model_name = "NousResearch/Hermes-3-Llama-3.2-3B"
quant_config = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_compute_dtype=torch.float16)
model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto", quantization_config=quant_config)
tokenizer = AutoTokenizer.from_pretrained(model_name)
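
# Some Llama-family tokenizers ship without a pad token; fall back to EOS if so
# (an assumption for this checkpoint, harmless when a pad token is already defined).
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token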

# Prepare model for QLoRA
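# prepare_model_for_kbit_training freezes the quantized base weights, upcasts norm layers
# to fp32 for numerical stability, and enables input gradients so the LoRA adapters can train.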
model = prepare_model_for_kbit_training(model)

# LoRA Configuration
lora_config = LoraConfig(
    r=16,
    lora_alpha=32,
    target_modules=["q_proj", "v_proj"],
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM"
)
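# With r=16 and lora_alpha=32, adapter updates are scaled by alpha/r = 2; q_proj and v_proj
# are the attention projections most commonly targeted in LoRA setups.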

model = get_peft_model(model, lora_config)
model.print_trainable_parameters()

# Load the FAQ dataset from the local JSON file
dataset = load_dataset("json", data_files="sevaai_faq.json")

# Rename the "output" column to "text" so SFTTrainer can find it
dataset["train"] = dataset["train"].rename_column("output", "text")

# Training arguments
training_args = TrainingArguments(
    per_device_train_batch_size=4,
    gradient_accumulation_steps=8,
    num_train_epochs=3,
    learning_rate=2e-4,
    logging_steps=10,
    output_dir="./nirmaya",
    save_steps=1000,
    save_total_limit=2,
    optim="adamw_torch"
)
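# Effective batch size = per_device_train_batch_size * gradient_accumulation_steps = 4 * 8 = 32
# per device; fp16=True (or bf16=True) could also be enabled here to match the 4-bit compute dtype.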

trainer = SFTTrainer(
    model=model,                      # already wrapped with LoRA adapters above, so no peft_config needed
    train_dataset=dataset["train"],
    tokenizer=tokenizer,              # newer trl releases rename this kwarg to processing_class
    args=training_args
)

# Train the model
trainer.train()

# Save fine-tuned model
trainer.save_model("./nirmaya")
print("Fine-tuning complete! Model saved to ./nirmaya")
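
# Optional sanity check (a sketch, not part of the original script): reload the saved
# adapter with peft's AutoPeftModelForCausalLM and generate a reply. The prompt is a
# made-up example question for the FAQ data.
from peft import AutoPeftModelForCausalLM

ft_model = AutoPeftModelForCausalLM.from_pretrained("./nirmaya", device_map="auto", torch_dtype=torch.float16)
prompt = "What services does SevaAI offer?"  # hypothetical example prompt
inputs = tokenizer(prompt, return_tensors="pt").to(ft_model.device)
outputs = ft_model.generate(**inputs, max_new_tokens=64)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))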