SyedHutter commited on
Commit
bf424e9
·
verified ·
1 Parent(s): 58ae0db

Upload 2 files

Browse files

Adding app and req Commit 2

Files changed (2) hide show
  1. app.py +119 -0
  2. requirements.txt +6 -0
app.py ADDED
@@ -0,0 +1,119 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi import FastAPI, HTTPException
2
+ from pydantic import BaseModel
3
+ from typing import List
4
+ from transformers import pipeline
5
+ import torch
6
+
7
+ # Initialize the FastAPI app
8
+ app = FastAPI()
9
+
10
+ # Determine device (use GPU if available, otherwise CPU)
11
+ device = 0 if torch.cuda.is_available() else -1
12
+
13
+ # Initialize the NER pipeline
14
+ ner_pipeline = pipeline(
15
+ "ner",
16
+ model="dbmdz/bert-large-cased-finetuned-conll03-english",
17
+ aggregation_strategy="simple", # Updated to replace deprecated grouped_entities
18
+ device=device
19
+ )
20
+
21
+ # Initialize the QA pipeline
22
+ qa_pipeline = pipeline(
23
+ "question-answering",
24
+ model="deepset/roberta-base-squad2",
25
+ device=device
26
+ )
27
+
28
+ # Allowed domains for filtering
29
+ allowed_domains = [
30
+ "clothing", "fashion", "shopping", "accessories", "sustainability", "shoes", "hats", "shirts",
31
+ "dresses", "pants", "jeans", "skirts", "jackets", "coats", "t-shirts", "sweaters", "hoodies",
32
+ "activewear", "formal wear", "casual wear", "sportswear", "outerwear", "swimwear", "underwear",
33
+ "lingerie", "socks", "scarves", "gloves", "belts", "ties", "caps", "beanies", "boots", "sandals",
34
+ "heels", "sneakers", "materials", "cotton", "polyester", "wool", "silk", "leather", "denim",
35
+ "linen", "athleisure", "ethnic wear", "fashion trends", "custom clothing", "tailoring",
36
+ "sustainable materials", "recycled clothing", "fashion brands", "streetwear"
37
+ ]
38
+
39
+ # Pydantic models for structured response
40
+ class Entity(BaseModel):
41
+ word: str
42
+ entity_group: str
43
+ score: float
44
+
45
+ class NERResponse(BaseModel):
46
+ entities: List[Entity]
47
+
48
+ class QAResponse(BaseModel):
49
+ question: str
50
+ answer: str
51
+ score: float
52
+
53
+ class CombinedRequest(BaseModel):
54
+ text: str # The input text prompt
55
+
56
+ class CombinedResponse(BaseModel):
57
+ ner: NERResponse # NER output
58
+ qa: QAResponse # QA output
59
+
60
+ # Function to check if the input text belongs to allowed domains
61
+ def is_text_in_allowed_domain(text: str, domains: List[str]) -> bool:
62
+ for domain in domains:
63
+ if domain in text.lower():
64
+ return True
65
+ return False
66
+
67
+ # Combined endpoint for NER and QA with domain filtering
68
+ @app.post("/process/", response_model=CombinedResponse)
69
+ async def process_request(request: CombinedRequest):
70
+ """
71
+ Process the input text for both NER and QA, returning both responses,
72
+ only if the text matches the allowed domains.
73
+ """
74
+ input_text = request.text
75
+
76
+ # Check if the input text belongs to the allowed domains
77
+ if not is_text_in_allowed_domain(input_text, allowed_domains):
78
+ raise HTTPException(
79
+ status_code=400,
80
+ detail=(
81
+ "The input text does not match the allowed domains. "
82
+ "Please provide a query related to clothing, fashion, or accessories."
83
+ )
84
+ )
85
+
86
+ # Perform Named Entity Recognition (NER)
87
+ ner_entities = ner_pipeline(input_text)
88
+
89
+ # Process the NER entities into the required format
90
+ formatted_entities = [
91
+ {
92
+ "word": entity["word"],
93
+ "entity_group": entity["entity_group"],
94
+ "score": float(entity["score"]), # Convert numpy.float32 to Python float
95
+ }
96
+ for entity in ner_entities
97
+ ]
98
+ ner_response = {"entities": formatted_entities}
99
+
100
+ # Perform Question Answering (QA)
101
+ qa_result = qa_pipeline(question=input_text, context=input_text)
102
+ qa_result["score"] = float(qa_result["score"]) # Convert numpy.float32 to Python float
103
+
104
+ qa_response = {
105
+ "question": input_text,
106
+ "answer": qa_result["answer"],
107
+ "score": qa_result["score"]
108
+ }
109
+
110
+ # Return both NER and QA responses
111
+ return {"ner": ner_response, "qa": qa_response}
112
+
113
+ # Root endpoint
114
+ @app.get("/")
115
+ async def root():
116
+ """
117
+ Root endpoint to confirm the server is running.
118
+ """
119
+ return {"message": "Welcome to the filtered NER and QA API!"}
requirements.txt ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ fastapi==0.98.0
2
+ uvicorn[standard]==0.23.2
3
+ transformers==4.34.0
4
+ torch==2.0.1
5
+ pydantic==1.10.9
6
+ numpy<2.0 # Compatibility with PyTorch and Transformers