Update app.py
app.py
CHANGED
@@ -3,11 +3,10 @@ import torch
 import pandas as pd
 import PyPDF2
 import pickle
-import
-from
-from huggingface_hub import login
+from transformers import AutoTokenizer, PreTrainedModel, PretrainedConfig
+from huggingface_hub import login, hf_hub_download
 import time
-from ch09util import subsequent_mask # Ensure ch09util.py is available
+from utils.ch09util import subsequent_mask, create_model # Ensure ch09util.py is available
 
 # Device setup
 DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
@@ -79,11 +78,10 @@ def load_model_and_resources(hf_token):
         token=hf_token
     )
 
-    #
-    from transformers import PreTrainedModel, PretrainedConfig
+    # Define Transformer configuration
    class TransformerConfig(PretrainedConfig):
        model_type = "custom_transformer"
-        def __init__(self, src_vocab_size, tgt_vocab_size, d_model=256, d_ff=1024, h=8, N=6, dropout=0.1, **kwargs):
+        def __init__(self, src_vocab_size=11055, tgt_vocab_size=11239, d_model=256, d_ff=1024, h=8, N=6, dropout=0.1, **kwargs):
            super().__init__(**kwargs)
            self.src_vocab_size = src_vocab_size
            self.tgt_vocab_size = tgt_vocab_size
@@ -93,11 +91,11 @@ def load_model_and_resources(hf_token):
            self.N = N
            self.dropout = dropout
 
+    # Define Transformer model
    class CustomTransformer(PreTrainedModel):
        config_class = TransformerConfig
        def __init__(self, config):
            super().__init__(config)
-            from utils.ch09util import create_model
            self.model = create_model(
                config.src_vocab_size,
                config.tgt_vocab_size,
@@ -110,18 +108,26 @@ def load_model_and_resources(hf_token):
        def forward(self, src, tgt, src_mask, tgt_mask, **kwargs):
            return self.model(src, tgt, src_mask, tgt_mask)
 
-    config
+    # Load config with validation
+    config_dict = TransformerConfig.from_pretrained(MODEL_NAME, token=hf_token).to_dict()
+    if "src_vocab_size" not in config_dict or "tgt_vocab_size" not in config_dict:
+        st.warning(
+            f"Config at {MODEL_NAME}/config.json is missing 'src_vocab_size' or 'tgt_vocab_size'. "
+            "Using defaults (11055, 11239). For accuracy, update the training script to save these values."
+        )
+        config = TransformerConfig()
+    else:
+        config = TransformerConfig(**config_dict)
+
+    # Load model
    model = CustomTransformer.from_pretrained(
        MODEL_NAME,
        config=config,
        token=hf_token
    ).to(DEVICE)
 
-    # Load dictionaries
-    dict_path = "dict.p"
-    if not os.path.exists(dict_path):
-        st.error("Dictionary file (dict.p) not found. Please ensure it was uploaded to the model repository.")
-        return None
+    # Load dictionaries from Hugging Face Hub
+    dict_path = hf_hub_download(repo_id=MODEL_NAME, filename="dict.p", token=hf_token)
    with open(dict_path, "rb") as fb:
        en_word_dict, en_idx_dict, fr_word_dict, fr_idx_dict = pickle.load(fb)
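
Note on the new import: the Space now expects a utils/ch09util.py module that exposes subsequent_mask and create_model. That file is not part of this diff, so the following is only a sketch, under the assumption that subsequent_mask builds the usual Annotated-Transformer-style causal mask used at decoding time:

import torch

def subsequent_mask(size):
    # Causal mask of shape (1, size, size): entry (i, j) is True when target
    # position i is allowed to attend to position j (i.e. j <= i).
    attn_shape = (1, size, size)
    upper = torch.triu(torch.ones(attn_shape, dtype=torch.uint8), diagonal=1)
    return upper == 0

create_model is likewise assumed to accept the source and target vocabulary sizes plus the d_model, d_ff, h, N and dropout values stored in TransformerConfig, matching the call site in the diff.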
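The st.warning branch only exists because an older config.json might not record the vocabulary sizes. A hypothetical training-side snippet that would persist them, reusing the diff's TransformerConfig and CustomTransformer plus the vocabularies stored in dict.p (the repository name below is a placeholder):

import pickle

with open("dict.p", "rb") as fb:
    en_word_dict, en_idx_dict, fr_word_dict, fr_idx_dict = pickle.load(fb)

config = TransformerConfig(
    src_vocab_size=len(en_word_dict),   # English source vocabulary
    tgt_vocab_size=len(fr_word_dict),   # French target vocabulary
)
model = CustomTransformer(config)
# ... training loop ...
model.save_pretrained("en-fr-transformer")  # writes config.json including both sizes
model.push_to_hub("your-username/en-fr-transformer", token=hf_token)

With the sizes present in config.json, the from_pretrained path in the diff takes the else branch and the hard-coded defaults are never used.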