karths committed on
Commit c1d360b · verified · 1 Parent(s): cf669ef

Update app.py

Files changed (1)
  1. app.py +142 -85
app.py CHANGED
@@ -1,45 +1,53 @@
- import json
  import gradio as gr
  import os
- import requests
- from huggingface_hub import AsyncInferenceClient
 
- HF_TOKEN = os.getenv('HF_TOKEN')
- api_url = os.getenv('API_URL')
- headers = {"Authorization": f"Bearer {HF_TOKEN}"}
- client = AsyncInferenceClient(api_url)
 
- title = "Python maintainability refactoring"
  description = """
  ## Instructions for Using the Model
- 
  ### Model Loading Time:
- - Please allow 2 to 3 minutes for the model to load and run, especially on the first usage which might experience a "Cold Start."
- 
  ### Code Submission:
- - You can enter or paste your python code you wish to have refactored, or use the provided example.
- 
  ### Python Code Constraints:
- - When using this tool, keep the code under 120 lines due to GPU constraints.
- 
  ### Understanding Changes:
- - It's important to read the "Changes made" section at the end of the refactored code response. This will help in understanding what modifications have been made to enhance the maintainability and readability of the code.
- 
  ### Usage Recommendation:
- - Do not use this for personal projects; try it for research purposes only, as running these GPUs is costly.
- 
- """
- system_prompt = """
- ### Instruction:
- Refactor the provided Python code to improve its maintainability and efficiency and reduce complexity. Include the refactored code along with the comments on the changes made for improving the metrics.
  ### Input:
- 
  """
  css = """.toast-wrap { display: none !important } """
- examples=[ ["""
- def analyze_sales_data(sales_records):
      active_sales = filter(lambda record: record['status'] == 'active', sales_records)
      sales_by_category = {}
      for record in active_sales:
@@ -51,102 +59,151 @@ def analyze_sales_data(sales_records):
          sales_by_category[category]['total_units'] += record['units_sold']
      average_sales_data = []
      for category, data in sales_by_category.items():
-         average_sales = data['total_sales'] / data['total_units']
          sales_by_category[category]['average_sales'] = average_sales
          average_sales_data.append((category, average_sales))
      average_sales_data.sort(key=lambda x: x[1], reverse=True)
      for rank, (category, _) in enumerate(average_sales_data, start=1):
          sales_by_category[category]['rank'] = rank
-     return sales_by_category
- """] ,
- ["""
- import pandas as pd
  import re
  import ast
- from code_bert_score import score
  import numpy as np
 
  def preprocess_code(source_text):
- 
      def remove_comments_and_docstrings(source_code):
          source_code = re.sub(r'#.*', '', source_code)
          source_code = re.sub(r'(\'\'\'(.*?)\'\'\'|\"\"\"(.*?)\"\"\")', '', source_code, flags=re.DOTALL)
-         return source_code
      pattern = r"```python\s+(.+?)\s+```"
      matches = re.findall(pattern, source_text, re.DOTALL)
      code_to_process = '\n'.join(matches) if matches else source_text
      cleaned_code = remove_comments_and_docstrings(code_to_process)
      return cleaned_code
 
  def evaluate_dataframe(df):
- 
      results = {'P': [], 'R': [], 'F1': [], 'F3': []}
      for index, row in df.iterrows():
          try:
-             cands = [preprocess_code(row['generated_text'])]
-             refs = [preprocess_code(row['output'])]
-             P, R, F1, F3 = score(cands, refs, lang='python')
-             results['P'].append(P[0])
-             results['R'].append(R[0])
-             results['F1'].append(F1[0])
-             results['F3'].append(F3[0])
          except Exception as e:
              print(f"Error processing row {index}: {e}")
              for key in results.keys():
-                 results[key].append(None)
      df_metrics = pd.DataFrame(results)
      return df_metrics
 
  def evaluate_dataframe_multiple_runs(df, runs=3):
- 
      all_results = []
      for run in range(runs):
-         df_metrics = evaluate_dataframe(df)
          all_results.append(df_metrics)
-     # Calculate mean and std deviation of metrics across runs
-     df_metrics_mean = pd.concat(all_results).groupby(level=0).mean()
-     df_metrics_std = pd.concat(all_results).groupby(level=0).std()
-     return df_metrics_mean, df_metrics_std
- """ ] ]
- 
- # Stream text - stream tokens with InferenceClient from TGI
- async def predict(message, chatbot, temperature=0.1, max_new_tokens=4096, top_p=0.6, repetition_penalty=1.15,):
-     temperature = float(temperature)
-     if temperature < 1e-2:
-         temperature = 1e-2
-     top_p = float(top_p)
- 
-     input_prompt = system_prompt + str(message) + " [/INST] "
- 
-     partial_message = ""
-     async for token in await client.text_generation(prompt=input_prompt,
-                                                     max_new_tokens=max_new_tokens,
-                                                     stream=True,
-                                                     best_of=1,
-                                                     temperature=temperature,
-                                                     top_p=top_p,
-                                                     do_sample=True,
-                                                     repetition_penalty=repetition_penalty):
-         partial_message = partial_message + token
-         yield partial_message
- 
  gr.ChatInterface(
      predict,
-     chatbot=gr.Chatbot(height=500),
-     textbox=gr.Textbox(lines=10, label="Python Code" , placeholder="Enter or Paste your Python code here..."),
      title=title,
      description=description,
-     theme="abidlabs/Lime",
      examples=examples,
-     cache_examples=False,
-     submit_btn = "Submit_code",
      retry_btn="Retry",
      undo_btn="Undo",
      clear_btn="Clear",
- ).queue().launch(share=True)
 
 
@@ -1,45 +1,53 @@
  import gradio as gr
  import os
