Spaces:
Runtime error
Runtime error
GitHub Actions Bot
Changes from 'https://github.com/ggruber193/pubmed-central-semantic-search'
fe5e03c
from enum import Enum | |
from datasets import Dataset | |
from pathlib import Path | |
from src.backend.data_fetching.data_fields import DataFields | |
class ChunkLevel(Enum):
    """Granularity at which article text is split into chunks.

    Each member's value is a ``(label, separator)`` pair: a short label and
    the string on which text is split for that granularity.
    """

    SENTENCE = ("sentence", '. ')
    PARAGRAPH = ("paragraph", '\n')
    SECTION = ("section", '$$$$$$$$$$$$$$$$$$$$$$$$')

    def __repr__(self):
        """Return the member's label (first element of its value)."""
        label, _separator = self.value
        return label

    def sep(self):
        """Return the member's separator string (second element of its value)."""
        _label, separator = self.value
        return separator
class ArticleChunker:
    """Split text into chunks using the separator of a ``ChunkLevel``."""

    def __init__(self, chunk_level: ChunkLevel = ChunkLevel.SENTENCE):
        """Remember the chunk level whose separator is used for splitting.

        Args:
            chunk_level: Granularity to split at; defaults to
                ``ChunkLevel.SENTENCE``.
        """
        self.chunk_level = chunk_level

    def sep(self):
        """Return the separator string of the configured chunk level."""
        # ChunkLevel.sep is a method: it has to be called, otherwise a
        # bound-method object is returned instead of the separator string.
        return self.chunk_level.sep()

    def __call__(self, text):
        """Split ``text`` on the configured separator.

        Args:
            text: The string to split.

        Returns:
            The list of pieces produced by ``str.split``.
        """
        # Passing the un-called method to str.split raises TypeError
        # ("must be str or None"), so resolve the separator string first.
        separator = self.chunk_level.sep()
        return text.split(separator)
class DataFetcher:
    """Load article data from several sources and chunk the section text.

    Section text is split with the separator of the ``ChunkLevel`` given at
    construction time.
    """

    def __init__(self, chunk_level: ChunkLevel = ChunkLevel.SENTENCE):
        """Create the internal chunker.

        Args:
            chunk_level: Granularity used when splitting section text;
                defaults to ``ChunkLevel.SENTENCE``.
        """
        self._chunker = ArticleChunker(chunk_level)

    def _chunk_separator(self):
        """Return the separator string of the configured chunk level."""
        # ChunkLevel.sep is a method and must be called; handing the bound
        # method itself to str.split raises TypeError.
        return self._chunker.chunk_level.sep()

    # function to make it easier using this specific dataset
    def from_hugging_face_scientific_papers_dataset(self, dataset: Dataset, used_separator="\n"):
        """Re-chunk a dataset that has an ``"article_abstract"`` column.

        For every row, the ``"article_abstract"`` entries and the
        ``DataFields.SECTIONS`` entries are each joined with
        ``used_separator`` and split on the chunk separator; the abstract
        chunks are placed in front of the section chunks. ``"Abstract"`` is
        prepended to ``DataFields.SECTION_NAMES``.

        Args:
            dataset: Dataset to transform via ``Dataset.map``.
            used_separator: String used to join the existing entries before
                they are re-split.

        Returns:
            The mapped dataset.
        """
        chunk_sep = self._chunk_separator()

        def _rechunk(row):
            abstract_chunks = used_separator.join(row["article_abstract"]).split(chunk_sep)
            section_chunks = used_separator.join(row[DataFields.SECTIONS]).split(chunk_sep)
            return {DataFields.SECTIONS: abstract_chunks + section_chunks}

        dataset = dataset.map(_rechunk)
        dataset = dataset.map(
            lambda row: {DataFields.SECTION_NAMES: ["Abstract"] + row[DataFields.SECTION_NAMES]})
        return dataset

    def from_dataset(self, dataset: Dataset, used_separator=''):
        """Re-chunk the ``DataFields.SECTIONS`` column of ``dataset``.

        Args:
            dataset: Dataset whose columns must include every entry of
                ``DataFields``.
            used_separator: String used to join the existing section entries
                before they are re-split on the chunk separator.

        Returns:
            The mapped dataset.

        Raises:
            AssertionError: If a ``DataFields`` entry is not a column.
        """
        missing = [field for field in DataFields if field not in dataset.column_names]
        # Kept as an assert so the exception type seen by callers is unchanged.
        assert not missing, f"dataset is missing required columns: {missing}"
        chunk_sep = self._chunk_separator()
        dataset = dataset.map(
            lambda row: {DataFields.SECTIONS: used_separator.join(row[DataFields.SECTIONS]).split(chunk_sep)})
        return dataset

    def from_pmcid(self, pmcid: str | list[str]):
        """Fetch one or more articles by PMC id and chunk their sections.

        Args:
            pmcid: A single id or a list of ids.

        Returns:
            A list with one fetched article per id, in input order. Each
            entry of an article's ``DataFields.SECTIONS`` is replaced by the
            list of chunks the chunker produced for it.
        """
        if isinstance(pmcid, str):
            pmcid = [pmcid]
        # Imported here rather than at module level, as in the original code.
        from src.backend.data_fetching.fetch_pmcid import (fetch_from_pmcid)
        output = []
        # Use a distinct loop name so the parameter is not shadowed.
        for single_id in pmcid:
            article_out = fetch_from_pmcid(single_id)
            article_out[DataFields.SECTIONS] = [
                self._chunker(section) for section in article_out[DataFields.SECTIONS]]
            output.append(article_out)
        return output

    def from_pdf(self, file: str | Path):
        """Load an article from a PDF file (not implemented).

        Raises:
            NotImplementedError: Always.
        """
        raise NotImplementedError()