Petr Tsvetkov
Use local timestamp instead of gradio-updated one to increase the keylogging precision
b2f40c2
import json | |
import os | |
import random | |
import uuid | |
from datetime import datetime | |
from difflib import ndiff | |
import gradio as gr | |
from data_loader import load_data | |
from hf_dataset_saver_builder import get_dataset_saver | |
# Credentials and the target dataset name are read from the environment;
# either value is None when its variable is not set.
HF_TOKEN = os.getenv('HF_REWRITING_TOKEN')
HF_DATASET = os.getenv('HF_REWRITING_DATASET')

# Records shown to the annotator, and how many of them there are.
data = load_data()
n_samples = len(data)

# Saver used by the UI below to persist submitted feedback.
saver = get_dataset_saver(HF_TOKEN, HF_DATASET, private=True)
def convert_diff_to_unified(diff_string):
    """Turn a JSON-encoded list of file modifications into one diff text.

    Args:
        diff_string: JSON text that decodes to a sequence of mappings, each
            providing the keys "old_path", "new_path" and "diff".

    Returns:
        One section per modified file ("--- old", "+++ new", then the file's
        diff body), with sections separated by a newline. An empty list
        yields an empty string.
    """
    sections = []
    for file_entry in json.loads(diff_string):
        old_path = file_entry["old_path"]
        new_path = file_entry["new_path"]
        body = file_entry["diff"]
        sections.append(f'--- {old_path}\n+++ {new_path}\n{body}')
    return "\n".join(sections)
def get_diff2html_view(raw_diff):
    """Build the HTML fragment that carries a diff to the page.

    The fragment is a scrollable container with two children: a hidden
    ``diff-raw`` element holding the diff text and an empty ``diff-view``
    placeholder (presumably filled in by a script from the page head --
    verify against head.html).

    Args:
        raw_diff: The diff text to embed.

    Returns:
        The HTML fragment as a string.
    """
    # Local import keeps this function self-contained.
    from html import escape

    # The diff is arbitrary text: without escaping, a diff that touches HTML
    # (or merely contains "<" or "&") would be parsed as markup and could
    # close the container early or inject elements. The browser decodes the
    # entities again while parsing, so the element's text content is still
    # the original diff.
    escaped_diff = escape(raw_diff)

    view_html = f"""
<div style='width:100%; height:1400px; overflow:auto; position: relative'>
    <div id='diff-raw' hidden>{escaped_diff}</div>
    <div class="d2h-view-wrapper">
        <div id='diff-view'></div>
    </div>
</div>
"""
    return view_html
def get_github_link_md(repo, hash):
    """Return a Markdown link to a commit on GitHub.

    Args:
        repo: Repository path placed after ``github.com/``.
        hash: Commit identifier placed after ``/commit/``.

    Returns:
        Markdown text of the form ``[See the commit on Github](<url>)``.
    """
    commit_url = f'https://github.com/{repo}/commit/{hash}'
    return f'[See the commit on Github]({commit_url})'
def char_diff_obj(change_type, pos, character, timestamp):
    """Bundle one character-level change into a dictionary.

    Returns:
        A dict with the keys "type", "pos", "char" and "timestamp", in that
        order, holding the four arguments unchanged.
    """
    return dict(type=change_type, pos=pos, char=character, timestamp=timestamp)
def update_commit_view(sample_ind):
    """Compute the values that display the record at ``sample_ind``.

    Args:
        sample_ind: Index into the module-level ``data``.

    Returns:
        ``None`` when ``sample_ind`` is not below ``n_samples``. Otherwise a
        10-tuple, in the order of the ``commit_view`` component list:
        GitHub link (Markdown), diff HTML, repo, hash, ISO timestamp taken
        when the view was built, summary text, then the predicted commit
        message three times (start value, editable value, previous value)
        and an empty edit history.
    """
    if sample_ind >= n_samples:
        return None

    sample = data[sample_ind]

    unified_diff = convert_diff_to_unified(sample['mods'])
    rendered_diff = get_diff2html_view(unified_diff)

    repo_name, commit_hash = sample['repo'], sample['hash']
    link_md = get_github_link_md(repo_name, commit_hash)

    # Taken after the diff has been rendered, as the moment of loading.
    loaded_at = datetime.now().isoformat()

    summary = f"{sample['summary']}"

    # The same prediction seeds the start, current and previous message
    # fields; the history starts out empty.
    predicted_message = sample['prediction']

    return (
        link_md, rendered_diff, repo_name, commit_hash, loaded_at, summary,
        predicted_message, predicted_message, predicted_message, [],
    )
