amiguel committed on
Commit
e8c22f8
·
verified ·
1 Parent(s): e6fe399

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +20 -14
app.py CHANGED
@@ -3,11 +3,10 @@ import torch
3
  import pandas as pd
4
  import PyPDF2
5
  import pickle
6
- import os
7
- from transformers import AutoTokenizer
8
- from huggingface_hub import login
9
  import time
10
- from ch09util import subsequent_mask # Ensure ch09util.py is available
11
 
12
  # Device setup
13
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
@@ -79,11 +78,10 @@ def load_model_and_resources(hf_token):
79
  token=hf_token
80
  )
81
 
82
- # Load model
83
- from transformers import PreTrainedModel, PretrainedConfig
84
  class TransformerConfig(PretrainedConfig):
85
  model_type = "custom_transformer"
86
- def __init__(self, src_vocab_size, tgt_vocab_size, d_model=256, d_ff=1024, h=8, N=6, dropout=0.1, **kwargs):
87
  super().__init__(**kwargs)
88
  self.src_vocab_size = src_vocab_size
89
  self.tgt_vocab_size = tgt_vocab_size
@@ -93,11 +91,11 @@ def load_model_and_resources(hf_token):
93
  self.N = N
94
  self.dropout = dropout
95
 
 
96
  class CustomTransformer(PreTrainedModel):
97
  config_class = TransformerConfig
98
  def __init__(self, config):
99
  super().__init__(config)
100
- from utils.ch09util import create_model
101
  self.model = create_model(
102
  config.src_vocab_size,
103
  config.tgt_vocab_size,
@@ -110,18 +108,26 @@ def load_model_and_resources(hf_token):
110
  def forward(self, src, tgt, src_mask, tgt_mask, **kwargs):
111
  return self.model(src, tgt, src_mask, tgt_mask)
112
 
113
- config = TransformerConfig.from_pretrained(MODEL_NAME, token=hf_token)
 
 
 
 
 
 
 
 
 
 
 
114
  model = CustomTransformer.from_pretrained(
115
  MODEL_NAME,
116
  config=config,
117
  token=hf_token
118
  ).to(DEVICE)
119
 
120
- # Load dictionaries (assumes dict.p was uploaded to the model repo)
121
- dict_path = "dict.p"
122
- if not os.path.exists(dict_path):
123
- st.error("Dictionary file (dict.p) not found. Please ensure it was uploaded to the model repository.")
124
- return None
125
  with open(dict_path, "rb") as fb:
126
  en_word_dict, en_idx_dict, fr_word_dict, fr_idx_dict = pickle.load(fb)
127
 
 
3
  import pandas as pd
4
  import PyPDF2
5
  import pickle
6
+ from transformers import AutoTokenizer, PreTrainedModel, PretrainedConfig
7
+ from huggingface_hub import login, hf_hub_download
 
8
  import time
9
+ from utils.ch09util import subsequent_mask, create_model # Ensure ch09util.py is available
10
 
11
  # Device setup
12
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
 
78
  token=hf_token
79
  )
80
 
81
+ # Define Transformer configuration
 
82
  class TransformerConfig(PretrainedConfig):
83
  model_type = "custom_transformer"
84
+ def __init__(self, src_vocab_size=11055, tgt_vocab_size=11239, d_model=256, d_ff=1024, h=8, N=6, dropout=0.1, **kwargs):
85
  super().__init__(**kwargs)
86
  self.src_vocab_size = src_vocab_size
87
  self.tgt_vocab_size = tgt_vocab_size
 
91
  self.N = N
92
  self.dropout = dropout
93
 
94
+ # Define Transformer model
95
  class CustomTransformer(PreTrainedModel):
96
  config_class = TransformerConfig
97
  def __init__(self, config):
98
  super().__init__(config)
 
99
  self.model = create_model(
100
  config.src_vocab_size,
101
  config.tgt_vocab_size,
 
108
  def forward(self, src, tgt, src_mask, tgt_mask, **kwargs):
109
  return self.model(src, tgt, src_mask, tgt_mask)
110
 
111
+ # Load config with validation
112
+ config_dict = TransformerConfig.from_pretrained(MODEL_NAME, token=hf_token).to_dict()
113
+ if "src_vocab_size" not in config_dict or "tgt_vocab_size" not in config_dict:
114
+ st.warning(
115
+ f"Config at {MODEL_NAME}/config.json is missing 'src_vocab_size' or 'tgt_vocab_size'. "
116
+ "Using defaults (11055, 11239). For accuracy, update the training script to save these values."
117
+ )
118
+ config = TransformerConfig()
119
+ else:
120
+ config = TransformerConfig(**config_dict)
121
+
122
+ # Load model
123
  model = CustomTransformer.from_pretrained(
124
  MODEL_NAME,
125
  config=config,
126
  token=hf_token
127
  ).to(DEVICE)
128
 
129
+ # Load dictionaries from Hugging Face Hub
130
+ dict_path = hf_hub_download(repo_id=MODEL_NAME, filename="dict.p", token=hf_token)
 
 
 
131
  with open(dict_path, "rb") as fb:
132
  en_word_dict, en_idx_dict, fr_word_dict, fr_idx_dict = pickle.load(fb)
133