w3robotics committed on
Commit
220ef37
·
verified ·
1 Parent(s): 2baa8f1

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +27 -2
app.py CHANGED
@@ -23,7 +23,7 @@ st.title("Classify Document Image")
23
  file_name = st.file_uploader("Upload a candidate image")
24
 
25
  if file_name is not None:
26
- col1, col2 = st.columns(2)
27
 
28
  image = Image.open(file_name)
29
  image = image.convert("RGB")
@@ -58,7 +58,32 @@ if file_name is not None:
58
 
59
  col2.header("Results")
60
  col2.subheader(processor.token2json(sequence))
61
-
 
 
62
 
 
 
 
 
 
 
 
63
 
 
 
 
 
 
 
 
 
 
 
64
 
 
 
 
 
 
 
 
23
  file_name = st.file_uploader("Upload a candidate image")
24
 
25
  if file_name is not None:
26
+ col1, col2, col3 = st.columns(3)
27
 
28
  image = Image.open(file_name)
29
  image = image.convert("RGB")
 
58
 
59
  col2.header("Results")
60
  col2.subheader(processor.token2json(sequence))
61
+
62
+ processor_ext = DonutProcessor.from_pretrained("naver-clova-ix/donut-base-finetuned-cord-v2")
63
+ model_ext = VisionEncoderDecoderModel.from_pretrained("naver-clova-ix/donut-base-finetuned-cord-v2")
64
 
65
+ model_ext.to(device)
66
+
67
+ # prepare decoder inputs
68
+ task_prompt = "<s_cord-v2>"
69
+ decoder_input_ids = processor.tokenizer(task_prompt, add_special_tokens=False, return_tensors="pt").input_ids
70
+
71
+ pixel_values = processor(image, return_tensors="pt").pixel_values
72
 
73
+ outputs = model_ext.generate(
74
+ pixel_values.to(device),
75
+ decoder_input_ids=decoder_input_ids.to(device),
76
+ max_length=model.decoder.config.max_position_embeddings,
77
+ pad_token_id=processor.tokenizer.pad_token_id,
78
+ eos_token_id=processor.tokenizer.eos_token_id,
79
+ use_cache=True,
80
+ bad_words_ids=[[processor.tokenizer.unk_token_id]],
81
+ return_dict_in_generate=True,
82
+ )
83
 
84
+ sequence = processor.batch_decode(outputs.sequences)[0]
85
+ sequence = sequence.replace(processor.tokenizer.eos_token, "").replace(processor.tokenizer.pad_token, "")
86
+ sequence = re.sub(r"<.*?>", "", sequence, count=1).strip() # remove first task start token
87
+ col3.header("Features")
88
+ col3.subheader(processor.token2json(sequence))
89
+