def next_sample(current_sample_ind, shuffled_idx):
    """Advance to the next sample in the session's shuffled order.

    Args:
        current_sample_ind: Position of the sample currently shown, i.e. an
            index into ``shuffled_idx``.
        shuffled_idx: The session's permutation of sample indices.

    Returns:
        ``None`` when there is nothing to advance from (the position is
        already at or past the end, or ``shuffled_idx`` is missing or
        empty). Otherwise a tuple ``(new_position, *view)`` where ``view``
        is the 10-tuple produced by ``update_commit_view``. When the step
        moves past the last sample, ``view`` is a blank view instead, so
        the counter still reaches its maximum.
    """
    total = len(shuffled_idx) if shuffled_idx else 0
    if current_sample_ind >= total:
        return None

    current_sample_ind += 1

    # The bound has to be checked AFTER incrementing: checking only before
    # (as the code used to) indexed one past the end of ``shuffled_idx``
    # when leaving the last sample and raised IndexError.
    if current_sample_ind < total:
        updated_view = update_commit_view(shuffled_idx[current_sample_ind])
        if updated_view is not None:
            return (current_sample_ind,) + updated_view

    # Blank view; must stay in step with update_commit_view's 10-tuple:
    # nine text fields followed by the edit history. Empty strings / an
    # empty list (rather than None) keep the message, previous-message and
    # history fields consistent with each other.
    blank_view = ("",) * 9 + ([],)
    return (current_sample_ind,) + blank_view
# Extra markup for the page <head>, passed to gr.Blocks(head=...) below.
# The encoding is pinned so the result does not depend on the platform's
# default locale encoding.
with open("head.html", encoding="utf-8") as head_file:
    head_html = head_file.read()
