terapyon commited on
Commit
097953f
·
1 Parent(s): 26879ed

modify a bit

Browse files
Files changed (1) hide show
  1. inference.py +2 -1
inference.py CHANGED
@@ -38,7 +38,7 @@ device = torch.device(f"cuda:{gpu}" if gpu>=0 else "cpu")
38
 
39
  #BERTモデルの定義
40
  class BertClassifier(nn.Module, PyTorchModelHubMixin):
41
- def __init__(self, cls_num: int):
42
  super().__init__()
43
  self.bert = T.BertModel.from_pretrained(bert_model_name, output_attentions=True)
44
  self.fc = nn.Linear(768, cls_num, bias=True)
@@ -124,6 +124,7 @@ def make_traind_model():
124
  # trained_models.append(trained_model)
125
  model_name = MODEL_BASE + str(k)
126
  trained_model = BertClassifier.from_pretrained(model_name).to(device)
 
127
  trained_models.append(trained_model)
128
  return trained_models
129
 
 
38
 
39
  #BERTモデルの定義
40
  class BertClassifier(nn.Module, PyTorchModelHubMixin):
41
+ def __init__(self, cls_num: int = 1):
42
  super().__init__()
43
  self.bert = T.BertModel.from_pretrained(bert_model_name, output_attentions=True)
44
  self.fc = nn.Linear(768, cls_num, bias=True)
 
124
  # trained_models.append(trained_model)
125
  model_name = MODEL_BASE + str(k)
126
  trained_model = BertClassifier.from_pretrained(model_name).to(device)
127
+ print(f"Got model {model_name}")
128
  trained_models.append(trained_model)
129
  return trained_models
130