|
"""Miscellaneous functions for testing masked arrays and subclasses |
|
|
|
:author: Pierre Gerard-Marchant |
|
:contact: pierregm_at_uga_dot_edu |
|
:version: $Id: testutils.py 3529 2007-11-13 08:01:14Z jarrod.millman $ |
|
|
|
""" |
|
from __future__ import division, absolute_import, print_function |
|
|
|
# Module metadata. The "$...$" fields have the form of version-control
# keyword expansions from the original repository; nothing in this module
# chunk reads them.
__author__ = "Pierre GF Gerard-Marchant ($Author: jarrod.millman $)"
__version__ = "1.0"
__revision__ = "$Revision: 3529 $"
__date__ = "$Date: 2007-11-13 10:01:14 +0200 (Tue, 13 Nov 2007) $"
|
|
|
|
|
import operator |
|
|
|
import numpy as np |
|
from numpy import ndarray, float_ |
|
import numpy.core.umath as umath |
|
from numpy.testing import * |
|
import numpy.testing.utils as utils |
|
|
|
from .core import mask_or, getmask, masked_array, nomask, masked, filled, \ |
|
equal, less |
|
|
|
|
|
def approx(a, b, fill_value=True, rtol=1e-5, atol=1e-8):
    """Compare `a` and `b` elementwise within a tolerance.

    Returns a flattened boolean array that is True wherever
    ``|a - b| <= atol + rtol * |b|``. `rtol` is the relative tolerance
    (should be positive and << 1.0); `atol` is the absolute tolerance,
    which dominates for elements of `b` that are very small or zero.

    Positions masked in either input are filled with `fill_value` on the
    `a` side and with 1 on the `b` side before comparing. With the default
    ``fill_value=True`` (i.e. 1) masked positions therefore compare equal;
    ``fill_value=False`` (0) normally makes them compare unequal.

    If either filled input has object dtype, masks and tolerances are
    ignored and plain elementwise equality is returned instead.
    """
    shared_mask = mask_or(getmask(a), getmask(b))
    data_a = filled(a)
    data_b = filled(b)
    if "O" in (data_a.dtype.char, data_b.dtype.char):
        # Object arrays cannot be cast to float: fall back to equality.
        return np.equal(data_a, data_b).ravel()
    lhs = filled(masked_array(data_a, copy=False, mask=shared_mask), fill_value)
    rhs = filled(masked_array(data_b, copy=False, mask=shared_mask), 1)
    lhs = lhs.astype(np.float64)
    rhs = rhs.astype(np.float64)
    tolerance = atol + rtol * np.absolute(rhs)
    within = np.less_equal(np.absolute(lhs - rhs), tolerance)
    return within.ravel()
|
|
|
|
|
def almost(a, b, decimal=6, fill_value=True):
    """Compare `a` and `b` elementwise up to `decimal` decimal places.

    Returns a flattened boolean array that is True wherever
    ``round(|a - b|, decimal) <= 10**-decimal``.

    Positions masked in either input are filled with `fill_value` on the
    `a` side and with 1 on the `b` side before comparing. With the default
    ``fill_value=True`` (i.e. 1) masked positions therefore compare equal;
    ``fill_value=False`` (0) makes them compare unequal.

    If either filled input has object dtype, masks and rounding are ignored
    and plain elementwise equality is returned instead.
    """
    shared_mask = mask_or(getmask(a), getmask(b))
    data_a = filled(a)
    data_b = filled(b)
    if "O" in (data_a.dtype.char, data_b.dtype.char):
        # Object arrays cannot be cast to float: fall back to equality.
        return np.equal(data_a, data_b).ravel()
    lhs = filled(masked_array(data_a, copy=False, mask=shared_mask),
                 fill_value).astype(np.float64)
    rhs = filled(masked_array(data_b, copy=False, mask=shared_mask),
                 1).astype(np.float64)
    rounded_gap = np.around(np.abs(lhs - rhs), decimal)
    return (rounded_gap <= 10.0 ** (-decimal)).ravel()
