SyedHutter commited on
Commit
b58b37b
·
verified ·
1 Parent(s): 144d63a

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +189 -0
app.py ADDED
@@ -0,0 +1,189 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi import FastAPI, HTTPException
2
+ from pydantic import BaseModel
3
+ from typing import List, Dict
4
+ from transformers import pipeline
5
+ from itertools import groupby
6
+
7
+ # Initialize the FastAPI app
8
+ app = FastAPI()
9
+
10
+ # Initialize the NER pipeline
11
+ ner_pipeline = pipeline(
12
+ "ner",
13
+ model="dbmdz/bert-large-cased-finetuned-conll03-english",
14
+ grouped_entities=True
15
+ )
16
+
17
+ # Initialize the QA pipeline
18
+ qa_pipeline = pipeline(
19
+ "question-answering",
20
+ model="deepset/roberta-base-squad2"
21
+ )
22
+
23
+ # Allowed domains for filtering
24
+ allowed_domains = [
25
+ "clothing",
26
+ "fashion",
27
+ "shopping",
28
+ "accessories",
29
+ "sustainability",
30
+ "shoes",
31
+ "hats",
32
+ "shirts",
33
+ "dresses",
34
+ "pants",
35
+ "jeans",
36
+ "skirts",
37
+ "jackets",
38
+ "coats",
39
+ "t-shirts",
40
+ "sweaters",
41
+ "hoodies",
42
+ "activewear",
43
+ "formal wear",
44
+ "casual wear",
45
+ "sportswear",
46
+ "outerwear",
47
+ "swimwear",
48
+ "underwear",
49
+ "lingerie",
50
+ "socks",
51
+ "scarves",
52
+ "gloves",
53
+ "belts",
54
+ "ties",
55
+ "caps",
56
+ "beanies",
57
+ "boots",
58
+ "sandals",
59
+ "heels",
60
+ "sneakers",
61
+ "materials",
62
+ "cotton",
63
+ "polyester",
64
+ "wool",
65
+ "silk",
66
+ "leather",
67
+ "denim",
68
+ "linen",
69
+ "athleisure",
70
+ "ethnic wear",
71
+ "fashion trends",
72
+ "custom clothing",
73
+ "tailoring",
74
+ "sustainable materials",
75
+ "recycled clothing",
76
+ "fashion brands",
77
+ "streetwear"
78
+ ]
79
+
80
+ # Pydantic models for structured response
81
+ class Entity(BaseModel):
82
+ word: str
83
+ entity_group: str
84
+ score: float
85
+
86
+ class NERResponse(BaseModel):
87
+ entities: List[Entity]
88
+
89
+ class QAResponse(BaseModel):
90
+ question: str
91
+ answer: str
92
+ score: float
93
+
94
+ class CombinedRequest(BaseModel):
95
+ text: str # The input text prompt
96
+
97
+ class CombinedResponse(BaseModel):
98
+ ner: NERResponse # NER output
99
+ qa: QAResponse # QA output
100
+
101
+ # Function to check if the input text belongs to allowed domains
102
+ def is_text_in_allowed_domain(text: str, domains: List[str]) -> bool:
103
+ for domain in domains:
104
+ if domain in text.lower():
105
+ return True
106
+ return False
107
+
108
+ # Combined endpoint for NER and QA with domain filtering
109
+ @app.post("/process/", response_model=CombinedResponse)
110
+ async def process_request(request: CombinedRequest):
111
+ """
112
+ Process the input text for both NER and QA, returning both responses,
113
+ only if the text matches the allowed domains.
114
+ """
115
+ input_text = request.text
116
+
117
+ # Check if the input text belongs to the allowed domains
118
+ if not is_text_in_allowed_domain(input_text, allowed_domains):
119
+ raise HTTPException(
120
+ status_code=400,
121
+ detail=(
122
+ "The input text does not match the allowed domains. "
123
+ "Please provide a query related to clothing, fashion, or accessories."
124
+ )
125
+ )
126
+
127
+ # Perform Named Entity Recognition (NER)
128
+ ner_entities = ner_pipeline(input_text)
129
+
130
+ # Process the NER entities into the required format
131
+ formatted_entities = [
132
+ {
133
+ "word": entity["word"],
134
+ "entity_group": entity["entity_group"],
135
+ "score": float(entity["score"]), # Convert numpy.float32 to Python float
136
+ }
137
+ for entity in ner_entities
138
+ ]
139
+ ner_response = {"entities": formatted_entities}
140
+
141
+ # Perform Question Answering (QA)
142
+ qa_result = qa_pipeline(question=input_text, context=input_text)
143
+ qa_result["score"] = float(qa_result["score"]) # Convert numpy.float32 to Python float
144
+
145
+ qa_response = {
146
+ "question": input_text,
147
+ "answer": qa_result["answer"],
148
+ "score": qa_result["score"]
149
+ }
150
+
151
+ # Return both NER and QA responses
152
+ return {"ner": ner_response, "qa": qa_response}
153
+
154
+
155
+ # Root endpoint
156
+ @app.get("/")
157
+ async def root():
158
+ """
159
+ Root endpoint to confirm the server is running.
160
+ """
161
+ return {"message": "Welcome to the filtered NER and QA API!"}
162
+
163
+
164
+ # JSon response
165
+
166
+ # {
167
+ # "entities": [
168
+ # {
169
+ # "word": "Nike",
170
+ # "entity_group": "ORG",
171
+ # "score": 0.995
172
+ # },
173
+ # {
174
+ # "word": "running shoes",
175
+ # "entity_group": "PRODUCT",
176
+ # "score": 0.987
177
+ # },
178
+ # {
179
+ # "word": "outdoor activities",
180
+ # "entity_group": "ACTIVITY",
181
+ # "score": 0.960
182
+ # }
183
+ # ]
184
+ # }
185
+ # {
186
+ # "question": "Can you suggest comfortable Nike running shoes for outdoor activities?",
187
+ # "answer": "Nike Air Zoom Pegasus or React Infinity Run are great options for outdoor running.",
188
+ # "score": 0.978
189
+ # }