jatingocodeo commited on
Commit
7036bb9
·
verified ·
1 Parent(s): 57f1e4e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +19 -7
app.py CHANGED
@@ -4,8 +4,9 @@ import torchvision.transforms as transforms
4
  from PIL import Image
5
  import requests
6
  from io import BytesIO
7
- from transformers import AutoImageProcessor, AutoModelForImageClassification
8
  import json
 
 
9
 
10
  # Load ImageNet class labels
11
  LABELS_URL = "https://raw.githubusercontent.com/anishathalye/imagenet-simple-labels/master/imagenet-simple-labels.json"
@@ -16,8 +17,21 @@ def load_model():
16
  """
17
  Load model and processor from Hugging Face Hub
18
  """
19
- model_id = "jatingocodeo/ImageNet" # Updated model repository ID
20
- model = AutoModelForImageClassification.from_pretrained(model_id)
 
 
 
 
 
 
 
 
 
 
 
 
 
21
  processor = AutoImageProcessor.from_pretrained(model_id)
22
  return model, processor
23
 
@@ -31,18 +45,16 @@ def predict(image):
31
  try:
32
  # Load model and processor (with caching)
33
  model, processor = load_model()
34
- model.eval()
35
 
36
  # Process image
37
  inputs = processor(image, return_tensors="pt")
38
 
39
  # Get predictions
40
  with torch.no_grad():
41
- outputs = model(**inputs)
42
- logits = outputs.logits
43
 
44
  # Get probabilities and classes
45
- probs = torch.nn.functional.softmax(logits, dim=1)[0]
46
  top_probs, top_indices = torch.topk(probs, k=5)
47
 
48
  # Format results
 
4
  from PIL import Image
5
  import requests
6
  from io import BytesIO
 
7
  import json
8
+ import torchvision.models as models
9
+ from transformers import AutoImageProcessor
10
 
11
  # Load ImageNet class labels
12
  LABELS_URL = "https://raw.githubusercontent.com/anishathalye/imagenet-simple-labels/master/imagenet-simple-labels.json"
 
17
  """
18
  Load model and processor from Hugging Face Hub
19
  """
20
+ model_id = "jatingocodeo/ImageNet"
21
+
22
+ # Initialize ResNet50 model
23
+ model = models.resnet50(weights=None)
24
+ model.fc = torch.nn.Linear(model.fc.in_features, 1000) # 1000 ImageNet classes
25
+
26
+ # Load model weights
27
+ checkpoint = torch.hub.load_state_dict_from_url(
28
+ f"https://huggingface.co/{model_id}/resolve/main/pytorch_model.bin",
29
+ map_location="cpu"
30
+ )
31
+ model.load_state_dict(checkpoint)
32
+ model.eval()
33
+
34
+ # Create processor
35
  processor = AutoImageProcessor.from_pretrained(model_id)
36
  return model, processor
37
 
 
45
  try:
46
  # Load model and processor (with caching)
47
  model, processor = load_model()
 
48
 
49
  # Process image
50
  inputs = processor(image, return_tensors="pt")
51
 
52
  # Get predictions
53
  with torch.no_grad():
54
+ outputs = model(inputs.pixel_values)
 
55
 
56
  # Get probabilities and classes
57
+ probs = torch.nn.functional.softmax(outputs, dim=1)[0]
58
  top_probs, top_indices = torch.topk(probs, k=5)
59
 
60
  # Format results