iisadia commited on
Commit
ce64498
·
verified ·
1 Parent(s): dc7a1eb

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +71 -0
app.py ADDED
@@ -0,0 +1,71 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # app.py
2
+
3
+ from transformers import pipeline
4
+
5
+ # Load Zero-Shot Classification Model (for detecting Requirement Type, Domain, Stakeholders, and Defects)
6
+ classifier = pipeline("zero-shot-classification", model="facebook/bart-large-mnli")
7
+
8
+ # Load T5 model for Rewriting (Paraphrasing)
9
+ from transformers import T5ForConditionalGeneration, T5Tokenizer
10
+
11
+ t5_model = T5ForConditionalGeneration.from_pretrained("t5-small")
12
+ t5_tokenizer = T5Tokenizer.from_pretrained("t5-small")
13
+
14
+ # Function to classify requirement type (Functional/Non-Functional)
15
+ def classify_requirement_type(requirement):
16
+ candidate_labels = ["Functional", "Non-Functional"]
17
+ result = classifier(requirement, candidate_labels)
18
+ return result['labels'][0]
19
+
20
+ # Function to identify stakeholders
21
+ def identify_stakeholders(requirement):
22
+ candidate_labels = ["End User", "Developer", "System Analyst", "Project Manager"]
23
+ result = classifier(requirement, candidate_labels)
24
+ return result['labels'][0]
25
+
26
+ # Function to classify domain of the requirement
27
+ def classify_domain(requirement):
28
+ candidate_labels = ["Bank", "Healthcare", "Education", "Finance", "Cybersecurity", "E-commerce"]
29
+ result = classifier(requirement, candidate_labels)
30
+ return result['labels'][0]
31
+
32
+ # Function to detect defects (e.g., Ambiguity, Incompleteness)
33
+ def detect_defects(requirement):
34
+ candidate_labels = ["Ambiguity", "Incompleteness", "Security Flaw", "Redundancy", "Performance Issue"]
35
+ result = classifier(requirement, candidate_labels)
36
+ return result['labels'][0]
37
+
38
+ # Function to rewrite the requirement in a simpler way using T5
39
+ def rewrite_requirement(requirement):
40
+ input_text = "paraphrase: " + requirement
41
+ input_ids = t5_tokenizer.encode(input_text, return_tensors="pt", max_length=512, truncation=True)
42
+
43
+ output_ids = t5_model.generate(input_ids, max_length=150, num_beams=5, early_stopping=True)
44
+ output_text = t5_tokenizer.decode(output_ids[0], skip_special_tokens=True)
45
+
46
+ return output_text
47
+
48
+ # Main function to take input requirement and process it
49
+ def process_requirement(requirement):
50
+ requirement_type = classify_requirement_type(requirement)
51
+ stakeholder = identify_stakeholders(requirement)
52
+ domain = classify_domain(requirement)
53
+ defects = detect_defects(requirement)
54
+ rewritten_requirement = rewrite_requirement(requirement)
55
+
56
+ return {
57
+ "Requirement Type": requirement_type,
58
+ "Stakeholder": stakeholder,
59
+ "Domain": domain,
60
+ "Defects": defects,
61
+ "Rewritten Requirement": rewritten_requirement
62
+ }
63
+
64
+ if __name__ == "__main__":
65
+ # Example usage:
66
+ requirement = input("Enter the software requirement: ")
67
+ result = process_requirement(requirement)
68
+
69
+ # Print the results
70
+ for key, value in result.items():
71
+ print(f"{key}: {value}")