paramasivan27 commited on
Commit
7b0a595
·
1 Parent(s): 6f69f46

Update space

Browse files
Files changed (1) hide show
  1. app.py +26 -1
app.py CHANGED
@@ -6,11 +6,36 @@ import streamlit as st
6
  model = AutoModelForSequenceClassification.from_pretrained('paramasivan27/RetailProductClassification_bert-base-uncased')
7
  tokenizer = AutoTokenizer.from_pretrained('paramasivan27/RetailProductClassification_bert-base-uncased')
8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9
  def predict(text):
10
  inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True, max_length=128)
11
  outputs = model(**inputs)
12
  predicted_class = torch.argmax(outputs.logits, dim=1).item()
13
- return predicted_class
 
14
 
15
  # Streamlit app layout
16
  st.title("Retail Product Classification")
 
6
  model = AutoModelForSequenceClassification.from_pretrained('paramasivan27/RetailProductClassification_bert-base-uncased')
7
  tokenizer = AutoTokenizer.from_pretrained('paramasivan27/RetailProductClassification_bert-base-uncased')
8
 
9
+ label_to_id = {'Electronics':0,
10
+ 'Sports & Outdoors':1,
11
+ 'Cell Phones & Accessories':2,
12
+ 'Automotive':3,
13
+ 'Toys & Games':4,
14
+ 'Tools & Home Improvement':5,
15
+ 'Health & Personal Care':6,
16
+ 'Beauty':7,
17
+ 'Grocery & Gourmet Food':8,
18
+ 'Office Products':9,
19
+ 'Arts, Crafts & Sewing':10,
20
+ 'Pet Supplies':11,
21
+ 'Patio, Lawn & Garden':12,
22
+ 'Clothing, Shoes & Jewelry':13,
23
+ 'Baby':14,
24
+ 'Musical Instruments':15,
25
+ 'Industrial & Scientific':16,
26
+ 'Baby Products':17,
27
+ 'Appliances':18,
28
+ 'All Beauty':19,
29
+ 'All Electronics':20}
30
+
31
+ id_to_label = {v: k for k, v in label_to_id.items()}
32
+
33
  def predict(text):
34
  inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True, max_length=128)
35
  outputs = model(**inputs)
36
  predicted_class = torch.argmax(outputs.logits, dim=1).item()
37
+ class_name = id_to_label[predicted_class]
38
+ return class_name
39
 
40
  # Streamlit app layout
41
  st.title("Retail Product Classification")