# Spaces: Running on Zero
import io
import os  # Keep os for file cleanup
import unittest
import unittest.mock
import urllib.error  # URLError is used to simulate network failures
import urllib.request  # Import urllib.request for patching

import numpy as np
from PIL import Image

from utils.image_utils import load_image, preprocess_image, get_image_from_input
from utils.ui_utils import update_input_visibility
# Mock object to mimic the value returned by gr.update.
class MockGradioUpdateReturn:
    """Lightweight stand-in for a ``gr.update(...)`` result.

    Only ``visible`` is modelled because that is the sole attribute the
    visibility tests read from the returned updates.
    """

    def __init__(self, visible=None):
        self.visible = visible

    def __repr__(self):
        # A readable repr makes assertion failures show the mocked visibility
        # instead of an opaque ``<... object at 0x...>``.
        return f"{type(self).__name__}(visible={self.visible!r})"
# Minimal fake for the response object produced by urllib.request.urlopen.
class SimpleMockURLResponse:
    """Stand-in for a ``urllib.request.urlopen`` response.

    Usable as a context manager; ``read()`` hands back the payload given at
    construction, unchanged, on every call.
    """

    def __init__(self, content):
        # Payload returned verbatim by read().
        self._content = content

    def __enter__(self):
        # ``with urlopen(...) as resp`` binds the response object itself.
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        # Nothing to release; a None result lets any in-flight exception
        # propagate out of the ``with`` block.
        return None

    def read(self):
        """Return the stored payload (it is not consumed between calls)."""
        return self._content
