w3robotics committed (verified)
Commit 29d33c6 · 1 Parent(s): 04d83d5

Update app.py

Files changed (1): app.py (+7, -5)
app.py CHANGED
@@ -62,28 +62,30 @@ if file_name is not None:
     processor_ext = DonutProcessor.from_pretrained("naver-clova-ix/donut-base-finetuned-cord-v2")
     model_ext = VisionEncoderDecoderModel.from_pretrained("naver-clova-ix/donut-base-finetuned-cord-v2")

+    device = "cuda" if torch.cuda.is_available() else "cpu"
     model_ext.to(device)
-
+
     # prepare decoder inputs
     task_prompt = "<s_cord-v2>"
-    decoder_input_ids = processor.tokenizer(task_prompt, add_special_tokens=False, return_tensors="pt").input_ids
-
+    decoder_input_ids = processor_ext.tokenizer(task_prompt, add_special_tokens=False, return_tensors="pt").input_ids
+
     pixel_values = processor_ext(image, return_tensors="pt").pixel_values

     outputs = model_ext.generate(
         pixel_values.to(device),
         decoder_input_ids=decoder_input_ids.to(device),
-        max_length=model.decoder.config.max_position_embeddings,
+        max_length=model_ext.decoder.config.max_position_embeddings,
         pad_token_id=processor.tokenizer.pad_token_id,
         eos_token_id=processor.tokenizer.eos_token_id,
         use_cache=True,
-        bad_words_ids=[[processor.tokenizer.unk_token_id]],
+        bad_words_ids=[[processor_ext.tokenizer.unk_token_id]],
         return_dict_in_generate=True,
     )

     sequence = processor_ext.batch_decode(outputs.sequences)[0]
     sequence = sequence.replace(processor_ext.tokenizer.eos_token, "").replace(processor_ext.tokenizer.pad_token, "")
     sequence = re.sub(r"<.*?>", "", sequence, count=1).strip()  # remove first task start token
+    print(processor_ext.token2json(sequence))
     col3.header("Features")
     col3.subheader(processor_ext.token2json(sequence))
91