pierreguillou commited on
Commit
1af9124
·
1 Parent(s): dfc5d28

Update files/functions.py

Browse files
Files changed (1) hide show
  1. files/functions.py +19 -7
files/functions.py CHANGED
@@ -137,13 +137,21 @@ langdetect2Tesseract = {v:k for k,v in Tesseract2langdetect.items()}
137
 
138
  ## model / feature extractor / tokenizer
139
 
 
140
  import torch
141
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
142
 
143
- from transformers import LayoutLMv2ForTokenClassification # LayoutXLMTokenizerFast,
 
 
 
 
 
144
 
145
- model = LayoutLMv2ForTokenClassification.from_pretrained(model_id);
146
- model.to(device);
 
 
147
 
148
  # feature extractor
149
  from transformers import LayoutLMv2FeatureExtractor
@@ -151,12 +159,16 @@ feature_extractor = LayoutLMv2FeatureExtractor(apply_ocr=False)
151
 
152
  # tokenizer
153
  from transformers import AutoTokenizer
154
- tokenizer = AutoTokenizer.from_pretrained(tokenizer_id)
155
 
156
  # get labels
157
- id2label = model.config.id2label
158
- label2id = model.config.label2id
159
- num_labels = len(id2label)
 
 
 
 
160
 
161
  ## General
162
 
 
137
 
138
  ## model / feature extractor / tokenizer
139
 
140
+ # get device
141
  import torch
142
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
143
 
144
+ ## model LiLT
145
+ import transformers
146
+ from transformers import AutoTokenizer, AutoModelForTokenClassification
147
+ tokenizer_lilt = AutoTokenizer.from_pretrained(model_id_lilt)
148
+ model_lilt = AutoModelForTokenClassification.from_pretrained(model_id_lilt);
149
+ model_lilt.to(device);
150
 
151
+ ## model LayoutXLM
152
+ from transformers import LayoutLMv2ForTokenClassification # LayoutXLMTokenizerFast,
153
+ model_layoutxlm = LayoutLMv2ForTokenClassification.from_pretrained(model_id_layoutxlm);
154
+ model_layoutxlm.to(device);
155
 
156
  # feature extractor
157
  from transformers import LayoutLMv2FeatureExtractor
 
159
 
160
  # tokenizer
161
  from transformers import AutoTokenizer
162
+ tokenizer_layoutxlm = AutoTokenizer.from_pretrained(tokenizer_id_layoutxlm)
163
 
164
  # get labels
165
+ id2label_lilt = model_lilt.config.id2label
166
+ label2id_lilt = model_lilt.config.label2id
167
+ num_labels_lilt = len(id2label_lilt)
168
+
169
+ id2label_layoutxlm = model_layoutxlm.config.id2label
170
+ label2id_layoutxlm = model_layoutxlm.config.label2id
171
+ num_labels_layoutxlm = len(id2label_layoutxlm)
172
 
173
  ## General
174