pierre Brault
imit
3ff674d
raw
history blame contribute delete
1.59 kB
from pickle import dump
from typing import List, Optional

from cohere import Client
from numpy import array

from gossip_semantic_search.models import Article, ProcessedDataset
from gossip_semantic_search.utils import embed_content, CustomUnpickler
class DatasetProcessor:
    """Load a pickled list of articles, embed their questions, and persist
    the resulting ``ProcessedDataset`` to disk.

    Pipeline (see :meth:`run`): load -> process -> save.
    """

    def __init__(self,
                 dataset_path: str,
                 saved_processed_dataset_path: str):
        # Path of the pickled List[Article] to read.
        self.dataset_path = dataset_path
        # Destination path for the pickled ProcessedDataset.
        self.saved_processed_dataset_path = saved_processed_dataset_path
        # Populated by process_dataset(); stays None until then, hence Optional.
        self.processed_dataset: Optional[ProcessedDataset] = None

    @staticmethod
    def load_dataset(dataset_path: str) -> List[Article]:
        """Unpickle and return the article list stored at *dataset_path*.

        Uses the project's ``CustomUnpickler`` rather than plain pickle,
        presumably so project classes resolve across module layouts —
        TODO confirm against CustomUnpickler.
        NOTE(review): unpickling is unsafe on untrusted files; only load
        datasets produced by this project.
        """
        with open(dataset_path, 'rb') as file:
            return CustomUnpickler(file).load()

    def process_dataset(self,
                        data: List[Article]) -> None:
        """Embed all questions and contexts; store the ProcessedDataset.

        ``y_true[j]`` is the index in *data* of the article that question
        ``j`` was drawn from, so retrieval can be scored later.
        """
        client = Client()
        # Flatten every article's questions, remembering the source index.
        y_true = [i for i, sample in enumerate(data) for _ in sample.questions]
        questions = [q for sample in data for q in sample.questions]
        self.processed_dataset = ProcessedDataset(
            y_true=array(y_true),
            embedded_queries=embed_content(questions, client),
            # 'embeded_content' spelling comes from the Article model.
            embedded_context=array([sample.embeded_content for sample in data]))

    def save_articles(self) -> None:
        """Pickle the processed dataset to the configured output path."""
        with open(self.saved_processed_dataset_path, 'wb') as f:
            dump(self.processed_dataset, f)

    def run(self) -> None:
        """Execute the full pipeline: load, process, save."""
        data = self.load_dataset(self.dataset_path)
        self.process_dataset(data)
        self.save_articles()