S-Dreamer commited on
Commit
8302e50
·
verified ·
1 Parent(s): 352ff58

Create evaluator.py

Browse files
Files changed (1) hide show
  1. evaluator.py +39 -0
evaluator.py ADDED
@@ -0,0 +1,39 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torchmetrics import BLEUScore, METEOR
3
+ from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
4
+
5
+ class CodeEvaluator:
6
+ def __init__(self, model_name):
7
+ self.tokenizer = AutoTokenizer.from_pretrained(model_name)
8
+ self.model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
9
+ self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
10
+ self.model.to(self.device)
11
+ self.bleu = BLEUScore()
12
+ self.meteor = METEOR()
13
+
14
+ def evaluate(self, nl_input, target_code):
15
+ inputs = self.tokenizer(nl_input, return_tensors="pt").to(self.device)
16
+ outputs = self.model.generate(
17
+ **inputs,
18
+ )
19
+ generated_code = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
20
+
21
+ bleu_score = self.bleu(generated_code, target_code)
22
+ meteor_score = self.meteor(generated_code, target_code)
23
+ return bleu_score, meteor_score
24
+
25
+ if __name__ == "__main__":
26
+ model_name = "S-Dreamer/PyCodeT5"
27
+ evaluator = CodeEvaluator(model_name)
28
+
29
+ nl_input = "Write a Python function to calculate the factorial of a number."
30
+ target_code = """
31
+ def factorial(n):
32
+ if n == 0:
33
+ return 1
34
+ else:
35
+ return n * factorial(n-1)
36
+ """
37
+ bleu_score, meteor_score = evaluator.evaluate(nl_input, target_code)
38
+ print(f"BLEU score: {bleu_score}")
39
+ print(f"METEOR score: {meteor_score}")