Karpernik commited on
Commit
e687136
·
verified ·
1 Parent(s): 3c4bffd

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +2 -3
app.py CHANGED
@@ -60,10 +60,9 @@ def load_model():
60
  model_name = 'model'
61
  cat_count = 358
62
 
63
- checkpoint = torch.load(os.path.join(chkp_folder, f"{model_name}.pt"), weights_only=False)
64
 
65
  # Создаём те же классы, что и внутри чекпоинта
66
- device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')
67
 
68
  model_ = AutoModelForSequenceClassification.from_pretrained('distilbert/distilbert-base-cased', num_labels=cat_count).to(device)
69
 
@@ -77,7 +76,7 @@ def load_model():
77
  model, optimizer = create_model_and_optimizer(model_)
78
 
79
  # Загружаем состояния из чекпоинта
80
- model.load_state_dict(checkpoint['model_state_dict'], map_location=torch.device('cpu'))
81
  ind_to_cat = checkpoint['ind_to_cat']
82
  tokenizer = AutoTokenizer.from_pretrained('distilbert/distilbert-base-cased')
83
  return model, tokenizer, ind_to_cat
 
60
  model_name = 'model'
61
  cat_count = 358
62
 
63
+ checkpoint = torch.load(os.path.join(chkp_folder, f"{model_name}.pt"), weights_only=False, map_location=torch.device('cpu'))
64
 
65
  # Создаём те же классы, что и внутри чекпоинта
 
66
 
67
  model_ = AutoModelForSequenceClassification.from_pretrained('distilbert/distilbert-base-cased', num_labels=cat_count).to(device)
68
 
 
76
  model, optimizer = create_model_and_optimizer(model_)
77
 
78
  # Загружаем состояния из чекпоинта
79
+ model.load_state_dict(checkpoint['model_state_dict'])
80
  ind_to_cat = checkpoint['ind_to_cat']
81
  tokenizer = AutoTokenizer.from_pretrained('distilbert/distilbert-base-cased')
82
  return model, tokenizer, ind_to_cat