terapyon commited on
Commit
9968acc
·
1 Parent(s): 097953f

modify get model with token

Browse files
Files changed (1) hide show
  1. inference.py +4 -3
inference.py CHANGED
@@ -1,3 +1,4 @@
 
1
  import re
2
  from pathlib import Path
3
  from typing import Generator
@@ -13,7 +14,7 @@ from huggingface_hub import PyTorchModelHubMixin # type: ignore
13
  from scipy import stats # type: ignore
14
  from sudachipy import dictionary, tokenizer # type: ignore
15
 
16
- # from transformers import AutoModel
17
 
18
  MODELS_PATH = Path(__file__).parent / "saved_model"
19
  # model_base_path = MODELS_PATH / "two_class"
@@ -38,7 +39,7 @@ device = torch.device(f"cuda:{gpu}" if gpu>=0 else "cpu")
38
 
39
  #BERTモデルの定義
40
  class BertClassifier(nn.Module, PyTorchModelHubMixin):
41
- def __init__(self, cls_num: int = 1):
42
  super().__init__()
43
  self.bert = T.BertModel.from_pretrained(bert_model_name, output_attentions=True)
44
  self.fc = nn.Linear(768, cls_num, bias=True)
@@ -123,7 +124,7 @@ def make_traind_model():
123
  # trained_model.load_state_dict(torch.load(model_path, map_location=device), strict=False)
124
  # trained_models.append(trained_model)
125
  model_name = MODEL_BASE + str(k)
126
- trained_model = BertClassifier.from_pretrained(model_name).to(device)
127
  print(f"Got model {model_name}")
128
  trained_models.append(trained_model)
129
  return trained_models
 
1
+ import os
2
  import re
3
  from pathlib import Path
4
  from typing import Generator
 
14
  from scipy import stats # type: ignore
15
  from sudachipy import dictionary, tokenizer # type: ignore
16
 
17
+ HF_AUTH_TOKEN = os.getenv("HF_AUTH_TOKEN")
18
 
19
  MODELS_PATH = Path(__file__).parent / "saved_model"
20
  # model_base_path = MODELS_PATH / "two_class"
 
39
 
40
  #BERTモデルの定義
41
  class BertClassifier(nn.Module, PyTorchModelHubMixin):
42
+ def __init__(self, cls_num: int):
43
  super().__init__()
44
  self.bert = T.BertModel.from_pretrained(bert_model_name, output_attentions=True)
45
  self.fc = nn.Linear(768, cls_num, bias=True)
 
124
  # trained_model.load_state_dict(torch.load(model_path, map_location=device), strict=False)
125
  # trained_models.append(trained_model)
126
  model_name = MODEL_BASE + str(k)
127
+ trained_model = BertClassifier.from_pretrained(model_name, token=HF_AUTH_TOKEN).to(device)
128
  print(f"Got model {model_name}")
129
  trained_models.append(trained_model)
130
  return trained_models