+ import json
+ from openai import OpenAI
 
+ # Load sensitive information from environment variables
+ RUNPOD_API_KEY = os.getenv('RUNPOD_API_KEY')
+ RUNPOD_ENDPOINT_ID = os.getenv('RUNPOD_ENDPOINT_ID')
 
+ BASE_URL = f"https://api.runpod.ai/v2/{RUNPOD_ENDPOINT_ID}/openai/v1"
+ MODEL_NAME = "karths/coder_commit_32B"  # The specific model hosted on RunPod
+ MAX_TOKENS = 4096  # Max tokens for the model response
 
+ # --- OpenAI Client Initialization ---
+ # Check if the API key is provided
+ if not RUNPOD_API_KEY:
+     raise ValueError("RunPod API key not found. Please set the RUNPOD_API_KEY environment variable or add it directly in the script.")
+ 
+ # Initialize the OpenAI client to connect to the RunPod endpoint
+ client = OpenAI(
+     api_key=RUNPOD_API_KEY,
+     base_url=BASE_URL,
+ )
+ 
+ # --- Gradio App Configuration ---
+ title = "Python Maintainability Refactoring (RunPod)"
  description = """
  ## Instructions for Using the Model
  ### Model Loading Time:
+ - Please allow time for the model on RunPod to initialize if it's starting fresh ("Cold Start").
  ### Code Submission:
+ - You can enter or paste the Python code you wish to have refactored, or use the provided example.
  ### Python Code Constraints:
+ - Keep the code reasonably sized. The 120-line limit applied to the previous setup, but large code blocks might still face limitations depending on the RunPod instance and model constraints. Max response length is set to {} tokens.
  ### Understanding Changes:
+ - It's important to read the "Changes made" section (if provided by the model) in the refactored code response. This will help in understanding what modifications have been made.
  ### Usage Recommendation:
+ - Intended for research and evaluation purposes.
+ """.format(MAX_TOKENS)
 
+ system_prompt = """### Instruction:
+ Refactor the provided Python code to improve its maintainability and efficiency and reduce complexity. Include the refactored code along with comments on the changes made for improving the metrics.
  ### Input:
  """
+ 
  css = """.toast-wrap { display: none !important } """
