Lord-Raven commited on
Commit
ef3a388
·
1 Parent(s): 8ec85f2

Experimenting with few-shot classification.

Browse files
Files changed (1) hide show
  1. app.py +3 -3
app.py CHANGED
@@ -74,7 +74,7 @@ few_shot_model = SetFitModel.from_pretrained("moshew/bge-small-en-v1.5_setfit-ss
74
  candidate_labels = ["supported", "refuted"]
75
  reference_dataset = load_dataset("emotion")
76
  dummy_dataset = Dataset.from_dict({})
77
- train_dataset = get_templated_dataset(dummy_dataset, candidate_labels=candidate_labels, sample_size=8, template="This hypothesis is {}")
78
  args = TrainingArguments(
79
  batch_size=32,
80
  num_epochs=1
@@ -110,13 +110,13 @@ def zero_shot_classification(data):
110
  return response_string
111
 
112
  def create_sequences(data):
113
- return ['###Given:\n' + data['sequence'] + '\n###Hypothesis:\n' + data['hypothesis_template'].format(label) for label in data['candidate_labels']]
114
 
115
  def few_shot_classification(data):
116
  sequences = create_sequences(data)
117
  print(sequences)
118
  results = onnx_few_shot_model(sequences)
119
- probs = onnx_few_shot_model.predict_proba(data['sequence'])
120
  print(results)
121
  print(probs)
122
  response_string = json.dumps(results.tolist())
 
74
  candidate_labels = ["supported", "refuted"]
75
  reference_dataset = load_dataset("emotion")
76
  dummy_dataset = Dataset.from_dict({})
77
+ train_dataset = get_templated_dataset(dummy_dataset, candidate_labels=candidate_labels, sample_size=8, template="Based on the Given passage, the hypothesis is {}.")
78
  args = TrainingArguments(
79
  batch_size=32,
80
  num_epochs=1
 
110
  return response_string
111
 
112
  def create_sequences(data):
113
+ return ['###Given:\n' + data['sequence'] + '\n###End Given\n###Hypothesis:\n' + data['hypothesis_template'].format(label) for label in data['candidate_labels'] + "\n###End Hypothesis" + "\nBased on the Given passage, the Hypothesis is {}"]
114
 
115
  def few_shot_classification(data):
116
  sequences = create_sequences(data)
117
  print(sequences)
118
  results = onnx_few_shot_model(sequences)
119
+ probs = onnx_few_shot_model.predict_proba(sequences)
120
  print(results)
121
  print(probs)
122
  response_string = json.dumps(results.tolist())