DINGOLANI commited on
Commit
8c1ee79
·
verified ·
1 Parent(s): 4898394

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +47 -21
app.py CHANGED
@@ -1,45 +1,71 @@
1
  import gradio as gr
2
  import torch
3
  from transformers import CLIPProcessor, CLIPModel
 
4
 
5
- # Load the FashionCLIP model
6
  model_name = "patrickjohncyh/fashion-clip"
7
  model = CLIPModel.from_pretrained(model_name)
8
  processor = CLIPProcessor.from_pretrained(model_name)
9
 
10
- def parse_query(user_query):
 
 
 
 
 
 
 
 
 
 
 
 
11
  """
12
- Parse fashion-related search queries into structured data.
13
  """
14
- # Define categories relevant to luxury fashion search
15
- fashion_categories = ["Brand", "Category", "Gender", "Price Range"]
16
 
17
- # Format user query for CLIP
18
- inputs = processor(text=[user_query], images=None, return_tensors="pt", padding=True)
19
 
20
- # Get model embeddings
21
- with torch.no_grad():
22
- outputs = model.get_text_features(**inputs)
 
 
 
 
23
 
24
- # Simulated parsing output (FashionCLIP itself does not generate structured JSON)
25
- parsed_output = {
26
- "Brand": "Gucci" if "Gucci" in user_query else "Unknown",
27
- "Category": "Perfume" if "perfume" in user_query else "Unknown",
28
- "Gender": "Men" if "men" in user_query else "Women" if "women" in user_query else "Unisex",
29
- "Price Range": "Under 200 AED" if "under 200" in user_query else "Above 200 AED",
30
- }
31
 
32
- return parsed_output
 
 
 
 
 
 
 
 
 
33
 
34
  # Define Gradio UI
 
 
 
 
 
 
 
35
  with gr.Blocks() as demo:
36
- gr.Markdown("# 🛍️ Luxury Fashion Query Parser (FashionCLIP)")
37
-
38
  query_input = gr.Textbox(label="Enter your search query", placeholder="e.g., Gucci men’s perfume under 200AED")
39
  output_box = gr.JSON(label="Parsed Output")
40
 
41
  parse_button = gr.Button("Parse Query")
42
  parse_button.click(parse_query, inputs=[query_input], outputs=[output_box])
43
 
44
- # Launch the app
45
  demo.launch()
 
1
  import gradio as gr
2
  import torch
3
  from transformers import CLIPProcessor, CLIPModel
4
+ import re
5
 
6
+ # Load FashionCLIP model
7
  model_name = "patrickjohncyh/fashion-clip"
8
  model = CLIPModel.from_pretrained(model_name)
9
  processor = CLIPProcessor.from_pretrained(model_name)
10
 
11
+ # Price extraction regex
12
+ price_pattern = re.compile(r'(\bunder\b|\babove\b|\bbelow\b|\bbetween\b)?\s?(\d{1,5})\s?(AED|USD|EUR)?', re.IGNORECASE)
13
+
14
+ def get_text_embedding(text):
15
+ """
16
+ Converts input text into an embedding using FashionCLIP.
17
+ """
18
+ inputs = processor(text=[text], images=None, return_tensors="pt", padding=True)
19
+ with torch.no_grad():
20
+ text_embedding = model.get_text_features(**inputs)
21
+ return text_embedding
22
+
23
+ def extract_attributes(query):
24
  """
25
+ Extract structured fashion attributes dynamically using FashionCLIP.
26
  """
27
+ structured_output = {"Brand": "Unknown", "Category": "Unknown", "Gender": "Unknown", "Price": "Unknown"}
 
28
 
29
+ # Get embedding for the query
30
+ query_embedding = get_text_embedding(query)
31
 
32
+ # Compare with embeddings of common fashion attribute words (using FashionCLIP)
33
+ reference_labels = ["Brand", "Category", "Gender", "Price"]
34
+ reference_embeddings = get_text_embedding(reference_labels)
35
+
36
+ # Compute cosine similarity to classify the type of query
37
+ similarities = torch.nn.functional.cosine_similarity(query_embedding, reference_embeddings)
38
+ best_match_index = similarities.argmax().item()
39
 
40
+ # Assign type dynamically
41
+ attribute_type = reference_labels[best_match_index]
 
 
 
 
 
42
 
43
+ # Extract price dynamically
44
+ price_match = price_pattern.search(query)
45
+ if price_match:
46
+ condition, amount, currency = price_match.groups()
47
+ structured_output["Price"] = f"{condition.capitalize() if condition else ''} {amount} {currency if currency else 'AED'}".strip()
48
+
49
+ # Extract brand & category dynamically using FashionCLIP similarity
50
+ structured_output[attribute_type] = query # Assigning full query text to matched attribute
51
+
52
+ return structured_output
53
 
54
  # Define Gradio UI
55
+ def parse_query(user_query):
56
+ """
57
+ Takes user query and returns structured attributes dynamically.
58
+ """
59
+ parsed_output = extract_attributes(user_query)
60
+ return parsed_output # Returns structured JSON
61
+
62
  with gr.Blocks() as demo:
63
+ gr.Markdown("# 🛍️ Fashion Query Parser using FashionCLIP")
64
+
65
  query_input = gr.Textbox(label="Enter your search query", placeholder="e.g., Gucci men’s perfume under 200AED")
66
  output_box = gr.JSON(label="Parsed Output")
67
 
68
  parse_button = gr.Button("Parse Query")
69
  parse_button.click(parse_query, inputs=[query_input], outputs=[output_box])
70
 
 
71
  demo.launch()