hackergeek98 commited on
Commit
767fba0
Β·
verified Β·
1 Parent(s): 62ffb32

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +11 -116
app.py CHANGED
@@ -1,5 +1,7 @@
 
1
  import torch
2
  import gradio as gr
 
3
  from transformers import (
4
  AutoModelForCausalLM,
5
  AutoTokenizer,
@@ -15,125 +17,18 @@ from urllib.parse import urlparse
15
  # Configure logging
16
  logging.basicConfig(stream=sys.stdout, level=logging.INFO)
17
 
18
- def parse_hf_dataset_url(url: str) -> tuple[str, str | None]:
19
- """Parse Hugging Face dataset URL into (dataset_name, config)"""
20
- parsed = urlparse(url)
21
- path_parts = parsed.path.split('/')
22
-
23
- try:
24
- # Find 'datasets' in path
25
- datasets_idx = path_parts.index('datasets')
26
- except ValueError:
27
- raise ValueError("Invalid Hugging Face dataset URL")
28
-
29
- dataset_parts = path_parts[datasets_idx+1:]
30
- dataset_name = "/".join(dataset_parts[0:2])
31
-
32
- # Try to find config (common pattern for datasets with viewer)
33
- try:
34
- viewer_idx = dataset_parts.index('viewer')
35
- config = dataset_parts[viewer_idx+1] if viewer_idx+1 < len(dataset_parts) else None
36
- except ValueError:
37
- config = None
38
-
39
- return dataset_name, config
40
 
41
  def train(dataset_url: str):
42
  try:
43
- # Parse dataset URL
44
- dataset_name, dataset_config = parse_hf_dataset_url(dataset_url)
45
- logging.info(f"Loading dataset: {dataset_name} (config: {dataset_config})")
46
-
47
- # Load model and tokenizer
48
- model_name = "microsoft/phi-2"
49
- tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
50
- model = AutoModelForCausalLM.from_pretrained(model_name, device_map="cpu", trust_remote_code=True)
51
-
52
- # Add padding token
53
- if tokenizer.pad_token is None:
54
- tokenizer.pad_token = tokenizer.eos_token
55
-
56
- # Load dataset from Hugging Face Hub
57
- dataset = load_dataset(
58
- dataset_name,
59
- dataset_config,
60
- trust_remote_code=True
61
- )
62
-
63
- # Handle dataset splits
64
- if "train" not in dataset:
65
- raise ValueError("Dataset must have a 'train' split")
66
 
67
- train_dataset = dataset["train"]
68
- eval_dataset = dataset.get("validation", None)
69
-
70
- # Split if no validation set
71
- if eval_dataset is None:
72
- split = train_dataset.train_test_split(test_size=0.1, seed=42)
73
- train_dataset = split["train"]
74
- eval_dataset = split["test"]
75
-
76
- # Tokenization function
77
- def tokenize_function(examples):
78
- return tokenizer(
79
- examples["text"], # Adjust column name as needed
80
- padding="max_length",
81
- truncation=True,
82
- max_length=256,
83
- return_tensors="pt",
84
- )
85
-
86
- # Tokenize datasets
87
- tokenized_train = train_dataset.map(
88
- tokenize_function,
89
- batched=True,
90
- remove_columns=train_dataset.column_names
91
- )
92
- tokenized_eval = eval_dataset.map(
93
- tokenize_function,
94
- batched=True,
95
- remove_columns=eval_dataset.column_names
96
- )
97
-
98
- # Data collator
99
- data_collator = DataCollatorForLanguageModeling(
100
- tokenizer=tokenizer,
101
- mlm=False
102
- )
103
-
104
- # Training arguments
105
- training_args = TrainingArguments(
106
- output_dir="./phi2-results",
107
- per_device_train_batch_size=2,
108
- per_device_eval_batch_size=2,
109
- num_train_epochs=3,
110
- logging_dir="./logs",
111
- logging_steps=10,
112
- fp16=False,
113
- )
114
-
115
- # Trainer
116
- trainer = Trainer(
117
- model=model,
118
- args=training_args,
119
- train_dataset=tokenized_train,
120
- eval_dataset=tokenized_eval,
121
- data_collator=data_collator,
122
- )
123
-
124
- # Start training
125
- logging.info("Training started...")
126
- trainer.train()
127
- trainer.save_model("./phi2-trained-model")
128
- logging.info("Training completed!")
129
-
130
- return "βœ… Training succeeded! Model saved."
131
-
132
  except Exception as e:
133
- logging.error(f"Training failed: {str(e)}")
134
- return f"❌ Training failed: {str(e)}"
135
 
136
- # Gradio UI with dataset URL input
137
  with gr.Blocks(title="Phi-2 Training") as demo:
138
  gr.Markdown("# πŸš€ Train Phi-2 with HF Hub Data")
139
 
@@ -147,7 +42,7 @@ with gr.Blocks(title="Phi-2 Training") as demo:
147
  status_output = gr.Textbox(label="Status", interactive=False)
148
 
149
  start_btn.click(
150
- fn=train,
151
  inputs=[dataset_url],
152
  outputs=status_output
153
  )
@@ -156,6 +51,6 @@ if __name__ == "__main__":
156
  demo.launch(
157
  server_name="0.0.0.0",
158
  server_port=7860,
159
- enable_queue=True, # Add queueing
160
- share=False # Disable public sharing
161
  )
 
1
+ # app.py
2
  import torch
3
  import gradio as gr
4
+ import threading
5
  from transformers import (
6
  AutoModelForCausalLM,
7
  AutoTokenizer,
 
17
  # Configure logging
18
  logging.basicConfig(stream=sys.stdout, level=logging.INFO)
19
 
20
+ def parse_hf_dataset_url(url: str):
21
+ # ... (keep previous URL parsing logic) ...
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
22
 
23
  def train(dataset_url: str):
24
  try:
25
+ # ... (keep previous training logic) ...
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
26
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
27
  except Exception as e:
28
+ logging.error(f"Critical error: {str(e)}")
29
+ return f"❌ Critical error: {str(e)}"
30
 
31
+ # Gradio interface
32
  with gr.Blocks(title="Phi-2 Training") as demo:
33
  gr.Markdown("# πŸš€ Train Phi-2 with HF Hub Data")
34
 
 
42
  status_output = gr.Textbox(label="Status", interactive=False)
43
 
44
  start_btn.click(
45
+ fn=lambda url: threading.Thread(target=train, args=(url,)).start(),
46
  inputs=[dataset_url],
47
  outputs=status_output
48
  )
 
51
  demo.launch(
52
  server_name="0.0.0.0",
53
  server_port=7860,
54
+ enable_queue=True,
55
+ share=False
56
  )