Spaces:
Sleeping
Sleeping
Create app.py
Browse files
app.py
ADDED
@@ -0,0 +1,220 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import gradio as gr
import os
import json  # NOTE(review): not referenced anywhere in this file — confirm before removing
from openai import OpenAI
import sys  # Added for flushing output in case of direct printing

# --- Credentials / endpoint configuration ---
# Load sensitive information from environment variables (never hard-code keys).
RUNPOD_API_KEY = os.getenv('RUNPOD_API_KEY')
RUNPOD_ENDPOINT_ID = os.getenv('RUNPOD_ENDPOINT_ID')

# --- Basic Input Validation ---
# Fail fast at import time so a misconfigured deployment is caught immediately
# instead of producing opaque request errors later.
if not RUNPOD_API_KEY:
    raise ValueError("RunPod API key not found. Please set the RUNPOD_API_KEY environment variable.")
if not RUNPOD_ENDPOINT_ID:
    raise ValueError("RunPod Endpoint ID not found. Please set the RUNPOD_ENDPOINT_ID environment variable.")

# RunPod serves each endpoint behind an OpenAI-compatible REST base URL.
BASE_URL = f"https://api.runpod.ai/v2/{RUNPOD_ENDPOINT_ID}/openai/v1"
MODEL_NAME = "karths/coder_commit_32B"  # The specific model hosted on RunPod
MAX_TOKENS = 4096  # Max tokens for the model response

# --- OpenAI Client Initialization ---
# The stock OpenAI SDK client is reused; only the base_url is redirected to RunPod.
client = OpenAI(
    api_key=RUNPOD_API_KEY,
    base_url=BASE_URL,
)
# --- Gradio App Configuration ---
# Static UI text. These strings are rendered verbatim by Gradio, so their
# content must not be altered casually.
title = "Python Maintainability Refactoring (RunPod - Streaming)"
# Markdown shown above the chat; the {} placeholder is filled with MAX_TOKENS.
description = """
## Instructions for Using the Model
### Model Loading Time:
- Please allow time for the model on RunPod to initialize if it's starting fresh ("Cold Start"). The response will appear token by token.
### Code Submission:
- You can enter or paste your Python code you wish to have refactored, or use the provided example.
### Python Code Constraints:
- Keep the code reasonably sized. Large code blocks might face limitations depending on the RunPod instance and model constraints. Max response length is set to {} tokens.
### Understanding Changes:
- It's important to read the "Changes made" section (if provided by the model) in the refactored code response. This will help in understanding what modifications have been made.
### Usage Recommendation:
- Intended for research and evaluation purposes.
""".format(MAX_TOKENS)

# Instruction prefix prepended to every user submission before it is sent to
# the model (Alpaca-style "### Instruction / ### Input" prompt format).
system_prompt = """### Instruction:
Refactor the provided Python code to improve its maintainability and efficiency and reduce complexity. Include the refactored code along with comments on the changes made for improving the metrics.
### Input:
"""

