SearchMesh / tag.py
Nick Sorros
Tag more grants and implement most common
b493a01
import itertools
import json

import srsly
import typer
from tqdm import tqdm
from transformers import AutoModel, AutoTokenizer
def load_data(data_path, sample_size):
    """Load and return the single JSON document stored at ``data_path``.

    Args:
        data_path: Path to a file containing one JSON document.
        sample_size: Currently unused; accepted only for interface
            compatibility. NOTE(review): presumably meant to limit the number
            of records returned — confirm intent before relying on it.

    Returns:
        The decoded JSON value (whatever top-level type the file holds).
    """
    # Decode as UTF-8 explicitly (the encoding JSON interchange requires)
    # instead of relying on the platform's locale-dependent default.
    with open(data_path, encoding="utf-8") as f:
        return json.load(f)
def tag(data_path, tagged_data_path, sample_size: int = 10, batch_size: int = 10):
    """Tag the first grants of a JSONL file with WellcomeBertMesh labels.

    Reads up to ``sample_size`` grants from ``data_path`` (JSON Lines).
    Runs the ``Wellcome/WellcomeBertMesh`` model over each grant's
    ``"title_and_description"`` in batches of ``batch_size``. Stores the
    labels returned for each grant under a new ``"tags"`` key and writes
    every read grant to ``tagged_data_path`` as JSON Lines.

    Args:
        data_path: Input JSONL file. Every line read must contain a
            ``"title_and_description"`` field.
        tagged_data_path: Output JSONL path.
        sample_size: Maximum number of grants to read from the start of the
            input. If the file is shorter, all of its lines are tagged.
            Non-positive values read nothing.
        batch_size: Number of texts passed to the model per call.
    """
    # islice stops cleanly when the file has fewer than sample_size lines;
    # repeated next() calls would raise StopIteration instead. Clamp at 0 so
    # negative sizes still yield an empty sample rather than a ValueError.
    data = list(itertools.islice(srsly.read_jsonl(data_path), max(sample_size, 0)))

    tokenizer = AutoTokenizer.from_pretrained("Wellcome/WellcomeBertMesh")
    model = AutoModel.from_pretrained(
        "Wellcome/WellcomeBertMesh", trust_remote_code=True
    )

    texts = [grant["title_and_description"] for grant in data]
    for start in tqdm(range(0, len(texts), batch_size)):
        stop = start + batch_size
        # padding="max_length" only pads; without truncation=True, texts longer
        # than the tokenizer's max length stay over-length, so the batch becomes
        # ragged and can exceed the model's maximum input length.
        inputs = tokenizer(texts[start:stop], padding="max_length", truncation=True)
        labels = model(**inputs, return_labels=True)
        # Pair each grant in this batch with its predicted labels.
        for grant, tags in zip(data[start:stop], labels):
            grant["tags"] = tags

    srsly.write_jsonl(tagged_data_path, data)
if __name__ == "__main__":
    # Run ``tag`` as a command-line program; typer builds the CLI arguments
    # and options from the function's signature.
    typer.run(tag)