# Captured from a Hugging Face Spaces app page (Space status at capture time: "Sleeping").
# app.py
from transformers import T5ForConditionalGeneration, T5Tokenizer, pipeline

# Hugging Face model identifiers, kept in one place so they are easy to swap.
ZERO_SHOT_MODEL_NAME = "facebook/bart-large-mnli"
REWRITE_MODEL_NAME = "t5-small"

# Zero-shot classification model shared by the requirement-type, stakeholder,
# domain and defect classification functions in this module. It is created
# once at import time so every call reuses the same loaded model.
classifier = pipeline("zero-shot-classification", model=ZERO_SHOT_MODEL_NAME)

# T5 model and tokenizer used to rewrite (paraphrase) requirements; also
# loaded once at import time.
t5_model = T5ForConditionalGeneration.from_pretrained(REWRITE_MODEL_NAME)
t5_tokenizer = T5Tokenizer.from_pretrained(REWRITE_MODEL_NAME)
def classify_requirement_type(requirement):
    """Label a requirement as "Functional" or "Non-Functional".

    The module-level zero-shot ``classifier`` scores the text against both
    labels, and the top-ranked label (the first entry of the result's
    ``labels``) is returned.

    Args:
        requirement: Requirement text to classify.

    Returns:
        Either "Functional" or "Non-Functional".
    """
    ranked_labels = classifier(requirement, ["Functional", "Non-Functional"])["labels"]
    return ranked_labels[0]
def identify_stakeholders(requirement):
    """Pick the stakeholder role a requirement most likely concerns.

    Args:
        requirement: Requirement text to classify.

    Returns:
        One of "End User", "Developer", "System Analyst" or
        "Project Manager": the top-ranked label (first entry of ``labels``)
        returned by the module-level zero-shot ``classifier``.
    """
    roles = ["End User", "Developer", "System Analyst", "Project Manager"]
    prediction = classifier(requirement, roles)
    ranked_roles = prediction["labels"]
    return ranked_roles[0]
def classify_domain(requirement):
    """Pick the application domain a requirement most likely belongs to.

    Args:
        requirement: Requirement text to classify.

    Returns:
        One of "Bank", "Healthcare", "Education", "Finance",
        "Cybersecurity" or "E-commerce": the top-ranked label (first entry of
        ``labels``) returned by the module-level zero-shot ``classifier``.
    """
    # NOTE(review): "Bank" and "Finance" overlap, so banking requirements may
    # be split between the two labels; consider merging them.
    domains = [
        "Bank",
        "Healthcare",
        "Education",
        "Finance",
        "Cybersecurity",
        "E-commerce",
    ]
    best_domain, *_ = classifier(requirement, domains)["labels"]
    return best_domain
def detect_defects(requirement, *, multi_label=False, threshold=0.5):
    """Detect the most likely quality defect(s) in a requirement.

    By default this keeps the original behaviour. The five defect labels are
    scored as mutually exclusive alternatives and the single top-ranked label
    is returned. In this mode a defect is *always* reported, even for a
    well-written requirement, because exactly one label is taken from the
    result.

    With ``multi_label=True`` each label is scored independently by the
    zero-shot pipeline. A requirement may then report several defects or
    none at all.

    Args:
        requirement: Requirement text to analyse.
        multi_label: Keyword-only. Score each defect label independently
            instead of picking exactly one.
        threshold: Keyword-only. Minimum score a label needs in
            ``multi_label`` mode to be reported. Ignored otherwise.

    Returns:
        The single top-ranked defect label by default. In ``multi_label``
        mode, a comma-separated string of every label scoring at least
        ``threshold`` (in the order the pipeline returns them), or
        "None detected" if no label reaches the threshold.
    """
    candidate_labels = [
        "Ambiguity",
        "Incompleteness",
        "Security Flaw",
        "Redundancy",
        "Performance Issue",
    ]
    if not multi_label:
        # Original single-label behaviour: one label always wins.
        result = classifier(requirement, candidate_labels)
        return result["labels"][0]

    result = classifier(requirement, candidate_labels, multi_label=True)
    found = [
        label
        for label, score in zip(result["labels"], result["scores"])
        if score >= threshold
    ]
    return ", ".join(found) if found else "None detected"
def rewrite_requirement(requirement):
    """Rewrite a requirement in simpler wording using the module's T5 model.

    The text is prefixed with "paraphrase: " and truncated to 512 input
    tokens. It is then decoded from a 5-beam search with ``max_length=150``
    and ``early_stopping=True``, and special tokens are stripped from the
    output.

    Args:
        requirement: Requirement text to rewrite.

    Returns:
        The generated rewrite as a plain string.
    """
    # NOTE(review): the stock "t5-small" checkpoint may not have been trained
    # with a "paraphrase:" task prefix; confirm the output is a real rewrite
    # and not a near-copy of the input.
    prompt = "paraphrase: " + requirement
    token_ids = t5_tokenizer.encode(
        prompt,
        truncation=True,
        max_length=512,
        return_tensors="pt",
    )
    beams = t5_model.generate(
        token_ids,
        num_beams=5,
        early_stopping=True,
        max_length=150,
    )
    best_sequence = beams[0]
    return t5_tokenizer.decode(best_sequence, skip_special_tokens=True)
def process_requirement(requirement):
    """Run every analysis on one requirement and collect the results.

    Args:
        requirement: Requirement text to analyse.

    Returns:
        A dict with the keys "Requirement Type", "Stakeholder", "Domain",
        "Defects" and "Rewritten Requirement", in that order. Each key maps
        to the output of the corresponding analysis function.
    """
    # A dict display evaluates its entries left to right, so the analyses run
    # in the same order as the keys appear.
    return {
        "Requirement Type": classify_requirement_type(requirement),
        "Stakeholder": identify_stakeholders(requirement),
        "Domain": classify_domain(requirement),
        "Defects": detect_defects(requirement),
        "Rewritten Requirement": rewrite_requirement(requirement),
    }
if __name__ == "__main__":
    # Interactive entry point: read one requirement from stdin, analyse it and
    # print each result as "Key: value".
    requirement = input("Enter the software requirement: ").strip()
    if not requirement:
        # A blank requirement gives the models nothing meaningful to classify
        # or rewrite, so stop with a non-zero exit status instead of printing
        # arbitrary labels.
        raise SystemExit("Error: the requirement must not be empty.")
    result = process_requirement(requirement)
    for key, value in result.items():
        print(f"{key}: {value}")