# NOTE(review): the three lines that were here ("Spaces:" / "Build error" / "Build error")
# were Hugging Face Spaces UI residue from a page scrape, not code — commented out so the
# file parses. The build error itself is addressed below (invalid "bool" Gradio input).
import spaces
import gradio as gr

import pandas as pd
from datasets import load_dataset
from sentence_transformers import SentenceTransformer, SentenceTransformerTrainer
from sentence_transformers.evaluation import InformationRetrievalEvaluator
from sentence_transformers.losses import CachedMultipleNegativesRankingLoss
from sentence_transformers.training_args import (
    BatchSamplers,
    SentenceTransformerTrainingArguments,
)
def get_ir_evaluator(eval_ds):
    """Build an InformationRetrievalEvaluator from an (anchor, positive) pair dataset.

    Each example's ``anchor`` becomes a query and its ``positive`` becomes the
    single relevant corpus document (qid => {cid}).

    NOTE: with only one relevant document per query the IR metrics are
    pessimistic — a richer relevance mapping (e.g. LLM-generated) would be
    a better evaluation signal.

    Args:
        eval_ds: iterable of examples with ``'anchor'`` and ``'positive'`` keys.

    Returns:
        InformationRetrievalEvaluator over the derived queries/corpus.
    """
    corpus = {}         # cid -> document text
    queries = {}        # qid -> query text
    relevant_docs = {}  # qid -> set of relevant cids
    for idx, example in enumerate(eval_ds):
        queries[idx] = example['anchor']
        corpus[idx] = example['positive']
        relevant_docs[idx] = {idx}  # only the paired positive counts as relevant
    return InformationRetrievalEvaluator(
        queries=queries,
        corpus=corpus,
        relevant_docs=relevant_docs,
        name="ir-evaluator",
    )
def train(hf_token, dataset_id, model_id, num_epochs, dev=True):
    """Fine-tune a SentenceTransformer on an (anchor, positive) pair dataset.

    Evaluates retrieval quality (InformationRetrievalEvaluator) before and
    after training so the two can be compared side by side.

    Args:
        hf_token: Hugging Face token used to read the dataset (and push the model).
        dataset_id: Hub id of a dataset whose 'train' split has 'anchor'/'positive' columns.
        model_id: Hub id of the base SentenceTransformer model.
        num_epochs: number of training epochs.
        dev: when True, cap the dataset at 1000 examples and skip pushing to the Hub.

    Returns:
        str: printable DataFrame comparing base vs fine-tuned IR metrics
        (one column per run, one row per metric).
    """
    ds = load_dataset(dataset_id, split="train", token=hf_token)
    ds = ds.shuffle(seed=42)
    if len(ds) > 1000 and dev:
        # Dev-mode cap. Was `range(0, 999)`, which kept only 999 examples.
        ds = ds.select(range(1000))
    ds = ds.train_test_split(train_size=0.75)
    train_ds, eval_ds = ds['train'], ds['test']
    print('train: ', len(train_ds), 'eval: ', len(eval_ds))

    model = SentenceTransformer(model_id)
    # In-batch-negatives loss; the cached variant bounds memory at larger batch sizes.
    loss = CachedMultipleNegativesRankingLoss(model)

    args = SentenceTransformerTrainingArguments(
        output_dir="outputs",  # required
        num_train_epochs=num_epochs,
        per_device_train_batch_size=16,
        warmup_ratio=0.1,
        # fp16=True,  # Set to True only if the GPU handles FP16
        # bf16=True,  # Set to True only if the GPU supports BF16
        # Losses using in-batch negatives benefit from no duplicates within a batch.
        batch_sampler=BatchSamplers.NO_DUPLICATES,
        save_total_limit=2,
    )

    # Same evaluator instance scores the base and the fine-tuned model.
    ir_evaluator = get_ir_evaluator(eval_ds)
    base_metrics = ir_evaluator(model)
    print(ir_evaluator.primary_metric)
    print(base_metrics[ir_evaluator.primary_metric])

    trainer = SentenceTransformerTrainer(
        model=model,
        args=args,
        train_dataset=train_ds,
        loss=loss,
    )
    trainer.train()

    ft_metrics = ir_evaluator(model)
    print(ir_evaluator.primary_metric)
    print(ft_metrics[ir_evaluator.primary_metric])

    if not dev:
        model.push_to_hub("fine-tuned-sentence-transformer", private=True, token=hf_token)

    metrics = pd.DataFrame([base_metrics, ft_metrics]).T
    print(metrics)
    return str(metrics)
## Streaming logs to the UI:
# https://github.com/gradio-app/gradio/issues/2362#issuecomment-1424446778
# Fix: "bool" is not a valid Gradio component shortcut (this breaks the Space
# build) — the boolean `dev` flag maps to the "checkbox" component.
demo = gr.Interface(
    fn=train,
    inputs=["text", "text", "text", "number", "checkbox"],
    outputs=["text"],  # could become "dataframe" if train() returned the DataFrame itself
)
demo.launch()