|
from fastapi import FastAPI, HTTPException
|
|
from pydantic import BaseModel
|
|
from typing import List, Dict
|
|
from transformers import pipeline
|
|
from itertools import groupby
|
|
|
|
|
|
# FastAPI application instance; the @app.post / @app.get decorators in this
# file register their route handlers on it.
app = FastAPI()
|
|
|
|
|
|
# Named-entity-recognition pipeline, built once at import time and reused by
# every request.
ner_pipeline = pipeline(
    "ner",
    model="dbmdz/bert-large-cased-finetuned-conll03-english",
    # `grouped_entities=True` is deprecated; "simple" aggregation is its
    # documented replacement and likewise returns grouped results carrying
    # the "word", "entity_group" and "score" keys read by the /process/ handler.
    aggregation_strategy="simple",
)
|
|
|
|
|
|
# Question-answering pipeline, created once when the module is imported and
# shared by every request.
qa_pipeline = pipeline(task="question-answering", model="deepset/roberta-base-squad2")
|
|
|
|
|
|
# Keywords that mark a request as on-topic; the /process/ handler passes this
# list to is_text_in_allowed_domain before running any model.
allowed_domains = [
    # General topics
    "clothing", "fashion", "shopping", "accessories", "sustainability",
    # Garments, headwear and everyday footwear
    "shoes", "hats", "shirts", "dresses", "pants", "jeans", "skirts",
    "jackets", "coats", "t-shirts", "sweaters", "hoodies",
    # Clothing categories
    "activewear", "formal wear", "casual wear", "sportswear", "outerwear",
    "swimwear", "underwear", "lingerie",
    # Accessories
    "socks", "scarves", "gloves", "belts", "ties", "caps", "beanies",
    # Footwear
    "boots", "sandals", "heels", "sneakers",
    # Materials
    "materials", "cotton", "polyester", "wool", "silk", "leather", "denim",
    "linen",
    # Styles, services and industry terms
    "athleisure", "ethnic wear", "fashion trends", "custom clothing",
    "tailoring", "sustainable materials", "recycled clothing",
    "fashion brands", "streetwear",
]
|
|
|
|
|
|
class Entity(BaseModel):
    """One entity reported by the NER step of the /process/ endpoint."""

    # Text of the entity, copied from the pipeline result's "word" key.
    word: str
    # Label copied from the pipeline result's "entity_group" key.
    entity_group: str
    # Pipeline score for this entity, converted to a built-in float.
    score: float
|
|
|
|
class NERResponse(BaseModel):
    """NER half of the combined response: the list of entities found."""

    # May be empty when the pipeline reports no entities.
    entities: List[Entity]
|
|
|
|
class QAResponse(BaseModel):
    """Question-answering half of the combined response."""

    # The question that was put to the QA pipeline.
    question: str
    # Answer text taken from the pipeline result's "answer" key.
    answer: str
    # Pipeline score for the answer, converted to a built-in float.
    score: float
|
|
|
|
class CombinedRequest(BaseModel):
    """Request body accepted by the /process/ endpoint."""

    # Free-form text to filter by domain and then analyse.
    text: str
|
|
|
|
class CombinedResponse(BaseModel):
    """Response body of the /process/ endpoint: NER and QA results together."""

    # Entities extracted from the submitted text.
    ner: NERResponse
    # Question-answering result for the submitted text.
    qa: QAResponse
|
|
|
|
|
|
def is_text_in_allowed_domain(text: str, domains: List[str]) -> bool:
    """Return True if any domain keyword occurs in ``text``, ignoring case.

    Matching is plain substring containment, not whole-word matching: the
    keyword "shirts" matches "sweatshirts", and "ties" matches "activities".
    An empty-string keyword therefore matches every text.

    Args:
        text: Free-form input to check.
        domains: Keywords to look for; each is compared case-insensitively.

    Returns:
        True as soon as one keyword is found in ``text``; False when none
        occurs, including when ``domains`` is empty.
    """
    # Lower-case the text once instead of on every loop iteration.
    lowered_text = text.lower()
    # Lower-case each keyword as well: a mixed-case entry such as "Shoes"
    # could otherwise never be found inside the all-lower-case text.
    return any(domain.lower() in lowered_text for domain in domains)
|
|
|
|
|
|
@app.post("/process/", response_model=CombinedResponse)
async def process_request(request: CombinedRequest):
    """
    Run NER and question answering over the submitted text.

    The text is first checked against ``allowed_domains``. When no keyword
    matches, the request is rejected with HTTP 400 and neither pipeline runs.
    Otherwise the same text goes through both pipelines and the two results
    are returned together.
    """
    text = request.text

    # Guard clause: refuse anything outside the allowed topics.
    on_topic = is_text_in_allowed_domain(text, allowed_domains)
    if not on_topic:
        rejection = (
            "The input text does not match the allowed domains. "
            "Please provide a query related to clothing, fashion, or accessories."
        )
        raise HTTPException(status_code=400, detail=rejection)

    # NOTE(review): both pipeline calls below are plain synchronous calls made
    # from inside a coroutine; presumably they hold up the event loop while
    # they run. Confirm whether that is acceptable for this service.
    entities = [
        {
            "word": found["word"],
            "entity_group": found["entity_group"],
            # float() turns the pipeline's score into a built-in float.
            "score": float(found["score"]),
        }
        for found in ner_pipeline(text)
    ]

    # The submitted text serves as both the question and the context.
    qa_output = qa_pipeline(question=text, context=text)
    confidence = float(qa_output["score"])

    return {
        "ner": {"entities": entities},
        "qa": {
            "question": text,
            "answer": qa_output["answer"],
            "score": confidence,
        },
    }
|
|
|
|
|
|
|
|
@app.get("/")
async def root():
    """
    Landing route: returns a fixed welcome message so callers can confirm
    the server is up.
    """
    greeting = {"message": "Welcome to the filtered NER and QA API!"}
    return greeting
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|