diff --git a/palette/Cargo.toml b/palette/Cargo.toml index 62707c1f8..dd57e97cc 100644 --- a/palette/Cargo.toml +++ b/palette/Cargo.toml @@ -79,6 +79,7 @@ lazy_static = "1" serde = "1" serde_derive = "1" serde_json = "1" +ron = "0.8.0" enterpolation = "0.2.0" scad = "1.2.2" # For regression testing #283 diff --git a/palette/src/alpha/alpha.rs b/palette/src/alpha/alpha.rs index 86bcc8591..b1e246821 100644 --- a/palette/src/alpha/alpha.rs +++ b/palette/src/alpha/alpha.rs @@ -30,11 +30,9 @@ use crate::{ /// An alpha component wrapper for colors. #[derive(Clone, Copy, Debug)] -#[cfg_attr(feature = "serializing", derive(Serialize, Deserialize))] #[repr(C)] pub struct Alpha { /// The color. - #[cfg_attr(feature = "serializing", serde(flatten))] pub color: C, /// The transparency component. 0.0 is fully transparent and 1.0 is fully @@ -658,6 +656,48 @@ where } } +#[cfg(feature = "serializing")] +impl serde::Serialize for Alpha +where + C: serde::Serialize, + T: serde::Serialize, +{ + fn serialize(&self, serializer: S) -> Result + where + S: serde::Serializer, + { + self.color.serialize(crate::serde::AlphaSerializer { + inner: serializer, + alpha: &self.alpha, + }) + } +} + +#[cfg(feature = "serializing")] +impl<'de, C, T> serde::Deserialize<'de> for Alpha +where + C: serde::Deserialize<'de>, + T: serde::Deserialize<'de>, +{ + fn deserialize(deserializer: D) -> Result + where + D: serde::Deserializer<'de>, + { + let mut alpha: Option = None; + + let color = C::deserialize(crate::serde::AlphaDeserializer { + inner: deserializer, + alpha: &mut alpha, + })?; + + if let Some(alpha) = alpha { + Ok(Self { color, alpha }) + } else { + Err(serde::de::Error::missing_field("alpha")) + } + } +} + #[cfg(feature = "random")] impl Distribution> for Standard where @@ -865,21 +905,109 @@ mod test { #[cfg(feature = "serializing")] #[test] fn serialize() { - let serialized = ::serde_json::to_string(&Rgba::::new(0.3, 0.8, 0.1, 0.5)).unwrap(); + let color = Rgba::::new(0.3, 0.8, 0.1, 0.5); assert_eq!( - serialized, + serde_json::to_string(&color).unwrap(), r#"{"red":0.3,"green":0.8,"blue":0.1,"alpha":0.5}"# ); + + assert_eq!( + ron::to_string(&color).unwrap(), + r#"(red:0.3,green:0.8,blue:0.1,alpha:0.5)"# + ); } #[cfg(feature = "serializing")] #[test] fn deserialize() { - let deserialized: Rgba = - ::serde_json::from_str(r#"{"red":0.3,"green":0.8,"blue":0.1,"alpha":0.5}"#).unwrap(); + let color = Rgba::::new(0.3, 0.8, 0.1, 0.5); + + assert_eq!( + serde_json::from_str::>(r#"{"alpha":0.5,"red":0.3,"green":0.8,"blue":0.1}"#) + .unwrap(), + color + ); + + assert_eq!( + ron::from_str::>(r#"(alpha:0.5,red:0.3,green:0.8,blue:0.1)"#).unwrap(), + color + ); + + assert_eq!( + ron::from_str::>(r#"Rgb(alpha:0.5,red:0.3,green:0.8,blue:0.1)"#).unwrap(), + color + ); + } + + #[cfg(feature = "serializing")] + #[test] + fn serde_round_trips() { + let color = Rgba::::new(0.3, 0.8, 0.1, 0.5); + + assert_eq!( + serde_json::from_str::>(&serde_json::to_string(&color).unwrap()).unwrap(), + color + ); + + assert_eq!( + ron::from_str::>(&ron::to_string(&color).unwrap()).unwrap(), + color + ); + } + + #[cfg(feature = "serializing")] + #[test] + fn serde_various_types() { + macro_rules! test_roundtrip { + ($value:expr $(, $ron_name:expr)?) => { + let value = super::Alpha { + color: $value, + alpha: 0.5, + }; + assert_eq!( + serde_json::from_str::>( + &serde_json::to_string(&value).expect("json serialization") + ) + .expect("json deserialization"), + value + ); + + let ron_string = ron::to_string(&value).expect("ron serialization"); + assert_eq!( + ron::from_str::>(&ron_string) + .expect("ron deserialization"), + value + ); + $( + assert_eq!( + ron::from_str::>(&format!("{}{ron_string}", $ron_name)) + .expect("ron deserialization"), + value + ); + )? + }; + } + + #[derive(Serialize, Deserialize, PartialEq, Debug)] + struct Empty; + #[derive(Serialize, Deserialize, PartialEq, Debug)] + struct UnitTuple(); + #[derive(Serialize, Deserialize, PartialEq, Debug)] + struct Newtype(f32); + #[derive(Serialize, Deserialize, PartialEq, Debug)] + struct Tuple(f32, f32); + #[derive(Serialize, Deserialize, PartialEq, Debug)] + struct Struct { + value: f32, + } - assert_eq!(deserialized, Rgba::::new(0.3, 0.8, 0.1, 0.5)); + test_roundtrip!(()); + test_roundtrip!(Empty, "Empty"); + test_roundtrip!(UnitTuple(), "UnitTuple"); + test_roundtrip!(Newtype(0.1), "Newtype"); + test_roundtrip!(Tuple(0.1, 0.2), "Tuple"); + test_roundtrip!(Struct { value: 0.1 }, "Struct"); } #[cfg(feature = "random")] diff --git a/palette/src/blend/pre_alpha.rs b/palette/src/blend/pre_alpha.rs index f1a4ec2a5..b4e21b31c 100644 --- a/palette/src/blend/pre_alpha.rs +++ b/palette/src/blend/pre_alpha.rs @@ -36,11 +36,9 @@ use super::Premultiply; /// component to be clamped to [0.0, 1.0], and fully transparent colors will /// become black. #[derive(Clone, Copy, Debug)] -#[cfg_attr(feature = "serializing", derive(Serialize, Deserialize))] #[repr(C)] pub struct PreAlpha { /// The premultiplied color components (`original.color * original.alpha`). - #[cfg_attr(feature = "serializing", serde(flatten))] pub color: C, /// The transparency component. 0.0 is fully transparent and 1.0 is fully @@ -347,6 +345,48 @@ impl DerefMut for PreAlpha { } } +#[cfg(feature = "serializing")] +impl serde::Serialize for PreAlpha +where + C: Premultiply + serde::Serialize, + C::Scalar: serde::Serialize, +{ + fn serialize(&self, serializer: S) -> Result + where + S: serde::Serializer, + { + self.color.serialize(crate::serde::AlphaSerializer { + inner: serializer, + alpha: &self.alpha, + }) + } +} + +#[cfg(feature = "serializing")] +impl<'de, C> serde::Deserialize<'de> for PreAlpha +where + C: Premultiply + serde::Deserialize<'de>, + C::Scalar: serde::Deserialize<'de>, +{ + fn deserialize(deserializer: D) -> Result + where + D: serde::Deserializer<'de>, + { + let mut alpha: Option = None; + + let color = C::deserialize(crate::serde::AlphaDeserializer { + inner: deserializer, + alpha: &mut alpha, + })?; + + if let Some(alpha) = alpha { + Ok(Self { color, alpha }) + } else { + Err(serde::de::Error::missing_field("alpha")) + } + } +} + #[cfg(feature = "bytemuck")] unsafe impl bytemuck::Zeroable for PreAlpha where @@ -380,25 +420,43 @@ mod test { alpha: 0.5, }; - let serialized = ::serde_json::to_string(&color).unwrap(); - assert_eq!( - serialized, + serde_json::to_string(&color).unwrap(), r#"{"red":0.3,"green":0.8,"blue":0.1,"alpha":0.5}"# ); + + assert_eq!( + ron::to_string(&color).unwrap(), + r#"(red:0.3,green:0.8,blue:0.1,alpha:0.5)"# + ); } #[cfg(feature = "serializing")] #[test] fn deserialize() { - let expected = PreAlpha { + let color = PreAlpha { color: LinSrgb::new(0.3, 0.8, 0.1), alpha: 0.5, }; - let deserialized: PreAlpha<_> = - ::serde_json::from_str(r#"{"red":0.3,"green":0.8,"blue":0.1,"alpha":0.5}"#).unwrap(); + assert_eq!( + serde_json::from_str::>( + r#"{"alpha":0.5,"red":0.3,"green":0.8,"blue":0.1}"# + ) + .unwrap(), + color + ); - assert_eq!(deserialized, expected); + assert_eq!( + ron::from_str::>(r#"(alpha:0.5,red:0.3,green:0.8,blue:0.1)"#) + .unwrap(), + color + ); + + assert_eq!( + ron::from_str::>(r#"Rgb(alpha:0.5,red:0.3,green:0.8,blue:0.1)"#) + .unwrap(), + color + ); } } diff --git a/palette/src/lib.rs b/palette/src/lib.rs index 0132efec0..29e7129a9 100644 --- a/palette/src/lib.rs +++ b/palette/src/lib.rs @@ -247,7 +247,7 @@ extern crate phf; #[cfg(feature = "serializing")] #[macro_use] -extern crate serde; +extern crate serde as _; #[cfg(all(test, feature = "serializing"))] extern crate serde_json; @@ -447,6 +447,9 @@ pub mod named; #[cfg(feature = "random")] mod random_sampling; +#[cfg(feature = "serializing")] +pub mod serde; + mod alpha; pub mod angle; pub mod blend; diff --git a/palette/src/serde.rs b/palette/src/serde.rs new file mode 100644 index 000000000..9dea3b087 --- /dev/null +++ b/palette/src/serde.rs @@ -0,0 +1,329 @@ +//! Utilities for serializing and deserializing with `serde`. +//! +//! These modules and functions can be combined with `serde`'s [field +//! attributes](https://serde.rs/field-attrs.html) to better control how to +//! serialize and deserialize colors. See each item's examples for more details. + +use serde::{Deserialize, Deserializer, Serialize, Serializer}; + +use crate::{ + blend::{PreAlpha, Premultiply}, + cast::{self, ArrayCast, UintCast}, + stimulus::Stimulus, + Alpha, +}; + +pub(crate) use self::{alpha_deserializer::AlphaDeserializer, alpha_serializer::AlphaSerializer}; + +mod alpha_deserializer; +mod alpha_serializer; + +/// Combines [`serialize_as_array`] and [`deserialize_as_array`] as a module for `#[serde(with = "...")]`. +/// +/// ``` +/// use serde::{Serialize, Deserialize}; +/// use palette::{Srgb, Srgba}; +/// +/// #[derive(Serialize, Deserialize, PartialEq, Debug)] +/// struct MyColors { +/// #[serde(with = "palette::serde::as_array")] +/// opaque: Srgb, +/// #[serde(with = "palette::serde::as_array")] +/// transparent: Srgba, +/// } +/// +/// let my_colors = MyColors { +/// opaque: Srgb::new(0.6, 0.8, 0.3), +/// transparent: Srgba::new(0.6, 0.8, 0.3, 0.5), +/// }; +/// +/// let json = serde_json::to_string(&my_colors).unwrap(); +/// +/// assert_eq!( +/// json, +/// r#"{"opaque":[0.6,0.8,0.3],"transparent":[0.6,0.8,0.3,0.5]}"# +/// ); +/// +/// assert_eq!( +/// serde_json::from_str::(&json).unwrap(), +/// my_colors +/// ); +/// ``` +pub mod as_array { + pub use super::deserialize_as_array as deserialize; + pub use super::serialize_as_array as serialize; +} + +/// Serialize the value as an array of its components. +/// +/// ``` +/// use serde::Serialize; +/// use palette::{Srgb, Srgba}; +/// +/// #[derive(Serialize)] +/// struct MyColors { +/// #[serde(serialize_with = "palette::serde::serialize_as_array")] +/// opaque: Srgb, +/// #[serde(serialize_with = "palette::serde::serialize_as_array")] +/// transparent: Srgba, +/// } +/// +/// let my_colors = MyColors { +/// opaque: Srgb::new(0.6, 0.8, 0.3), +/// transparent: Srgba::new(0.6, 0.8, 0.3, 0.5), +/// }; +/// +/// assert_eq!( +/// serde_json::to_string(&my_colors).unwrap(), +/// r#"{"opaque":[0.6,0.8,0.3],"transparent":[0.6,0.8,0.3,0.5]}"# +/// ); +/// ``` +pub fn serialize_as_array(value: &T, serializer: S) -> Result +where + T: ArrayCast, + T::Array: Serialize, + S: Serializer, +{ + cast::into_array_ref(value).serialize(serializer) +} + +/// Deserialize a value from an array of its components. +/// +/// ``` +/// use serde::Deserialize; +/// use palette::{Srgb, Srgba}; +/// +/// #[derive(Deserialize, PartialEq, Debug)] +/// struct MyColors { +/// #[serde(deserialize_with = "palette::serde::deserialize_as_array")] +/// opaque: Srgb, +/// #[serde(deserialize_with = "palette::serde::deserialize_as_array")] +/// transparent: Srgba, +/// } +/// +/// let my_colors = MyColors { +/// opaque: Srgb::new(0.6, 0.8, 0.3), +/// transparent: Srgba::new(0.6, 0.8, 0.3, 0.5), +/// }; +/// +/// let json = r#"{"opaque":[0.6,0.8,0.3],"transparent":[0.6,0.8,0.3,0.5]}"#; +/// assert_eq!( +/// serde_json::from_str::(json).unwrap(), +/// my_colors +/// ); +/// ``` +pub fn deserialize_as_array<'de, T, D>(deserializer: D) -> Result +where + T: ArrayCast, + T::Array: Deserialize<'de>, + D: Deserializer<'de>, +{ + Ok(cast::from_array(T::Array::deserialize(deserializer)?)) +} + +/// Combines [`serialize_as_uint`] and [`deserialize_as_uint`] as a module for `#[serde(with = "...")]`. +/// +/// ``` +/// use serde::{Serialize, Deserialize}; +/// use palette::{Srgb, Srgba, rgb::{PackedArgb, PackedRgba}}; +/// +/// #[derive(Serialize, Deserialize, PartialEq, Debug)] +/// struct MyColors { +/// #[serde(with = "palette::serde::as_uint")] +/// argb: PackedArgb, +/// #[serde(with = "palette::serde::as_uint")] +/// rgba: PackedRgba, +/// } +/// +/// let my_colors = MyColors { +/// argb: Srgb::new(0x17, 0xC6, 0x4C).into(), +/// rgba: Srgba::new(0x17, 0xC6, 0x4C, 0xFF).into(), +/// }; +/// +/// let json = serde_json::to_string(&my_colors).unwrap(); +/// +/// assert_eq!( +/// json, +/// r#"{"argb":4279748172,"rgba":398871807}"# +/// ); +/// +/// assert_eq!( +/// serde_json::from_str::(&json).unwrap(), +/// my_colors +/// ); +/// ``` +pub mod as_uint { + pub use super::deserialize_as_uint as deserialize; + pub use super::serialize_as_uint as serialize; +} + +/// Serialize the value as an unsigned integer. +/// +/// ``` +/// use serde::Serialize; +/// use palette::{Srgb, Srgba, rgb::{PackedArgb, PackedRgba}}; +/// +/// #[derive(Serialize)] +/// struct MyColors { +/// #[serde(serialize_with = "palette::serde::serialize_as_uint")] +/// argb: PackedArgb, +/// #[serde(serialize_with = "palette::serde::serialize_as_uint")] +/// rgba: PackedRgba, +/// } +/// +/// let my_colors = MyColors { +/// argb: Srgb::new(0x17, 0xC6, 0x4C).into(), +/// rgba: Srgba::new(0x17, 0xC6, 0x4C, 0xFF).into(), +/// }; +/// +/// assert_eq!( +/// serde_json::to_string(&my_colors).unwrap(), +/// r#"{"argb":4279748172,"rgba":398871807}"# +/// ); +/// ``` +pub fn serialize_as_uint(value: &T, serializer: S) -> Result +where + T: UintCast, + T::Uint: Serialize, + S: Serializer, +{ + cast::into_uint_ref(value).serialize(serializer) +} + +/// Deserialize a value from an unsigned integer. +/// +/// ``` +/// use serde::Deserialize; +/// use palette::{Srgb, Srgba, rgb::{PackedArgb, PackedRgba}}; +/// +/// #[derive(Deserialize, PartialEq, Debug)] +/// struct MyColors { +/// #[serde(deserialize_with = "palette::serde::deserialize_as_uint")] +/// argb: PackedArgb, +/// #[serde(deserialize_with = "palette::serde::deserialize_as_uint")] +/// rgba: PackedRgba, +/// } +/// +/// let my_colors = MyColors { +/// argb: Srgb::new(0x17, 0xC6, 0x4C).into(), +/// rgba: Srgba::new(0x17, 0xC6, 0x4C, 0xFF).into(), +/// }; +/// +/// let json = r#"{"argb":4279748172,"rgba":398871807}"#; +/// assert_eq!( +/// serde_json::from_str::(json).unwrap(), +/// my_colors +/// ); +/// ``` +pub fn deserialize_as_uint<'de, T, D>(deserializer: D) -> Result +where + T: UintCast, + T::Uint: Deserialize<'de>, + D: Deserializer<'de>, +{ + Ok(cast::from_uint(T::Uint::deserialize(deserializer)?)) +} + +/// Deserialize a transparent color without requiring the alpha to be specified. +/// +/// A color with missing alpha will be interpreted as fully opaque. +/// +/// ``` +/// use serde::Deserialize; +/// use palette::Srgba; +/// +/// #[derive(Deserialize, PartialEq, Debug)] +/// struct MyColors { +/// #[serde(deserialize_with = "palette::serde::deserialize_with_optional_alpha")] +/// opaque: Srgba, +/// #[serde(deserialize_with = "palette::serde::deserialize_with_optional_alpha")] +/// transparent: Srgba, +/// } +/// +/// let my_colors = MyColors { +/// opaque: Srgba::new(0.6, 0.8, 0.3, 1.0), +/// transparent: Srgba::new(0.6, 0.8, 0.3, 0.5), +/// }; +/// +/// let json = r#"{ +/// "opaque":{"red":0.6,"green":0.8,"blue":0.3}, +/// "transparent":{"red":0.6,"green":0.8,"blue":0.3,"alpha":0.5} +/// }"#; +/// assert_eq!( +/// serde_json::from_str::(json).unwrap(), +/// my_colors +/// ); +/// ``` +pub fn deserialize_with_optional_alpha<'de, T, A, D>( + deserializer: D, +) -> Result, D::Error> +where + T: Deserialize<'de>, + A: Stimulus + Deserialize<'de>, + D: Deserializer<'de>, +{ + let mut alpha: Option = None; + + let color = T::deserialize(crate::serde::AlphaDeserializer { + inner: deserializer, + alpha: &mut alpha, + })?; + + Ok(Alpha { + color, + alpha: alpha.unwrap_or_else(A::max_intensity), + }) +} + +/// Deserialize a premultiplied transparent color without requiring the alpha to be specified. +/// +/// A color with missing alpha will be interpreted as fully opaque. +/// +/// ``` +/// use serde::Deserialize; +/// use palette::{LinSrgba, LinSrgb, blend::PreAlpha}; +/// +/// type PreRgba = PreAlpha>; +/// +/// #[derive(Deserialize, PartialEq, Debug)] +/// struct MyColors { +/// #[serde(deserialize_with = "palette::serde::deserialize_with_optional_pre_alpha")] +/// opaque: PreRgba, +/// #[serde(deserialize_with = "palette::serde::deserialize_with_optional_pre_alpha")] +/// transparent: PreRgba, +/// } +/// +/// let my_colors = MyColors { +/// opaque: LinSrgba::new(0.6, 0.8, 0.3, 1.0).into(), +/// transparent: LinSrgba::new(0.6, 0.8, 0.3, 0.5).into(), +/// }; +/// +/// let json = r#"{ +/// "opaque":{"red":0.6,"green":0.8,"blue":0.3}, +/// "transparent":{"red":0.3,"green":0.4,"blue":0.15,"alpha":0.5} +/// }"#; +/// assert_eq!( +/// serde_json::from_str::(json).unwrap(), +/// my_colors +/// ); +/// ``` +pub fn deserialize_with_optional_pre_alpha<'de, T, D>( + deserializer: D, +) -> Result, D::Error> +where + T: Premultiply + Deserialize<'de>, + T::Scalar: Stimulus + Deserialize<'de>, + D: Deserializer<'de>, +{ + let mut alpha: Option = None; + + let color = T::deserialize(crate::serde::AlphaDeserializer { + inner: deserializer, + alpha: &mut alpha, + })?; + + Ok(PreAlpha { + color, + alpha: alpha.unwrap_or_else(T::Scalar::max_intensity), + }) +} diff --git a/palette/src/serde/alpha_deserializer.rs b/palette/src/serde/alpha_deserializer.rs new file mode 100644 index 000000000..966926994 --- /dev/null +++ b/palette/src/serde/alpha_deserializer.rs @@ -0,0 +1,796 @@ +use core::marker::PhantomData; + +use serde::{ + de::{DeserializeSeed, MapAccess, Visitor}, + Deserialize, Deserializer, +}; + +/// Deserializes a color with an attached alpha value. The alpha value is +/// expected to be found alongside the other values in a flattened structure. +pub(crate) struct AlphaDeserializer<'a, D, A> { + pub inner: D, + pub alpha: &'a mut Option, +} + +impl<'de, 'a, D, A> Deserializer<'de> for AlphaDeserializer<'a, D, A> +where + D: Deserializer<'de>, + A: Deserialize<'de>, +{ + type Error = D::Error; + + fn deserialize_seq(self, visitor: V) -> Result + where + V: serde::de::Visitor<'de>, + { + self.inner.deserialize_seq(AlphaSeqVisitor { + inner: visitor, + alpha: self.alpha, + }) + } + + fn deserialize_tuple(self, len: usize, visitor: V) -> Result + where + V: serde::de::Visitor<'de>, + { + self.inner.deserialize_tuple( + len + 1, + AlphaMapVisitor { + inner: visitor, + alpha: self.alpha, + field_count: Some(len), + }, + ) + } + + fn deserialize_tuple_struct( + self, + name: &'static str, + len: usize, + visitor: V, + ) -> Result + where + V: serde::de::Visitor<'de>, + { + self.inner.deserialize_tuple_struct( + name, + len + 1, + AlphaMapVisitor { + inner: visitor, + alpha: self.alpha, + field_count: Some(len), + }, + ) + } + + fn deserialize_map(self, visitor: V) -> Result + where + V: serde::de::Visitor<'de>, + { + self.inner.deserialize_map(AlphaMapVisitor { + inner: visitor, + alpha: self.alpha, + field_count: None, + }) + } + + fn deserialize_struct( + self, + name: &'static str, + fields: &'static [&'static str], + visitor: V, + ) -> Result + where + V: serde::de::Visitor<'de>, + { + self.inner.deserialize_struct( + name, + fields, // We can't add to the expected fields so we just hope it works anyway. + AlphaMapVisitor { + inner: visitor, + alpha: self.alpha, + field_count: Some(fields.len()), + }, + ) + } + + fn deserialize_ignored_any(self, visitor: V) -> Result + where + V: serde::de::Visitor<'de>, + { + self.inner.deserialize_ignored_any(AlphaSeqVisitor { + inner: visitor, + alpha: self.alpha, + }) + } + + fn deserialize_unit(self, visitor: V) -> Result + where + V: serde::de::Visitor<'de>, + { + self.inner.deserialize_tuple( + 1, + AlphaMapVisitor { + inner: visitor, + alpha: self.alpha, + field_count: None, + }, + ) + } + + fn deserialize_unit_struct( + self, + name: &'static str, + visitor: V, + ) -> Result + where + V: serde::de::Visitor<'de>, + { + self.inner.deserialize_newtype_struct( + name, + AlphaMapVisitor { + inner: visitor, + alpha: self.alpha, + field_count: Some(0), + }, + ) + } + + fn deserialize_newtype_struct( + self, + name: &'static str, + visitor: V, + ) -> Result + where + V: serde::de::Visitor<'de>, + { + self.deserialize_tuple_struct(name, 1, visitor) + } + + // Unsupported methods: + + fn deserialize_any(self, _visitor: V) -> Result + where + V: serde::de::Visitor<'de>, + { + alpha_deserializer_error() + } + + fn deserialize_bool(self, _visitor: V) -> Result + where + V: serde::de::Visitor<'de>, + { + alpha_deserializer_error() + } + + fn deserialize_i8(self, _visitor: V) -> Result + where + V: serde::de::Visitor<'de>, + { + alpha_deserializer_error() + } + + fn deserialize_i16(self, _visitor: V) -> Result + where + V: serde::de::Visitor<'de>, + { + alpha_deserializer_error() + } + + fn deserialize_i32(self, _visitor: V) -> Result + where + V: serde::de::Visitor<'de>, + { + alpha_deserializer_error() + } + + fn deserialize_i64(self, _visitor: V) -> Result + where + V: serde::de::Visitor<'de>, + { + alpha_deserializer_error() + } + + fn deserialize_u8(self, _visitor: V) -> Result + where + V: serde::de::Visitor<'de>, + { + alpha_deserializer_error() + } + + fn deserialize_u16(self, _visitor: V) -> Result + where + V: serde::de::Visitor<'de>, + { + alpha_deserializer_error() + } + + fn deserialize_u32(self, _visitor: V) -> Result + where + V: serde::de::Visitor<'de>, + { + alpha_deserializer_error() + } + + fn deserialize_u64(self, _visitor: V) -> Result + where + V: serde::de::Visitor<'de>, + { + alpha_deserializer_error() + } + + fn deserialize_f32(self, _visitor: V) -> Result + where + V: serde::de::Visitor<'de>, + { + alpha_deserializer_error() + } + + fn deserialize_f64(self, _visitor: V) -> Result + where + V: serde::de::Visitor<'de>, + { + alpha_deserializer_error() + } + + fn deserialize_char(self, _visitor: V) -> Result + where + V: serde::de::Visitor<'de>, + { + alpha_deserializer_error() + } + + fn deserialize_str(self, _visitor: V) -> Result + where + V: serde::de::Visitor<'de>, + { + alpha_deserializer_error() + } + + fn deserialize_string(self, _visitor: V) -> Result + where + V: serde::de::Visitor<'de>, + { + alpha_deserializer_error() + } + + fn deserialize_bytes(self, _visitor: V) -> Result + where + V: serde::de::Visitor<'de>, + { + alpha_deserializer_error() + } + + fn deserialize_byte_buf(self, _visitor: V) -> Result + where + V: serde::de::Visitor<'de>, + { + alpha_deserializer_error() + } + + fn deserialize_option(self, _visitor: V) -> Result + where + V: serde::de::Visitor<'de>, + { + alpha_deserializer_error() + } + + fn deserialize_enum( + self, + _name: &'static str, + _variants: &'static [&'static str], + _visitor: V, + ) -> Result + where + V: serde::de::Visitor<'de>, + { + alpha_deserializer_error() + } + + fn deserialize_identifier(self, _visitor: V) -> Result + where + V: serde::de::Visitor<'de>, + { + alpha_deserializer_error() + } +} + +fn alpha_deserializer_error() -> ! { + unimplemented!("AlphaDeserializer can only deserialize structs, maps and sequences") +} + +/// Deserializes a sequence with the alpha value last. +struct AlphaSeqVisitor<'a, D, A> { + inner: D, + alpha: &'a mut Option, +} + +impl<'de, 'a, D, A> Visitor<'de> for AlphaSeqVisitor<'a, D, A> +where + D: Visitor<'de>, + A: Deserialize<'de>, +{ + type Value = D::Value; + + fn expecting(&self, formatter: &mut core::fmt::Formatter) -> core::fmt::Result { + self.inner.expecting(formatter)?; + write!(formatter, " with an alpha value") + } + + fn visit_seq(self, mut seq: T) -> Result + where + T: serde::de::SeqAccess<'de>, + { + let color = self.inner.visit_seq(&mut seq)?; + *self.alpha = seq.next_element()?; + + Ok(color) + } +} + +/// Deserializes a map or a struct with an "alpha" key, or a tuple with the +/// alpha value as the last value. +struct AlphaMapVisitor<'a, D, A> { + inner: D, + alpha: &'a mut Option, + field_count: Option, +} + +impl<'de, 'a, D, A> Visitor<'de> for AlphaMapVisitor<'a, D, A> +where + D: Visitor<'de>, + A: Deserialize<'de>, +{ + type Value = D::Value; + + fn expecting(&self, formatter: &mut core::fmt::Formatter) -> core::fmt::Result { + self.inner.expecting(formatter)?; + write!(formatter, " with an alpha value") + } + + fn visit_seq(self, mut seq: T) -> Result + where + T: serde::de::SeqAccess<'de>, + { + let color = if self.field_count == None { + self.inner.visit_unit()? + } else { + self.inner.visit_seq(&mut seq)? + }; + *self.alpha = seq.next_element()?; + + Ok(color) + } + + fn visit_map(self, map: T) -> Result + where + T: serde::de::MapAccess<'de>, + { + self.inner.visit_map(MapWrapper { + inner: map, + alpha: self.alpha, + field_count: self.field_count, + }) + } + + fn visit_newtype_struct(self, deserializer: T) -> Result + where + T: Deserializer<'de>, + { + *self.alpha = Some(A::deserialize(deserializer)?); + self.inner.visit_unit() + } +} + +/// Intercepts map deserializing to catch the alpha value while deserializing +/// the entries. +struct MapWrapper<'a, T, A> { + inner: T, + alpha: &'a mut Option, + field_count: Option, +} + +impl<'a, 'de, T, A> MapAccess<'de> for MapWrapper<'a, T, A> +where + T: MapAccess<'de>, + A: Deserialize<'de>, +{ + type Error = T::Error; + + fn next_key_seed(&mut self, mut seed: K) -> Result, Self::Error> + where + K: serde::de::DeserializeSeed<'de>, + { + // Look for and extract the alpha value if its key is found, then return + // the next key after that. The first key that isn't alpha is + // immediately returned to the wrapped type's visitor. + loop { + seed = match self.inner.next_key_seed(AlphaFieldDeserializerSeed { + inner: seed, + field_count: self.field_count, + }) { + Ok(Some(AlphaField::Alpha(seed))) => { + // We found the alpha value, so deserialize it... + if self.alpha.is_some() { + return Err(serde::de::Error::duplicate_field("alpha")); + } + *self.alpha = Some(self.inner.next_value()?); + + // ...then give the seed back for the next key + seed + } + Ok(Some(AlphaField::Other(other))) => return Ok(Some(other)), + Ok(None) => return Ok(None), + Err(error) => return Err(error), + }; + } + } + + fn next_value_seed(&mut self, seed: V) -> Result + where + V: serde::de::DeserializeSeed<'de>, + { + self.inner.next_value_seed(seed) + } +} + +struct AlphaFieldDeserializerSeed { + inner: T, + field_count: Option, +} + +impl<'de, T> DeserializeSeed<'de> for AlphaFieldDeserializerSeed +where + T: DeserializeSeed<'de>, +{ + type Value = AlphaField; + + fn deserialize(self, deserializer: D) -> Result + where + D: Deserializer<'de>, + { + deserializer.deserialize_identifier(AlphaFieldVisitor { + inner: self.inner, + field_count: self.field_count, + }) + } +} + +/// An alpha struct field or another struct field. +enum AlphaField { + Alpha(A), + Other(O), +} + +/// A struct field name that hasn't been serialized yet. +enum StructField<'de> { + Unsigned(u64), + Str(&'de str), + Bytes(&'de [u8]), +} + +struct AlphaFieldVisitor { + inner: T, + field_count: Option, +} + +impl<'de, T> Visitor<'de> for AlphaFieldVisitor +where + T: DeserializeSeed<'de>, +{ + type Value = AlphaField; + + fn expecting(&self, formatter: &mut core::fmt::Formatter) -> core::fmt::Result { + write!(formatter, "alpha field") + } + + fn visit_u64(self, v: u64) -> Result + where + E: serde::de::Error, + { + // We need the field count here to get the last tuple field. No field + // count implies that we definitely expected a struct or a map. + let field_count = self.field_count.ok_or(serde::de::Error::invalid_type( + serde::de::Unexpected::Unsigned(v), + &"map key or struct field", + ))?; + + // Assume that it's the alpha value if it's after the expected number of + // fields. Otherwise, pass on to the wrapped type's deserializer. + if v == field_count as u64 { + Ok(AlphaField::Alpha(self.inner)) + } else { + Ok(AlphaField::Other(self.inner.deserialize( + StructFieldDeserializer { + struct_field: StructField::Unsigned(v), + error: PhantomData, + }, + )?)) + } + } + + fn visit_str(self, v: &str) -> Result + where + E: serde::de::Error, + { + // Assume that it's the alpha value if it's named "alpha". Otherwise, + // pass on to the wrapped type's deserializer. + if v == "alpha" { + Ok(AlphaField::Alpha(self.inner)) + } else { + Ok(AlphaField::Other(self.inner.deserialize( + StructFieldDeserializer { + struct_field: StructField::Str(v), + error: PhantomData, + }, + )?)) + } + } + + fn visit_bytes(self, v: &[u8]) -> Result + where + E: serde::de::Error, + { + // Assume that it's the alpha value if it's named "alpha". Otherwise, + // pass on to the wrapped type's deserializer. + if v == b"alpha" { + Ok(AlphaField::Alpha(self.inner)) + } else { + Ok(AlphaField::Other(self.inner.deserialize( + StructFieldDeserializer { + struct_field: StructField::Bytes(v), + error: PhantomData, + }, + )?)) + } + } +} + +/// Deserializes a non-alpha struct field name. +struct StructFieldDeserializer<'a, E> { + struct_field: StructField<'a>, + error: PhantomData E>, +} + +impl<'a, 'de, E> Deserializer<'de> for StructFieldDeserializer<'a, E> +where + E: serde::de::Error, +{ + type Error = E; + + fn deserialize_identifier(self, visitor: V) -> Result + where + V: Visitor<'de>, + { + match self.struct_field { + StructField::Unsigned(v) => visitor.visit_u64(v), + StructField::Str(v) => visitor.visit_str(v), + StructField::Bytes(v) => visitor.visit_bytes(v), + } + } + + fn deserialize_ignored_any(self, visitor: V) -> Result + where + V: Visitor<'de>, + { + self.deserialize_identifier(visitor) + } + + fn deserialize_any(self, visitor: V) -> Result + where + V: Visitor<'de>, + { + self.deserialize_identifier(visitor) + } + + // Unsupported methods:: + + fn deserialize_bool(self, _visitor: V) -> Result + where + V: Visitor<'de>, + { + struct_field_deserializer_error() + } + + fn deserialize_i8(self, _visitor: V) -> Result + where + V: Visitor<'de>, + { + struct_field_deserializer_error() + } + + fn deserialize_i16(self, _visitor: V) -> Result + where + V: Visitor<'de>, + { + struct_field_deserializer_error() + } + + fn deserialize_i32(self, _visitor: V) -> Result + where + V: Visitor<'de>, + { + struct_field_deserializer_error() + } + + fn deserialize_i64(self, _visitor: V) -> Result + where + V: Visitor<'de>, + { + struct_field_deserializer_error() + } + + fn deserialize_u8(self, _visitor: V) -> Result + where + V: Visitor<'de>, + { + struct_field_deserializer_error() + } + + fn deserialize_u16(self, _visitor: V) -> Result + where + V: Visitor<'de>, + { + struct_field_deserializer_error() + } + + fn deserialize_u32(self, _visitor: V) -> Result + where + V: Visitor<'de>, + { + struct_field_deserializer_error() + } + + fn deserialize_u64(self, _visitor: V) -> Result + where + V: Visitor<'de>, + { + struct_field_deserializer_error() + } + + fn deserialize_f32(self, _visitor: V) -> Result + where + V: Visitor<'de>, + { + struct_field_deserializer_error() + } + + fn deserialize_f64(self, _visitor: V) -> Result + where + V: Visitor<'de>, + { + struct_field_deserializer_error() + } + + fn deserialize_char(self, _visitor: V) -> Result + where + V: Visitor<'de>, + { + struct_field_deserializer_error() + } + + fn deserialize_str(self, _visitor: V) -> Result + where + V: Visitor<'de>, + { + struct_field_deserializer_error() + } + + fn deserialize_string(self, _visitor: V) -> Result + where + V: Visitor<'de>, + { + struct_field_deserializer_error() + } + + fn deserialize_bytes(self, _visitor: V) -> Result + where + V: Visitor<'de>, + { + struct_field_deserializer_error() + } + + fn deserialize_byte_buf(self, _visitor: V) -> Result + where + V: Visitor<'de>, + { + struct_field_deserializer_error() + } + + fn deserialize_option(self, _visitor: V) -> Result + where + V: Visitor<'de>, + { + struct_field_deserializer_error() + } + + fn deserialize_unit(self, _visitor: V) -> Result + where + V: Visitor<'de>, + { + struct_field_deserializer_error() + } + + fn deserialize_unit_struct( + self, + _name: &'static str, + _visitor: V, + ) -> Result + where + V: Visitor<'de>, + { + struct_field_deserializer_error() + } + + fn deserialize_newtype_struct( + self, + _name: &'static str, + _visitor: V, + ) -> Result + where + V: Visitor<'de>, + { + struct_field_deserializer_error() + } + + fn deserialize_seq(self, _visitor: V) -> Result + where + V: Visitor<'de>, + { + struct_field_deserializer_error() + } + + fn deserialize_tuple(self, _len: usize, _visitor: V) -> Result + where + V: Visitor<'de>, + { + struct_field_deserializer_error() + } + + fn deserialize_tuple_struct( + self, + _name: &'static str, + _len: usize, + _visitor: V, + ) -> Result + where + V: Visitor<'de>, + { + struct_field_deserializer_error() + } + + fn deserialize_map(self, _visitor: V) -> Result + where + V: Visitor<'de>, + { + struct_field_deserializer_error() + } + + fn deserialize_struct( + self, + _name: &'static str, + _fields: &'static [&'static str], + _visitor: V, + ) -> Result + where + V: Visitor<'de>, + { + struct_field_deserializer_error() + } + + fn deserialize_enum( + self, + _name: &'static str, + _variants: &'static [&'static str], + _visitor: V, + ) -> Result + where + V: Visitor<'de>, + { + struct_field_deserializer_error() + } +} + +fn struct_field_deserializer_error() -> ! { + unimplemented!("StructFieldDeserializer can only deserialize identifiers") +} diff --git a/palette/src/serde/alpha_serializer.rs b/palette/src/serde/alpha_serializer.rs new file mode 100644 index 000000000..094dd02d1 --- /dev/null +++ b/palette/src/serde/alpha_serializer.rs @@ -0,0 +1,420 @@ +use serde::{ + ser::{ + SerializeMap, SerializeSeq, SerializeStruct, SerializeStructVariant, SerializeTuple, + SerializeTupleStruct, SerializeTupleVariant, + }, + Serialize, Serializer, +}; + +/// Serializes a color with an attached alpha value. The alpha value is added +/// alongside the other values in a flattened structure. +pub(crate) struct AlphaSerializer<'a, S, A> { + pub inner: S, + pub alpha: &'a A, +} + +impl<'a, S, A> Serializer for AlphaSerializer<'a, S, A> +where + S: Serializer, + A: Serialize, +{ + type Ok = S::Ok; + + type Error = S::Error; + + type SerializeSeq = AlphaSerializer<'a, S::SerializeSeq, A>; + + type SerializeTuple = AlphaSerializer<'a, S::SerializeTuple, A>; + + type SerializeTupleStruct = AlphaSerializer<'a, S::SerializeTupleStruct, A>; + + type SerializeTupleVariant = AlphaSerializer<'a, S::SerializeTupleVariant, A>; + + type SerializeMap = AlphaSerializer<'a, S::SerializeMap, A>; + + type SerializeStruct = AlphaSerializer<'a, S::SerializeStruct, A>; + + type SerializeStructVariant = AlphaSerializer<'a, S::SerializeStructVariant, A>; + + fn serialize_seq(self, len: Option) -> Result { + Ok(AlphaSerializer { + inner: self.inner.serialize_seq(len.map(|len| len + 1))?, + alpha: self.alpha, + }) + } + + fn serialize_tuple(self, len: usize) -> Result { + Ok(AlphaSerializer { + inner: self.inner.serialize_tuple(len + 1)?, + alpha: self.alpha, + }) + } + + fn serialize_tuple_struct( + self, + name: &'static str, + len: usize, + ) -> Result { + Ok(AlphaSerializer { + inner: self.inner.serialize_tuple_struct(name, len + 1)?, + alpha: self.alpha, + }) + } + + fn serialize_map(self, len: Option) -> Result { + Ok(AlphaSerializer { + inner: self.inner.serialize_map(len.map(|len| len + 1))?, + alpha: self.alpha, + }) + } + + fn serialize_struct( + self, + name: &'static str, + len: usize, + ) -> Result { + Ok(AlphaSerializer { + inner: self.inner.serialize_struct(name, len + 1)?, + alpha: self.alpha, + }) + } + + fn serialize_newtype_struct( + self, + name: &'static str, + value: &T, + ) -> Result + where + T: serde::Serialize, + { + let mut serializer = self.serialize_tuple_struct(name, 1)?; + serializer.serialize_field(value)?; + serializer.end() + } + + fn serialize_unit_struct(self, name: &'static str) -> Result { + self.inner.serialize_newtype_struct(name, self.alpha) + } + + fn serialize_unit(self) -> Result { + self.serialize_tuple(0)?.end() + } + + // Unsupported methods: + + fn serialize_bool(self, _v: bool) -> Result { + alpha_serializer_error() + } + + fn serialize_i8(self, _v: i8) -> Result { + alpha_serializer_error() + } + + fn serialize_i16(self, _v: i16) -> Result { + alpha_serializer_error() + } + + fn serialize_i32(self, _v: i32) -> Result { + alpha_serializer_error() + } + + fn serialize_i64(self, _v: i64) -> Result { + alpha_serializer_error() + } + + fn serialize_u8(self, _v: u8) -> Result { + alpha_serializer_error() + } + + fn serialize_u16(self, _v: u16) -> Result { + alpha_serializer_error() + } + + fn serialize_u32(self, _v: u32) -> Result { + alpha_serializer_error() + } + + fn serialize_u64(self, _v: u64) -> Result { + alpha_serializer_error() + } + + fn serialize_f32(self, _v: f32) -> Result { + alpha_serializer_error() + } + + fn serialize_f64(self, _v: f64) -> Result { + alpha_serializer_error() + } + + fn serialize_char(self, _v: char) -> Result { + alpha_serializer_error() + } + + fn serialize_str(self, _v: &str) -> Result { + alpha_serializer_error() + } + + fn serialize_bytes(self, _v: &[u8]) -> Result { + alpha_serializer_error() + } + + fn serialize_none(self) -> Result { + alpha_serializer_error() + } + + fn serialize_some(self, _value: &T) -> Result + where + T: serde::Serialize, + { + alpha_serializer_error() + } + + fn serialize_tuple_variant( + self, + _name: &'static str, + _variant_index: u32, + _variant: &'static str, + _len: usize, + ) -> Result { + alpha_serializer_error() + } + + fn serialize_struct_variant( + self, + _name: &'static str, + _variant_index: u32, + _variant: &'static str, + _len: usize, + ) -> Result { + alpha_serializer_error() + } + + fn serialize_unit_variant( + self, + _name: &'static str, + _variant_index: u32, + _variant: &'static str, + ) -> Result { + alpha_serializer_error() + } + + fn serialize_newtype_variant( + self, + _name: &'static str, + _variant_index: u32, + _variant: &'static str, + _value: &T, + ) -> Result + where + T: serde::Serialize, + { + alpha_serializer_error() + } + + fn serialize_i128(self, v: i128) -> Result { + let _ = v; + alpha_serializer_error() + } + + fn serialize_u128(self, v: u128) -> Result { + let _ = v; + alpha_serializer_error() + } + + fn is_human_readable(&self) -> bool { + self.inner.is_human_readable() + } +} + +fn alpha_serializer_error() -> ! { + unimplemented!("AlphaSerializer can only serialize structs, maps and sequences") +} + +impl<'a, S, A> SerializeSeq for AlphaSerializer<'a, S, A> +where + S: SerializeSeq, + A: Serialize, +{ + type Ok = S::Ok; + + type Error = S::Error; + + fn serialize_element(&mut self, value: &T) -> Result<(), Self::Error> + where + T: Serialize, + { + self.inner.serialize_element(value) + } + + fn end(mut self) -> Result { + self.inner.serialize_element(self.alpha)?; + self.inner.end() + } +} + +impl<'a, S, A> SerializeTuple for AlphaSerializer<'a, S, A> +where + S: SerializeTuple, + A: Serialize, +{ + type Ok = S::Ok; + + type Error = S::Error; + + fn serialize_element(&mut self, value: &T) -> Result<(), Self::Error> + where + T: Serialize, + { + self.inner.serialize_element(value) + } + + fn end(mut self) -> Result { + self.inner.serialize_element(self.alpha)?; + self.inner.end() + } +} + +impl<'a, S, A> SerializeTupleStruct for AlphaSerializer<'a, S, A> +where + S: SerializeTupleStruct, + A: Serialize, +{ + type Ok = S::Ok; + + type Error = S::Error; + + fn serialize_field(&mut self, value: &T) -> Result<(), Self::Error> + where + T: Serialize, + { + self.inner.serialize_field(value) + } + + fn end(mut self) -> Result { + self.inner.serialize_field(self.alpha)?; + self.inner.end() + } +} + +impl<'a, S, A> SerializeTupleVariant for AlphaSerializer<'a, S, A> +where + S: SerializeTupleVariant, + A: Serialize, +{ + type Ok = S::Ok; + + type Error = S::Error; + + fn serialize_field(&mut self, value: &T) -> Result<(), Self::Error> + where + T: Serialize, + { + self.inner.serialize_field(value) + } + + fn end(mut self) -> Result { + self.inner.serialize_field(self.alpha)?; + self.inner.end() + } +} + +impl<'a, S, A> SerializeMap for AlphaSerializer<'a, S, A> +where + S: SerializeMap, + A: Serialize, +{ + type Ok = S::Ok; + + type Error = S::Error; + + fn serialize_key(&mut self, key: &T) -> Result<(), Self::Error> + where + T: Serialize, + { + self.inner.serialize_key(key) + } + + fn serialize_value(&mut self, value: &T) -> Result<(), Self::Error> + where + T: Serialize, + { + self.inner.serialize_value(value) + } + + fn serialize_entry( + &mut self, + key: &K, + value: &V, + ) -> Result<(), Self::Error> + where + K: Serialize, + V: Serialize, + { + self.inner.serialize_entry(key, value) + } + + fn end(mut self) -> Result { + self.inner.serialize_entry("alpha", self.alpha)?; + self.inner.end() + } +} + +impl<'a, S, A> SerializeStruct for AlphaSerializer<'a, S, A> +where + S: SerializeStruct, + A: Serialize, +{ + type Ok = S::Ok; + + type Error = S::Error; + + fn serialize_field( + &mut self, + key: &'static str, + value: &T, + ) -> Result<(), Self::Error> + where + T: Serialize, + { + self.inner.serialize_field(key, value) + } + + fn skip_field(&mut self, key: &'static str) -> Result<(), Self::Error> { + self.inner.skip_field(key) + } + + fn end(mut self) -> Result { + self.inner.serialize_field("alpha", self.alpha)?; + self.inner.end() + } +} + +impl<'a, S, A> SerializeStructVariant for AlphaSerializer<'a, S, A> +where + S: SerializeStructVariant, + A: Serialize, +{ + type Ok = S::Ok; + + type Error = S::Error; + + fn serialize_field( + &mut self, + key: &'static str, + value: &T, + ) -> Result<(), Self::Error> + where + T: Serialize, + { + self.inner.serialize_field(key, value) + } + + fn skip_field(&mut self, key: &'static str) -> Result<(), Self::Error> { + self.inner.skip_field(key) + } + + fn end(mut self) -> Result { + self.inner.serialize_field("alpha", self.alpha)?; + self.inner.end() + } +}