File size: 2,197 Bytes
fe5e03c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
from enum import Enum
from datasets import Dataset
from pathlib import Path

from src.backend.data_fetching.data_fields import DataFields

class ChunkLevel(Enum):
    """Granularity at which article text is chunked.

    Each member's value is a ``(label, separator)`` pair: the label is a
    human-readable name and the separator is the string to split text on.
    """

    SENTENCE = ("sentence", '. ')
    PARAGRAPH = ("paragraph", '\n')
    SECTION = ("section", '$$$$$$$$$$$$$$$$$$$$$$$$')

    def __repr__(self):
        # Show only the human-readable label (e.g. "sentence").
        label, _ = self.value
        return label

    @property
    def sep(self):
        """Separator string used to split text at this chunk level."""
        _, separator = self.value
        return separator


class ArticleChunker:
    """Callable that splits article text into chunks at a configured level."""

    def __init__(self, chunk_level: ChunkLevel = ChunkLevel.SENTENCE):
        self.chunk_level = chunk_level

    @property
    def sep(self):
        """Separator string of the configured chunk level."""
        return self.chunk_level.sep

    def __call__(self, text):
        """Return *text* split on this chunker's separator."""
        return text.split(self.sep)

class DataFetcher:
    """Fetches articles from several sources and chunks their sections.

    All loaders return data whose section text has been split by the
    configured :class:`ArticleChunker`.
    """

    def __init__(self, chunk_level: ChunkLevel = ChunkLevel.SENTENCE):
        self._chunker = ArticleChunker(chunk_level)

    # function to make it easier using this specific dataset
    def from_hugging_face_scientific_papers_dataset(self, dataset: Dataset, used_separator: str = "\n") -> Dataset:
        """Adapt the HF scientific-papers dataset to this project's fields.

        Joins the abstract and section texts with *used_separator*, re-splits
        them with the chunker's separator, and prepends an "Abstract" entry
        to the section names.
        """
        sep = self._chunker.sep  # hoisted: same separator for both map calls
        dataset = dataset.map(
            lambda x: {
                DataFields.SECTIONS:
                    used_separator.join(x["article_abstract"]).split(sep)
                    + used_separator.join(x[DataFields.SECTIONS]).split(sep)
            }
        )
        dataset = dataset.map(
            lambda x: {DataFields.SECTION_NAMES: ["Abstract"] + x[DataFields.SECTION_NAMES]}
        )
        return dataset

    def from_dataset(self, dataset: Dataset, used_separator: str = '') -> Dataset:
        """Re-chunk a dataset that already has all required DataFields columns.

        Raises:
            ValueError: if any required column is missing.  (An explicit
            exception replaces the previous ``assert``, which is stripped
            under ``python -O``.)
        """
        missing = [field for field in DataFields if field not in dataset.column_names]
        if missing:
            raise ValueError(f"dataset is missing required columns: {missing}")
        dataset = dataset.map(
            lambda x: {DataFields.SECTIONS: used_separator.join(x[DataFields.SECTIONS]).split(self._chunker.sep)}
        )
        return dataset

    def from_pmcid(self, pmcid: str | list[str]) -> list:
        """Fetch one or more articles by PMCID and chunk their sections."""
        if isinstance(pmcid, str):
            pmcid = [pmcid]
        # Local import kept to avoid pulling the fetch dependency at module load.
        from src.backend.data_fetching.fetch_pmcid import fetch_from_pmcid
        output = []
        # Distinct loop variable: the original shadowed the `pmcid` list itself.
        for single_pmcid in pmcid:
            article_out = fetch_from_pmcid(single_pmcid)
            article_out[DataFields.SECTIONS] = [
                self._chunker(section) for section in article_out[DataFields.SECTIONS]
            ]
            output.append(article_out)
        return output

    def from_pdf(self, file: str | Path):
        """Not yet implemented: load an article from a PDF file."""
        raise NotImplementedError()