chandan06 committed on
Commit 8094952 · 1 Parent(s): 70cb8f2

Update donut_inference.py

Files changed (1)
  1. donut_inference.py +3 -2
donut_inference.py CHANGED
@@ -26,7 +26,7 @@ def inference(image):
 
     # device = "cuda" if torch.cuda.is_available() else "cpu"
     # model.to(device)
-
+    start_time = time.time()
     outputs = model.generate(pixel_values.to(device),
                              decoder_input_ids=decoder_input_ids.to(device),
                              max_length=model.decoder.config.max_position_embeddings,
@@ -38,11 +38,12 @@ def inference(image):
                              bad_words_ids=[[processor.tokenizer.unk_token_id]],
                              return_dict_in_generate=True,
                              output_scores=True,)
-
+    end_time = time.time()
     sequence = processor.batch_decode(outputs.sequences)[0]
     sequence = sequence.replace(processor.tokenizer.eos_token, "").replace(processor.tokenizer.pad_token, "")
     sequence = re.sub(r"<.*?>", "", sequence, count=1).strip()  # remove first task start token
     print(processor.token2json(sequence))
+    print(f"Donut Inference time {start_time-end_time}")
     return processor.token2json(sequence)
 
 # data = inference(image)
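Note on the change: the added lines wrap the model.generate call with time.time() calls to report how long Donut decoding takes. Two details are worth flagging. First, the new code only runs if the time module is imported at the top of donut_inference.py; the import is not part of this diff. Second, as committed the f-string prints start_time-end_time, which is a negative number; the elapsed seconds are end_time - start_time. Below is a minimal sketch of the intended timing pattern with those two points applied. It assumes model, processor, pixel_values, decoder_input_ids, and device are already set up as elsewhere in donut_inference.py; it is an illustrative sketch, not the committed code.

import time

# Time the Donut decoding step (assumes model/processor/inputs are prepared
# as in donut_inference.py; this sketch only shows the timing pattern).
start_time = time.time()
outputs = model.generate(pixel_values.to(device),
                         decoder_input_ids=decoder_input_ids.to(device),
                         max_length=model.decoder.config.max_position_embeddings,
                         bad_words_ids=[[processor.tokenizer.unk_token_id]],
                         return_dict_in_generate=True,
                         output_scores=True,)
end_time = time.time()
# Elapsed wall-clock time in seconds (end minus start, not the reverse).
print(f"Donut Inference time {end_time - start_time:.2f}s")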