mestrevh committed on
Commit
dadd386
·
1 Parent(s): 15dcb93

Add the files

Browse files
Files changed (3) hide show
  1. app.py +17 -0
  2. requirements.txt +4 -0
  3. train.py +47 -0
app.py ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import gradio as gr
from transformers import pipeline

# Load the classifier directly from the Hugging Face Hub.
classifier = pipeline("image-classification", model="mestrevh/computer-vision-cifar-10")


def predict_image(image):
    """Classify a PIL image.

    Returns the pipeline's output: a list of {"label": str, "score": float}
    dicts, which Gradio's "label" output component renders directly.
    """
    return classifier(image)


# Gradio interface.
# FIX: the original used ``gr.inputs.Image(type="pil")`` — the ``gr.inputs``
# namespace is the deprecated pre-3.0 API (removed in Gradio 4); the
# supported component is ``gr.Image``.
interface = gr.Interface(
    fn=predict_image,
    inputs=gr.Image(type="pil"),
    outputs="label",
    live=True,  # re-run prediction whenever the input image changes
)

interface.launch()
requirements.txt ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ transformers==4.29.0
2
+ datasets==2.10.0
3
+ torch==2.1.0
4
+ gradio==3.33.0
train.py ADDED
@@ -0,0 +1,47 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
from transformers import Trainer, TrainingArguments
from datasets import load_dataset
from transformers import ViTForImageClassification, ViTFeatureExtractor

# Load the "beans" dataset (splits: train / validation / test;
# columns: image_file_path, image, labels).
dataset = load_dataset("beans")

# Load the pretrained ViT backbone and size the classification head
# for the 3 beans classes (angular_leaf_spot, bean_rust, healthy).
model = ViTForImageClassification.from_pretrained(
    "google/vit-base-patch16-224-in21k",
    num_labels=3,  # beans has 3 classes
)
feature_extractor = ViTFeatureExtractor.from_pretrained("google/vit-base-patch16-224-in21k")


def preprocess_function(examples):
    """Convert a batch of PIL images into model inputs.

    Produces ``pixel_values`` via the feature extractor and carries the
    integer class labels through under the ``labels`` key Trainer expects.
    """
    inputs = feature_extractor(examples["image"], return_tensors="pt")
    inputs["labels"] = examples["labels"]
    return inputs


# Apply preprocessing to every split.
# FIX: the raw "image" (PIL) and "image_file_path" columns must be dropped;
# the default data collator cannot batch PIL images / strings, so keeping
# them makes Trainer crash at the first training step.
dataset = dataset.map(
    preprocess_function,
    batched=True,
    remove_columns=["image", "image_file_path"],
)

# Training hyperparameters.
training_args = TrainingArguments(
    output_dir="./results",
    evaluation_strategy="epoch",
    learning_rate=2e-5,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=64,
    num_train_epochs=3,
    weight_decay=0.01,
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=dataset["train"],
    eval_dataset=dataset["validation"],  # beans ships a "validation" split
)

# Fine-tune the model.
trainer.train()

# Persist the fine-tuned weights and the feature extractor config together,
# so the directory can be loaded back with ``from_pretrained``.
model.save_pretrained("./computer-vision-beans")
feature_extractor.save_pretrained("./computer-vision-beans")