# Hugging Face Space (status shown when this page was captured: Sleeping)
import json | |
import gradio as gr | |
import os | |
import requests | |
from huggingface_hub import AsyncInferenceClient | |
# Endpoint settings are read from the environment; both values are None when
# the corresponding variable is unset.
HF_TOKEN = os.getenv('HF_TOKEN')
api_url = os.getenv('API_URL')

# Authorization header built from the token. It is not referenced elsewhere
# in the visible code; kept so any external user of this name keeps working.
headers = {"Authorization": f"Bearer {HF_TOKEN}"}

# Hand the token to the client explicitly so requests to the endpoint are
# authenticated with HF_TOKEN instead of depending on the library's implicit
# token lookup. Passing None is the same as omitting the argument.
client = AsyncInferenceClient(api_url, token=HF_TOKEN)
# Heading passed to the chat interface as its title.
title = "Python maintainability refactoring"

# Markdown usage notes passed to the chat interface as its description.
description = """
## Instructions for Using the Model
### Model Loading Time:
- Please allow 2 to 3 minutes for the model to load and run, especially on the first usage which might experience a "Cold Start."
### Code Submission:
- You can enter or paste your python code you wish to have refactored, or use the provided example.
### Python Code Constraints:
- When using this tool, keep the code under 120 lines due to GPU constraints.
### Understanding Changes:
- It's important to read the "Changes made" section at the end of the refactored code response. This will help in understanding what modifications have been made to enhance the maintainability and readability of the code.
### Usage Recommendation:
- Do not use this for personal projects; try it for research purposes only, as running these GPUs is costly.
"""

# Instruction preamble placed in front of the submitted code when the prompt
# is assembled. Written as adjacent literals so each line break is explicit.
system_prompt = (
    "\n"
    "### Instruction:\n"
    "Refactor the provided Python code to improve its maintainability and efficiency and reduce complexity. Include the refactored code along with the comments on the changes made for improving the metrics.\n"
    "### Input:\n"
)

