"""Gradio demo that parses free-text fashion search queries with FashionCLIP."""

import functools
import re

import gradio as gr
import torch
from transformers import CLIPModel, CLIPProcessor

# Load FashionCLIP model (fetched via `from_pretrained` when the module is imported).
model_name = "patrickjohncyh/fashion-clip"
model = CLIPModel.from_pretrained(model_name)
processor = CLIPProcessor.from_pretrained(model_name)

# Labels a query is compared against; these are also the keys of the parsed output.
REFERENCE_LABELS = ["Brand", "Category", "Gender", "Price"]

# Price extraction regex. Groups:
#   1: optional condition word (under / above / below / between), followed by whitespace
#   2: amount, 1-5 digits, not glued to a preceding word char and not part of a longer number
#   3: optional upper bound of a range ("100 and 200", "100 to 200", "100-200")
#   4: optional currency code
price_pattern = re.compile(
    r"(?:\b(under|above|below|between)\s*)?"
    r"\b(\d{1,5})(?!\d)"
    r"(?:\s*(?:and|to|-)\s*(\d{1,5})(?!\d))?"
    r"\s*(AED|USD|EUR)?",
    re.IGNORECASE,
)


def get_text_embedding(text):
    """
    Convert text into FashionCLIP text features.

    Args:
        text: A single string, or a list of strings to embed as one batch.

    Returns:
        The tensor returned by ``model.get_text_features`` for the batch,
        i.e. one row per input string.
    """
    # Wrap a lone string into a batch of one; pass a list through as-is.
    # (Wrapping a list again would hand the tokenizer a nested list instead
    # of a batch of separate labels.)
    texts = [text] if isinstance(text, str) else list(text)
    # truncation=True: CLIP's text encoder has a fixed maximum sequence
    # length, so over-long queries would otherwise fail to encode.
    inputs = processor(
        text=texts,
        images=None,
        return_tensors="pt",
        padding=True,
        truncation=True,
    )
    with torch.no_grad():
        text_embedding = model.get_text_features(**inputs)
    return text_embedding


@functools.lru_cache(maxsize=1)
def _reference_embeddings():
    """Embed REFERENCE_LABELS once; the labels are constant across queries."""
    return get_text_embedding(REFERENCE_LABELS)


def _extract_price(query):
    """Return a normalised price string found in *query*, or None if no number is present."""
    matches = list(price_pattern.finditer(query))
    if not matches:
        return None
    # Prefer a number that carries a condition word or a currency
    # ("under 200", "150 USD"); otherwise fall back to the first bare number.
    match = next((m for m in matches if m.group(1) or m.group(4)), matches[0])
    condition, amount, upper, currency = match.groups()
    if upper:
        amount = f"{amount}-{upper}"
    parts = [
        condition.capitalize() if condition else "",
        amount,
        currency.upper() if currency else "AED",  # AED is the default currency
    ]
    return " ".join(part for part in parts if part)


def extract_attributes(query):
    """
    Extract structured fashion attributes from a free-text query.

    The query embedding is compared (cosine similarity) with the embeddings of
    the label words in REFERENCE_LABELS. The full query text is stored under
    the closest label. A price, if one can be found by regex, is stored under
    "Price".

    Args:
        query: The user's search text.

    Returns:
        A dict with keys "Brand", "Category", "Gender" and "Price"; fields
        that were not filled hold "Unknown".
    """
    structured_output = {label: "Unknown" for label in REFERENCE_LABELS}

    # Nothing to parse: skip the model call entirely.
    if not query or not query.strip():
        return structured_output

    # Classify the query by similarity to the label words themselves.
    # NOTE(review): comparing against bare label names is a coarse heuristic.
    query_embedding = get_text_embedding(query)
    similarities = torch.nn.functional.cosine_similarity(
        query_embedding, _reference_embeddings()
    )
    attribute_type = REFERENCE_LABELS[similarities.argmax().item()]

    price = _extract_price(query)
    if price is not None:
        structured_output["Price"] = price

    # Store the full query under the best-matching label, but never let it
    # overwrite a price that was actually parsed above.
    if not (attribute_type == "Price" and price is not None):
        structured_output[attribute_type] = query

    return structured_output


# Define Gradio UI
def parse_query(user_query):
    """Take the user's query and return its structured attributes as a dict (rendered as JSON)."""
    return extract_attributes(user_query)


with gr.Blocks() as demo:
    gr.Markdown("# 🛍️ Fashion Query Parser using FashionCLIP")
    query_input = gr.Textbox(
        label="Enter your search query",
        placeholder="e.g., Gucci men’s perfume under 200AED",
    )
    output_box = gr.JSON(label="Parsed Output")
    parse_button = gr.Button("Parse Query")
    parse_button.click(parse_query, inputs=[query_input], outputs=[output_box])

if __name__ == "__main__":
    demo.launch()