bnoy1 commited on
Commit
c8a491f
verified
1 Parent(s): 758d7e9

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +63 -0
app.py ADDED
@@ -0,0 +1,63 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from datasets import load_dataset
3
+ from transformers import AutoImageProcessor, AutoModelForImageClassification, Trainer, TrainingArguments
4
+ import torch
5
+ import os
6
+
7
# Load the dataset from the local image folders and carve the single
# "train" folder into an 80/20 train/test split.
dataset = load_dataset(
    "imagefolder",
    data_dir=".",
    split={"train": "train[:80%]", "test": "train[80%:]"},
)

# Base checkpoint to fine-tune.
# Fix: "google/vit-tiny-patch16-224" is not a published Hub repository
# (Google released base/large/huge ViT variants only); the standard
# ViT-tiny checkpoint is the community one below.
checkpoint = "WinKawaks/vit-tiny-patch16-224"
processor = AutoImageProcessor.from_pretrained(checkpoint)
model = AutoModelForImageClassification.from_pretrained(
    checkpoint,
    num_labels=3,  # rock / paper / scissors
    id2label={0: "rock", 1: "paper", 2: "scissors"},
    label2id={"rock": 0, "paper": 1, "scissors": 2},
    # The checkpoint ships a 1000-class ImageNet head; swapping in a
    # 3-class head requires discarding the mismatched classifier weights,
    # otherwise from_pretrained raises a size-mismatch error.
    ignore_mismatched_sizes=True,
)
19
+
20
# Batch transform applied to every split via dataset.map below.
def preprocess(examples):
    """Convert a batch of PIL images into model inputs, carrying labels through."""
    batch = processor(
        images=[img.convert("RGB") for img in examples["image"]],
        return_tensors="pt",
    )
    batch["labels"] = examples["label"]
    return batch


dataset = dataset.map(preprocess, batched=True)
28
+
29
# Fine-tuning hyper-parameters: evaluate and checkpoint once per epoch so
# load_best_model_at_end can restore the best-scoring epoch afterwards.
training_args = TrainingArguments(
    output_dir="./results",
    logging_dir='./logs',
    num_train_epochs=5,
    per_device_train_batch_size=4,
    per_device_eval_batch_size=4,
    evaluation_strategy="epoch",
    save_strategy="epoch",
    load_best_model_at_end=True,
    logging_steps=5,
)
41
+
42
# Wire the model and hyper-parameters to the processed splits, then fine-tune.
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=dataset["train"],
    eval_dataset=dataset["test"],
)

# Run training (blocks until all epochs finish).
trainer.train()
51
+
52
# Inference entry point used by the Gradio interface below.
def predict(image):
    """Classify one image and return its label ("rock", "paper" or "scissors").

    Fixes over the original: trainer.train() leaves the model in training
    mode, so dropout stayed active and predictions were nondeterministic;
    autograd was also tracking gradients during inference for no reason.
    """
    inputs = processor(images=image, return_tensors="pt")
    model.eval()  # disable dropout — deterministic predictions
    with torch.no_grad():  # inference only, skip gradient bookkeeping
        outputs = model(**inputs)
    predicted_class_idx = outputs.logits.argmax(-1).item()
    return model.config.id2label[predicted_class_idx]
60
+
61
# Gradio front end: upload an image, get the predicted gesture back as text.
demo = gr.Interface(
    fn=predict,
    inputs="image",
    outputs="text",
)
demo.launch()