w3robotics committed
Commit 8c52817 · verified · 1 Parent(s): 4f85dbd

Update app.py

Files changed (1): app.py +44 -5
app.py CHANGED
@@ -1,8 +1,47 @@
- import streamlit as st
  from PIL import Image

- # Load model directly
- from transformers import AutoTokenizer, AutoModel

- tokenizer = AutoTokenizer.from_pretrained("jinhybr/OCR-Donut-CORD")
- model = AutoModel.from_pretrained("jinhybr/OCR-Donut-CORD")
+ import re
+
+ from transformers import DonutProcessor, VisionEncoderDecoderModel
+ from datasets import load_dataset
+ import torch
  from PIL import Image
+ import numpy as np
+ import streamlit as st
+
+ processor = DonutProcessor.from_pretrained("naver-clova-ix/donut-base-finetuned-rvlcdip")
+ model = VisionEncoderDecoderModel.from_pretrained("naver-clova-ix/donut-base-finetuned-rvlcdip")
+
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+ model.to(device)
+
+ #image = Image.open(r"C:\Invoices\Sample Invoices\sample invoice 1.tif")
+ #image = image.convert("RGB")
+ #print(np.array(image).shape)
+
+
+ # load document image
+ dataset = load_dataset("hf-internal-testing/example-documents", split="test")
+ image = dataset[2]["image"]
+
+
+ task_prompt = "<s_rvlcdip>"
+ decoder_input_ids = processor.tokenizer(task_prompt, add_special_tokens=False, return_tensors="pt").input_ids
+
+ pixel_values = processor(image, return_tensors="pt").pixel_values
+
+ outputs = model.generate(
+     pixel_values.to(device),
+     decoder_input_ids=decoder_input_ids.to(device),
+     max_length=model.decoder.config.max_position_embeddings,
+     pad_token_id=processor.tokenizer.pad_token_id,
+     eos_token_id=processor.tokenizer.eos_token_id,
+     use_cache=True,
+     bad_words_ids=[[processor.tokenizer.unk_token_id]],
+     return_dict_in_generate=True,
+ )
+
+ sequence = processor.batch_decode(outputs.sequences)[0]
+ sequence = sequence.replace(processor.tokenizer.eos_token, "").replace(processor.tokenizer.pad_token, "")
+ sequence = re.sub(r"<.*?>", "", sequence, count=1).strip()  # remove first task start token
+ print(processor.token2json(sequence))
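Note that the updated app.py imports streamlit but does not use it yet, so the classification pipeline runs once at import time against a sample document from the test dataset. Below is a minimal sketch, not part of the commit, of how the same Donut pipeline could be wired into a Streamlit UI; the helper names load_donut and classify_document and the uploader labels are hypothetical, and it assumes a Streamlit version with st.cache_resource.

import re

import streamlit as st
import torch
from PIL import Image
from transformers import DonutProcessor, VisionEncoderDecoderModel


@st.cache_resource  # load the checkpoint once per session, not on every rerun
def load_donut():
    processor = DonutProcessor.from_pretrained("naver-clova-ix/donut-base-finetuned-rvlcdip")
    model = VisionEncoderDecoderModel.from_pretrained("naver-clova-ix/donut-base-finetuned-rvlcdip")
    device = "cuda" if torch.cuda.is_available() else "cpu"
    return processor, model.to(device), device


def classify_document(image: Image.Image) -> dict:
    # Same generate-and-decode steps as the committed script, wrapped in a function.
    processor, model, device = load_donut()
    decoder_input_ids = processor.tokenizer(
        "<s_rvlcdip>", add_special_tokens=False, return_tensors="pt"
    ).input_ids
    pixel_values = processor(image, return_tensors="pt").pixel_values
    outputs = model.generate(
        pixel_values.to(device),
        decoder_input_ids=decoder_input_ids.to(device),
        max_length=model.decoder.config.max_position_embeddings,
        pad_token_id=processor.tokenizer.pad_token_id,
        eos_token_id=processor.tokenizer.eos_token_id,
        use_cache=True,
        bad_words_ids=[[processor.tokenizer.unk_token_id]],
        return_dict_in_generate=True,
    )
    sequence = processor.batch_decode(outputs.sequences)[0]
    sequence = sequence.replace(processor.tokenizer.eos_token, "").replace(processor.tokenizer.pad_token, "")
    sequence = re.sub(r"<.*?>", "", sequence, count=1).strip()  # remove first task start token
    return processor.token2json(sequence)


uploaded = st.file_uploader("Upload a document image", type=["png", "jpg", "jpeg", "tif", "tiff"])
if uploaded is not None:
    image = Image.open(uploaded).convert("RGB")
    st.image(image, caption="Input document")
    st.json(classify_document(image))

Caching the processor and model with st.cache_resource matters here because Streamlit re-executes the whole script on every user interaction, which would otherwise reload the checkpoint each time.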