hbofficial-1005 committed on
Commit
31d916b
·
1 Parent(s): 68f05a6

Updated Gradio App

Browse files
Files changed (1) hide show
  1. train.py +13 -7
train.py CHANGED
@@ -6,7 +6,7 @@ from datasets import load_dataset, load_metric
6
  # Load dataset
7
  dataset = load_dataset("conll2003")
8
 
9
- # Load tokenizer
10
  model_checkpoint = "dbmdz/bert-large-cased-finetuned-conll03-english"
11
  tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)
12
 
@@ -17,7 +17,7 @@ def tokenize_and_align_labels(examples):
17
 
18
  tokenized_datasets = dataset.map(tokenize_and_align_labels, batched=True)
19
 
20
- # Load model
21
  model = AutoModelForTokenClassification.from_pretrained(model_checkpoint, num_labels=9)
22
 
23
  # Training arguments
@@ -37,9 +37,10 @@ metric = load_metric("seqeval")
37
 
38
  def compute_metrics(eval_pred):
39
  predictions, labels = eval_pred
40
- return metric.compute(predictions=predictions.argmax(-1), references=labels)
 
41
 
42
- # Trainer
43
  trainer = Trainer(
44
  model=model,
45
  args=training_args,
@@ -52,10 +53,15 @@ trainer = Trainer(
52
  # Train model
53
  trainer.train()
54
 
55
- # Ensure directory exists before saving
56
  output_dir = "./models/ner_model"
57
  os.makedirs(output_dir, exist_ok=True)
58
 
59
- # Save model
60
- trainer.save_model(output_dir)
 
 
 
 
 
61
  tokenizer.save_pretrained(output_dir)
 
6
  # Load dataset
7
  dataset = load_dataset("conll2003")
8
 
9
+ # Load tokenizer and model checkpoint
10
  model_checkpoint = "dbmdz/bert-large-cased-finetuned-conll03-english"
11
  tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)
12
 
 
17
 
18
  tokenized_datasets = dataset.map(tokenize_and_align_labels, batched=True)
19
 
20
+ # Load model for token classification (with specified number of labels)
21
  model = AutoModelForTokenClassification.from_pretrained(model_checkpoint, num_labels=9)
22
 
23
  # Training arguments
 
37
 
38
  def compute_metrics(eval_pred):
39
  predictions, labels = eval_pred
40
+ predictions = predictions.argmax(-1)
41
+ return metric.compute(predictions=predictions, references=labels)
42
 
43
+ # Initialize Trainer
44
  trainer = Trainer(
45
  model=model,
46
  args=training_args,
 
53
  # Train model
54
  trainer.train()
55
 
56
+ # Ensure the output directory exists
57
  output_dir = "./models/ner_model"
58
  os.makedirs(output_dir, exist_ok=True)
59
 
60
+ # Make sure the model config has a model_type key.
61
+ # Since we started with a BERT checkpoint, we set it to "bert".
62
+ if not hasattr(model.config, "model_type") or not model.config.model_type:
63
+ model.config.model_type = "bert"
64
+
65
+ # Save the trained model and tokenizer
66
+ model.save_pretrained(output_dir)
67
  tokenizer.save_pretrained(output_dir)