{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"id": "aaed9cbc",
"metadata": {},
"outputs": [],
"source": [
"import task\n",
"import deit\n",
"import trocr_models\n",
"import torch\n",
"import fairseq\n",
"from fairseq import utils\n",
"from fairseq_cli import generate\n",
"from PIL import Image\n",
"import torchvision.transforms as transforms\n",
"\n",
"\n",
"def init(model_path, beam=5):\n",
" model, cfg, task = fairseq.checkpoint_utils.load_model_ensemble_and_task(\n",
" [model_path],\n",
" arg_overrides={\"beam\": beam, \"task\": \"text_recognition\", \"data\": \"\", \"fp16\": False})\n",
"\n",
" device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
" model[0].to(device)\n",
"\n",
" img_transform = transforms.Compose([\n",
" transforms.Resize((384, 384), interpolation=3),\n",
" transforms.ToTensor(),\n",
" transforms.Normalize(0.5, 0.5)\n",
" ])\n",
"\n",
" generator = task.build_generator(\n",
" model, cfg.generation, extra_gen_cls_kwargs={'lm_model': None, 'lm_weight': None}\n",
" )\n",
"\n",
" bpe = task.build_bpe(cfg.bpe)\n",
"\n",
" return model, cfg, task, generator, bpe, img_transform, device\n",
"\n",
"\n",
"def preprocess(img_path, img_transform):\n",
" im = Image.open(img_path).convert('RGB').resize((384, 384))\n",
" im = img_transform(im).unsqueeze(0).to(device).float()\n",
"\n",
" sample = {\n",
" 'net_input': {\"imgs\": im},\n",
" }\n",
"\n",
" return sample\n",
"\n",
"\n",
"def get_text(cfg, generator, model, sample, bpe):\n",
" decoder_output = task.inference_step(generator, model, sample, prefix_tokens=None, constraints=None)\n",
" decoder_output = decoder_output[0][0] #top1\n",
"\n",
" hypo_tokens, hypo_str, alignment = utils.post_process_prediction(\n",
" hypo_tokens=decoder_output[\"tokens\"].int().cpu(),\n",
" src_str=\"\",\n",
" alignment=decoder_output[\"alignment\"],\n",
" align_dict=None,\n",
" tgt_dict=model[0].decoder.dictionary,\n",
" remove_bpe=cfg.common_eval.post_process,\n",
" extra_symbols_to_ignore=generate.get_symbols_to_strip_from_output(generator),\n",
" )\n",
"\n",
" detok_hypo_str = bpe.decode(hypo_str)\n",
"\n",
" return detok_hypo_str"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "b95c01e4",
"metadata": {},
"outputs": [],
"source": [
"model_path = 'path/to/model'\n",
"jpg_path = \"path/to/pic\"\n",
"beam = 5\n",
"\n",
"model, cfg, task, generator, bpe, img_transform, device = init(model_path, beam)\n",
"\n",
"sample = preprocess(jpg_path, img_transform)\n",
"\n",
"text = get_text(cfg, generator, model, sample, bpe)\n",
"\n",
"print(text)\n",
"\n",
"print('done')"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3.8.5 ('base')",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.5"
},
"vscode": {
"interpreter": {
"hash": "0b8488e5f98ef3932f4ff0893213e55e6ba8b00dde307078d0f3efb25017ce11"
}
}
},
"nbformat": 4,
"nbformat_minor": 5
}