Update donut_inference.py
donut_inference.py (+3 −2)
@@ -26,7 +26,7 @@ def inference(image):
 
     # device = "cuda" if torch.cuda.is_available() else "cpu"
    # model.to(device)
-
+    start_time = time.time()
     outputs = model.generate(pixel_values.to(device),
                              decoder_input_ids=decoder_input_ids.to(device),
                              max_length=model.decoder.config.max_position_embeddings,
@@ -38,11 +38,12 @@ def inference(image):
                              bad_words_ids=[[processor.tokenizer.unk_token_id]],
                              return_dict_in_generate=True,
                              output_scores=True,)
-
+    end_time = time.time()
     sequence = processor.batch_decode(outputs.sequences)[0]
     sequence = sequence.replace(processor.tokenizer.eos_token, "").replace(processor.tokenizer.pad_token, "")
     sequence = re.sub(r"<.*?>", "", sequence, count=1).strip()  # remove first task start token
     print(processor.token2json(sequence))
+    print(f"Donut Inference time {end_time - start_time}")
     return processor.token2json(sequence)
 
 # data = inference(image)
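For context, the change above brackets model.generate with two time.time() calls and prints the difference; it assumes time is already imported elsewhere in donut_inference.py, since the diff adds no import. Below is a minimal, self-contained sketch of the same timing pattern as a reusable helper. The timed context manager and its label argument are hypothetical, not part of this repo, and time.perf_counter() is swapped in for time.time() because it is monotonic and better suited to measuring elapsed intervals.

import time
from contextlib import contextmanager

@contextmanager
def timed(label):
    # Hypothetical helper, not in donut_inference.py: measures the enclosed
    # block with the monotonic perf_counter clock and prints the result.
    start = time.perf_counter()
    try:
        yield
    finally:
        elapsed = time.perf_counter() - start
        print(f"{label} took {elapsed:.3f}s")

# Usage mirroring the generate call in the diff:
#     with timed("Donut inference"):
#         outputs = model.generate(...)
with timed("demo sleep"):
    time.sleep(0.1)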