Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
Fix reshape interoperability test (#17155)
Browse files Browse the repository at this point in the history
* fix reshape interoperability test

* fix for scipy import
  • Loading branch information
haojin2 authored Dec 25, 2019
1 parent efc4ad8 commit 318d9c7
Show file tree
Hide file tree
Showing 3 changed files with 6 additions and 5 deletions.
4 changes: 2 additions & 2 deletions ci/docker/install/requirements
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,8 @@ h5py==2.8.0rc1
mock==2.0.0
nose==1.3.7
nose-timer==0.7.3
numpy>1.16.0,<2.0.0
numpy>1.16.0,<1.18.0
pylint==2.3.1; python_version >= '3.0'
requests<2.19.0,>=2.18.4
scipy==1.0.1
scipy==1.2.1
six==1.11.0
5 changes: 3 additions & 2 deletions tests/python/unittest/test_metric.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import mxnet as mx
import numpy as np
import scipy
from scipy.stats import pearsonr
import json
import math
from common import with_seed
Expand Down Expand Up @@ -267,7 +268,7 @@ def test_pearsonr():
pred1 = mx.nd.array([[0.3, 0.7], [0, 1.], [0.4, 0.6]])
label1 = mx.nd.array([[1, 0], [0, 1], [0, 1]])
pearsonr_expected_np = np.corrcoef(pred1.asnumpy().ravel(), label1.asnumpy().ravel())[0, 1]
pearsonr_expected_scipy, _ = scipy.stats.pearsonr(pred1.asnumpy().ravel(), label1.asnumpy().ravel())
pearsonr_expected_scipy, _ = pearsonr(pred1.asnumpy().ravel(), label1.asnumpy().ravel())
macro_pr = mx.metric.create('pearsonr', average='macro')
micro_pr = mx.metric.create('pearsonr', average='micro')

Expand All @@ -289,7 +290,7 @@ def test_pearsonr():
label12 = mx.nd.array([[1, 0], [0, 1], [0, 1], [1, 0], [0, 1], [0, 1]])

pearsonr_expected_np = np.corrcoef(pred12.asnumpy().ravel(), label12.asnumpy().ravel())[0, 1]
pearsonr_expected_scipy, _ = scipy.stats.pearsonr(pred12.asnumpy().ravel(), label12.asnumpy().ravel())
pearsonr_expected_scipy, _ = pearsonr(pred12.asnumpy().ravel(), label12.asnumpy().ravel())

macro_pr.reset()
micro_pr.update([label2], [pred2])
Expand Down
2 changes: 1 addition & 1 deletion tests/python/unittest/test_numpy_interoperability.py
Original file line number Diff line number Diff line change
Expand Up @@ -758,7 +758,7 @@ def _add_workload_reshape():
# OpArgMngr.add_workload('reshape', b, (2, 2), order='F') # Items are not equal with order='F'

a = np.array(_np.ones((0, 2)))
OpArgMngr.add_workload('reshape', a, -1, 2)
OpArgMngr.add_workload('reshape', a, (-1, 2))


def _add_workload_rint(array_pool):
Expand Down

0 comments on commit 318d9c7

Please sign in to comment.