SyedHutter commited on
Commit
0888a3b
·
verified ·
1 Parent(s): 212a6ed

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +115 -0
app.py ADDED
@@ -0,0 +1,115 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi import FastAPI, HTTPException
2
+ from pydantic import BaseModel
3
+ from typing import List
4
+ from transformers import pipeline
5
+ import torch
6
+
7
+ # Initialize the FastAPI app
8
+ app = FastAPI()
9
+
10
+ # Determine device (use GPU if available, otherwise CPU)
11
+ device = 0 if torch.cuda.is_available() else -1
12
+
13
+ # Initialize the NER pipeline
14
+ ner_pipeline = pipeline(
15
+ "ner",
16
+ model="dbmdz/bert-large-cased-finetuned-conll03-english",
17
+ aggregation_strategy="simple",
18
+ device=device
19
+ )
20
+
21
+ # Initialize the QA pipeline
22
+ qa_pipeline = pipeline(
23
+ "question-answering",
24
+ model="deepset/roberta-base-squad2",
25
+ device=device
26
+ )
27
+
28
+ # Allowed domains for filtering
29
+ allowed_domains = [
30
+ "clothing", "fashion", "shopping", "accessories", "pants", "jeans", "shirts", "sustainable materials"
31
+ ]
32
+
33
+ # Context for the QA pipeline
34
+ context_msg = """
35
+ We offer a wide variety of clothing options, including sustainable pants, jeans, chinos, and trousers.
36
+ Our products are made with eco-friendly materials and are available in styles such as casual wear, formal wear, and activewear.
37
+ """
38
+
39
+ # Pydantic models for structured responses
40
+ class Entity(BaseModel):
41
+ word: str
42
+ entity_group: str
43
+ score: float
44
+
45
+ class NERResponse(BaseModel):
46
+ entities: List[Entity]
47
+ words: List[str] # List of extracted words (added)
48
+
49
+ class QAResponse(BaseModel):
50
+ question: str
51
+ answer: str
52
+ score: float
53
+
54
+ class CombinedRequest(BaseModel):
55
+ text: str # The input text prompt
56
+
57
+ class CombinedResponse(BaseModel):
58
+ ner: NERResponse # NER output
59
+ qa: QAResponse # QA output
60
+
61
+ # Function to check if the input text belongs to allowed domains
62
+ def is_text_in_allowed_domain(text: str, domains: List[str]) -> bool:
63
+ for domain in domains:
64
+ if domain in text.lower():
65
+ return True
66
+ return False
67
+
68
+ # Combined endpoint for NER and QA with domain filtering
69
+ @app.post("/process/", response_model=CombinedResponse)
70
+ async def process_request(request: CombinedRequest):
71
+ input_text = request.text
72
+
73
+ # Check if the input text belongs to the allowed domains
74
+ if not is_text_in_allowed_domain(input_text, allowed_domains):
75
+ raise HTTPException(
76
+ status_code=400,
77
+ detail="The input text does not match the allowed domains. Please provide a query related to clothing, fashion, or accessories."
78
+ )
79
+
80
+ # Perform Named Entity Recognition (NER)
81
+ ner_entities = ner_pipeline(input_text)
82
+
83
+ # Process NER results into a structured response
84
+ formatted_entities = [
85
+ {
86
+ "word": entity["word"],
87
+ "entity_group": entity["entity_group"],
88
+ "score": float(entity["score"]),
89
+ }
90
+ for entity in ner_entities
91
+ ]
92
+ ner_words = [entity["word"] for entity in ner_entities] # Collect only the words
93
+
94
+ ner_response = {
95
+ "entities": formatted_entities,
96
+ "words": ner_words # Include the list of words
97
+ }
98
+
99
+ # Perform Question Answering (QA)
100
+ qa_result = qa_pipeline(question=input_text, context=context_msg)
101
+ qa_result["score"] = float(qa_result["score"]) # Convert numpy.float32 to Python float
102
+
103
+ qa_response = {
104
+ "question": input_text,
105
+ "answer": qa_result["answer"],
106
+ "score": qa_result["score"]
107
+ }
108
+
109
+ # Return both NER and QA responses
110
+ return {"ner": ner_response, "qa": qa_response}
111
+
112
+ # Root endpoint
113
+ @app.get("/")
114
+ async def root():
115
+ return {"message": "Welcome to the NER and QA API!"}