|
|
|
|
|
|
|
def _assert_equal_on_sequences(actual, desired, err_msg=''):
    """Assert item-by-item equality of two non-array sequences.

    The lengths are compared first. Then each pair of items is checked with
    ``assert_equal``, and the failure message is prefixed with the index of
    the offending item.
    """
    assert_equal(len(actual), len(desired), err_msg)
    for index, (got, expected) in enumerate(zip(actual, desired)):
        assert_equal(got, expected, 'item=%r\n%s' % (index, err_msg))
|
|
|
def assert_equal_records(a, b):
    """Assert that two records are equal. Pretty crude for now.

    The dtypes must match. After that, every named field is compared with
    ``assert_equal``. A field is skipped when its value on either side is
    the ``masked`` singleton.
    """
    assert_equal(a.dtype, b.dtype)
    for field in a.dtype.names:
        left, right = a[field], b[field]
        if left is masked or right is masked:
            continue
        assert_equal(left, right)
|
|
|
|
|
def assert_equal(actual, desired, err_msg=''):
    """Assert that two items are equal, with masked-array awareness.

    Dispatches on the types of the inputs:

    * ``desired`` is a dict: ``actual`` must be a dict of the same length,
      every key of ``desired`` must be in ``actual``, and the values are
      compared recursively.
    * both are lists/tuples: compared item by item.
    * neither is an ndarray: compared with ``==``.
    * otherwise: both are converted to arrays (subclasses preserved) and
      compared elementwise with ``assert_array_equal``. Byte-string ("S")
      arrays are compared as Python lists instead.

    Parameters
    ----------
    actual, desired : object
        The items to compare.
    err_msg : str, optional
        Extra text included in the failure message.

    Raises
    ------
    AssertionError
        If the items differ.
    ValueError
        If exactly one of the two items is the ``masked`` singleton.
    """
    if isinstance(desired, dict):
        if not isinstance(actual, dict):
            raise AssertionError(repr(type(actual)))
        assert_equal(len(actual), len(desired), err_msg)
        for key in desired:
            if key not in actual:
                raise AssertionError("%s not in %s" % (key, actual))
            assert_equal(actual[key], desired[key],
                         'key=%r\n%s' % (key, err_msg))
        return

    if isinstance(desired, (list, tuple)) and isinstance(actual, (list, tuple)):
        # Forward the caller's err_msg so sequence failures keep its context.
        return _assert_equal_on_sequences(actual, desired, err_msg=err_msg)

    if not (isinstance(actual, ndarray) or isinstance(desired, ndarray)):
        msg = build_err_msg([actual, desired], err_msg,)
        if not desired == actual:
            raise AssertionError(msg)
        return

    # Exactly one side being the masked singleton is a hard mismatch.
    if (actual is masked) != (desired is masked):
        msg = build_err_msg([actual, desired],
                            err_msg, header='', names=('x', 'y'))
        raise ValueError(msg)

    # np.asanyarray is what np.array(copy=False, subok=True) meant on
    # NumPy 1.x; on NumPy >= 2.0 the copy=False spelling raises whenever a
    # copy is required (e.g. a scalar or list compared against an array).
    actual = np.asanyarray(actual)
    desired = np.asanyarray(desired)
    if actual.dtype.char == "S" and desired.dtype.char == "S":
        return _assert_equal_on_sequences(actual.tolist(),
                                          desired.tolist(),
                                          err_msg=err_msg)
    return assert_array_equal(actual, desired, err_msg)
