Skip to content

Commit

Permalink
Merge pull request #351 from wiredfool/numpy-tests
Browse files Browse the repository at this point in the history
tests for img -> numpy.array
  • Loading branch information
aclark4life committed Sep 30, 2013
2 parents 6954e5c + ca35a9d commit 6599fe0
Show file tree
Hide file tree
Showing 2 changed files with 50 additions and 3 deletions.
40 changes: 37 additions & 3 deletions Tests/test_numpy.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from tester import *

from PIL import Image
import struct

try:
import site
Expand Down Expand Up @@ -65,10 +66,43 @@ def test_3d_array():
assert_image(Image.fromarray(a[:, :, 1]), "L", (10, 10))


def _test_img_equals_nparray(img, np):
assert_equal(img.size, np.shape[0:2])
px = img.load()
for x in xrange(0, img.size[0], img.size[0]/10):
for y in xrange(0, img.size[1], img.size[1]/10):
assert_deep_equal(px[x,y], np[y,x])


def test_16bit():
img = Image.open('Tests/images/12bit.cropped.tif')
px = img.load()
np_img = numpy.array(img)
assert_equal(np_img.shape, (64,64))
assert_equal(px[1,1],np_img[1,1])
_test_img_equals_nparray(img, np_img)
assert_equal(np_img.dtype, numpy.dtype('uint16'))

def test_to_array():

def _to_array(mode, dtype):
img = lena(mode)
np_img = numpy.array(img)
_test_img_equals_nparray(img, np_img)
assert_equal(np_img.dtype, numpy.dtype(dtype))


modes = [("L", 'uint8'),
("I", 'int32'),
("F", 'float32'),
("RGB", 'uint8'),
("RGBA", 'uint8'),
("RGBX", 'uint8'),
("CMYK", 'uint8'),
("YCbCr", 'uint8'),
("I;16", 'uint16'),
("I;16B", '>u2'),
("I;16L", 'uint16'),
]


for mode in modes:
assert_no_exception(lambda: _to_array(*mode))

13 changes: 13 additions & 0 deletions Tests/tester.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,19 @@ def assert_equal(a, b, msg=None):
else:
failure(msg or "got %r, expected %r" % (a, b))

def assert_deep_equal(a, b, msg=None):
try:
if len(a) == len(b):
if all([x==y for x,y in zip(a,b)]):
success()
else:
failure(msg or "got %s, expected %s" % (a,b))
else:
failure(msg or "got length %s, expected %s" % (len(a), len(b)))
except:
assert_equal(a,b,msg)


def assert_match(v, pattern, msg=None):
import re
if re.match(pattern, v):
Expand Down

0 comments on commit 6599fe0

Please sign in to comment.