From 06d362054fcf6352f184fe3ac56acb9c1caae63d Mon Sep 17 00:00:00 2001 From: Sait Cakmak Date: Tue, 17 Oct 2023 17:06:09 -0700 Subject: [PATCH] Clean up usage of some deprecated functionality (#2049) Summary: Updates some code & tests that lead to deprecation & other warnings. See also https://github.com/cornellius-gp/gpytorch/pull/2423 Reviewed By: esantorella Differential Revision: D50384526 --- botorch/acquisition/input_constructors.py | 6 +- .../multi_output_risk_measures.py | 2 +- botorch/posteriors/gpytorch.py | 8 +- test/acquisition/test_active_learning.py | 2 +- test/acquisition/test_cached_cholesky.py | 2 +- test/models/kernels/test_categorical.py | 8 +- .../likelihoods/test_pairwise_likelihood.py | 2 +- test/models/test_contextual_multioutput.py | 14 +- test/posteriors/test_gpytorch.py | 147 ++++++------------ 9 files changed, 75 insertions(+), 116 deletions(-) diff --git a/botorch/acquisition/input_constructors.py b/botorch/acquisition/input_constructors.py index 6f4229d2a3..a1357cc106 100644 --- a/botorch/acquisition/input_constructors.py +++ b/botorch/acquisition/input_constructors.py @@ -1033,7 +1033,7 @@ def construct_inputs_qMES( X = _get_dataset_field(training_data, "X", first_only=True) _kw = {"device": X.device, "dtype": X.dtype} _rvs = torch.rand(candidate_size, len(bounds), **_kw) - _bounds = torch.tensor(bounds, **_kw).transpose(0, 1) + _bounds = torch.as_tensor(bounds, **_kw).transpose(0, 1) return { "model": model, "candidate_set": _bounds[0] + (_bounds[1] - _bounds[0]) * _rvs, @@ -1090,7 +1090,7 @@ def construct_inputs_qKG( r"""Construct kwargs for `qKnowledgeGradient` constructor.""" X = _get_dataset_field(training_data, "X", first_only=True) - _bounds = torch.tensor(bounds, dtype=X.dtype, device=X.device) + _bounds = torch.as_tensor(bounds, dtype=X.dtype, device=X.device) _, current_value = optimize_objective( model=model, @@ -1181,7 +1181,7 @@ def construct_inputs_qMFMES( ) X = _get_dataset_field(training_data, "X", first_only=True) - 
_bounds = torch.tensor(bounds, dtype=X.dtype, device=X.device) + _bounds = torch.as_tensor(bounds, dtype=X.dtype, device=X.device) _, current_value = optimize_objective( model=model, bounds=_bounds.t(), diff --git a/botorch/acquisition/multi_objective/multi_output_risk_measures.py b/botorch/acquisition/multi_objective/multi_output_risk_measures.py index e00bacecba..b740ffe4cb 100644 --- a/botorch/acquisition/multi_objective/multi_output_risk_measures.py +++ b/botorch/acquisition/multi_objective/multi_output_risk_measures.py @@ -429,7 +429,7 @@ def get_mvar_set_gpu(self, Y: Tensor) -> Tensor: [ torch.stack( torch.meshgrid( - [Y_sorted[b, :, i] for i in range(m)], indexing=None + [Y_sorted[b, :, i] for i in range(m)], indexing="ij" ), dim=-1, ).view(-1, m) diff --git a/botorch/posteriors/gpytorch.py b/botorch/posteriors/gpytorch.py index 31d32a00a0..62faeec2fb 100644 --- a/botorch/posteriors/gpytorch.py +++ b/botorch/posteriors/gpytorch.py @@ -125,7 +125,10 @@ def rsample_from_base_samples( `self._extended_shape(sample_shape=sample_shape)`. """ if base_samples.shape[: len(sample_shape)] != sample_shape: - raise RuntimeError("`sample_shape` disagrees with shape of `base_samples`.") + raise RuntimeError( + "`sample_shape` disagrees with shape of `base_samples`. " + f"Got {sample_shape=} and {base_samples.shape=}." + ) if self._is_mt: base_samples = _reshape_base_samples_non_interleaved( mvn=self.distribution, @@ -171,7 +174,8 @@ def rsample( ) if base_samples.shape[: len(sample_shape)] != sample_shape: raise RuntimeError( - "`sample_shape` disagrees with shape of `base_samples`." + "`sample_shape` disagrees with shape of `base_samples`. " + f"Got {sample_shape=} and {base_samples.shape=}." 
) # get base_samples to the correct shape base_samples = base_samples.expand(self._extended_shape(sample_shape)) diff --git a/test/acquisition/test_active_learning.py b/test/acquisition/test_active_learning.py index 193ff232b6..f2b53455b4 100644 --- a/test/acquisition/test_active_learning.py +++ b/test/acquisition/test_active_learning.py @@ -129,7 +129,7 @@ def test_q_neg_int_post_variance(self): class TestPairwiseMCPosteriorVariance(BotorchTestCase): def test_pairwise_mc_post_var(self): - train_X = torch.tensor([[0.0, 0.0, 0.0], [1.0, 1.0, 0.0]]) + train_X = torch.tensor([[0.0, 0.0, 0.0], [1.0, 1.0, 0.0]], dtype=torch.double) train_comp = torch.tensor([[0, 1]], dtype=torch.long) model = PairwiseGP(train_X, train_comp) diff --git a/test/acquisition/test_cached_cholesky.py b/test/acquisition/test_cached_cholesky.py index 8cbaf163c2..6e9098ccd5 100644 --- a/test/acquisition/test_cached_cholesky.py +++ b/test/acquisition/test_cached_cholesky.py @@ -95,7 +95,7 @@ def test_cache_root_decomposition(self): acqf = DummyCachedCholeskyAcqf( model=model, sampler=sampler, - objective=GenericMCObjective(lambda Y: Y[..., 0]), + objective=GenericMCObjective(lambda Y, _: Y[..., 0]), ) baseline_L = torch.eye(2, **tkwargs) with mock.patch( diff --git a/test/models/kernels/test_categorical.py b/test/models/kernels/test_categorical.py index b83df1ddd1..749b65c429 100644 --- a/test/models/kernels/test_categorical.py +++ b/test/models/kernels/test_categorical.py @@ -75,7 +75,7 @@ def test_ard(self): self.assertAllClose(res, actual) # diag - res = kernel(x1, x2).diag() + res = kernel(x1, x2).diagonal() actual = torch.diagonal(actual, dim1=-1, dim2=-2) self.assertAllClose(res, actual) @@ -85,7 +85,7 @@ def test_ard(self): self.assertAllClose(res, actual) # batch_dims + diag - res = kernel(x1, x2, last_dim_is_batch=True).diag() + res = kernel(x1, x2, last_dim_is_batch=True).diagonal() self.assertAllClose(res, torch.diagonal(actual, dim1=-1, dim2=-2)) def test_ard_batch(self): @@ -131,7 
+131,7 @@ def test_ard_separate_batch(self): self.assertAllClose(res, actual) # diag - res = kernel(x1, x2).diag() + res = kernel(x1, x2).diagonal() actual = torch.diagonal(actual, dim1=-1, dim2=-2) self.assertAllClose(res, actual) @@ -141,5 +141,5 @@ def test_ard_separate_batch(self): self.assertAllClose(res, actual) # batch_dims + diag - res = kernel(x1, x2, last_dim_is_batch=True).diag() + res = kernel(x1, x2, last_dim_is_batch=True).diagonal() self.assertAllClose(res, torch.diagonal(actual, dim1=-1, dim2=-2)) diff --git a/test/models/likelihoods/test_pairwise_likelihood.py b/test/models/likelihoods/test_pairwise_likelihood.py index f83a6edeec..b795a1ef92 100644 --- a/test/models/likelihoods/test_pairwise_likelihood.py +++ b/test/models/likelihoods/test_pairwise_likelihood.py @@ -51,7 +51,7 @@ def p(self, utility: Tensor, D: Tensor) -> Tensor: n_datapoints = 4 n_comps = 3 X_dim = 4 - train_X = torch.rand(*batch_shape, n_datapoints, X_dim) + train_X = torch.rand(*batch_shape, n_datapoints, X_dim, dtype=torch.double) train_Y = train_X.sum(dim=-1, keepdim=True) train_comp = torch.stack( [ diff --git a/test/models/test_contextual_multioutput.py b/test/models/test_contextual_multioutput.py index ae044f774c..02a9305386 100644 --- a/test/models/test_contextual_multioutput.py +++ b/test/models/test_contextual_multioutput.py @@ -18,7 +18,7 @@ class ContextualMultiOutputTest(BotorchTestCase): - def testLCEMGP(self): + def test_LCEMGP(self): d = 1 for dtype, fixed_noise in ((torch.float, True), (torch.double, False)): # test with batch evaluation @@ -100,7 +100,7 @@ def testLCEMGP(self): self.assertIsInstance(embeddings2, Tensor) self.assertEqual(embeddings2.shape, torch.Size([2, 3])) - def testFixedNoiseLCEMGP(self): + def test_FixedNoiseLCEMGP(self): d = 1 for dtype in (torch.float, torch.double): train_x = torch.rand(10, d, device=self.device, dtype=dtype) @@ -111,9 +111,13 @@ def testFixedNoiseLCEMGP(self): train_x = torch.cat([train_x, task_indices.unsqueeze(-1)], 
axis=1) train_yvar = torch.ones(10, 1, device=self.device, dtype=dtype) * 0.01 - model = FixedNoiseLCEMGP( - train_X=train_x, train_Y=train_y, train_Yvar=train_yvar, task_feature=d - ) + with self.assertWarnsRegex(DeprecationWarning, "FixedNoiseLCEMGP"): + model = FixedNoiseLCEMGP( + train_X=train_x, + train_Y=train_y, + train_Yvar=train_yvar, + task_feature=d, + ) mll = ExactMarginalLogLikelihood(model.likelihood, model) fit_gpytorch_mll(mll, optimizer_kwargs={"options": {"maxiter": 1}}) diff --git a/test/posteriors/test_gpytorch.py b/test/posteriors/test_gpytorch.py index 911927b835..0e5d282712 100644 --- a/test/posteriors/test_gpytorch.py +++ b/test/posteriors/test_gpytorch.py @@ -74,19 +74,23 @@ def test_GPyTorchPosterior(self): mock_func.assert_called_once() # rsample w/ base samples - base_samples = torch.randn(4, 3, 1, device=self.device, dtype=dtype) + base_samples = torch.randn(4, 3, device=self.device, dtype=dtype) # incompatible shapes - with self.assertRaises(RuntimeError): + with self.assertRaisesRegex(RuntimeError, "sample_shape"): + posterior.rsample_from_base_samples( + sample_shape=torch.Size([3]), base_samples=base_samples + ) + with self.assertRaisesRegex(RuntimeError, "sample_shape"): posterior.rsample( sample_shape=torch.Size([3]), base_samples=base_samples ) # ensure consistent result for sample_shape in ([4], [4, 2]): base_samples = torch.randn( - *sample_shape, 3, 1, device=self.device, dtype=dtype + *sample_shape, 3, device=self.device, dtype=dtype ) samples = [ - posterior.rsample( + posterior.rsample_from_base_samples( sample_shape=torch.Size(sample_shape), base_samples=base_samples ) for _ in range(2) @@ -112,9 +116,12 @@ def test_GPyTorchPosterior(self): b_mvn = MultivariateNormal(b_mean, to_linear_operator(b_covar)) b_posterior = GPyTorchPosterior(distribution=b_mvn) b_base_samples = torch.randn(4, 1, 3, 1, device=self.device, dtype=dtype) - b_samples = b_posterior.rsample( - sample_shape=torch.Size([4]), base_samples=b_base_samples - 
) + with self.assertWarnsRegex( + DeprecationWarning, "`base_samples` with `rsample`" + ): + b_samples = b_posterior.rsample( + sample_shape=torch.Size([4]), base_samples=b_base_samples + ) self.assertEqual(b_samples.shape, torch.Size([4, 2, 3, 1])) def test_GPyTorchPosterior_Multitask(self): @@ -137,18 +144,18 @@ def test_GPyTorchPosterior_Multitask(self): self.assertEqual(samples2.shape, torch.Size([4, 2, 3, 2])) # rsample w/ base samples base_samples = torch.randn(4, 3, 2, device=self.device, dtype=dtype) - samples_b1 = posterior.rsample( + samples_b1 = posterior.rsample_from_base_samples( sample_shape=torch.Size([4]), base_samples=base_samples ) - samples_b2 = posterior.rsample( + samples_b2 = posterior.rsample_from_base_samples( sample_shape=torch.Size([4]), base_samples=base_samples ) self.assertAllClose(samples_b1, samples_b2) base_samples2 = torch.randn(4, 2, 3, 2, device=self.device, dtype=dtype) - samples2_b1 = posterior.rsample( + samples2_b1 = posterior.rsample_from_base_samples( sample_shape=torch.Size([4, 2]), base_samples=base_samples2 ) - samples2_b2 = posterior.rsample( + samples2_b2 = posterior.rsample_from_base_samples( sample_shape=torch.Size([4, 2]), base_samples=base_samples2 ) self.assertAllClose(samples2_b1, samples2_b2) @@ -159,87 +166,47 @@ def test_GPyTorchPosterior_Multitask(self): b_mvn = MultitaskMultivariateNormal(b_mean, to_linear_operator(b_covar)) b_posterior = GPyTorchPosterior(distribution=b_mvn) b_base_samples = torch.randn(4, 1, 3, 2, device=self.device, dtype=dtype) - b_samples = b_posterior.rsample( - sample_shape=torch.Size([4]), base_samples=b_base_samples - ) - self.assertEqual(b_samples.shape, torch.Size([4, 2, 3, 2])) - - def test_degenerate_GPyTorchPosterior(self): - for dtype in (torch.float, torch.double): - # singular covariance matrix - degenerate_covar = torch.tensor( - [[1, 1, 0], [1, 1, 0], [0, 0, 2]], dtype=dtype, device=self.device - ) - mean = torch.rand(3, dtype=dtype, device=self.device) - mvn = 
MultivariateNormal(mean, to_linear_operator(degenerate_covar)) - posterior = GPyTorchPosterior(distribution=mvn) - # basics - self.assertEqual(posterior.device.type, self.device.type) - self.assertTrue(posterior.dtype == dtype) - self.assertEqual(posterior._extended_shape(), torch.Size([3, 1])) - self.assertTrue(torch.equal(posterior.mean, mean.unsqueeze(-1))) - variance_exp = degenerate_covar.diag().unsqueeze(-1) - self.assertTrue(torch.equal(posterior.variance, variance_exp)) - - # rsample - with warnings.catch_warnings(record=True) as ws: - # we check that the p.d. warning is emitted - this only - # happens once per posterior, so we need to check only once - samples = posterior.rsample(sample_shape=torch.Size([4])) - self.assertTrue(any(issubclass(w.category, RuntimeWarning) for w in ws)) - self.assertTrue(any("not p.d" in str(w.message) for w in ws)) - self.assertEqual(samples.shape, torch.Size([4, 3, 1])) - samples2 = posterior.rsample(sample_shape=torch.Size([4, 2])) - self.assertEqual(samples2.shape, torch.Size([4, 2, 3, 1])) - # rsample w/ base samples - base_samples = torch.randn(4, 3, 1, device=self.device, dtype=dtype) - samples_b1 = posterior.rsample( - sample_shape=torch.Size([4]), base_samples=base_samples - ) - samples_b2 = posterior.rsample( - sample_shape=torch.Size([4]), base_samples=base_samples - ) - self.assertAllClose(samples_b1, samples_b2) - base_samples2 = torch.randn(4, 2, 3, 1, device=self.device, dtype=dtype) - samples2_b1 = posterior.rsample( - sample_shape=torch.Size([4, 2]), base_samples=base_samples2 - ) - samples2_b2 = posterior.rsample( - sample_shape=torch.Size([4, 2]), base_samples=base_samples2 - ) - self.assertAllClose(samples2_b1, samples2_b2) - # collapse_batch_dims - b_mean = torch.rand(2, 3, dtype=dtype, device=self.device) - b_degenerate_covar = degenerate_covar.expand(2, *degenerate_covar.shape) - b_mvn = MultivariateNormal(b_mean, to_linear_operator(b_degenerate_covar)) - b_posterior = 
GPyTorchPosterior(distribution=b_mvn) - b_base_samples = torch.randn(4, 2, 3, 1, device=self.device, dtype=dtype) - with warnings.catch_warnings(record=True) as ws: + with self.assertWarnsRegex( + DeprecationWarning, "`base_samples` with `rsample`" + ): b_samples = b_posterior.rsample( sample_shape=torch.Size([4]), base_samples=b_base_samples ) - self.assertTrue(any(issubclass(w.category, RuntimeWarning) for w in ws)) - self.assertTrue(any("not p.d" in str(w.message) for w in ws)) - self.assertEqual(b_samples.shape, torch.Size([4, 2, 3, 1])) + self.assertEqual(b_samples.shape, torch.Size([4, 2, 3, 2])) - def test_degenerate_GPyTorchPosterior_Multitask(self): - for dtype in (torch.float, torch.double): + def test_degenerate_GPyTorchPosterior(self): + for dtype, multi_task in ( + (torch.float, False), + (torch.double, False), + (torch.double, True), + ): # singular covariance matrix degenerate_covar = torch.tensor( [[1, 1, 0], [1, 1, 0], [0, 0, 2]], dtype=dtype, device=self.device ) mean = torch.rand(3, dtype=dtype, device=self.device) mvn = MultivariateNormal(mean, to_linear_operator(degenerate_covar)) - mvn = MultitaskMultivariateNormal.from_independent_mvns([mvn, mvn]) + mean_exp = mean.unsqueeze(-1) + variance_exp = degenerate_covar.diag().unsqueeze(-1) + if multi_task: + expected_dim = 2 + mvn = MultitaskMultivariateNormal.from_independent_mvns([mvn, mvn]) + mean_exp = mean_exp.repeat(1, 2) + variance_exp = variance_exp.repeat(1, 2) + base_samples = torch.randn(4, 3, 2, device=self.device, dtype=dtype) + base_samples2 = torch.randn(4, 2, 3, 2, device=self.device, dtype=dtype) + else: + expected_dim = 1 + base_samples = torch.randn(4, 3, device=self.device, dtype=dtype) + base_samples2 = torch.randn(4, 2, 3, device=self.device, dtype=dtype) posterior = GPyTorchPosterior(distribution=mvn) # basics self.assertEqual(posterior.device.type, self.device.type) self.assertTrue(posterior.dtype == dtype) - self.assertEqual(posterior._extended_shape(), torch.Size([3, 2])) - 
mean_exp = mean.unsqueeze(-1).repeat(1, 2) + self.assertEqual(posterior._extended_shape(), torch.Size([3, expected_dim])) self.assertTrue(torch.equal(posterior.mean, mean_exp)) - variance_exp = degenerate_covar.diag().unsqueeze(-1).repeat(1, 2) self.assertTrue(torch.equal(posterior.variance, variance_exp)) + # rsample with warnings.catch_warnings(record=True) as ws: # we check that the p.d. warning is emitted - this only @@ -247,40 +214,24 @@ def test_degenerate_GPyTorchPosterior_Multitask(self): samples = posterior.rsample(sample_shape=torch.Size([4])) self.assertTrue(any(issubclass(w.category, RuntimeWarning) for w in ws)) self.assertTrue(any("not p.d" in str(w.message) for w in ws)) - self.assertEqual(samples.shape, torch.Size([4, 3, 2])) + self.assertEqual(samples.shape, torch.Size([4, 3, expected_dim])) samples2 = posterior.rsample(sample_shape=torch.Size([4, 2])) - self.assertEqual(samples2.shape, torch.Size([4, 2, 3, 2])) + self.assertEqual(samples2.shape, torch.Size([4, 2, 3, expected_dim])) # rsample w/ base samples - base_samples = torch.randn(4, 3, 2, device=self.device, dtype=dtype) - samples_b1 = posterior.rsample( + samples_b1 = posterior.rsample_from_base_samples( sample_shape=torch.Size([4]), base_samples=base_samples ) - samples_b2 = posterior.rsample( + samples_b2 = posterior.rsample_from_base_samples( sample_shape=torch.Size([4]), base_samples=base_samples ) self.assertAllClose(samples_b1, samples_b2) - base_samples2 = torch.randn(4, 2, 3, 2, device=self.device, dtype=dtype) - samples2_b1 = posterior.rsample( + samples2_b1 = posterior.rsample_from_base_samples( sample_shape=torch.Size([4, 2]), base_samples=base_samples2 ) - samples2_b2 = posterior.rsample( + samples2_b2 = posterior.rsample_from_base_samples( sample_shape=torch.Size([4, 2]), base_samples=base_samples2 ) self.assertAllClose(samples2_b1, samples2_b2) - # collapse_batch_dims - b_mean = torch.rand(2, 3, dtype=dtype, device=self.device) - b_degenerate_covar = degenerate_covar.expand(2, 
*degenerate_covar.shape) - b_mvn = MultivariateNormal(b_mean, to_linear_operator(b_degenerate_covar)) - b_mvn = MultitaskMultivariateNormal.from_independent_mvns([b_mvn, b_mvn]) - b_posterior = GPyTorchPosterior(distribution=b_mvn) - b_base_samples = torch.randn(4, 1, 3, 2, device=self.device, dtype=dtype) - with warnings.catch_warnings(record=True) as ws: - b_samples = b_posterior.rsample( - sample_shape=torch.Size([4]), base_samples=b_base_samples - ) - self.assertTrue(any(issubclass(w.category, RuntimeWarning) for w in ws)) - self.assertTrue(any("not p.d" in str(w.message) for w in ws)) - self.assertEqual(b_samples.shape, torch.Size([4, 2, 3, 2])) def test_scalarize_posterior(self): for batch_shape, m, lazy, dtype in itertools.product(