hbofficial-1005 committed on
Commit 9756ad2 · 1 Parent(s): 31d916b

Updated Gradio App

Files changed (1)
  1. train.py +18 -13
train.py CHANGED
@@ -1,28 +1,35 @@
 import os
-import torch
+import shutil
 from transformers import AutoTokenizer, AutoModelForTokenClassification, TrainingArguments, Trainer
 from datasets import load_dataset, load_metric
 
-# Load dataset
+# Define output directory
+output_dir = "./models/ner_model"
+
+# Remove the old model directory (if exists) to ensure a clean save
+if os.path.exists(output_dir):
+    shutil.rmtree(output_dir)
+
+# Load the CoNLL2003 dataset
 dataset = load_dataset("conll2003")
 
-# Load tokenizer and model checkpoint
+# Load the pretrained tokenizer and model checkpoint
 model_checkpoint = "dbmdz/bert-large-cased-finetuned-conll03-english"
 tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)
 
-# Tokenize the dataset
+# Tokenize the dataset; note that we use `is_split_into_words=True`
 def tokenize_and_align_labels(examples):
     tokenized_inputs = tokenizer(examples["tokens"], truncation=True, is_split_into_words=True)
     return tokenized_inputs
 
 tokenized_datasets = dataset.map(tokenize_and_align_labels, batched=True)
 
-# Load model for token classification (with specified number of labels)
+# Load the model for token classification, specifying number of labels
 model = AutoModelForTokenClassification.from_pretrained(model_checkpoint, num_labels=9)
 
-# Training arguments
+# Define training arguments
 training_args = TrainingArguments(
-    output_dir="./models/ner_model",
+    output_dir=output_dir,
     evaluation_strategy="epoch",
     save_strategy="epoch",
     learning_rate=2e-5,
@@ -32,7 +39,7 @@ training_args = TrainingArguments(
     weight_decay=0.01,
 )
 
-# Load metric
+# Load evaluation metric
 metric = load_metric("seqeval")
 
 def compute_metrics(eval_pred):
@@ -40,7 +47,7 @@ def compute_metrics(eval_pred):
     predictions = predictions.argmax(-1)
     return metric.compute(predictions=predictions, references=labels)
 
-# Initialize Trainer
+# Initialize the Trainer
 trainer = Trainer(
     model=model,
     args=training_args,
@@ -50,15 +57,13 @@ trainer = Trainer(
     compute_metrics=compute_metrics,
 )
 
-# Train model
+# Train the model
 trainer.train()
 
 # Ensure the output directory exists
-output_dir = "./models/ner_model"
 os.makedirs(output_dir, exist_ok=True)
 
-# Make sure the model config has a model_type key.
-# Since we started with a BERT checkpoint, we set it to "bert".
+# Explicitly set model_type in the configuration if it is missing or empty.
 if not hasattr(model.config, "model_type") or not model.config.model_type:
     model.config.model_type = "bert"
 
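
A note on the tokenization step shown above: `tokenize_and_align_labels` (in both the old and new versions) returns only the tokenized inputs and never attaches a `labels` column, so the `Trainer` has no targets to compute a loss against. The sketch below shows the usual Hugging Face word-ID alignment pattern, assuming the CoNLL2003 `ner_tags` column and a fast tokenizer; it is not part of this commit.

# Sketch only (not part of this commit): align CoNLL2003 "ner_tags" with the
# sub-word tokens produced by the fast tokenizer, masking ignored positions with -100.
def tokenize_and_align_labels(examples):
    tokenized_inputs = tokenizer(examples["tokens"], truncation=True, is_split_into_words=True)
    all_labels = []
    for i, ner_tags in enumerate(examples["ner_tags"]):
        word_ids = tokenized_inputs.word_ids(batch_index=i)
        labels, previous_word_id = [], None
        for word_id in word_ids:
            if word_id is None:
                labels.append(-100)               # special tokens are ignored by the loss
            elif word_id != previous_word_id:
                labels.append(ner_tags[word_id])  # first sub-token carries the word's tag
            else:
                labels.append(-100)               # mask the remaining sub-tokens of the word
            previous_word_id = word_id
        all_labels.append(labels)
    tokenized_inputs["labels"] = all_labels
    return tokenized_inputs

With labels attached this way, passing `DataCollatorForTokenClassification(tokenizer)` to the `Trainer` would pad inputs and labels consistently.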
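
A note on the metric computation: `seqeval` expects per-sentence lists of tag strings, while the committed `compute_metrics` hands it raw integer IDs and keeps every position, including any ignored (-100) ones. A minimal sketch of the usual conversion, assuming the label names come from the dataset's `ner_tags` feature and that ignored positions are marked with -100:

# Sketch only: map IDs to tag names and drop ignored (-100) positions before seqeval.
# Assumes `dataset` and `metric` as defined in train.py above.
import numpy as np

label_list = dataset["train"].features["ner_tags"].feature.names

def compute_metrics(eval_pred):
    predictions, labels = eval_pred
    predictions = np.argmax(predictions, axis=-1)
    true_predictions = [
        [label_list[p] for p, l in zip(pred, lab) if l != -100]
        for pred, lab in zip(predictions, labels)
    ]
    true_labels = [
        [label_list[l] for p, l in zip(pred, lab) if l != -100]
        for pred, lab in zip(predictions, labels)
    ]
    return metric.compute(predictions=true_predictions, references=true_labels)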
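
A library note: `datasets.load_metric` is deprecated and has been removed in recent `datasets` releases, so this line may fail in a current environment. On newer installs the seqeval metric is loaded through the separate `evaluate` package instead; a sketch of the drop-in replacement, assuming `evaluate` and `seqeval` are installed:

# Sketch only: the evaluate package replaces datasets.load_metric on newer installs.
import evaluate

metric = evaluate.load("seqeval")  # same compute(predictions=..., references=...) interface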
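
A note on persistence: the hunks shown end after patching `model.config.model_type`, and nothing visible here writes the fine-tuned weights or tokenizer to disk. If the script does not already do so further down (the diff only shows the changed hunks), calls along these lines would make `output_dir` contain a loadable model; they are standard `transformers` API, not part of this commit:

# Sketch only: write the fine-tuned model and tokenizer into output_dir so that
# AutoModelForTokenClassification.from_pretrained(output_dir) works later.
trainer.save_model(output_dir)         # saves the model weights and config.json
tokenizer.save_pretrained(output_dir)  # saves the tokenizer files alongside them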