bnoy1 commited on
Commit
e4ecb2d
·
verified ·
1 Parent(s): 42a7263

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +15 -50
app.py CHANGED
@@ -1,63 +1,28 @@
1
  import gradio as gr
2
- from datasets import load_dataset
3
- from transformers import AutoImageProcessor, AutoModelForImageClassification, Trainer, TrainingArguments
 
 
4
  import torch
5
- import os
6
 
7
- # טוענים דאטאסט מהתיקיות
8
  dataset = load_dataset("imagefolder", data_dir=".", split={"train": "train[:80%]", "test": "train[80%:]"})
9
 
10
- # בוחרים מודל בסיסי
 
11
  checkpoint = "facebook/deit-tiny-patch16-224"
12
  processor = AutoImageProcessor.from_pretrained(checkpoint)
 
 
13
  model = AutoModelForImageClassification.from_pretrained(
14
  checkpoint,
15
  num_labels=3,
16
  id2label={0: "rock", 1: "paper", 2: "scissors"},
17
- label2id={"rock": 0, "paper": 1, "scissors": 2}
18
- )
19
-
20
- # פונקציה לעיבוד התמונות
21
- def preprocess(examples):
22
- images = [x.convert("RGB") for x in examples["image"]]
23
- inputs = processor(images=images, return_tensors="pt")
24
- inputs["labels"] = examples["label"]
25
- return inputs
26
-
27
- dataset = dataset.map(preprocess, batched=True)
28
-
29
- # הגדרות אימון מהיר
30
- training_args = TrainingArguments(
31
- output_dir="./results",
32
- evaluation_strategy="epoch",
33
- save_strategy="epoch",
34
- per_device_train_batch_size=4,
35
- per_device_eval_batch_size=4,
36
- num_train_epochs=2, # ✅ אפוקים מהירים: רק 2!
37
- load_best_model_at_end=True,
38
- logging_dir='./logs',
39
- logging_steps=5,
40
- )
41
-
42
- trainer = Trainer(
43
- model=model,
44
- args=training_args,
45
- train_dataset=dataset["train"],
46
- eval_dataset=dataset["test"],
47
  )
48
 
49
- # אימון
50
- trainer.train()
51
-
52
- # פונקציה להרצת חיזוי על תמונה חדשה
53
- def predict(image):
54
- inputs = processor(images=image, return_tensors="pt")
55
- outputs = model(**inputs)
56
- logits = outputs.logits
57
- predicted_class_idx = logits.argmax(-1).item()
58
- label = model.config.id2label[predicted_class_idx]
59
- return label
60
-
61
- # בניית אפליקציה
62
- demo = gr.Interface(fn=predict, inputs="image", outputs="text")
63
- demo.launch()
 
1
  import gradio as gr
2
+ import random
3
+ from PIL import Image
4
+ import time
5
+ from transformers import AutoImageProcessor, AutoModelForImageClassification
6
  import torch
7
+ from datasets import load_dataset
8
 
9
+ # טוענים דאטאסט
10
  dataset = load_dataset("imagefolder", data_dir=".", split={"train": "train[:80%]", "test": "train[80%:]"})
11
 
12
+
13
+ # טוענים מודל בסיס
14
  checkpoint = "facebook/deit-tiny-patch16-224"
15
  processor = AutoImageProcessor.from_pretrained(checkpoint)
16
+
17
+ # במקום להטעין מודל סופי → טוענים ואז משנים את הראש
18
  model = AutoModelForImageClassification.from_pretrained(
19
  checkpoint,
20
  num_labels=3,
21
  id2label={0: "rock", 1: "paper", 2: "scissors"},
22
+ label2id={"rock": 0, "paper": 1, "scissors": 2},
23
+ ignore_mismatched_sizes=True, # 🔥 קריטי!
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
24
  )
25
 
26
+ # פונקציית אימון קטנה (נוסיף עוד מעט)
27
+ # פונקציית משחק
28
+ # אפליקציה Gradio