# Stylesheet rule that hides elements with the `toast-wrap` class.
# NOTE(review): nothing in the visible code passes this to the interface;
# confirm whether it was meant to be supplied as `css=`.
css = ".toast-wrap { display: none !important } "
# Sample inputs offered by the chat interface: a list of rows, each row holding
# one code snippet as a string.
#
# The snippets are demonstration *inputs* for the refactoring model, so their
# code is deliberately left as-is (no clean-up of the sample logic).
#
# Two things matter for them to be shown as valid Python:
#   * They are raw strings, so backslash sequences inside the samples
#     (`'\n'`, `\s`, `\'`, `\"`) reach the user verbatim instead of being
#     interpreted by this module's string parser.
#   * Block indentation is written out explicitly. NOTE(review): nesting was
#     reconstructed from the statement structure; confirm against the
#     deployed app if exact whitespace matters.
examples = [
    [r"""
def analyze_sales_data(sales_records):
    active_sales = filter(lambda record: record['status'] == 'active', sales_records)
    sales_by_category = {}
    for record in active_sales:
        category = record['category']
        total_sales = record['units_sold'] * record['price_per_unit']
        if category not in sales_by_category:
            sales_by_category[category] = {'total_sales': 0, 'total_units': 0}
        sales_by_category[category]['total_sales'] += total_sales
        sales_by_category[category]['total_units'] += record['units_sold']
    average_sales_data = []
    for category, data in sales_by_category.items():
        average_sales = data['total_sales'] / data['total_units']
        sales_by_category[category]['average_sales'] = average_sales
        average_sales_data.append((category, average_sales))
    average_sales_data.sort(key=lambda x: x[1], reverse=True)
    for rank, (category, _) in enumerate(average_sales_data, start=1):
        sales_by_category[category]['rank'] = rank
    return sales_by_category
"""],
    [r"""
import pandas as pd
import re
import ast
from code_bert_score import score
import numpy as np
def preprocess_code(source_text):
    def remove_comments_and_docstrings(source_code):
        source_code = re.sub(r'#.*', '', source_code)
        source_code = re.sub(r'(\'\'\'(.*?)\'\'\'|\"\"\"(.*?)\"\"\")', '', source_code, flags=re.DOTALL)
        return source_code
    pattern = r"```python\s+(.+?)\s+```"
    matches = re.findall(pattern, source_text, re.DOTALL)
    code_to_process = '\n'.join(matches) if matches else source_text
    cleaned_code = remove_comments_and_docstrings(code_to_process)
    return cleaned_code
def evaluate_dataframe(df):
    results = {'P': [], 'R': [], 'F1': [], 'F3': []}
    for index, row in df.iterrows():
        try:
            cands = [preprocess_code(row['generated_text'])]
            refs = [preprocess_code(row['output'])]
            P, R, F1, F3 = score(cands, refs, lang='python')
            results['P'].append(P[0])
            results['R'].append(R[0])
            results['F1'].append(F1[0])
            results['F3'].append(F3[0])
        except Exception as e:
            print(f"Error processing row {index}: {e}")
            for key in results.keys():
                results[key].append(None)
    df_metrics = pd.DataFrame(results)
    return df_metrics
def evaluate_dataframe_multiple_runs(df, runs=3):
    all_results = []
    for run in range(runs):
        df_metrics = evaluate_dataframe(df)
        all_results.append(df_metrics)
    # Calculate mean and std deviation of metrics across runs
    df_metrics_mean = pd.concat(all_results).groupby(level=0).mean()
    df_metrics_std = pd.concat(all_results).groupby(level=0).std()
    return df_metrics_mean, df_metrics_std
"""],
]
# Stream text - stream tokens with InferenceClient from TGI
async def predict(message, chatbot, temperature=0.1, max_new_tokens=4096, top_p=0.6, repetition_penalty=1.15):
    """Stream the model's reply for one submitted message.

    The prompt is ``system_prompt`` followed by the message text and the
    marker ``" [/INST] "``. Generation is requested from ``client`` in
    streaming mode, and the text accumulated so far is yielded after every
    received token.

    Args:
        message: Submitted text; converted with ``str()`` before use.
        chatbot: Second positional argument supplied by the caller; not used
            here (presumably the conversation history from the chat
            interface; verify against the Gradio version in use).
        temperature: Sampling temperature. Values below 0.01 are raised to
            0.01.
        max_new_tokens: Upper bound on generated tokens; coerced to ``int``.
        top_p: Nucleus sampling threshold; coerced to ``float``.
        repetition_penalty: Repetition penalty; coerced to ``float``.

    Yields:
        str: The reply text generated so far.
    """
    # Coerce every sampling parameter the same way, so numeric strings or
    # float-valued UI inputs (e.g. 4096.0) are sent in the expected type.
    temperature = max(float(temperature), 1e-2)
    top_p = float(top_p)
    max_new_tokens = int(max_new_tokens)
    repetition_penalty = float(repetition_penalty)

    input_prompt = system_prompt + str(message) + " [/INST] "

    # With stream=True the awaited call returns an async iterator of tokens.
    token_stream = await client.text_generation(
        prompt=input_prompt,
        max_new_tokens=max_new_tokens,
        stream=True,
        best_of=1,
        temperature=temperature,
        top_p=top_p,
        do_sample=True,
        repetition_penalty=repetition_penalty,
    )

    partial_message = ""
    async for token in token_stream:
        partial_message += token
        # Yield the whole text each time, not just the newest token.
        yield partial_message
# Assemble the chat UI around `predict`: a 500px-high chat pane, a 10-line
# code input box, the static title/description text and the sample inputs.
demo = gr.ChatInterface(
    predict,
    chatbot=gr.Chatbot(height=500),
    textbox=gr.Textbox(
        lines=10,
        label="Python Code",
        placeholder="Enter or Paste your Python code here...",
    ),
    title=title,
    description=description,
    theme="abidlabs/Lime",
    examples=examples,
    cache_examples=False,
    submit_btn="Submit_code",
    retry_btn="Retry",
    undo_btn="Undo",
    clear_btn="Clear",
)

# Enable the request queue, then start the server with a share link requested.
demo.queue().launch(share=True)