ejschwartz committed on
Commit
2ab342b
·
1 Parent(s): 762a224

let's go field decoding

Browse files
Files changed (1) hide show
  1. app.py +36 -30
app.py CHANGED
@@ -15,10 +15,10 @@ huggingface_hub.login(token=hf_key)
15
 
16
  tokenizer = AutoTokenizer.from_pretrained("bigcode/starcoderbase-3b")
17
  vardecoder_model = AutoModelForCausalLM.from_pretrained(
18
- "ejschwartz/resym-vardecoder", torch_dtype=torch.bfloat16#, device_map={"": 0}
19
  ).to("cuda")
20
  fielddecoder_model = AutoModelForCausalLM.from_pretrained(
21
- "ejschwartz/resym-fielddecoder", torch_dtype=torch.bfloat16#, device_map={"": 0}
22
  ).to("cuda")
23
 
24
  gradio_client = Client("https://ejschwartz-resym-field-helper.hf.space/")
@@ -42,10 +42,12 @@ def field_prompt(code):
42
  print(f"fields: {fields}")
43
 
44
  prompt = f"```\n{code}\n```\nWhat are the variable name and type for the following memory accesses:{', '.join(fields)}?\n"
 
 
45
 
46
  print(f"field prompt: {prompt}")
47
 
48
- return prompt, field_helper_result
49
 
50
  @spaces.GPU
51
  def infer(code):
@@ -65,18 +67,18 @@ def infer(code):
65
 
66
  varstring = ", ".join([f"`{v}`" for v in vars])
67
 
68
- var_name = vars[0]
69
 
70
  # ejs: Yeah, this var_name thing is really bizarre. But look at https://github.com/lt-asset/resym/blob/main/training_src/fielddecoder_inf.py
71
- var_prompt = f"What are the original name and data types of variables {varstring}?\n```\n{code}\n```{var_name}"
72
 
73
  print(f"Prompt:\n{var_prompt}")
74
 
75
- input_ids = tokenizer.encode(var_prompt, return_tensors="pt").cuda()[
76
  :, : 8192 - 1024
77
  ]
78
  var_output = vardecoder_model.generate(
79
- input_ids=input_ids,
80
  max_new_tokens=1024,
81
  num_beams=4,
82
  num_return_sequences=1,
@@ -86,32 +88,36 @@ def infer(code):
86
  eos_token_id=0,
87
  )[0]
88
  var_output = tokenizer.decode(
89
- var_output[input_ids.size(1) :],
90
  skip_special_tokens=True,
91
  clean_up_tokenization_spaces=True,
92
  )
93
 
94
- field_prompt_result, field_helper_result = field_prompt(code)
95
-
96
- # field_output = fielddecoder_model.generate(
97
- # input_ids=input_ids,
98
- # max_new_tokens=1024,
99
- # num_beams=4,
100
- # num_return_sequences=1,
101
- # do_sample=False,
102
- # early_stopping=False,
103
- # pad_token_id=0,
104
- # eos_token_id=0,
105
- # )[0]
106
- # field_output = tokenizer.decode(
107
- # field_output[input_ids.size(1) :],
108
- # skip_special_tokens=True,
109
- # clean_up_tokenization_spaces=True,
110
- # )
111
-
112
- var_output = var_name + ":" + var_output
113
- # field_output = var_name + ":" + field_output
114
- return var_output, varstring
 
 
 
 
115
 
116
 
117
  demo = gr.Interface(
@@ -121,7 +127,7 @@ demo = gr.Interface(
121
  ],
122
  outputs=[
123
  gr.Text(label="Var Decoder Output"),
124
- # gr.Text(label="Field Decoder Output"),
125
  gr.Text(label="Generated Variable List"),
126
  ],
127
  description=frontmatter.load("README.md").content,
 
15
 
16
  tokenizer = AutoTokenizer.from_pretrained("bigcode/starcoderbase-3b")
17
  vardecoder_model = AutoModelForCausalLM.from_pretrained(
18
+ "ejschwartz/resym-vardecoder", torch_dtype=torch.bfloat16
19
  ).to("cuda")
20
  fielddecoder_model = AutoModelForCausalLM.from_pretrained(
21
+ "ejschwartz/resym-fielddecoder", torch_dtype=torch.bfloat16
22
  ).to("cuda")
23
 
24
  gradio_client = Client("https://ejschwartz-resym-field-helper.hf.space/")
 
42
  print(f"fields: {fields}")
43
 
44
  prompt = f"```\n{code}\n```\nWhat are the variable name and type for the following memory accesses:{', '.join(fields)}?\n"
45
+ if len(fields) > 0:
46
+ prompt += f"{fields[0]}:"
47
 
48
  print(f"field prompt: {prompt}")
49
 
50
+ return prompt, fields, field_helper_result
51
 
52
  @spaces.GPU
53
  def infer(code):
 
67
 
68
  varstring = ", ".join([f"`{v}`" for v in vars])
69
 
70
+ first_var = vars[0]
71
 
72
  # ejs: Yeah, this var_name thing is really bizarre. But look at https://github.com/lt-asset/resym/blob/main/training_src/fielddecoder_inf.py
73
+ var_prompt = f"What are the original name and data types of variables {varstring}?\n```\n{code}\n```{first_var}"
74
 
75
  print(f"Prompt:\n{var_prompt}")
76
 
77
+ var_input_ids = tokenizer.encode(var_prompt, return_tensors="pt").cuda()[
78
  :, : 8192 - 1024
79
  ]
80
  var_output = vardecoder_model.generate(
81
+ input_ids=var_input_ids,
82
  max_new_tokens=1024,
83
  num_beams=4,
84
  num_return_sequences=1,
 
88
  eos_token_id=0,
89
  )[0]
90
  var_output = tokenizer.decode(
91
+ var_output[var_input_ids.size(1) :],
92
  skip_special_tokens=True,
93
  clean_up_tokenization_spaces=True,
94
  )
95
 
96
+ field_prompt_result, fields, field_helper_result = field_prompt(code)
97
+ field_input_ids = tokenizer.encode(field_prompt_result, return_tensors="pt").cuda()[
98
+ :, : 8192 - 1024
99
+ ]
100
+
101
+ field_output = fielddecoder_model.generate(
102
+ input_ids=field_input_ids,
103
+ max_new_tokens=1024,
104
+ num_beams=4,
105
+ num_return_sequences=1,
106
+ do_sample=False,
107
+ early_stopping=False,
108
+ pad_token_id=0,
109
+ eos_token_id=0,
110
+ )[0]
111
+ field_output = tokenizer.decode(
112
+ field_output[var_input_ids.size(1) :],
113
+ skip_special_tokens=True,
114
+ clean_up_tokenization_spaces=True,
115
+ )
116
+
117
+ var_output = first_var + ":" + var_output
118
+ if len(fields) > 0:
119
+ field_output = fields[0] + ":" + field_output
120
+ return var_output, field_output, varstring
121
 
122
 
123
  demo = gr.Interface(
 
127
  ],
128
  outputs=[
129
  gr.Text(label="Var Decoder Output"),
130
+ gr.Text(label="Field Decoder Output"),
131
  gr.Text(label="Generated Variable List"),
132
  ],
133
  description=frontmatter.load("README.md").content,