File size: 879 Bytes
4709571
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
import json
from itertools import islice

import srsly
import typer
from transformers import AutoModel, AutoTokenizer


def load_data(data_path, sample_size=None):
    """Load a JSON document from *data_path*, optionally keeping only a prefix.

    Args:
        data_path: Path to a UTF-8 encoded JSON file.
        sample_size: When not ``None`` and the decoded document is a list,
            keep only its first ``sample_size`` items. ``None`` keeps
            everything. Non-list documents are returned unchanged.

    Returns:
        The decoded JSON value (possibly truncated, see ``sample_size``).

    Raises:
        ValueError: If ``sample_size`` is negative.
    """
    if sample_size is not None and sample_size < 0:
        # A negative slice bound would silently take items from the end.
        raise ValueError(f"sample_size must be >= 0, got {sample_size}")

    # Explicit encoding so decoding does not depend on the platform locale.
    with open(data_path, encoding="utf-8") as f:
        data = json.load(f)

    # Previously sample_size was accepted but ignored; honour it for lists.
    if sample_size is not None and isinstance(data, list):
        data = data[:sample_size]

    return data


def tag(data_path, tagged_data_path, sample_size: int = 10):
    """Tag the first ``sample_size`` records of a JSONL file with WellcomeBertMesh.

    Reads records from ``data_path`` as JSON Lines and runs the
    ``Wellcome/WellcomeBertMesh`` model over each record's
    ``"title_and_description"`` text. The labels the model returns (called
    with ``return_labels=True``) are stored under a new ``"tags"`` key. The
    records are then written to ``tagged_data_path`` as JSON Lines.

    Args:
        data_path: Input JSONL file; every record must contain a
            ``"title_and_description"`` key.
        tagged_data_path: Output JSONL file.
        sample_size: Maximum number of records to read from the start of the
            input. Fewer are processed when the file is shorter.
    """
    # islice stops cleanly at end of input. The previous
    # ``[next(it) for _ in range(sample_size)]`` crashed with an uncaught
    # StopIteration whenever the file had fewer than sample_size records.
    data = list(islice(srsly.read_jsonl(data_path), sample_size))
    if not data:
        # Nothing to tag: skip the expensive model load and emit an empty file.
        srsly.write_jsonl(tagged_data_path, data)
        return

    tokenizer = AutoTokenizer.from_pretrained("Wellcome/WellcomeBertMesh")
    model = AutoModel.from_pretrained(
        "Wellcome/WellcomeBertMesh", trust_remote_code=True
    )

    texts = [grant["title_and_description"] for grant in data]
    # padding="max_length" only pads; without truncation, texts longer than
    # the model's max length stay longer and the batch becomes ragged.
    inputs = tokenizer(texts, padding="max_length", truncation=True)
    labels = model(**inputs, return_labels=True)

    for grant, tags in zip(data, labels):
        grant["tags"] = tags

    srsly.write_jsonl(tagged_data_path, data)


# Script entry point: typer builds a command-line interface from tag()'s
# signature and invokes it with the parsed arguments.
if __name__ == "__main__":
    typer.run(tag)