|
|
|
|
|
def fail_if_equal(actual, desired, err_msg='',):
    """Raise an AssertionError if two items are equal.

    Equality is judged structurally:

    * dicts are equal only if ``actual`` is also a dict with the same keys
      and every pair of values is equal (checked recursively);
    * two lists/tuples are equal only if they have the same length and
      every pair of items is equal (checked recursively);
    * if either input is an ndarray the check is delegated to
      ``fail_if_array_equal``;
    * anything else is compared with ``!=``.

    Parameters
    ----------
    actual, desired : object
        Items that are expected to differ.
    err_msg : str, optional
        Extra text included in the failure message.

    Raises
    ------
    AssertionError
        If `actual` and `desired` are equal.
    """
    if isinstance(desired, dict):
        if not isinstance(actual, dict) or set(actual) != set(desired):
            # A type or key-set mismatch already makes the items differ.
            return
        pairs = [(actual[key], desired[key], 'key=%r\n%s' % (key, err_msg))
                 for key in desired]
    elif isinstance(desired, (list, tuple)) and isinstance(actual, (list, tuple)):
        if len(actual) != len(desired):
            # Different lengths: the sequences cannot be equal.
            return
        pairs = [(got, expected, 'item=%r\n%s' % (index, err_msg))
                 for index, (got, expected) in enumerate(zip(actual, desired))]
    elif isinstance(actual, np.ndarray) or isinstance(desired, np.ndarray):
        return fail_if_array_equal(actual, desired, err_msg)
    else:
        if not desired != actual:
            raise AssertionError(build_err_msg([actual, desired], err_msg))
        return

    # The containers are equal only if *every* item pair is equal, i.e. only
    # if every recursive check raises; the first differing pair proves the
    # containers differ.
    for got, expected, item_msg in pairs:
        try:
            fail_if_equal(got, expected, item_msg)
        except AssertionError:
            continue
        return
    raise AssertionError(build_err_msg([actual, desired], err_msg))


assert_not_equal = fail_if_equal
|
|
|
|
|
def assert_almost_equal(actual, desired, decimal=7, err_msg='', verbose=True):
    """Assert that two items are equal up to `decimal` decimal places.

    For non-array inputs the check passes when
    ``round(abs(desired - actual), decimal) == 0``. This roughly means
    ``abs(desired - actual) < 0.5 * 10**(-decimal)``.

    If either input is an ndarray, the call is delegated to
    ``assert_array_almost_equal`` with the same `decimal`, `err_msg` and
    `verbose`, and its result is returned.
    """
    if not isinstance(actual, np.ndarray) and not isinstance(desired, np.ndarray):
        msg = build_err_msg([actual, desired],
                            err_msg=err_msg, verbose=verbose)
        residual = round(abs(desired - actual), decimal)
        if not residual == 0:
            raise AssertionError(msg)
        return None
    return assert_array_almost_equal(actual, desired, decimal=decimal,
                                     err_msg=err_msg, verbose=verbose)


assert_close = assert_almost_equal
|
|
|
|
|
def assert_array_compare(comparison, x, y, err_msg='', verbose=True, header='',
                         fill_value=True):
    """Assert that `comparison` holds elementwise between two masked arrays.

    The masks of `x` and `y` are merged. Every position masked on either
    side is filled with `fill_value` on *both* sides, and the filled plain
    arrays are passed to ``numpy.testing.utils.assert_array_compare``.

    Parameters
    ----------
    comparison : callable
        Called with the two filled arrays; returns a boolean result.
    x, y : array_like
        The (possibly masked) inputs.
    err_msg, verbose, header :
        Forwarded to the numpy.testing comparison for the failure report.
    fill_value : scalar, optional
        Value written into the masked positions before comparing.

    Notes
    -----
    A lone ``masked`` singleton is not rejected here. Its mask is merged
    into the other input, so both sides end up fully masked and filled with
    `fill_value` before `comparison` is applied.
    """
    # Merge both masks so a position masked on either side is masked (and
    # therefore filled identically) on both sides.
    m = mask_or(getmask(x), getmask(y))
    x = masked_array(x, copy=False, mask=m, keep_mask=False, subok=False)
    y = masked_array(y, copy=False, mask=m, keep_mask=False, subok=False)
    # NOTE: a check like "x is masked and y is not" cannot work at this
    # point: masked_array() above always builds a fresh object, so that
    # test could never be true and has been dropped as unreachable.
    return utils.assert_array_compare(comparison,
                                      x.filled(fill_value),
                                      y.filled(fill_value),
                                      err_msg=err_msg,
                                      verbose=verbose, header=header)
