Spaces:
Runtime error
Runtime error
File size: 2,197 Bytes
fe5e03c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 |
from enum import Enum
from datasets import Dataset
from pathlib import Path
from src.backend.data_fetching.data_fields import DataFields
class ChunkLevel(Enum):
    """Granularity at which article text is split.

    Each member's value is a ``(label, separator)`` pair: the label is what
    ``repr()`` shows, and the separator is the string the text is split on.
    """

    SENTENCE = ("sentence", '. ')
    PARAGRAPH = ("paragraph", '\n')
    SECTION = ("section", '$$$$$$$$$$$$$$$$$$$$$$$$')

    def __repr__(self):
        """Return only the human-readable label of the member."""
        label, _separator = self.value
        return label

    @property
    def sep(self):
        """Separator string used to split text at this level."""
        _label, separator = self.value
        return separator
class ArticleChunker:
    """Callable that splits a text on the separator of a ``ChunkLevel``."""

    def __init__(self, chunk_level: ChunkLevel = ChunkLevel.SENTENCE):
        # The level is kept as a public attribute; the separator is always
        # looked up from it on demand rather than cached.
        self.chunk_level = chunk_level

    @property
    def sep(self):
        """Separator string of the currently configured chunk level."""
        level = self.chunk_level
        return level.sep

    def __call__(self, text):
        """Split ``text`` on the chunk level's separator.

        Returns the list produced by ``str.split``; empty strings between
        adjacent separators are kept.
        """
        delimiter = self.chunk_level.sep
        pieces = text.split(delimiter)
        return pieces
class DataFetcher:
    """Load articles from different sources and chunk their section text.

    Text is split with an ``ArticleChunker`` built from the ``chunk_level``
    given at construction time.
    """

    def __init__(self, chunk_level: ChunkLevel = ChunkLevel.SENTENCE):
        self._chunker = ArticleChunker(chunk_level)

    def from_hugging_face_scientific_papers_dataset(self, dataset: Dataset, used_separator="\n"):
        """Convenience loader for this specific Hugging Face dataset layout.

        For every example, the ``"article_abstract"`` entries and the
        ``DataFields.SECTIONS`` entries are each joined with
        ``used_separator`` and re-split on the chunker's separator; the
        abstract chunks are placed in front of the section chunks.
        ``"Abstract"`` is prepended to ``DataFields.SECTION_NAMES``.

        Args:
            dataset: Dataset exposing ``"article_abstract"``,
                ``DataFields.SECTIONS`` and ``DataFields.SECTION_NAMES``.
            used_separator: String used to join the stored pieces before
                they are re-split.

        Returns:
            The mapped dataset.
        """
        # NOTE(review): after re-splitting, the number of section chunks no
        # longer has to match the number of section names - confirm that
        # downstream consumers do not rely on them being aligned.
        sep = self._chunker.sep

        def _prepend_abstract(example):
            # Both columns are produced in one pass; the section names only
            # depend on the untouched input column, so a second map() over
            # the whole dataset is unnecessary.
            abstract_chunks = used_separator.join(example["article_abstract"]).split(sep)
            section_chunks = used_separator.join(example[DataFields.SECTIONS]).split(sep)
            return {
                DataFields.SECTIONS: abstract_chunks + section_chunks,
                DataFields.SECTION_NAMES: ["Abstract"] + example[DataFields.SECTION_NAMES],
            }

        return dataset.map(_prepend_abstract)

    def from_dataset(self, dataset: Dataset, used_separator=''):
        """Re-chunk the ``DataFields.SECTIONS`` column of a prepared dataset.

        Args:
            dataset: Dataset whose columns include every entry yielded by
                iterating ``DataFields``.
            used_separator: String used to join the stored section pieces
                before they are re-split on the chunker's separator.

        Returns:
            The mapped dataset.

        Raises:
            AssertionError: If a ``DataFields`` entry is not a column of
                ``dataset`` (only when assertions are enabled).
        """
        columns = dataset.column_names
        # The message is only evaluated when the check fails.
        assert all(field in columns for field in DataFields), (
            "dataset is missing required columns: "
            f"{[field for field in DataFields if field not in columns]}"
        )
        sep = self._chunker.sep
        return dataset.map(
            lambda example: {
                DataFields.SECTIONS: used_separator.join(example[DataFields.SECTIONS]).split(sep)
            }
        )

    def from_pmcid(self, pmcid: str | list[str]):
        """Fetch one or more articles by PMCID and chunk each section.

        Args:
            pmcid: A single id or a list of ids.

        Returns:
            A list with one fetched article per id, in input order. In each
            article, ``DataFields.SECTIONS`` is replaced by a list holding
            the chunk list of every original section.
        """
        # A lone id is wrapped so both input forms share the same loop.
        pmcids = [pmcid] if isinstance(pmcid, str) else pmcid

        # Imported lazily, as in the original code, so the fetching module is
        # only loaded when this entry point is actually used.
        from src.backend.data_fetching.fetch_pmcid import (fetch_from_pmcid)

        output = []
        # A distinct loop name avoids rebinding the ``pmcid`` parameter.
        for article_id in pmcids:
            article_out = fetch_from_pmcid(article_id)
            article_out[DataFields.SECTIONS] = [
                self._chunker(section) for section in article_out[DataFields.SECTIONS]
            ]
            output.append(article_out)
        return output

    def from_pdf(self, file: str | Path):
        """Load an article from a PDF file (not implemented yet).

        Raises:
            NotImplementedError: Always.
        """
        raise NotImplementedError()
|