Lord-Raven commited on
Commit
523f8fd
·
1 Parent(s): d07a287

Experimenting with few-shot classification.

Browse files
Files changed (1) hide show
  1. app.py +6 -11
app.py CHANGED
@@ -57,22 +57,17 @@ classifier = pipeline(task="zero-shot-classification", model=model, tokenizer=to
57
 
58
  few_shot_tokenizer = AutoTokenizer.from_pretrained('BAAI/bge-small-en-v1.5', model_max_length=512)
59
  ort_model = ORTModelForFeatureExtraction.from_pretrained('BAAI/bge-small-en-v1.5', file_name="onnx/model.onnx")
60
- few_shot_model = SetFitModel.from_pretrained("moshew/bge-small-en-v1.5_setfit-sst2-english")
61
 
62
-
63
-
64
- candidate_labels = ["true", "false"]
65
  reference_dataset = load_dataset("emotion")
66
  dummy_dataset = Dataset.from_dict({})
67
- train_dataset = get_templated_dataset(dummy_dataset, candidate_labels=candidate_labels, sample_size=8, template="This statement is {}")
68
-
69
-
70
-
71
  args = TrainingArguments(
72
  batch_size=32,
73
  num_epochs=1
74
  )
75
-
76
  trainer = Trainer(
77
  model=few_shot_model,
78
  args=args,
@@ -81,8 +76,8 @@ trainer = Trainer(
81
  )
82
  trainer.train()
83
 
84
- metrics = trainer.evaluate()
85
- print(metrics)
86
 
87
  onnx_few_shot_model = OnnxSetFitModel(ort_model, few_shot_tokenizer, few_shot_model.model_head)
88
 
 
57
 
58
  few_shot_tokenizer = AutoTokenizer.from_pretrained('BAAI/bge-small-en-v1.5', model_max_length=512)
59
  ort_model = ORTModelForFeatureExtraction.from_pretrained('BAAI/bge-small-en-v1.5', file_name="onnx/model.onnx")
60
+ few_shot_model = SetFitModel.from_pretrained("moshew/bge-small-en-v1.5_setfit-sst2-english", multi_target_strategy="multi-output")
61
 
62
+ # Train few_shot_model
63
+ candidate_labels = ["supported", "refuted"]
 
64
  reference_dataset = load_dataset("emotion")
65
  dummy_dataset = Dataset.from_dict({})
66
+ train_dataset = get_templated_dataset(dummy_dataset, candidate_labels=candidate_labels, sample_size=8, template="This hypothesis is {}")
 
 
 
67
  args = TrainingArguments(
68
  batch_size=32,
69
  num_epochs=1
70
  )
 
71
  trainer = Trainer(
72
  model=few_shot_model,
73
  args=args,
 
76
  )
77
  trainer.train()
78
 
79
+ # metrics = trainer.evaluate()
80
+ # print(metrics)
81
 
82
  onnx_few_shot_model = OnnxSetFitModel(ort_model, few_shot_tokenizer, few_shot_model.model_head)
83