SyedHutter commited on
Commit
d86a00a
·
verified ·
1 Parent(s): 0888a3b

Delete app.py

Browse files
Files changed (1) hide show
  1. app.py +0 -115
app.py DELETED
@@ -1,115 +0,0 @@
1
- from fastapi import FastAPI, HTTPException
2
- from pydantic import BaseModel
3
- from typing import List
4
- from transformers import pipeline
5
- import torch
6
-
7
- # Initialize the FastAPI app
8
- app = FastAPI()
9
-
10
- # Determine device (use GPU if available, otherwise CPU)
11
- device = 0 if torch.cuda.is_available() else -1
12
-
13
- # Initialize the NER pipeline
14
- ner_pipeline = pipeline(
15
- "ner",
16
- model="dbmdz/bert-large-cased-finetuned-conll03-english",
17
- aggregation_strategy="simple",
18
- device=device
19
- )
20
-
21
- # Initialize the QA pipeline
22
- qa_pipeline = pipeline(
23
- "question-answering",
24
- model="deepset/roberta-base-squad2",
25
- device=device
26
- )
27
-
28
- # Allowed domains for filtering
29
- allowed_domains = [
30
- "clothing", "fashion", "shopping", "accessories", "pants", "jeans", "shirts", "sustainable materials"
31
- ]
32
-
33
- # Context for the QA pipeline
34
- context_msg = """
35
- We offer a wide variety of clothing options, including sustainable pants, jeans, chinos, and trousers.
36
- Our products are made with eco-friendly materials and are available in styles such as casual wear, formal wear, and activewear.
37
- """
38
-
39
- # Pydantic models for structured responses
40
- class Entity(BaseModel):
41
- word: str
42
- entity_group: str
43
- score: float
44
-
45
- class NERResponse(BaseModel):
46
- entities: List[Entity]
47
- words: List[str] # List of extracted words (added)
48
-
49
- class QAResponse(BaseModel):
50
- question: str
51
- answer: str
52
- score: float
53
-
54
- class CombinedRequest(BaseModel):
55
- text: str # The input text prompt
56
-
57
- class CombinedResponse(BaseModel):
58
- ner: NERResponse # NER output
59
- qa: QAResponse # QA output
60
-
61
- # Function to check if the input text belongs to allowed domains
62
- def is_text_in_allowed_domain(text: str, domains: List[str]) -> bool:
63
- for domain in domains:
64
- if domain in text.lower():
65
- return True
66
- return False
67
-
68
- # Combined endpoint for NER and QA with domain filtering
69
- @app.post("/process/", response_model=CombinedResponse)
70
- async def process_request(request: CombinedRequest):
71
- input_text = request.text
72
-
73
- # Check if the input text belongs to the allowed domains
74
- if not is_text_in_allowed_domain(input_text, allowed_domains):
75
- raise HTTPException(
76
- status_code=400,
77
- detail="The input text does not match the allowed domains. Please provide a query related to clothing, fashion, or accessories."
78
- )
79
-
80
- # Perform Named Entity Recognition (NER)
81
- ner_entities = ner_pipeline(input_text)
82
-
83
- # Process NER results into a structured response
84
- formatted_entities = [
85
- {
86
- "word": entity["word"],
87
- "entity_group": entity["entity_group"],
88
- "score": float(entity["score"]),
89
- }
90
- for entity in ner_entities
91
- ]
92
- ner_words = [entity["word"] for entity in ner_entities] # Collect only the words
93
-
94
- ner_response = {
95
- "entities": formatted_entities,
96
- "words": ner_words # Include the list of words
97
- }
98
-
99
- # Perform Question Answering (QA)
100
- qa_result = qa_pipeline(question=input_text, context=context_msg)
101
- qa_result["score"] = float(qa_result["score"]) # Convert numpy.float32 to Python float
102
-
103
- qa_response = {
104
- "question": input_text,
105
- "answer": qa_result["answer"],
106
- "score": qa_result["score"]
107
- }
108
-
109
- # Return both NER and QA responses
110
- return {"ner": ner_response, "qa": qa_response}
111
-
112
- # Root endpoint
113
- @app.get("/")
114
- async def root():
115
- return {"message": "Welcome to the NER and QA API!"}