Requirements / app.py
iisadia's picture
Create app.py
ce64498 verified
raw
history blame
2.82 kB
# app.py
from transformers import pipeline
# Zero-shot classification pipeline shared by the classification helpers in
# this module (requirement type, stakeholder, domain, defect). It is built at
# import time, so importing this module instantiates the model.
classifier = pipeline("zero-shot-classification", model="facebook/bart-large-mnli")
# T5 model and tokenizer used by rewrite_requirement() to generate the
# rewritten text; both are also loaded at import time.
from transformers import T5ForConditionalGeneration, T5Tokenizer
t5_model = T5ForConditionalGeneration.from_pretrained("t5-small")
t5_tokenizer = T5Tokenizer.from_pretrained("t5-small")
# Function to classify requirement type (Functional/Non-Functional)
def classify_requirement_type(requirement):
    """Label a requirement as functional or non-functional.

    Passes *requirement* to the module-level zero-shot ``classifier`` with
    the two candidate labels and returns the first entry of the result's
    ``labels`` list (the pipeline documents that list as ordered by
    likelihood).

    Args:
        requirement: Requirement text to classify.

    Returns:
        The first label reported by the classifier, expected to be
        ``"Functional"`` or ``"Non-Functional"``.
    """
    outcome = classifier(requirement, ["Functional", "Non-Functional"])
    ranked_labels = outcome['labels']
    return ranked_labels[0]
# Function to identify stakeholders
def identify_stakeholders(requirement):
    """Pick the stakeholder role that best matches a requirement.

    Passes *requirement* to the module-level zero-shot ``classifier`` with a
    fixed list of roles and returns the first entry of the result's
    ``labels`` list. Despite the plural name, exactly one role is returned.

    Args:
        requirement: Requirement text to classify.

    Returns:
        The first label reported by the classifier, expected to be one of
        ``"End User"``, ``"Developer"``, ``"System Analyst"`` or
        ``"Project Manager"``.
    """
    roles = ["End User", "Developer", "System Analyst", "Project Manager"]
    outcome = classifier(requirement, roles)
    ranked_roles = outcome['labels']
    return ranked_roles[0]
# Function to classify domain of the requirement
def classify_domain(requirement):
    """Pick the application domain that best matches a requirement.

    Passes *requirement* to the module-level zero-shot ``classifier`` with a
    fixed list of domains and returns the first entry of the result's
    ``labels`` list.

    Args:
        requirement: Requirement text to classify.

    Returns:
        The first label reported by the classifier, expected to be one of
        the domain names listed in the body.
    """
    domains = [
        "Bank",
        "Healthcare",
        "Education",
        "Finance",
        "Cybersecurity",
        "E-commerce",
    ]
    outcome = classifier(requirement, domains)
    ranked_domains = outcome['labels']
    return ranked_domains[0]
# Function to detect defects (e.g., Ambiguity, Incompleteness)
def detect_defects(requirement):
    """Pick the defect category that best matches a requirement.

    Passes *requirement* to the module-level zero-shot ``classifier`` with a
    fixed list of defect categories and returns the first entry of the
    result's ``labels`` list.

    Note that a category is always returned: the candidate list has no
    "no defect" option and no score threshold is applied, so this function
    never reports a requirement as defect-free.

    Args:
        requirement: Requirement text to classify.

    Returns:
        The first label reported by the classifier, expected to be one of
        the defect categories listed in the body.
    """
    defect_kinds = [
        "Ambiguity",
        "Incompleteness",
        "Security Flaw",
        "Redundancy",
        "Performance Issue",
    ]
    outcome = classifier(requirement, defect_kinds)
    ranked_defects = outcome['labels']
    return ranked_defects[0]
# Function to rewrite the requirement in a simpler way using T5
def rewrite_requirement(requirement):
    """Generate a rewritten version of a requirement with the T5 model.

    The requirement is prefixed with ``"paraphrase: "``, tokenized (truncated
    to at most 512 tokens), fed to ``t5_model.generate`` using 5-beam search
    with a 150-token output cap, and the first generated sequence is decoded
    back to text with special tokens removed.

    NOTE(review): "paraphrase:" may not be a task prefix the stock
    ``t5-small`` checkpoint was trained on -- confirm the output quality is
    acceptable or switch to a paraphrase-tuned checkpoint.

    Args:
        requirement: Requirement text to rewrite.

    Returns:
        The decoded text of the first generated sequence.
    """
    prompt = "paraphrase: " + requirement
    encoded_prompt = t5_tokenizer.encode(
        prompt,
        return_tensors="pt",
        max_length=512,
        truncation=True,
    )
    generated = t5_model.generate(
        encoded_prompt,
        max_length=150,
        num_beams=5,
        early_stopping=True,
    )
    # Only one prompt was supplied, so the first row is the only result.
    return t5_tokenizer.decode(generated[0], skip_special_tokens=True)
# Main function to take input requirement and process it
def process_requirement(requirement):
    """Run every analysis step on a requirement and collect the results.

    Calls, in order, ``classify_requirement_type``, ``identify_stakeholders``,
    ``classify_domain``, ``detect_defects`` and ``rewrite_requirement`` on the
    same text.

    Args:
        requirement: Requirement text to analyse.

    Returns:
        A dict with the keys ``"Requirement Type"``, ``"Stakeholder"``,
        ``"Domain"``, ``"Defects"`` and ``"Rewritten Requirement"``, each
        mapped to the value returned by the corresponding helper.
    """
    # Dict values are evaluated top to bottom, so the helpers run in the
    # order they are listed here.
    return {
        "Requirement Type": classify_requirement_type(requirement),
        "Stakeholder": identify_stakeholders(requirement),
        "Domain": classify_domain(requirement),
        "Defects": detect_defects(requirement),
        "Rewritten Requirement": rewrite_requirement(requirement),
    }
if __name__ == "__main__":
    # Script entry point: read one requirement from the console, analyse it,
    # and print every result field on its own "<name>: <value>" line.
    user_text = input("Enter the software requirement: ")
    report = process_requirement(user_text)
    for field_name, field_value in report.items():
        print(field_name, field_value, sep=": ")