hackergeek98 committed on
Commit e20d86e · verified · 1 Parent(s): 0c5c5b6

Update app.py

Files changed (1):
  1. app.py +78 -17
app.py CHANGED
@@ -1,27 +1,88 @@
  import os
- from autotrain import logger
- from autotrain.trainers.common import ALLOW_REMOTE_CODE
- from autotrain.trainers.text_generation import LLMTrainingParams, LLMTrainer

  def train():
-
-
-     # Define training parameters
-     params = LLMTrainingParams(
-         model_name="microsoft/phi-4",  # Replace with your model
-         data_path="lavita/medical-qa-datasets",
-         project_name="phi4-training",
-         learning_rate=2e-5,
          num_train_epochs=3,
-         batch_size=2,
-         fp16=True,
-         push_to_hub=True,
-         repo_id="hackergeek98/phi4-trained",
      )

-     # Initialize and run trainer
-     trainer = LLMTrainer(params=params)
      trainer.train()

  if __name__ == "__main__":
      train()

+ import torch
+ from transformers import (
+     AutoModelForCausalLM,
+     AutoTokenizer,
+     TrainingArguments,
+     Trainer,
+     DataCollatorForLanguageModeling
+ )
+ from datasets import load_dataset
  import os

  def train():
+     # Load model and tokenizer
+     model_name = "microsoft/phi-2"
+     tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
+     model = AutoModelForCausalLM.from_pretrained(model_name, device_map="cpu", trust_remote_code=True)
+
+     # Add padding token if missing
+     if tokenizer.pad_token is None:
+         tokenizer.pad_token = tokenizer.eos_token
+
+     # Load dataset (update paths as needed)
+     dataset = load_dataset(
+         "csv",
+         data_files={
+             "train": "eswardivi/medical_qa",
+             "validation": "eswardivi/medical_qa"
+         }
+     )
+
+     # Tokenization function
+     def tokenize_function(examples):
+         return tokenizer(
+             examples["text"],
+             padding="max_length",
+             truncation=True,
+             max_length=256,
+             return_tensors="pt",
+         )
+
+     # Preprocess dataset
+     tokenized_dataset = dataset.map(
+         tokenize_function,
+         batched=True,
+         remove_columns=["text"]
+     )
+
+     # Data collator
+     data_collator = DataCollatorForLanguageModeling(
+         tokenizer=tokenizer,
+         mlm=False
+     )
+
+     # Training arguments
+     training_args = TrainingArguments(
+         output_dir="./phi2-cpu-results",
+         overwrite_output_dir=True,
+         per_device_train_batch_size=2,
+         per_device_eval_batch_size=2,
          num_train_epochs=3,
+         logging_dir="./logs",
+         logging_steps=100,
+         evaluation_strategy="epoch",
+         save_strategy="epoch",
+         fp16=False,
+         report_to="none",
      )

+     # Initialize Trainer
+     trainer = Trainer(
+         model=model,
+         args=training_args,
+         train_dataset=tokenized_dataset["train"],
+         eval_dataset=tokenized_dataset["validation"],
+         data_collator=data_collator,
+     )
+
+     # Start training
+     print("Starting training...")
      trainer.train()

+     # Save model
+     trainer.save_model("./phi2-trained-model")
+     tokenizer.save_pretrained("./phi2-trained-model")
+     print("Training complete! Model saved.")
+
  if __name__ == "__main__":
      train()
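
One detail worth flagging in the new version: load_dataset("csv", data_files=...) expects local file paths or URLs, so passing the Hub dataset ID eswardivi/medical_qa as a CSV file will most likely fail, and the hard-coded examples["text"] column must actually exist in that dataset. Below is a minimal sketch, assuming the dataset is pulled straight from the Hub by ID; the split and column handling are assumptions, not something taken from this commit.

from datasets import load_dataset

# Hedged sketch: load the dataset by its Hub ID instead of the "csv" builder.
# Split names and column names are assumptions; inspect the printed output
# before wiring this into train().
dataset = load_dataset("eswardivi/medical_qa")
print(dataset)                                   # shows the available splits
first_split = next(iter(dataset))                # don't assume the split is named "train"
print(dataset[first_split].column_names)         # shows which column(s) to tokenize

If the dataset exposes question/answer style columns rather than a single text field, tokenize_function and the remove_columns list in the committed script would need to be adjusted to match.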