hackergeek98 committed
Commit 925ba7d · verified · 1 Parent(s): 221c9ad

Update app.py

Files changed (1)
  1. app.py +82 -65
app.py CHANGED
@@ -1,4 +1,5 @@
 import torch
+import gradio as gr
 from transformers import (
     AutoModelForCausalLM,
     AutoTokenizer,
@@ -7,82 +8,98 @@ from transformers import (
     DataCollatorForLanguageModeling
 )
 from datasets import load_dataset
-import os
+import logging
+import sys
+
+# Configure logging
+logging.basicConfig(stream=sys.stdout, level=logging.INFO)
 
 def train():
-    # Load model and tokenizer
-    model_name = "microsoft/phi-2"
-    tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
-    model = AutoModelForCausalLM.from_pretrained(model_name, device_map="cpu", trust_remote_code=True)
-
-    # Add padding token if missing
-    if tokenizer.pad_token is None:
-        tokenizer.pad_token = tokenizer.eos_token
-
-    # Load dataset (update paths as needed)
-    dataset = load_dataset(
-        "csv",
-        data_files={
-            "train": "eswardivi/medical_qa",
-            "validation": "eswardivi/medical_qa"
-        }
-    )
+    try:
+        # Load model and tokenizer
+        model_name = "microsoft/phi-2"
+        tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
+        model = AutoModelForCausalLM.from_pretrained(model_name, device_map="cpu", trust_remote_code=True)
+
+        # Add padding token
+        if tokenizer.pad_token is None:
+            tokenizer.pad_token = tokenizer.eos_token
+
+        # Load dataset
+        dataset = load_dataset(
+            "csv",
+            data_files={
+                "train": "data/train/data.csv",
+                "validation": "data/validation/data.csv"
+            }
+        )
 
-    # Tokenization function
-    def tokenize_function(examples):
-        return tokenizer(
-            examples["text"],
-            padding="max_length",
-            truncation=True,
-            max_length=256,
-            return_tensors="pt",
-        )
+        # Tokenization function
+        def tokenize_function(examples):
+            return tokenizer(
+                examples["text"],
+                padding="max_length",
+                truncation=True,
+                max_length=256,
+                return_tensors="pt",
+            )
 
-    # Preprocess dataset
-    tokenized_dataset = dataset.map(
-        tokenize_function,
-        batched=True,
-        remove_columns=["text"]
-    )
+        tokenized_dataset = dataset.map(
+            tokenize_function,
+            batched=True,
+            remove_columns=["text"]
+        )
 
-    # Data collator
-    data_collator = DataCollatorForLanguageModeling(
-        tokenizer=tokenizer,
-        mlm=False
-    )
+        # Data collator
+        data_collator = DataCollatorForLanguageModeling(
+            tokenizer=tokenizer,
+            mlm=False
+        )
 
-    # Training arguments
-    training_args = TrainingArguments(
-        output_dir="./phi2-cpu-results",
-        overwrite_output_dir=True,
-        per_device_train_batch_size=2,
-        per_device_eval_batch_size=2,
-        num_train_epochs=3,
-        logging_dir="./logs",
-        logging_steps=100,
-        evaluation_strategy="epoch",
-        save_strategy="epoch",
-        fp16=False,
-        report_to="none",
-    )
+        # Training arguments
+        training_args = TrainingArguments(
+            output_dir="./phi2-results",
+            per_device_train_batch_size=2,
+            per_device_eval_batch_size=2,
+            num_train_epochs=3,
+            logging_dir="./logs",
+            logging_steps=10,
+            fp16=False,
+        )
 
-    # Initialize Trainer
-    trainer = Trainer(
-        model=model,
-        args=training_args,
-        train_dataset=tokenized_dataset["train"],
-        eval_dataset=tokenized_dataset["validation"],
-        data_collator=data_collator,
-    )
+        # Trainer
+        trainer = Trainer(
+            model=model,
+            args=training_args,
+            train_dataset=tokenized_dataset["train"],
+            eval_dataset=tokenized_dataset["validation"],
+            data_collator=data_collator,
+        )
 
-    # Start training
-    print("Starting training...")
-    trainer.train()
-
-    # Save model
-    trainer.save_model("./phi2-trained-model")
-    tokenizer.save_pretrained("./phi2-trained-model")
-    print("Training complete! Model saved.")
+        # Start training
+        logging.info("Training started...")
+        trainer.train()
+        trainer.save_model("./phi2-trained-model")
+        logging.info("Training completed!")
+
+        return "✅ Training succeeded! Model saved."
+
+    except Exception as e:
+        logging.error(f"Training failed: {str(e)}")
+        return f"❌ Training failed: {str(e)}"
+
+# Gradio UI
+with gr.Blocks(title="Phi-2 Training") as demo:
+    gr.Markdown("# 🚀 Train Phi-2 on CPU")
+
+    with gr.Row():
+        start_btn = gr.Button("Start Training", variant="primary")
+        status_output = gr.Textbox(label="Status", interactive=False)
+
+    start_btn.click(
+        fn=train,
+        outputs=status_output
+    )
 
 if __name__ == "__main__":
-    train()
+    demo.launch(server_name="0.0.0.0", server_port=7860)
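
Note: the key functional change in this commit is that load_dataset("csv", ...) now points at local files (data/train/data.csv, data/validation/data.csv) instead of "eswardivi/medical_qa", which is a Hub dataset id rather than a local CSV path. A minimal sketch of the data layout the updated app appears to expect (the file paths come from the diff; the "text" column name is inferred from tokenize_function; the sample row is purely illustrative):

import csv
import os

# Create the directory layout referenced in the committed data_files dict.
for split in ("train", "validation"):
    os.makedirs(f"data/{split}", exist_ok=True)
    with open(f"data/{split}/data.csv", "w", newline="", encoding="utf-8") as f:
        writer = csv.DictWriter(f, fieldnames=["text"])  # column consumed by tokenize_function
        writer.writeheader()
        writer.writerow({"text": "Illustrative question-answer pair for smoke testing."})

With those files in place, running app.py starts the Gradio UI on port 7860, and the "Start Training" button invokes train().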