SyedHutter commited on
Commit
7727abe
·
verified ·
1 Parent(s): d86a00a

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +129 -0
app.py ADDED
@@ -0,0 +1,129 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi import FastAPI, HTTPException
2
+ from pydantic import BaseModel
3
+ from typing import List
4
+ from transformers import pipeline
5
+ import torch
6
+
7
+ # Initialize the FastAPI app
8
+ app = FastAPI()
9
+
10
+ # Determine device (use GPU if available, otherwise CPU)
11
+ device = 0 if torch.cuda.is_available() else -1
12
+
13
+ # Initialize the NER pipeline
14
+ ner_pipeline = pipeline(
15
+ "ner",
16
+ model="dbmdz/bert-large-cased-finetuned-conll03-english",
17
+ aggregation_strategy="simple",
18
+ device=device
19
+ )
20
+
21
+ # Initialize the QA pipeline
22
+ qa_pipeline = pipeline(
23
+ "question-answering",
24
+ model="deepset/roberta-base-squad2",
25
+ device=device
26
+ )
27
+
28
+ # Allowed domains for filtering
29
+ allowed_domains = [
30
+ "clothing", "fashion", "shopping", "accessories", "sustainability", "shoes", "hats", "shirts",
31
+ "dresses", "pants", "jeans", "skirts", "jackets", "coats", "t-shirts", "sweaters", "hoodies",
32
+ "activewear", "formal wear", "casual wear", "sportswear", "outerwear", "swimwear", "underwear",
33
+ "lingerie", "socks", "scarves", "gloves", "belts", "ties", "caps", "beanies", "boots", "sandals",
34
+ "heels", "sneakers", "materials", "cotton", "polyester", "wool", "silk", "leather", "denim",
35
+ "linen", "athleisure", "ethnic wear", "fashion trends", "custom clothing", "tailoring",
36
+ "sustainable materials", "recycled clothing", "fashion brands", "streetwear",
37
+ "footwear", "handbags", "jewelry", "watches", "eyewear", "cosmetics", "beauty products",
38
+ "personal care", "fragrances", "home decor", "lifestyle", "luxury goods", "vintage clothing",
39
+ "second-hand clothing", "upcycled fashion", "ethical fashion", "eco-friendly products",
40
+ "fashion technology", "textile innovation", "fashion marketing", "fashion retail"
41
+ ]
42
+
43
+ # Context for the QA pipeline
44
+ context_msg = (
45
+ "Hutter Products GmbH provides a wide array of services to help businesses create high-quality, sustainable products. "
46
+ "Their offerings include comprehensive product design, ensuring items are both visually appealing and functional, and product consulting, "
47
+ "which provides expert advice on features, materials, and design elements. They also offer sustainability consulting to integrate eco-friendly practices, "
48
+ "such as using recycled materials and Ocean Bound Plastic. Additionally, they manage customized production to ensure products meet the highest standards "
49
+ "and offer product animation services, creating realistic rendered images and animations to enhance online engagement. These services collectively enable "
50
+ "businesses to develop products that are sustainable, market-responsive, and aligned with their brand identity."
51
+ )
52
+
53
+ # Pydantic models for structured responses
54
+ class Entity(BaseModel):
55
+ word: str
56
+ entity_group: str
57
+ score: float
58
+
59
+ class NERResponse(BaseModel):
60
+ entities: List[Entity]
61
+ words: List[str] # List of extracted words (added)
62
+
63
+ class QAResponse(BaseModel):
64
+ question: str
65
+ answer: str
66
+ score: float
67
+
68
+ class CombinedRequest(BaseModel):
69
+ text: str # The input text prompt
70
+
71
+ class CombinedResponse(BaseModel):
72
+ ner: NERResponse # NER output
73
+ qa: QAResponse # QA output
74
+
75
+ # Function to check if the input text belongs to allowed domains
76
+ def is_text_in_allowed_domain(text: str, domains: List[str]) -> bool:
77
+ for domain in domains:
78
+ if domain in text.lower():
79
+ return True
80
+ return False
81
+
82
+ # Combined endpoint for NER and QA with domain filtering
83
+ @app.post("/process/", response_model=CombinedResponse)
84
+ async def process_request(request: CombinedRequest):
85
+ input_text = request.text
86
+
87
+ # Check if the input text belongs to the allowed domains
88
+ if not is_text_in_allowed_domain(input_text, allowed_domains):
89
+ raise HTTPException(
90
+ status_code=400,
91
+ detail="The input text does not match the allowed domains. Please provide a query related to clothing, fashion, or accessories."
92
+ )
93
+
94
+ # Perform Named Entity Recognition (NER)
95
+ ner_entities = ner_pipeline(input_text)
96
+
97
+ # Process NER results into a structured response
98
+ formatted_entities = [
99
+ {
100
+ "word": entity["word"],
101
+ "entity_group": entity["entity_group"],
102
+ "score": float(entity["score"]),
103
+ }
104
+ for entity in ner_entities
105
+ ]
106
+ ner_words = [entity["word"] for entity in ner_entities] # Collect only the words
107
+
108
+ ner_response = {
109
+ "entities": formatted_entities,
110
+ "words": ner_words # Include the list of words
111
+ }
112
+
113
+ # Perform Question Answering (QA)
114
+ qa_result = qa_pipeline(question=input_text, context=context_msg)
115
+ qa_result["score"] = float(qa_result["score"]) # Convert numpy.float32 to Python float
116
+
117
+ qa_response = {
118
+ "question": input_text,
119
+ "answer": qa_result["answer"],
120
+ "score": qa_result["score"]
121
+ }
122
+
123
+ # Return both NER and QA responses
124
+ return {"ner": ner_response, "qa": qa_response}
125
+
126
+ # Root endpoint
127
+ @app.get("/")
128
+ async def root():
129
+ return {"message": "Welcome to the NER and QA API!"}