"""Process a pickled gossip-article dataset into embedded queries and contexts."""
from pickle import dump
from typing import List

from cohere import Client
from numpy import array

from gossip_semantic_search.models import Article, ProcessedDataset
from gossip_semantic_search.utils import embed_content, CustomUnpickler
class DatasetProcessor:
    """Pipeline that loads a pickled article dataset, embeds its questions,
    and pickles the processed result back to disk.

    Attributes set in ``__init__``:
        dataset_path -- path to the pickled ``List[Article]`` input file.
        saved_processed_dataset_path -- destination path for the processed dataset.
        processed_dataset -- populated by ``process_dataset``; ``None`` until then.
    """

    def __init__(self,
                 dataset_path: str,
                 saved_processed_dataset_path: str):
        self.dataset_path = dataset_path
        self.saved_processed_dataset_path = saved_processed_dataset_path
        self.processed_dataset: ProcessedDataset = None

    @staticmethod
    def load_dataset(dataset_path: str) -> List[Article]:
        """Unpickle and return the list of articles stored at ``dataset_path``.

        BUG FIX: the original definition had neither ``self`` nor
        ``@staticmethod``, so ``self.load_dataset(path)`` in ``run`` passed
        two arguments to a one-argument function and raised ``TypeError``.
        Declaring it a staticmethod keeps both ``self.load_dataset(p)`` and
        ``DatasetProcessor.load_dataset(p)`` call forms working.

        NOTE(security): unpickling executes arbitrary code — only load
        trusted dataset files.
        """
        with open(dataset_path, 'rb') as file:
            return CustomUnpickler(file).load()

    def process_dataset(self,
                        data: List[Article]) -> None:
        """Embed every question of every article and store a ProcessedDataset.

        ``y_true[i]`` is the index of the article that question ``i`` came
        from, so embedded queries stay aligned with their source context.
        """
        client = Client()
        y_true = []
        questions = []
        for article_idx, sample in enumerate(data):
            for question in sample.questions:
                y_true.append(article_idx)
                questions.append(question)
        self.processed_dataset = ProcessedDataset(
            y_true=array(y_true),
            embedded_queries=embed_content(questions, client),
            # 'embeded_content' spelling comes from the Article model; keep it.
            embedded_context=array([sample.embeded_content for sample in data]))

    def save_articles(self) -> None:
        """Pickle ``processed_dataset`` to ``saved_processed_dataset_path``."""
        with open(self.saved_processed_dataset_path, 'wb') as f:
            dump(self.processed_dataset, f)

    def run(self) -> None:
        """Execute the full pipeline: load -> process -> save."""
        data = self.load_dataset(self.dataset_path)
        self.process_dataset(data)
        self.save_articles()