karths committed on
Commit
0e34d9a
·
verified ·
1 Parent(s): 66e2369

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +220 -0
app.py ADDED
@@ -0,0 +1,220 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import gradio as gr
import os
import json
from openai import OpenAI
import sys  # stderr logging for stream errors

# --- RunPod credentials, read from the environment ---
RUNPOD_API_KEY = os.getenv('RUNPOD_API_KEY')
RUNPOD_ENDPOINT_ID = os.getenv('RUNPOD_ENDPOINT_ID')

# Fail fast at import time when either credential is missing.
if not RUNPOD_API_KEY:
    raise ValueError("RunPod API key not found. Please set the RUNPOD_API_KEY environment variable.")
if not RUNPOD_ENDPOINT_ID:
    raise ValueError("RunPod Endpoint ID not found. Please set the RUNPOD_ENDPOINT_ID environment variable.")

# OpenAI-compatible endpoint exposed by the RunPod serverless worker.
BASE_URL = f"https://api.runpod.ai/v2/{RUNPOD_ENDPOINT_ID}/openai/v1"
MODEL_NAME = "karths/coder_commit_32B"  # The specific model hosted on RunPod
MAX_TOKENS = 4096  # Upper bound on tokens in the model response

# Client is standard OpenAI SDK, pointed at the RunPod base URL.
client = OpenAI(
    api_key=RUNPOD_API_KEY,
    base_url=BASE_URL,
)
26
+
27
# --- Gradio App Configuration ---
# Text shown in the UI; the {} placeholder is filled with MAX_TOKENS below.
title = "Python Maintainability Refactoring (RunPod - Streaming)"
description = """
## Instructions for Using the Model
### Model Loading Time:
- Please allow time for the model on RunPod to initialize if it's starting fresh ("Cold Start"). The response will appear token by token.
### Code Submission:
- You can enter or paste your Python code you wish to have refactored, or use the provided example.
### Python Code Constraints:
- Keep the code reasonably sized. Large code blocks might face limitations depending on the RunPod instance and model constraints. Max response length is set to {} tokens.
### Understanding Changes:
- It's important to read the "Changes made" section (if provided by the model) in the refactored code response. This will help in understanding what modifications have been made.
### Usage Recommendation:
- Intended for research and evaluation purposes.
""".format(MAX_TOKENS)

# Instruction prefix prepended to every user submission before it is
# sent to the model (the user's code follows "### Input:").
system_prompt = """### Instruction:
Refactor the provided Python code to improve its maintainability and efficiency and reduce complexity. Include the refactored code along with comments on the changes made for improving the metrics.
### Input:
"""

