|
from typing import List

import torch
from fastapi import FastAPI, HTTPException
from fastapi.concurrency import run_in_threadpool
from pydantic import BaseModel
from transformers import pipeline
|
|
|
|
|
|
# ASGI application object; the route decorators in this module register on it.
app = FastAPI()

# Device index handed to transformers' pipeline(): CUDA device 0 when a GPU is
# visible to torch, otherwise -1, which selects the CPU.
if torch.cuda.is_available():
    device = 0
else:
    device = -1
|
|
|
|
|
|
# Both models below are downloaded/loaded once, at import time, onto `device`.

# Named-entity recognition. aggregation_strategy="simple" groups sub-word
# tokens into whole entities, so each result carries an "entity_group" label.
ner_pipeline = pipeline(
    task="ner",
    model="dbmdz/bert-large-cased-finetuned-conll03-english",
    device=device,
    aggregation_strategy="simple",
)

# Extractive question answering: picks an answer span out of a given context.
qa_pipeline = pipeline(
    task="question-answering",
    model="deepset/roberta-base-squad2",
    device=device,
)
|
|
|
|
|
|
# Lower-case keywords that mark a request as on-topic; process_request screens
# the input text against them via is_text_in_allowed_domain.
allowed_domains = [
    "clothing",
    "fashion",
    "shopping",
    "accessories",
    "pants",
    "jeans",
    "shirts",
    "sustainable materials",
]

# Fixed passage that the question-answering pipeline extracts answers from.
context_msg = """


We offer a wide variety of clothing options, including sustainable pants, jeans, chinos, and trousers.


Our products are made with eco-friendly materials and are available in styles such as casual wear, formal wear, and activewear.


"""
|
|
|
|
|
|
# One entity from the aggregated NER pipeline output, as returned to clients.
class Entity(BaseModel):
    # Entity text as reported by the pipeline.
    word: str
    # Entity label (presumably CoNLL-03 tags such as PER/ORG/LOC/MISC for this
    # model — TODO confirm).
    entity_group: str
    # Score the pipeline assigned to this entity.
    score: float
|
|
|
|
# NER part of the /process/ response.
class NERResponse(BaseModel):
    # Every recognised entity with its label and score.
    entities: List[Entity]
    # Only the entity texts, built from the same pipeline output as `entities`.
    words: List[str]
|
|
|
|
# Question-answering part of the /process/ response.
class QAResponse(BaseModel):
    # The question that was asked (the request text, echoed back).
    question: str
    # Answer span the QA pipeline extracted from the service's context passage.
    answer: str
    # Score the pipeline assigned to that answer.
    score: float
|
|
|
|
# Request body for POST /process/.
class CombinedRequest(BaseModel):
    # Free-form user text; used both as the NER input and as the QA question.
    text: str
|
|
|
|
# Response body for POST /process/: NER and QA results side by side.
class CombinedResponse(BaseModel):
    # Entities found in the request text.
    ner: NERResponse
    # Answer to the request text, treated as a question.
    qa: QAResponse
|
|
|
|
|
|
def is_text_in_allowed_domain(text: str, domains: List[str]) -> bool:
    """Return True if any domain keyword occurs in *text*, ignoring case.

    Matching is a plain substring test on lower-cased strings, so a keyword
    also matches inside longer words (e.g. "pants" inside "participants").

    Args:
        text: Free-form user input to screen.
        domains: Keywords that mark the text as on-topic. Compared
            case-insensitively. Empty entries are skipped, because an empty
            string is a substring of every text and would allow anything.

    Returns:
        True when at least one non-empty keyword is found in *text*,
        otherwise False (including when *domains* is empty).
    """
    # Normalise the input once rather than on every keyword comparison.
    haystack = text.lower()
    return any(domain.lower() in haystack for domain in domains if domain)
|
|
|
|
|
|
def _build_ner_response(ner_entities: List[dict]) -> dict:
    """Shape aggregated NER pipeline output into the NERResponse layout.

    Args:
        ner_entities: Items returned by ``ner_pipeline``; each must provide
            "word", "entity_group" and "score" keys.

    Returns:
        A dict with "entities" (word / entity_group / score per item) and
        "words" (only the entity texts, in the same order).
    """
    entities = [
        {
            "word": entity["word"],
            "entity_group": entity["entity_group"],
            # Cast to a built-in float; the pipeline may return numpy scalars.
            "score": float(entity["score"]),
        }
        for entity in ner_entities
    ]
    return {
        "entities": entities,
        "words": [entity["word"] for entity in entities],
    }


def _build_qa_response(question: str, qa_result: dict) -> dict:
    """Shape question-answering pipeline output into the QAResponse layout.

    Args:
        question: The question that was asked (echoed back to the client).
        qa_result: Output of ``qa_pipeline``; must provide "answer" and
            "score" keys. It is not modified.

    Returns:
        A dict with "question", "answer" and a built-in float "score".
    """
    return {
        "question": question,
        "answer": qa_result["answer"],
        "score": float(qa_result["score"]),
    }


# POST /process/: reject text that matches none of allowed_domains (HTTP 400),
# otherwise run NER over it and ask it as a question against context_msg.
#
# The transformers pipelines are blocking calls. Awaiting them through
# run_in_threadpool keeps this endpoint a coroutine while moving inference off
# the event loop, so one slow request no longer stalls every other route.
# NOTE(review): requests can now run inference concurrently on the shared
# pipeline objects; confirm the models/tokenizers tolerate that, or guard the
# calls with a lock.
@app.post("/process/", response_model=CombinedResponse)
async def process_request(request: CombinedRequest):
    input_text = request.text

    if not is_text_in_allowed_domain(input_text, allowed_domains):
        raise HTTPException(
            status_code=400,
            detail="The input text does not match the allowed domains. Please provide a query related to clothing, fashion, or accessories.",
        )

    ner_entities = await run_in_threadpool(ner_pipeline, input_text)
    qa_result = await run_in_threadpool(
        qa_pipeline, question=input_text, context=context_msg
    )

    return {
        "ner": _build_ner_response(ner_entities),
        "qa": _build_qa_response(input_text, qa_result),
    }
|
|
|
|
|
|
# GET /: returns a fixed welcome message.
@app.get("/")
async def root():
    greeting = "Welcome to the NER and QA API!"
    return dict(message=greeting)
|
|
|