# cvtools/tests/test_utils.py
# Last commit: 2d3e7bb ("wip") by rifatramadhani.
import io
import os  # Keep os for file paths / cleanup
import tempfile
import unittest
import unittest.mock
import urllib.error
import urllib.request  # Import urllib.request for patching

import numpy as np
from PIL import Image

from utils.image_utils import load_image, preprocess_image, get_image_from_input
from utils.ui_utils import update_input_visibility
# Mock object to mimic gr.update return value
class MockGradioUpdateReturn:
    """Lightweight substitute for the object produced by ``gr.update``.

    Only the ``visible`` flag is modelled; it is stored exactly as given.
    """

    def __init__(self, visible=None):
        # Keep the flag verbatim (None when the caller did not supply one).
        self.visible = visible
# Simple mock class for urllib.request.urlopen response
class SimpleMockURLResponse:
    """Minimal stand-in for the response returned by ``urllib.request.urlopen``.

    The payload is served through an in-memory stream, so ``read`` behaves
    like a real HTTP response:

    * ``read()`` returns all remaining bytes;
    * ``read(n)`` returns at most *n* bytes;
    * reads after exhaustion, or after ``close()``, return ``b""``.

    Like the real response, it can be used as a context manager and is
    closed on exit.
    """

    def __init__(self, content):
        # Raw payload, kept unchanged for inspection by tests.
        self._content = content
        # Stream with a read cursor, so repeated reads make progress.
        self._stream = io.BytesIO(content)

    def read(self, amt=None):
        """Return up to *amt* bytes, or all remaining bytes if *amt* is None.

        Once the response is closed this returns ``b""`` instead of raising,
        matching ``http.client.HTTPResponse.read``.
        """
        if self._stream.closed:
            return b""
        return self._stream.read(amt)

    def close(self):
        """Release the in-memory stream; later reads return ``b""``."""
        self._stream.close()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        # Close like the real response. Returning None (falsy) lets any
        # exception raised inside the ``with`` body propagate.
        self.close()
