pierreguillou commited on
Commit
dfc5d28
·
1 Parent(s): d9aec25

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +19 -11
app.py CHANGED
@@ -52,16 +52,21 @@ os.system('python -m pip install --upgrade pip')
52
 
53
  ## model / feature extractor / tokenizer
54
 
 
55
  import torch
56
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
57
 
58
- # model
59
- from transformers import LayoutLMv2ForTokenClassification
 
 
 
 
60
 
61
- model_id = "pierreguillou/layout-xlm-base-finetuned-with-DocLayNet-base-at-paragraphlevel-ml512"
62
-
63
- model = LayoutLMv2ForTokenClassification.from_pretrained(model_id);
64
- model.to(device);
65
 
66
  # feature extractor
67
  from transformers import LayoutLMv2FeatureExtractor
@@ -69,13 +74,16 @@ feature_extractor = LayoutLMv2FeatureExtractor(apply_ocr=False)
69
 
70
  # tokenizer
71
  from transformers import AutoTokenizer
72
- tokenizer_id = "xlm-roberta-base"
73
- tokenizer = AutoTokenizer.from_pretrained(tokenizer_id)
74
 
75
  # get labels
76
- id2label = model.config.id2label
77
- label2id = model.config.label2id
78
- num_labels = len(id2label)
 
 
 
 
79
 
80
  # APP outputs
81
  def app_outputs(uploaded_pdf):
 
52
 
53
  ## model / feature extractor / tokenizer
54
 
55
+ # get device
56
  import torch
57
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
58
 
59
+ ## model LiLT
60
+ import transformers
61
+ from transformers import AutoTokenizer, AutoModelForTokenClassification
62
+ tokenizer_lilt = AutoTokenizer.from_pretrained(model_id_lilt)
63
+ model_lilt = AutoModelForTokenClassification.from_pretrained(model_id_lilt);
64
+ model_lilt.to(device);
65
 
66
+ ## model LayoutXLM
67
+ from transformers import LayoutLMv2ForTokenClassification # LayoutXLMTokenizerFast,
68
+ model_layoutxlm = LayoutLMv2ForTokenClassification.from_pretrained(model_id_layoutxlm);
69
+ model_layoutxlm.to(device);
70
 
71
  # feature extractor
72
  from transformers import LayoutLMv2FeatureExtractor
 
74
 
75
  # tokenizer
76
  from transformers import AutoTokenizer
77
+ tokenizer_layoutxlm = AutoTokenizer.from_pretrained(tokenizer_id_layoutxlm)
 
78
 
79
  # get labels
80
+ id2label_lilt = model_lilt.config.id2label
81
+ label2id_lilt = model_lilt.config.label2id
82
+ num_labels_lilt = len(id2label_lilt)
83
+
84
+ id2label_layoutxlm = model_layoutxlm.config.id2label
85
+ label2id_layoutxlm = model_layoutxlm.config.label2id
86
+ num_labels_layoutxlm = len(id2label_layoutxlm)
87
 
88
  # APP outputs
89
  def app_outputs(uploaded_pdf):