File size: 3,630 Bytes
e20d86e
925ba7d
e20d86e
 
 
 
 
 
 
 
925ba7d
 
 
 
 
ef46523
6c8c083
925ba7d
 
 
 
 
e20d86e
925ba7d
 
 
e20d86e
6c8c083
 
925ba7d
6c8c083
 
 
 
925ba7d
 
6c8c083
 
 
 
925ba7d
 
6c8c083
925ba7d
 
 
 
 
e20d86e
925ba7d
 
 
6c8c083
e20d86e
 
925ba7d
 
 
 
 
e20d86e
925ba7d
 
 
 
 
 
 
 
 
 
e20d86e
925ba7d
 
 
 
 
6c8c083
925ba7d
 
ef46523
925ba7d
 
 
 
 
 
 
e20d86e
925ba7d
 
 
ef46523
6c8c083
925ba7d
6c8c083
925ba7d
 
6c8c083
 
 
 
 
925ba7d
 
 
6c8c083
925ba7d
 
e20d86e
ef46523
925ba7d
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
import torch
import gradio as gr
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    TrainingArguments,
    Trainer,
    DataCollatorForLanguageModeling
)
from datasets import load_dataset
import logging
import sys

# Route log output to stdout so it is visible in hosted-container logs
# (e.g. a Hugging Face Space), not just stderr.
logging.basicConfig(level=logging.INFO, stream=sys.stdout)

def train(dataset_name: str, dataset_config: str = None):
    """Fine-tune microsoft/phi-2 on a dataset from the Hugging Face Hub.

    Args:
        dataset_name: Hub dataset identifier (e.g. "mozilla-foundation/common_voice_11_0").
        dataset_config: Optional dataset configuration/subset name (e.g. a
            language code). An empty string is treated as "no config".

    Returns:
        A human-readable status string describing success or failure; this
        function never raises (all errors are caught and reported).
    """
    try:
        # A cleared Gradio textbox arrives as "", not None; normalize it so
        # load_dataset does not look for a config literally named "".
        dataset_config = dataset_config or None

        # Load model and tokenizer (CPU-only; fp16 is disabled below to match)
        model_name = "microsoft/phi-2"
        tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
        model = AutoModelForCausalLM.from_pretrained(model_name, device_map="cpu", trust_remote_code=True)

        # Phi-2's tokenizer ships without a pad token; reuse EOS for padding.
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token

        # Load dataset from Hugging Face Hub.
        # BUG FIX: the original f-string interpolated the bare expression
        # {eswardivi/medical_qa}, which raises NameError at runtime; log the
        # actual dataset_name argument instead.
        logging.info(f"Loading dataset: {dataset_name} (config: {dataset_config})")
        dataset = load_dataset(
            dataset_name,
            dataset_config,  # Optional config (e.g., language for Common Voice)
            split="train+validation",  # Combine splits
            trust_remote_code=True  # Required for some datasets
        )

        # Carve out a held-out eval split (fixed seed for reproducibility).
        dataset = dataset.train_test_split(test_size=0.1, seed=42)

        # Tokenization function (adjust based on dataset columns)
        def tokenize_function(examples):
            return tokenizer(
                examples["text"],  # Replace "text" with your dataset's text column
                padding="max_length",
                truncation=True,
                max_length=256,
                return_tensors="pt",
            )

        tokenized_dataset = dataset.map(
            tokenize_function,
            batched=True,
            # Drop the raw columns; the Trainer only needs the token tensors.
            remove_columns=dataset["train"].column_names
        )

        # Causal-LM collator (mlm=False => next-token-prediction labels).
        data_collator = DataCollatorForLanguageModeling(
            tokenizer=tokenizer,
            mlm=False
        )

        # Training arguments — small batch sizes and fp16 off for CPU training.
        training_args = TrainingArguments(
            output_dir="./phi2-results",
            per_device_train_batch_size=2,
            per_device_eval_batch_size=2,
            num_train_epochs=3,
            logging_dir="./logs",
            logging_steps=10,
            fp16=False,
        )

        # Trainer
        trainer = Trainer(
            model=model,
            args=training_args,
            train_dataset=tokenized_dataset["train"],
            eval_dataset=tokenized_dataset["test"],
            data_collator=data_collator,
        )

        # Start training
        logging.info("Training started...")
        trainer.train()
        trainer.save_model("./phi2-trained-model")
        logging.info("Training completed!")

        return "βœ… Training succeeded! Model saved."

    except Exception as e:
        # logging.exception includes the traceback, which logging.error did not.
        logging.exception(f"Training failed: {str(e)}")
        return f"❌ Training failed: {str(e)}"

# ---- Gradio front-end ------------------------------------------------------
# Two textboxes collect the dataset spec, a button kicks off train(), and a
# read-only status box displays the returned status string.
with gr.Blocks(title="Phi-2 Training") as demo:
    gr.Markdown("# πŸš€ Train Phi-2 with HF Hub Data")

    with gr.Row():
        dataset_name = gr.Textbox(
            label="Dataset Name",
            value="mozilla-foundation/common_voice_11_0",
        )
        dataset_config = gr.Textbox(
            label="Dataset Config (optional)",
            value="en",
        )

    start_btn = gr.Button("Start Training", variant="primary")
    status_output = gr.Textbox(label="Status", interactive=False)

    start_btn.click(fn=train, inputs=[dataset_name, dataset_config], outputs=status_output)

if __name__ == "__main__":
    # Bind to all interfaces on 7860 — the Hugging Face Spaces default port.
    demo.launch(server_port=7860, server_name="0.0.0.0")