# Hide Gradio's toast notifications.
css = """.toast-wrap { display: none !important } """
# Example inputs offered in the UI. Each entry is a one-element list holding a
# Python source snippet as a string (the single argument to `predict`).
# NOTE(review): the indentation inside these embedded code strings was
# reconstructed from a whitespace-mangled source — verify against the original.
examples = [
    # Example 1: small analytics function with nested dict bookkeeping.
    ["""def analyze_sales_data(sales_records):
    active_sales = filter(lambda record: record['status'] == 'active', sales_records)
    sales_by_category = {}
    for record in active_sales:
        category = record['category']
        total_sales = record['units_sold'] * record['price_per_unit']
        if category not in sales_by_category:
            sales_by_category[category] = {'total_sales': 0, 'total_units': 0}
        sales_by_category[category]['total_sales'] += total_sales
        sales_by_category[category]['total_units'] += record['units_sold']
    average_sales_data = []
    for category, data in sales_by_category.items():
        average_sales = data['total_sales'] / data['total_units'] if data['total_units'] > 0 else 0 # Avoid division by zero
        sales_by_category[category]['average_sales'] = average_sales
        average_sales_data.append((category, average_sales))
    average_sales_data.sort(key=lambda x: x[1], reverse=True)
    for rank, (category, _) in enumerate(average_sales_data, start=1):
        sales_by_category[category]['rank'] = rank
    return sales_by_category"""],
    # Example 2: larger evaluation script using pandas + code_bert_score.
    ["""import pandas as pd
import re
import ast
from code_bert_score import score # Assuming this library is available in the environment
import numpy as np

def preprocess_code(source_text):
    def remove_comments_and_docstrings(source_code):
        # Remove single-line comments
        source_code = re.sub(r'#.*', '', source_code)
        # Remove multi-line strings (docstrings)
        source_code = re.sub(r'(\\'\\'\\'(.*?)\\'\\'\\'|\\"\\"\\"(.*?)\\"\\"\\")', '', source_code, flags=re.DOTALL)
        return source_code.strip() # Added strip

    # Pattern to extract code specifically from markdown blocks if present
    pattern = r"```python\\s+(.+?)\\s+```"
    matches = re.findall(pattern, source_text, re.DOTALL)
    code_to_process = '\\n'.join(matches) if matches else source_text

    cleaned_code = remove_comments_and_docstrings(code_to_process)
    return cleaned_code

def evaluate_dataframe(df):
    results = {'P': [], 'R': [], 'F1': [], 'F3': []}
    for index, row in df.iterrows():
        try:
            # Ensure inputs are lists of strings
            cands = [preprocess_code(str(row['generated_text']))] # Added str() conversion
            refs = [preprocess_code(str(row['output']))] # Added str() conversion

            # Ensure code_bert_score.score returns four values
            score_results = score(cands, refs, lang='python')
            if len(score_results) == 4:
                P, R, F1, F3 = score_results
                results['P'].append(P.item() if hasattr(P, 'item') else P) # Handle potential tensor output
                results['R'].append(R.item() if hasattr(R, 'item') else R)
                results['F1'].append(F1.item() if hasattr(F1, 'item') else F1)
                results['F3'].append(F3.item() if hasattr(F3, 'item') else F3) # Assuming F3 is returned
            else:
                print(f"Warning: Unexpected number of return values from score function for row {index}. Got {len(score_results)} values.")
                for key in results.keys():
                    results[key].append(np.nan) # Append NaN for unexpected format

        except Exception as e:
            print(f"Error processing row {index}: {e}")
            for key in results.keys():
                results[key].append(np.nan) # Use NaN for errors

    df_metrics = pd.DataFrame(results)
    return df_metrics

def evaluate_dataframe_multiple_runs(df, runs=3):
    all_results = []
    print(f"Starting evaluation for {runs} runs...")
    for run in range(runs):
        print(f"Run {run + 1}/{runs}")
        df_metrics = evaluate_dataframe(df.copy()) # Use a copy to avoid side effects if df is modified
        all_results.append(df_metrics)
        print(f"Run {run + 1} completed.")

    if not all_results:
        print("No results collected.")
        return pd.DataFrame(), pd.DataFrame()

    # Concatenate results and calculate statistics
    try:
        concatenated_results = pd.concat(all_results)
        df_metrics_mean = concatenated_results.groupby(level=0).mean()
        df_metrics_std = concatenated_results.groupby(level=0).std()
        print("Mean and standard deviation calculated.")
    except Exception as e:
        print(f"Error calculating statistics: {e}")
        # Return empty DataFrames or handle as appropriate
        return pd.DataFrame(), pd.DataFrame()

    return df_metrics_mean, df_metrics_std"""]
]
# --- Core Logic (Modified for Streaming) ---
def gen_solution_stream(prompt):
    """
    Stream the model's answer for *prompt* from the RunPod-hosted LLM.

    Opens a chat-completion request with ``stream=True`` against the
    OpenAI-compatible endpoint and yields each text fragment as soon as
    it arrives.

    Parameters:
    - prompt (str): Full prompt text (instruction prefix plus user code).

    Yields:
    - str: successive fragments of the generated answer, or a single
      error-message string if the request fails.
    """
    try:
        # Kick off the streaming request against the RunPod endpoint.
        response_iter = client.chat.completions.create(
            model=MODEL_NAME,
            messages=[{"role": "user", "content": prompt}],
            temperature=0.1,  # low temperature -> near-deterministic refactoring
            top_p=1.0,
            max_tokens=MAX_TOKENS,
            stream=True,  # server sends incremental deltas
        )

        # Forward every non-empty text delta; skip keep-alive / metadata
        # chunks that carry no choices or no content.
        for event in response_iter:
            if not event.choices:
                continue
            delta = event.choices[0].delta
            if delta and delta.content:
                yield delta.content

    except Exception as e:
        # Best-effort failure path: surface the error in the chat UI
        # rather than crashing the Gradio worker.
        error_message = f"Error: Could not get streaming response from the model. Details: {str(e)}"
        print(error_message, file=sys.stderr)  # also log to stderr
        yield error_message
# --- Gradio Interface Function (Modified for Streaming) ---
def predict(message, history):
    """
    Gradio ChatInterface callback: build the full prompt, stream the
    model's reply, and yield the growing text so the chat bubble updates
    incrementally.

    'history' is supplied by gr.ChatInterface; it is intentionally unused
    because every request is handled independently.
    """
    # Prepend the fixed refactoring instructions to the user's code.
    full_prompt = system_prompt + str(message)

    # Accumulate streamed fragments; each yield repaints the chatbot
    # output with everything received so far.
    accumulated = ""
    for piece in gen_solution_stream(full_prompt):
        accumulated += piece
        yield accumulated
# --- Launch Gradio Interface ---
# Build the chat UI around the streaming `predict` generator.
demo = gr.ChatInterface(
    fn=predict,  # generator function; Gradio streams its yields into the chatbot
    chatbot=gr.Chatbot(height=500, label="Refactored Code and Explanation", show_copy_button=True),  # copy button enabled
    textbox=gr.Textbox(lines=10, label="Python Code", placeholder="Enter or Paste your Python code here..."),
    title=title,
    description=description,
    theme="abidlabs/Lime",  # or any other theme, e.g. gr.themes.Default()
    examples=examples,
    cache_examples=False,  # examples are re-run on click rather than pre-computed
    submit_btn="Submit Code",
    retry_btn="Retry",
    undo_btn="Undo",
    clear_btn="Clear",
    css=css,  # apply the custom CSS above
)

# Queue requests (needed for generator/streaming handlers), then launch.
# NOTE(review): share=True opens a public tunnel link — use with caution.
demo.queue().launch(share=True)