# UI definition and event wiring.
# NOTE(review): the nesting of the layout containers below was reconstructed
# from a copy of this file that had lost its indentation -- confirm it
# against the rendered page before merging.
with gr.Blocks(theme=gr.themes.Soft(), head=head_html, css="style_overrides.css") as application:
    # Hidden per-session state: the commit being shown and the session's
    # shuffled sample order.
    repo_val = gr.Textbox(interactive=False, label='repo', visible=False)
    hash_val = gr.Textbox(interactive=False, label='hash', visible=False)
    shuffled_idx_val = gr.JSON(visible=False)

    with gr.Row():
        with gr.Accordion("Help"):
            # Encoding pinned so the guide renders the same regardless of
            # the platform's default locale encoding.
            with open("survey_guide.md", encoding="utf-8") as content_file:
                gr.Markdown(content_file.read())

    with gr.Row():
        # Read-only progress indicator; its value is also the position in
        # the shuffled order that the handlers below receive.
        current_sample_sld = gr.Slider(
            minimum=0,
            maximum=n_samples,
            step=1,
            value=0,
            interactive=False,
            label='sample_ind',
            info=f"Samples labeled/skipped (out of {n_samples})",
            show_label=False,
            container=False,
            scale=5,
        )
        with gr.Column(scale=1):
            skip_btn = gr.Button("Skip the current sample")

    with gr.Row():
        with gr.Column(scale=2):
            github_link = gr.Markdown()
            diff_view = gr.HTML()
        with gr.Column(scale=1):
            with gr.Accordion("Commit summary (AI generated)", open=False):
                commit_summary = gr.Markdown()
            # commit_msg is the only editable field; the hidden *_start,
            # *_prev and *_history components track how it was edited.
            commit_msg_start = gr.TextArea(label="commit_msg_start", visible=False)
            commit_msg = gr.TextArea(label="commit_msg_end", show_label=False,
                                     info="Commit message (can be scrollable)")
            commit_msg_prev = gr.TextArea(visible=False)
            commit_msg_history = gr.JSON(label="commit_msg_history", visible=False)
            submit_btn = gr.Button("Submit")
            session_val = gr.Textbox(info='Session', interactive=False, container=True,
                                     show_label=False, label='session')

    with gr.Row(visible=False):
        sample_loaded_timestamp = gr.Textbox(info="Sample loaded", label='loaded_ts',
                                             interactive=False, container=True,
                                             show_label=False)
        # The value is a callable refreshed via `every`, so the field holds
        # a recent timestamp at the moment of submission.
        now_timestamp = gr.Textbox(info="Current time",
                                   interactive=False, container=True, show_label=False,
                                   value=lambda: datetime.now().isoformat(), every=0.1,
                                   label='submitted_ts')

    # Components refreshed when a sample is loaded; the order must match
    # the tuple returned by update_commit_view.
    commit_view = [
        github_link,
        diff_view,
        repo_val,
        hash_val,
        sample_loaded_timestamp,
        commit_summary,
        commit_msg_start,
        commit_msg,
        commit_msg_prev,
        commit_msg_history,
    ]

    # Components whose values are stored with every submission.
    feedback_metadata = [
        session_val,
        repo_val,
        hash_val,
        sample_loaded_timestamp,
        now_timestamp,
    ]

    feedback_form = [
        commit_msg_start,
        commit_msg,
        commit_msg_history,
    ]

    saver.setup([current_sample_sld] + feedback_metadata + feedback_form, "feedback")

    # Skipping moves on without storing anything.
    skip_btn.click(next_sample, inputs=[current_sample_sld, shuffled_idx_val],
                   outputs=[current_sample_sld] + commit_view)


    def submit(current_sample, shuffled_idx, *args):
        """Store the feedback for the current sample, then load the next one.

        ``args`` are the values of ``feedback_metadata + feedback_form`` in
        that order; together with ``current_sample`` they match the
        component list given to ``saver.setup``.
        """
        # Once every sample has been processed there is nothing on screen
        # to give feedback on, so do not store an empty record.
        if current_sample < n_samples:
            saver.flag((current_sample,) + args)
        return next_sample(current_sample, shuffled_idx)


    submit_btn.click(
        submit,
        inputs=[current_sample_sld, shuffled_idx_val] + feedback_metadata + feedback_form,
        outputs=[current_sample_sld] + commit_view,
    )


    def on_commit_msg_changed(message, prev_message, history):
        """Append the character-level changes of one edit to the history.

        Args:
            message: Current text of the commit message field.
            prev_message: Text of the field before this change.
            history: List of change records accumulated so far.

        Returns:
            ``(message, history)``: the new "previous message" value and
            the extended history.
        """
        # Timestamp taken here, at the moment the change is handled, and
        # shared by all characters of this edit.
        timestamp = datetime.now().isoformat()

        # Work on a copy and tolerate missing values: a cleared component
        # may deliver None instead of "" / [].
        history = list(history) if history else []

        # ndiff over two strings compares them character by character; each
        # entry starts with its change marker and ends with the character.
        # Note that "pos" is the entry's index in the ndiff output, not an
        # offset into either string.
        history.extend(
            char_diff_obj(entry[0], index, entry[-1], timestamp)
            for index, entry in enumerate(ndiff(prev_message or "", message or ""))
            if entry[0] in ('+', '-')
        )
        return message, history


    commit_msg.change(on_commit_msg_changed,
                      inputs=[commit_msg, commit_msg_prev, commit_msg_history],
                      outputs=[commit_msg_prev, commit_msg_history])


    def init_session(current_sample):
        """Create a session id and a shuffled sample order, and load a sample.

        Returns:
            ``(session_id, shuffled_indices, *view)`` where ``view`` is the
            tuple from ``update_commit_view`` for the sample at position
            ``current_sample`` of the shuffled order.
        """
        session = str(uuid.uuid4())

        shuffled_idx = list(range(n_samples))
        random.shuffle(shuffled_idx)

        return (session, shuffled_idx) + update_commit_view(shuffled_idx[current_sample])


    # Runs on every page load, so each visitor gets a new session.
    application.load(init_session,
                     inputs=[current_sample_sld],
                     outputs=[session_val, shuffled_idx_val] + commit_view, )

application.launch()