From ca35a9d65d07920335cad1208e46b0aee2c4d024 Mon Sep 17 00:00:00 2001 From: wiredfool Date: Mon, 30 Sep 2013 14:10:58 -0700 Subject: [PATCH] tests for img -> numpy.array --- Tests/test_numpy.py | 40 +++++++++++++++++++++++++++++++++++++--- Tests/tester.py | 13 +++++++++++++ 2 files changed, 50 insertions(+), 3 deletions(-) diff --git a/Tests/test_numpy.py b/Tests/test_numpy.py index 988189391c4..d833d7d815c 100644 --- a/Tests/test_numpy.py +++ b/Tests/test_numpy.py @@ -1,6 +1,7 @@ from tester import * from PIL import Image +import struct try: import site @@ -65,10 +66,43 @@ def test_3d_array(): assert_image(Image.fromarray(a[:, :, 1]), "L", (10, 10)) +def _test_img_equals_nparray(img, np): + assert_equal(img.size, np.shape[0:2]) + px = img.load() + for x in xrange(0, img.size[0], img.size[0]/10): + for y in xrange(0, img.size[1], img.size[1]/10): + assert_deep_equal(px[x,y], np[y,x]) + + def test_16bit(): img = Image.open('Tests/images/12bit.cropped.tif') - px = img.load() np_img = numpy.array(img) - assert_equal(np_img.shape, (64,64)) - assert_equal(px[1,1],np_img[1,1]) + _test_img_equals_nparray(img, np_img) + assert_equal(np_img.dtype, numpy.dtype('uint16')) + +def test_to_array(): + + def _to_array(mode, dtype): + img = lena(mode) + np_img = numpy.array(img) + _test_img_equals_nparray(img, np_img) + assert_equal(np_img.dtype, numpy.dtype(dtype)) + + + modes = [("L", 'uint8'), + ("I", 'int32'), + ("F", 'float32'), + ("RGB", 'uint8'), + ("RGBA", 'uint8'), + ("RGBX", 'uint8'), + ("CMYK", 'uint8'), + ("YCbCr", 'uint8'), + ("I;16", 'uint16'), + ("I;16B", '>u2'), + ("I;16L", 'uint16'), + ] + + for mode in modes: + assert_no_exception(lambda: _to_array(*mode)) + diff --git a/Tests/tester.py b/Tests/tester.py index 109265120ac..5f322cc203a 100644 --- a/Tests/tester.py +++ b/Tests/tester.py @@ -67,6 +67,19 @@ def assert_equal(a, b, msg=None): else: failure(msg or "got %r, expected %r" % (a, b)) +def assert_deep_equal(a, b, msg=None): + try: + if len(a) == len(b): + if all([x==y for x,y in zip(a,b)]): + success() + else: + failure(msg or "got %s, expected %s" % (a,b)) + else: + failure(msg or "got length %s, expected %s" % (len(a), len(b))) + except: + assert_equal(a,b,msg) + + def assert_match(v, pattern, msg=None): import re if re.match(pattern, v):