# Hide Gradio's toast pop-ups.
css = """.toast-wrap { display: none !important } """
49
+
50
+ examples = [
51
+ ["""def analyze_sales_data(sales_records):
52
+ active_sales = filter(lambda record: record['status'] == 'active', sales_records)
53
+ sales_by_category = {}
54
+ for record in active_sales:
55
+ category = record['category']
56
+ total_sales = record['units_sold'] * record['price_per_unit']
57
+ if category not in sales_by_category:
58
+ sales_by_category[category] = {'total_sales': 0, 'total_units': 0}
59
+ sales_by_category[category]['total_sales'] += total_sales
60
+ sales_by_category[category]['total_units'] += record['units_sold']
61
+ average_sales_data = []
62
+ for category, data in sales_by_category.items():
63
+ average_sales = data['total_sales'] / data['total_units'] if data['total_units'] > 0 else 0 # Avoid division by zero
64
+ sales_by_category[category]['average_sales'] = average_sales
65
+ average_sales_data.append((category, average_sales))
66
+ average_sales_data.sort(key=lambda x: x[1], reverse=True)
67
+ for rank, (category, _) in enumerate(average_sales_data, start=1):
68
+ sales_by_category[category]['rank'] = rank
69
+ return sales_by_category"""],
70
+ ["""import pandas as pd
71
+ import re
72
+ import ast
73
+ from code_bert_score import score # Assuming this library is available in the environment
74
+ import numpy as np
75
+
76
+ def preprocess_code(source_text):
77
+ def remove_comments_and_docstrings(source_code):
78
+ # Remove single-line comments
79
+ source_code = re.sub(r'#.*', '', source_code)
80
+ # Remove multi-line strings (docstrings)
81
+ source_code = re.sub(r'(\'\'\'(.*?)\'\'\'|\"\"\"(.*?)\"\"\")', '', source_code, flags=re.DOTALL)
82
+ return source_code.strip() # Added strip
83
+
84
+ # Pattern to extract code specifically from markdown blocks if present
85
+ pattern = r"```python\s+(.+?)\s+```"
86
+ matches = re.findall(pattern, source_text, re.DOTALL)
87
+ code_to_process = '\n'.join(matches) if matches else source_text
88
+
89
+ cleaned_code = remove_comments_and_docstrings(code_to_process)
90
+ return cleaned_code
91
+
92
+ def evaluate_dataframe(df):
93
+ results = {'P': [], 'R': [], 'F1': [], 'F3': []}
94
+ for index, row in df.iterrows():
95
+ try:
96
+ # Ensure inputs are lists of strings
97
+ cands = [preprocess_code(str(row['generated_text']))] # Added str() conversion
98
+ refs = [preprocess_code(str(row['output']))] # Added str() conversion
99
+
100
+ # Ensure code_bert_score.score returns four values
101
+ score_results = score(cands, refs, lang='python')
102
+ if len(score_results) == 4:
103
+ P, R, F1, F3 = score_results
104
+ results['P'].append(P.item() if hasattr(P, 'item') else P) # Handle potential tensor output
105
+ results['R'].append(R.item() if hasattr(R, 'item') else R)
106
+ results['F1'].append(F1.item() if hasattr(F1, 'item') else F1)
107
+ results['F3'].append(F3.item() if hasattr(F3, 'item') else F3) # Assuming F3 is returned
108
+ else:
109
+ print(f"Warning: Unexpected number of return values from score function for row {index}. Got {len(score_results)} values.")
110
+ for key in results.keys():
111
+ results[key].append(np.nan) # Append NaN for unexpected format
112
+
113
+ except Exception as e:
114
+ print(f"Error processing row {index}: {e}")
115
+ for key in results.keys():
116
+ results[key].append(np.nan) # Use NaN for errors
117
+
118
+ df_metrics = pd.DataFrame(results)
119
+ return df_metrics
120
+
121
+ def evaluate_dataframe_multiple_runs(df, runs=3):
122
+ all_results = []
123
+ print(f"Starting evaluation for {runs} runs...")
124
+ for run in range(runs):
125
+ print(f"Run {run + 1}/{runs}")
126
+ df_metrics = evaluate_dataframe(df.copy()) # Use a copy to avoid side effects if df is modified
127
+ all_results.append(df_metrics)
128
+ print(f"Run {run + 1} completed.")
129
+
130
+ if not all_results:
131
+ print("No results collected.")
132
+ return pd.DataFrame(), pd.DataFrame()
133
+
134
+ # Concatenate results and calculate statistics
135
+ try:
136
+ concatenated_results = pd.concat(all_results)
137
+ df_metrics_mean = concatenated_results.groupby(level=0).mean()
138
+ df_metrics_std = concatenated_results.groupby(level=0).std()
139
+ print("Mean and standard deviation calculated.")
140
+ except Exception as e:
141
+ print(f"Error calculating statistics: {e}")
142
+ # Return empty DataFrames or handle as appropriate
143
+ return pd.DataFrame(), pd.DataFrame()
144
+
145
+ return df_metrics_mean, df_metrics_std"""]
146
+ ]
147
+
148
# --- Core Logic (Streaming) ---
def gen_solution_stream(prompt):
    """Stream the model's refactoring answer from the RunPod endpoint.

    Parameters:
    - prompt (str): Full prompt text (instruction prefix + user code).

    Yields:
    - str: successive chunks of generated text as they arrive, or a
      single error message when the request or stream fails.
    """
    try:
        # OpenAI-compatible chat completion with streaming enabled.
        response = client.chat.completions.create(
            model=MODEL_NAME,
            messages=[{"role": "user", "content": prompt}],
            temperature=0.1,  # low temperature for near-deterministic refactoring
            top_p=1.0,
            max_tokens=MAX_TOKENS,
            stream=True,
        )
        for event in response:
            # Skip keep-alive / empty-delta chunks defensively.
            if not event.choices:
                continue
            delta = event.choices[0].delta
            if delta and delta.content:
                yield delta.content
    except Exception as exc:
        error_message = f"Error: Could not get streaming response from the model. Details: {str(exc)}"
        print(error_message, file=sys.stderr)  # log server-side
        yield error_message  # surface the failure in the chat UI
184
+
185
# --- Gradio Interface Function (Streaming) ---
def predict(message, history):
    """ChatInterface callback: stream the refactored code back to the UI.

    `history` is required by gr.ChatInterface's callback signature but
    is intentionally unused — each submission is treated independently.
    Yields the progressively accumulated response text; Gradio replaces
    the displayed message on every yield, producing the streaming effect.
    """
    # Full prompt = fixed instruction prefix + the submitted code.
    full_prompt = system_prompt + str(message)

    accumulated = ""
    for piece in gen_solution_stream(full_prompt):
        accumulated += piece
        yield accumulated  # emit the whole text so far, not just the delta
203
+
204
# --- Launch Gradio Interface ---
# ChatInterface consumes the generator returned by `predict`, so the
# chatbot pane updates incrementally as chunks arrive from RunPod.
gr.ChatInterface(
    predict,
    chatbot=gr.Chatbot(height=500, label="Refactored Code and Explanation", show_copy_button=True),
    textbox=gr.Textbox(lines=10, label="Python Code", placeholder="Enter or Paste your Python code here..."),
    title=title,
    description=description,
    theme="abidlabs/Lime",
    examples=examples,
    cache_examples=False,  # examples hit the live endpoint; avoid stale caches
    submit_btn="Submit Code",
    retry_btn="Retry",
    undo_btn="Undo",
    clear_btn="Clear",
    css=css,
).queue().launch(share=True)  # share=True exposes a public URL — use with caution