# Mocks are injected with per-method ``unittest.mock.patch`` decorators, so each
# test receives exactly the mock arguments its signature declares.
class TestUtils(unittest.TestCase):
    """Tests for the image-loading helpers and the input-visibility helper."""

    # A well-formed 1x1 PNG as a data URI, shared by both base64 tests so they
    # exercise the same valid payload.
    _PNG_1X1_BASE64 = (
        "data:image/png;base64,"
        "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAQAAAC1HAwCAAAAC0lEQVR42mNkYAAAAAYAAjCB0C8"
        "AAAAASUVORK5CYII="
    )

    @staticmethod
    def _png_bytes(size, color):
        """Return the PNG encoding of a solid-colour RGB image of ``size``."""
        buffer = io.BytesIO()
        Image.new("RGB", size, color=color).save(buffer, format="PNG")
        return buffer.getvalue()

    def _check_visibility(self, mock_gr_update, choice, expected):
        """Run ``update_input_visibility(choice)`` against a stubbed gr.update.

        ``expected`` holds the three visibilities the stub returns, in call
        order; the function's result must carry them in the same order.
        """
        # NOTE(review): because the stub returns pre-baked visibilities, this
        # mainly checks call count/order. Asserting on
        # ``mock_gr_update.call_args_list`` would pin the real arguments, but
        # that depends on how ui_utils calls gr.update — confirm first.
        mock_gr_update.side_effect = [
            MockGradioUpdateReturn(visible=flag) for flag in expected
        ]
        updates = update_input_visibility(choice)
        self.assertEqual(len(updates), 3)
        self.assertEqual([update.visible for update in updates], list(expected))

    # NOTE(review): patching "urllib.request.urlopen" assumes utils.image_utils
    # calls it as ``urllib.request.urlopen(...)``; if it does
    # ``from urllib.request import urlopen`` the target must become
    # "utils.image_utils.urlopen".
    @unittest.mock.patch("urllib.request.urlopen")
    def test_load_image_from_url(self, mock_urlopen):
        """A URL is fetched through urlopen and decoded into a PIL image."""
        mock_urlopen.return_value = SimpleMockURLResponse(
            self._png_bytes((10, 10), "purple")
        )
        url = "https://www.example.com/dummy_image.png"  # Use a dummy URL
        image = load_image(url)
        self.assertIsNotNone(image)
        self.assertIsInstance(image, Image.Image)
        self.assertEqual(image.size, (10, 10))
        mock_urlopen.assert_called_once_with(url)  # Verify urlopen was called

    def test_load_image_from_base64(self):
        """A base64 data URI is decoded into a PIL image."""
        image = load_image(self._PNG_1X1_BASE64)
        self.assertIsNotNone(image)
        self.assertIsInstance(image, Image.Image)
        self.assertEqual(image.size, (1, 1))

    def test_load_image_from_file(self):
        """A local file path is opened as a PIL image."""
        dummy_file_path = "dummy_test_image.png"
        Image.new("RGB", (10, 10), color="red").save(dummy_file_path)
        # Register cleanup right away so the file is removed even when an
        # assertion below fails.
        self.addCleanup(os.remove, dummy_file_path)
        image = load_image(dummy_file_path)
        self.assertIsNotNone(image)
        self.assertIsInstance(image, Image.Image)
        self.assertEqual(image.size, (10, 10))

    @unittest.mock.patch("urllib.request.urlopen")
    def test_load_image_invalid_url(self, mock_urlopen):
        """A network error while fetching a URL yields None."""
        mock_urlopen.side_effect = urllib.error.URLError("Simulated network error")
        url = "http://invalid.url/image.jpg"
        image = load_image(url)
        self.assertIsNone(image)
        mock_urlopen.assert_called_once_with(url)

    def test_load_image_invalid_base64(self):
        """An undecodable base64 data URI yields None."""
        base64_string = "data:image/png;base64,invalidbase64string"
        image = load_image(base64_string)
        self.assertIsNone(image)

    def test_preprocess_image(self):
        """preprocess_image turns a PIL image into an RGB numpy array."""
        dummy_image = Image.new("RGB", (20, 20), color="blue")
        processed_image = preprocess_image(dummy_image)
        self.assertIsNotNone(processed_image)
        self.assertIsInstance(processed_image, np.ndarray)
        self.assertEqual(processed_image.shape, (20, 20, 3))  # Check shape for RGB

    @unittest.mock.patch("urllib.request.urlopen")
    def test_get_image_from_input_url(self, mock_urlopen):
        """The "Enter URL" choice loads the image from the given URL."""
        mock_urlopen.return_value = SimpleMockURLResponse(
            self._png_bytes((30, 30), "orange")
        )
        url = "https://www.example.com/another_dummy_image.png"
        image = get_image_from_input("Enter URL", None, url, "")
        self.assertIsNotNone(image)
        self.assertIsInstance(image, Image.Image)
        self.assertEqual(image.size, (30, 30))
        mock_urlopen.assert_called_once_with(url)

    def test_get_image_from_input_upload(self):
        """The "Upload File" choice returns the uploaded PIL image."""
        mock_uploaded_image = Image.new("RGB", (30, 30), color="green")
        image = get_image_from_input("Upload File", mock_uploaded_image, "", "")
        self.assertIsNotNone(image)
        self.assertIsInstance(image, Image.Image)
        self.assertEqual(image.size, (30, 30))

    def test_get_image_from_input_base64(self):
        """The "Enter Base64" choice decodes the given data URI."""
        image = get_image_from_input("Enter Base64", None, "", self._PNG_1X1_BASE64)
        self.assertIsNotNone(image)
        self.assertIsInstance(image, Image.Image)
        self.assertEqual(image.size, (1, 1))

    def test_get_image_from_input_no_input(self):
        """"Upload File" with no uploaded image yields None."""
        image = get_image_from_input("Upload File", None, "", "")
        self.assertIsNone(image)

    def test_get_image_from_input_invalid_type(self):
        """An unrecognised input-type choice yields None."""
        image = get_image_from_input(
            "Invalid Type", Image.new("RGB", (10, 10)), "url", "base64"
        )
        self.assertIsNone(image)

    # Visibility tests stub gr.update so no real Gradio components are built.
    # NOTE(review): the "gradio.update" target assumes utils.ui_utils calls
    # ``gr.update`` / ``gradio.update`` through the module; if it imports
    # ``update`` directly, patch "utils.ui_utils.update" instead.
    @unittest.mock.patch("gradio.update")
    def test_update_input_visibility_upload(self, mock_gr_update):
        self._check_visibility(mock_gr_update, "Upload File", (True, False, False))

    @unittest.mock.patch("gradio.update")
    def test_update_input_visibility_url(self, mock_gr_update):
        self._check_visibility(mock_gr_update, "Enter URL", (False, True, False))

    @unittest.mock.patch("gradio.update")
    def test_update_input_visibility_base64(self, mock_gr_update):
        self._check_visibility(mock_gr_update, "Enter Base64", (False, False, True))

    @unittest.mock.patch("gradio.update")
    def test_update_input_visibility_default(self, mock_gr_update):
        # An unknown choice is expected to fall back to the upload layout.
        self._check_visibility(mock_gr_update, "Invalid Choice", (True, False, False))
if __name__ == "__main__":
    # Run every TestCase defined in this module when executed directly.
    unittest.main(module="__main__")