# Tests for utils.image_utils and utils.ui_utils. urlopen and gradio.update are patched per test method, not at class level.
class TestUtils(unittest.TestCase):
    """Tests for ``utils.image_utils`` and ``utils.ui_utils``.

    No real network or Gradio calls are made. URL-based tests patch
    ``urllib.request.urlopen``, and visibility tests patch ``gradio.update``.
    Each patch is a per-method ``unittest.mock.patch`` decorator.
    """

    # A 1x1 PNG (grayscale + alpha) as a data URI, ending in a well-formed
    # IEND chunk. An earlier copy of this fixture had a corrupted IEND CRC
    # and trailing bytes; it happened to pass only because PIL does not
    # verify IEND when opening.
    _ONE_PIXEL_PNG_DATA_URI = (
        "data:image/png;base64,"
        "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAQAAAC1HAwC"
        "AAAAC0lEQVR42mNkYAAAAAYAAjCB0C8AAAAASUVORK5CYII="
    )

    # ------------------------------------------------------------------ helpers

    @staticmethod
    def _png_bytes(size, color):
        """Return the PNG-encoded bytes of a solid-colour RGB image."""
        buffer = io.BytesIO()
        Image.new("RGB", size, color=color).save(buffer, format="PNG")
        return buffer.getvalue()

    def _assert_pil_image(self, image, size):
        """Assert that *image* is a PIL image whose ``.size`` equals *size*."""
        self.assertIsNotNone(image)
        self.assertIsInstance(image, Image.Image)
        self.assertEqual(image.size, size)

    def _assert_visibility(self, mock_gr_update, choice, expected):
        """Check the visibility flags returned by ``update_input_visibility``.

        The patched ``gr.update`` echoes back the ``visible`` keyword it was
        called with. The assertions therefore verify the arguments the code
        under test actually passed, not values pre-scripted into the mock.

        Args:
            mock_gr_update: the patched ``gradio.update`` mock.
            choice: the input-mode label passed to ``update_input_visibility``.
            expected: expected ``visible`` values in return order
                (upload, URL, base64 widgets).
        """

        def echo_update(*args, **kwargs):
            return MockGradioUpdateReturn(visible=kwargs.get("visible"))

        mock_gr_update.side_effect = echo_update
        updates = update_input_visibility(choice)
        self.assertEqual(len(updates), len(expected))
        for index, (update, want) in enumerate(zip(updates, expected)):
            with self.subTest(index=index):
                # assertEqual, not assertFalse: a missing visible= (None)
                # must not count as "hidden".
                self.assertEqual(update.visible, want)

    # --------------------------------------------------------------- load_image

    @unittest.mock.patch("urllib.request.urlopen")
    def test_load_image_from_url(self, mock_urlopen):
        mock_urlopen.return_value = SimpleMockURLResponse(
            self._png_bytes((10, 10), "purple")
        )
        url = "https://www.example.com/dummy_image.png"  # never actually fetched
        image = load_image(url)
        self._assert_pil_image(image, (10, 10))
        mock_urlopen.assert_called_once_with(url)

    def test_load_image_from_base64(self):
        image = load_image(self._ONE_PIXEL_PNG_DATA_URI)
        self._assert_pil_image(image, (1, 1))

    def test_load_image_from_file(self):
        # Write the fixture into a private temp dir instead of the CWD.
        # Cleanup is registered before any assertion runs, so the file is
        # removed even when the test fails.
        tmp_dir = tempfile.TemporaryDirectory()
        self.addCleanup(tmp_dir.cleanup)
        dummy_file_path = os.path.join(tmp_dir.name, "dummy_test_image.png")
        Image.new("RGB", (10, 10), color="red").save(dummy_file_path)
        image = load_image(dummy_file_path)
        self._assert_pil_image(image, (10, 10))

    @unittest.mock.patch("urllib.request.urlopen")
    def test_load_image_invalid_url(self, mock_urlopen):
        # Simulate a network failure; load_image is expected to return None.
        mock_urlopen.side_effect = urllib.error.URLError("Simulated network error")
        url = "http://invalid.url/image.jpg"
        self.assertIsNone(load_image(url))
        mock_urlopen.assert_called_once_with(url)

    def test_load_image_invalid_base64(self):
        self.assertIsNone(load_image("data:image/png;base64,invalidbase64string"))

    # --------------------------------------------------------- preprocess_image

    def test_preprocess_image(self):
        processed_image = preprocess_image(Image.new("RGB", (20, 20), color="blue"))
        self.assertIsNotNone(processed_image)
        self.assertIsInstance(processed_image, np.ndarray)
        # Square RGB input, so expect 20x20 with 3 channels.
        self.assertEqual(processed_image.shape, (20, 20, 3))

    # ----------------------------------------------------- get_image_from_input

    @unittest.mock.patch("urllib.request.urlopen")
    def test_get_image_from_input_url(self, mock_urlopen):
        mock_urlopen.return_value = SimpleMockURLResponse(
            self._png_bytes((30, 30), "orange")
        )
        url = "https://www.example.com/another_dummy_image.png"
        image = get_image_from_input("Enter URL", None, url, "")
        self._assert_pil_image(image, (30, 30))
        mock_urlopen.assert_called_once_with(url)

    def test_get_image_from_input_upload(self):
        uploaded_image = Image.new("RGB", (30, 30), color="green")
        image = get_image_from_input("Upload File", uploaded_image, "", "")
        self._assert_pil_image(image, (30, 30))

    def test_get_image_from_input_base64(self):
        image = get_image_from_input(
            "Enter Base64", None, "", self._ONE_PIXEL_PNG_DATA_URI
        )
        self._assert_pil_image(image, (1, 1))

    def test_get_image_from_input_no_input(self):
        self.assertIsNone(get_image_from_input("Upload File", None, "", ""))

    def test_get_image_from_input_invalid_type(self):
        image = get_image_from_input(
            "Invalid Type", Image.new("RGB", (10, 10)), "url", "base64"
        )
        self.assertIsNone(image)

    # -------------------------------------------------- update_input_visibility

    @unittest.mock.patch("gradio.update")
    def test_update_input_visibility_upload(self, mock_gr_update):
        self._assert_visibility(mock_gr_update, "Upload File", (True, False, False))

    @unittest.mock.patch("gradio.update")
    def test_update_input_visibility_url(self, mock_gr_update):
        self._assert_visibility(mock_gr_update, "Enter URL", (False, True, False))

    @unittest.mock.patch("gradio.update")
    def test_update_input_visibility_base64(self, mock_gr_update):
        self._assert_visibility(mock_gr_update, "Enter Base64", (False, False, True))

    @unittest.mock.patch("gradio.update")
    def test_update_input_visibility_default(self, mock_gr_update):
        # An unrecognised choice is expected to fall back to the upload widget.
        self._assert_visibility(
            mock_gr_update, "Invalid Choice", (True, False, False)
        )
if __name__ == "__main__":
    # Support running this module directly, e.g. `python tests/test_utils.py`,
    # in addition to discovery via `python -m unittest`.
    unittest.main()