From 19084539195ee461cd2b8f22179dbcec7254c2b5 Mon Sep 17 00:00:00 2001 From: Junpeng Lao Date: Wed, 1 Mar 2023 14:58:47 +0100 Subject: [PATCH] use chex for typing (#503) use chex for typing --- blackjax/types.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/blackjax/types.py b/blackjax/types.py index 5f02bc661..dc2181a03 100644 --- a/blackjax/types.py +++ b/blackjax/types.py @@ -1,11 +1,7 @@ from typing import Any, Iterable, Mapping, Union import jax -import jax.numpy as jnp -import numpy as np - -#: JAX or Numpy array -Array = Union[np.ndarray, jnp.ndarray] +from chex import Array #: JAX PyTrees PyTree = Union[Array, Iterable["PyTree"], Mapping[Any, "PyTree"]]