+ 
+ examples = [
+     ["""def analyze_sales_data(sales_records):
      active_sales = filter(lambda record: record['status'] == 'active', sales_records)
      sales_by_category = {}
      for record in active_sales:
@@ -51,102 +59,151 @@ def analyze_sales_data(sales_records):
          sales_by_category[category]['total_units'] += record['units_sold']
      average_sales_data = []
      for category, data in sales_by_category.items():
+         average_sales = data['total_sales'] / data['total_units'] if data['total_units'] > 0 else 0  # Avoid division by zero
          sales_by_category[category]['average_sales'] = average_sales
          average_sales_data.append((category, average_sales))
      average_sales_data.sort(key=lambda x: x[1], reverse=True)
      for rank, (category, _) in enumerate(average_sales_data, start=1):
          sales_by_category[category]['rank'] = rank
+     return sales_by_category"""],
+     ["""import pandas as pd
  import re
  import ast
+ from code_bert_score import score  # Assuming this library is available in the environment
  import numpy as np
+ 
  def preprocess_code(source_text):
      def remove_comments_and_docstrings(source_code):
+         # Remove single-line comments
          source_code = re.sub(r'#.*', '', source_code)
+         # Remove multi-line strings (docstrings)
          source_code = re.sub(r'(\'\'\'(.*?)\'\'\'|\"\"\"(.*?)\"\"\")', '', source_code, flags=re.DOTALL)
+         return source_code.strip()  # Added strip
+ 
+     # Pattern to extract code specifically from markdown blocks if present
      pattern = r"```python\s+(.+?)\s+```"
      matches = re.findall(pattern, source_text, re.DOTALL)
      code_to_process = '\n'.join(matches) if matches else source_text
+ 
      cleaned_code = remove_comments_and_docstrings(code_to_process)
      return cleaned_code
+ 
  def evaluate_dataframe(df):
      results = {'P': [], 'R': [], 'F1': [], 'F3': []}
      for index, row in df.iterrows():
          try:
+             # Ensure inputs are lists of strings
+             cands = [preprocess_code(str(row['generated_text']))]  # Added str() conversion
+             refs = [preprocess_code(str(row['output']))]  # Added str() conversion
+ 
+             # Ensure code_bert_score.score returns four values
+             score_results = score(cands, refs, lang='python')
+             if len(score_results) == 4:
+                 P, R, F1, F3 = score_results
+                 results['P'].append(P.item() if hasattr(P, 'item') else P)  # Handle potential tensor output
+                 results['R'].append(R.item() if hasattr(R, 'item') else R)
+                 results['F1'].append(F1.item() if hasattr(F1, 'item') else F1)
+                 results['F3'].append(F3.item() if hasattr(F3, 'item') else F3)  # Assuming F3 is returned
+             else:
+                 print(f"Warning: Unexpected number of return values from score function for row {index}. Got {len(score_results)} values.")
+                 for key in results.keys():
+                     results[key].append(np.nan)  # Append NaN for unexpected format
+ 
          except Exception as e:
              print(f"Error processing row {index}: {e}")
              for key in results.keys():
+                 results[key].append(np.nan)  # Use NaN for errors
+ 
      df_metrics = pd.DataFrame(results)
      return df_metrics
+ 
  def evaluate_dataframe_multiple_runs(df, runs=3):
      all_results = []
+     print(f"Starting evaluation for {runs} runs...")
      for run in range(runs):
+         print(f"Run {run + 1}/{runs}")
+         df_metrics = evaluate_dataframe(df.copy())  # Use a copy to avoid side effects if df is modified
          all_results.append(df_metrics)
+         print(f"Run {run + 1} completed.")
+ 
+     if not all_results:
+         print("No results collected.")
+         return pd.DataFrame(), pd.DataFrame()
+ 
+     # Concatenate results and calculate statistics
+     try:
+         concatenated_results = pd.concat(all_results)
+         df_metrics_mean = concatenated_results.groupby(level=0).mean()
+         df_metrics_std = concatenated_results.groupby(level=0).std()
+         print("Mean and standard deviation calculated.")
+     except Exception as e:
+         print(f"Error calculating statistics: {e}")
+         # Return empty DataFrames or handle as appropriate
+         return pd.DataFrame(), pd.DataFrame()
+ 
+     return df_metrics_mean, df_metrics_std"""]
+ ]
+ 
+ # --- Core Logic ---
+ def gen_solution(prompt):
+     """
+     Generates a solution for a given problem prompt by calling the LLM via RunPod.
+ 
+     Parameters:
+     - prompt (str): The problem prompt including the system message and user input.
+ 
+     Returns:
+     - str: The generated solution text, or an error message.
+     """
+     try:
+         # Call the OpenAI-compatible endpoint on RunPod
+         completion = client.chat.completions.create(
+             model=MODEL_NAME,
+             messages=[{"role": "user", "content": prompt}],
+             temperature=0.1,  # Keep temperature low for more deterministic refactoring
+             top_p=1.0,
+             max_tokens=MAX_TOKENS,
+             # stream=False  # Explicitly setting stream to False (default)
+         )
+         # Extract the response content
+         response_content = completion.choices[0].message.content
+         return response_content
+ 
+     except Exception as e:
+         print(f"Error calling RunPod API: {e}")
+         # Provide a user-friendly error message
+         return f"Error: Could not get response from the model. Details: {str(e)}"
+ 
+ # --- Gradio Interface Function ---
+ def predict(message, history):
+     """
+     Handles the user input, calls the backend model, and returns the response.
+     The 'history' parameter is required by gr.ChatInterface but is not used here.
+     """
+     # Construct the full prompt
+     input_prompt = system_prompt + str(message)  # Using the format from the original code
+ 
+     # Get the refactored code from the backend
+     refactored_code_response = gen_solution(input_prompt)
+ 
+     # The response is returned directly to the ChatInterface
+     return refactored_code_response
+ 
+ # --- Launch Gradio Interface ---
+ # Use gr.ChatInterface for a chat-like experience
  gr.ChatInterface(
      predict,
+     chatbot=gr.Chatbot(height=500, label="Refactored Code and Explanation"),
+     textbox=gr.Textbox(lines=10, label="Python Code", placeholder="Enter or Paste your Python code here..."),
      title=title,
      description=description,
+     theme="abidlabs/Lime",  # Or choose another theme, e.g., gr.themes.Default()
      examples=examples,
+     cache_examples=False,  # Consider enabling caching if examples don't change often
+     submit_btn="Submit Code",
      retry_btn="Retry",
      undo_btn="Undo",
      clear_btn="Clear",
+     css=css  # Apply custom CSS if needed
+ ).queue().launch(share=True)  # share=True creates a public link (use with caution)
209