# Space status: Runtime error
import gradio as gr | |
import torch | |
from transformers import CLIPProcessor, CLIPModel | |
import re | |
# Load the FashionCLIP checkpoint once, at import time: the processor
# (tokenization/preprocessing) and the model share the same checkpoint name.
model_name = "patrickjohncyh/fashion-clip"
processor = CLIPProcessor.from_pretrained(model_name)
model = CLIPModel.from_pretrained(model_name)

# Price pattern: an optional condition word (under/above/below/between), a
# 1-5 digit amount, and an optional currency code, matched case-insensitively,
# e.g. "under 200AED", "below 50 usd".
# NOTE: every group except the amount is optional, so any digit run matches
# (e.g. "size 42"), and for "between 100 and 200" only 100 is captured.
price_pattern = re.compile(
    r'(\bunder\b|\babove\b|\bbelow\b|\bbetween\b)?\s?(\d{1,5})\s?(AED|USD|EUR)?',
    re.IGNORECASE,
)
def get_text_embedding(text):
    """
    Encode text with FashionCLIP's text tower.

    Args:
        text: a single string, or an iterable (e.g. a list) of strings to
            encode together as one batch.

    Returns:
        The tensor produced by ``model.get_text_features``, with one row per
        input string (a single string is encoded as a batch of one).
    """
    # Wrap only a bare string. Wrapping a list would create a nested list
    # (a list inside the batch), which the tokenizer does not accept as plain
    # text -- that broke calls such as get_text_embedding(["Brand", "Price"]).
    texts = [text] if isinstance(text, str) else list(text)
    inputs = processor(text=texts, images=None, return_tensors="pt", padding=True)
    # Inference only: no gradients are needed.
    with torch.no_grad():
        text_embedding = model.get_text_features(**inputs)
    return text_embedding
def _extract_price(query):
    """
    Return a normalised price string (e.g. "Under 200 AED") from *query*, or
    None if ``price_pattern`` finds no number at all.

    Matches that carry a condition word or a currency code are preferred over
    bare numbers, so "size 42 under 200AED" yields "Under 200 AED" rather than
    "42 AED". The currency is upper-cased and defaults to AED.
    """
    matches = list(price_pattern.finditer(query))
    if not matches:
        return None
    best = next((m for m in matches if m.group(1) or m.group(3)), matches[0])
    condition, amount, currency = best.groups()
    parts = [
        condition.capitalize() if condition else "",
        amount,
        currency.upper() if currency else "AED",
    ]
    # Drop the empty condition slot so no leading space is produced.
    return " ".join(part for part in parts if part)


def extract_attributes(query):
    """
    Map a free-text fashion query onto fixed attribute slots.

    The query's FashionCLIP text embedding is compared (cosine similarity) with
    the embeddings of the slot names "Brand", "Category", "Gender" and "Price";
    the best-matching slot receives the whole query text. Independently, a
    regex extracts a price (condition, amount, currency) into "Price".

    Args:
        query: the user's search string.

    Returns:
        dict with keys "Brand", "Category", "Gender", "Price"; slots that were
        not filled keep the value "Unknown".
    """
    structured_output = {"Brand": "Unknown", "Category": "Unknown", "Gender": "Unknown", "Price": "Unknown"}
    reference_labels = list(structured_output)

    query_embedding = get_text_embedding(query)
    # Embed each label on its own (each call returns a batch of one row) and
    # stack the rows, giving one reference row per label.
    reference_embeddings = torch.cat(
        [get_text_embedding(label) for label in reference_labels], dim=0
    )
    # The single query row broadcasts against every reference row.
    similarities = torch.nn.functional.cosine_similarity(query_embedding, reference_embeddings)
    attribute_type = reference_labels[similarities.argmax().item()]

    price = _extract_price(query)
    if price is not None:
        structured_output["Price"] = price

    # Heuristic: the best-matching slot gets the full query text, but it must
    # not clobber a price the regex has already extracted.
    if not (attribute_type == "Price" and price is not None):
        structured_output[attribute_type] = query
    return structured_output
# Gradio callback used by the "Parse Query" button.
def parse_query(user_query):
    """
    Parse *user_query* into the structured attribute dict.

    The dict from extract_attributes is returned unchanged so the gr.JSON
    output component can render it.
    """
    return extract_attributes(user_query)
with gr.Blocks() as demo:
    gr.Markdown("# 🛍️ Fashion Query Parser using FashionCLIP")
    query_input = gr.Textbox(
        label="Enter your search query",
        placeholder="e.g., Gucci men’s perfume under 200AED",
    )
    output_box = gr.JSON(label="Parsed Output")
    parse_button = gr.Button("Parse Query")
    # On click: run parse_query on the textbox value and show the result as JSON.
    parse_button.click(fn=parse_query, inputs=[query_input], outputs=[output_box])

# Start the Gradio app; this runs at module top level, i.e. on import.
demo.launch()