karths committed on
Commit 709da00 · verified · 1 Parent(s): de24b75

Create app_test.py

Files changed (1)
  1. app_test.py +241 -0
app_test.py ADDED
@@ -0,0 +1,241 @@
+import gradio as gr
+import os
+import torch
+import numpy as np
+import random
+from huggingface_hub import login, HfFolder
+from transformers import AutoTokenizer, AutoModelForSequenceClassification, AutoModelForCausalLM, TextIteratorStreamer
+from scipy.special import softmax
+import logging
+import spaces
+from threading import Thread
+from collections.abc import Iterator
+
+# Set up logging
+logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(message)s")
+
+# Set a seed for reproducibility
+seed = 42
+np.random.seed(seed)
+random.seed(seed)
+torch.manual_seed(seed)
+if torch.cuda.is_available():
+    torch.cuda.manual_seed_all(seed)
+
+# Log in to Hugging Face
+token = os.getenv("hf_token")
+HfFolder.save_token(token)
+login(token)
+
+# --- Quality Prediction Model Setup ---
+model_paths = [
+    "karths/binary_classification_train_test",
+    "karths/binary_classification_train_process",
+    "karths/binary_classification_train_infrastructure",
+    "karths/binary_classification_train_documentation",
+    "karths/binary_classification_train_design",
+    "karths/binary_classification_train_defect",
+    "karths/binary_classification_train_code",
+    "karths/binary_classification_train_build",
+    "karths/binary_classification_train_automation",
+    "karths/binary_classification_train_people",
+    "karths/binary_classification_train_architecture",
+]
+
+quality_mapping = {
+    "binary_classification_train_test": "Test",
+    "binary_classification_train_process": "Process",
+    "binary_classification_train_infrastructure": "Infrastructure",
+    "binary_classification_train_documentation": "Documentation",
+    "binary_classification_train_design": "Design",
+    "binary_classification_train_defect": "Defect",
+    "binary_classification_train_code": "Code",
+    "binary_classification_train_build": "Build",
+    "binary_classification_train_automation": "Automation",
+    "binary_classification_train_people": "People",
+    "binary_classification_train_architecture": "Architecture",
+}
+
+# Pre-load the shared tokenizer and all quality-prediction models
+tokenizer = AutoTokenizer.from_pretrained("distilroberta-base")
+models = {path: AutoModelForSequenceClassification.from_pretrained(path) for path in model_paths}
+
+
+def get_quality_name(model_name):
+    return quality_mapping.get(model_name.split("/")[-1], "Unknown Quality")
+
+
+@spaces.GPU
+def model_prediction(model, text, device):
+    model.to(device)
+    model.eval()
+    inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True, max_length=512)
+    inputs = {k: v.to(device) for k, v in inputs.items()}
+    with torch.no_grad():
+        outputs = model(**inputs)
+    logits = outputs.logits
+    probs = softmax(logits.cpu().numpy(), axis=1)
+    avg_prob = np.mean(probs[:, 1])  # mean probability of the positive class
+    return avg_prob
+
+
+# --- Llama 3.2 3B Model Setup ---
+LLAMA_MAX_MAX_NEW_TOKENS = 2048
+LLAMA_DEFAULT_MAX_NEW_TOKENS = 1024
+LLAMA_MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096"))
+llama_device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")  # Explicitly define the device
+llama_model_id = "meta-llama/Llama-3.2-3B-Instruct"
+llama_tokenizer = AutoTokenizer.from_pretrained(llama_model_id)
+llama_model = AutoModelForCausalLM.from_pretrained(
+    llama_model_id,
+    device_map="auto",  # Automatically distribute the model across available devices
+    torch_dtype=torch.bfloat16,
+)
+llama_model.eval()
+
+
+@spaces.GPU(duration=90)
+def llama_generate(
+    message: str,
+    max_new_tokens: int = LLAMA_DEFAULT_MAX_NEW_TOKENS,
+    temperature: float = 0.6,
+    top_p: float = 0.9,
+    top_k: int = 50,
+    repetition_penalty: float = 1.2,
+) -> Iterator[str]:
+    input_ids = llama_tokenizer.encode(message, return_tensors="pt").to(llama_model.device)
+
+    if input_ids.shape[1] > LLAMA_MAX_INPUT_TOKEN_LENGTH:
+        input_ids = input_ids[:, -LLAMA_MAX_INPUT_TOKEN_LENGTH:]
+        gr.Warning(f"Trimmed input from conversation as it was longer than {LLAMA_MAX_INPUT_TOKEN_LENGTH} tokens.")
+
+    streamer = TextIteratorStreamer(llama_tokenizer, timeout=60.0, skip_prompt=True, skip_special_tokens=True)
+    generate_kwargs = dict(
+        input_ids=input_ids,
+        streamer=streamer,
+        max_new_tokens=max_new_tokens,
+        do_sample=True,
+        top_p=top_p,
+        top_k=top_k,
+        temperature=temperature,
+        num_beams=1,
+        repetition_penalty=repetition_penalty,
+    )
+    # Run generation in a background thread so tokens can be streamed as they arrive
+    t = Thread(target=llama_model.generate, kwargs=generate_kwargs)
+    t.start()
+    outputs = []
+    for text in streamer:
+        outputs.append(text)
+        yield "".join(outputs)  # each yield is the full text generated so far
+
+
+def generate_explanation(issue_text, top_qualities):
+    """Generates an explanation using Llama 3.2 3B."""
+    if not top_qualities:
+        return "No explanation available as no quality tags were predicted."
+
+    prompt = f"""
+    Given the following issue description:
+    ---
+    {issue_text}
+    ---
+    Explain why this issue might be classified under the following quality categories: {', '.join([q[0] for q in top_qualities])}.
+    Provide a concise explanation for each category, relating it back to the issue description.
+    """
+    explanation = ""
+    try:
+        # llama_generate yields cumulative snapshots of the output, so keep
+        # only the latest one; concatenating them would duplicate text.
+        for chunk in llama_generate(prompt):
+            explanation = chunk
+    except Exception as e:
+        logging.error(f"Error during Llama generation: {e}")
+        return "An error occurred while generating the explanation."
+
+    return explanation
+
+
+def main_interface(text):
+    if not text.strip():
+        return "<div style='color: red;'>No text provided. Please enter a valid issue description.</div>", "", ""
+
+    if len(text) < 30:
+        return "<div style='color: red;'>Text is less than 30 characters.</div>", "", ""
+
+    device = "cuda" if torch.cuda.is_available() else "cpu"
+    results = []
+    for model_path, model in models.items():
+        quality_name = get_quality_name(model_path)
+        avg_prob = model_prediction(model, text, device)
+        if avg_prob >= 0.95:  # keep only high-confidence predictions
+            results.append((quality_name, avg_prob))
+        logging.info(f"Model: {model_path}, Quality: {quality_name}, Average Probability: {avg_prob:.3f}")
+
+    if not results:
+        return "<div style='color: red;'>No recommendation. Prediction probability is below the threshold.</div>", "", ""
+
+    top_qualities = sorted(results, key=lambda x: x[1], reverse=True)[:3]
+    output_html = render_html_output(top_qualities)
+
+    # Generate an explanation of the top qualities from the original input text
+    explanation = generate_explanation(text, top_qualities)
+
+    return output_html, "", explanation  # the explanation is the third output
+
+
+def render_html_output(top_qualities):
+    styles = """
+    <style>
+        .quality-container {
+            font-family: Arial, sans-serif;
+            text-align: center;
+            margin-top: 20px;
+        }
+        .quality-label, .ranking {
+            display: inline-block;
+            padding: 0.5em 1em;
+            font-size: 18px;
+            font-weight: bold;
+            color: white;
+            background-color: #007bff;
+            border-radius: 0.5rem;
+            margin-right: 10px;
+            box-shadow: 0 2px 4px rgba(0, 0, 0, 0.2);
+        }
+        .probability {
+            display: block;
+            margin-top: 10px;
+            font-size: 16px;
+            color: #007bff;
+        }
+    </style>
+    """
+    html_content = ""
+    ranking_labels = ["Top 1 Prediction", "Top 2 Prediction", "Top 3 Prediction"]
+    top_n = min(len(top_qualities), len(ranking_labels))
+    for i in range(top_n):
+        quality, prob = top_qualities[i]
+        html_content += f"""
+        <div class="quality-container">
+            <span class="ranking">{ranking_labels[i]}</span>
+            <span class="quality-label">{quality}</span>
+        </div>
+        """
+    return styles + html_content
+
+
+example_texts = [
+    ["The algorithm does not accurately distinguish between the positive and negative classes during edge cases.\n\nEnvironment: Production\nReproduction: Run the classifier on the test dataset with known edge cases."],
+    ["The regression tests do not cover scenarios involving concurrent user sessions.\n\nEnvironment: Test automation suite\nReproduction: Update the test scripts to include tests for concurrent sessions."],
+    ["There is frequent miscommunication between the development and QA teams regarding feature specifications.\n\nEnvironment: Inter-team meetings\nReproduction: Audit recent communication logs and meeting notes between the teams."],
+    ["The service-oriented architecture does not effectively isolate failures, leading to cascading failures across services.\n\nEnvironment: Microservices architecture\nReproduction: Simulate a service failure and observe the impact on other services."],
+]
+
+interface = gr.Interface(
+    fn=main_interface,
+    inputs=gr.Textbox(lines=7, label="Issue Description", placeholder="Enter your issue text here"),
+    outputs=[
+        gr.HTML(label="Prediction Output"),
+        gr.Textbox(label="Predictions", visible=False),
+        gr.Textbox(label="Explanation", lines=5),  # textbox for the generated explanation
+    ],
+    title="QualityTagger",
+    description="This tool classifies issue text into software quality domains such as Test, Design, Documentation, etc., and explains the predictions.",
+    examples=example_texts,
+)
+interface.launch(share=True)