|
|
|
|
|
def assert_array_equal(x, y, err_msg='', verbose=True):
    """Assert elementwise equality of two (masked) arrays.

    This wraps ``assert_array_compare`` with ``==``. Positions masked on
    either side are filled identically on both sides before comparing, so
    they never cause a mismatch.
    """
    header = 'Arrays are not equal'
    assert_array_compare(operator.eq, x, y, err_msg=err_msg,
                         verbose=verbose, header=header)
|
|
|
|
|
def fail_if_array_equal(x, y, err_msg='', verbose=True):
    """Raise an AssertionError if two masked arrays are equal.

    The arrays count as equal when every element agrees within the default
    tolerances of ``approx``. Positions masked on either side are filled
    identically on both sides (by ``assert_array_compare``), so they count
    as equal.

    Parameters
    ----------
    x, y : array_like
        Arrays expected to differ somewhere.
    err_msg, verbose :
        Forwarded to ``assert_array_compare`` for the failure report.
    """
    def compare(filled_x, filled_y):
        # True (i.e. "passes") as soon as at least one element differs.
        # np.all replaces np.alltrue, deprecated in NumPy 1.25 and removed
        # in NumPy 2.0.
        return not np.all(approx(filled_x, filled_y))

    # The failure being reported is that the arrays *are* equal.
    assert_array_compare(compare, x, y, err_msg=err_msg, verbose=verbose,
                         header='Arrays are equal')
|
|
|
|
|
def assert_array_approx_equal(x, y, decimal=6, err_msg='', verbose=True):
    """Assert elementwise near-equality of two masked arrays.

    Elements are compared with ``approx`` using a relative tolerance of
    ``10**-decimal`` (and approx's default absolute tolerance).
    """
    relative_tolerance = 10. ** -decimal

    def compare(lhs, rhs):
        "Return the loose elementwise comparison of lhs and rhs."
        return approx(lhs, rhs, rtol=relative_tolerance)

    assert_array_compare(compare, x, y,
                         err_msg=err_msg,
                         verbose=verbose,
                         header='Arrays are not almost equal')
|
|
|
|
|
def assert_array_almost_equal(x, y, decimal=6, err_msg='', verbose=True):
    """Assert elementwise equality of two masked arrays up to `decimal`
    decimal places.

    Elements are compared with ``almost``. An element passes when its
    absolute difference, rounded to `decimal` places, is at most
    ``10**-decimal``.
    """
    def compare(lhs, rhs):
        "Return the decimal-rounded elementwise comparison of lhs and rhs."
        return almost(lhs, rhs, decimal)

    assert_array_compare(compare, x, y,
                         err_msg=err_msg,
                         verbose=verbose,
                         header='Arrays are not almost equal')
|
|
|
|
|
def assert_array_less(x, y, err_msg='', verbose=True):
    """Assert that `x` is strictly smaller than `y` elementwise.

    This wraps ``assert_array_compare`` with ``<``.
    """
    # NOTE(review): with assert_array_compare's default fill_value (True),
    # positions masked on either side become 1 on both sides, and 1 < 1 is
    # False, so masked positions appear to make this assertion fail --
    # confirm this is intended.
    header = 'Arrays are not less-ordered'
    assert_array_compare(operator.lt, x, y, err_msg=err_msg,
                         verbose=verbose, header=header)
|
|
|
|
|
def assert_mask_equal(m1, m2, err_msg=''):
    """Assert that two masks are equal.

    ``nomask`` only matches ``nomask``. If either argument is ``nomask``, the
    other must be too; an all-False mask array does not count as equal to
    ``nomask``. Otherwise the masks are compared elementwise with
    ``assert_array_equal``.

    Parameters
    ----------
    m1, m2 : mask or ``nomask``
        The masks to compare.
    err_msg : str, optional
        Extra text included in the failure message (for ``nomask``
        mismatches as well as elementwise mismatches).
    """
    if m1 is nomask or m2 is nomask:
        # Forward err_msg so nomask mismatches also carry the caller's context.
        assert_(m1 is nomask and m2 is nomask, err_msg)
        # Both are nomask: nothing left to compare.
        return
    assert_array_equal(m1, m2, err_msg=err_msg)
|
|