diff --git a/src/keras_format/constraint_config.ts b/src/keras_format/constraint_config.ts index 01ee061a4..c0091fa60 100644 --- a/src/keras_format/constraint_config.ts +++ b/src/keras_format/constraint_config.ts @@ -38,3 +38,5 @@ export type MinMaxNormSerialization = export type ConstraintSerialization = MaxNormSerialization|NonNegSerialization| UnitNormSerialization|MinMaxNormSerialization; + +export type ConstraintClassName = ConstraintSerialization['class_name']; diff --git a/src/keras_format/initializer_config.ts b/src/keras_format/initializer_config.ts index 17415cb19..bf3697bdf 100644 --- a/src/keras_format/initializer_config.ts +++ b/src/keras_format/initializer_config.ts @@ -92,3 +92,5 @@ export type InitializerSerialization = ZerosSerialization|OnesSerialization| ConstantSerialization|RandomUniformSerialization|RandomNormalSerialization| TruncatedNormalSerialization|IdentitySerialization| VarianceScalingSerialization|OrthogonalSerialization; + +export type InitializerClassName = InitializerSerialization['class_name']; diff --git a/src/keras_format/keras_class_names.ts b/src/keras_format/keras_class_names.ts new file mode 100644 index 000000000..7258a8f85 --- /dev/null +++ b/src/keras_format/keras_class_names.ts @@ -0,0 +1,22 @@ +/** + * @license + * Copyright 2018 Google LLC + * + * Use of this source code is governed by an MIT-style + * license that can be found in the LICENSE file or at + * https://opensource.org/licenses/MIT. + * ============================================================================= + */ + +import {ConstraintClassName} from './constraint_config'; +import {InitializerClassName} from './initializer_config'; +import {LayerClassName} from './layers/layer_serialization'; +import {RegularizerClassName} from './regularizer_config'; + +/** + * A type representing all valid values of `class_name` in a Keras JSON file + * (regardless of context, which will naturally further restrict the valid + * values). + */ +export type KerasClassName = LayerClassName|ConstraintClassName| + InitializerClassName|RegularizerClassName; diff --git a/src/keras_format/layers/pooling_serialization.ts b/src/keras_format/layers/pooling_serialization.ts index 4aa183566..6830ab3f9 100644 --- a/src/keras_format/layers/pooling_serialization.ts +++ b/src/keras_format/layers/pooling_serialization.ts @@ -59,4 +59,3 @@ export type PoolingLayerSerialization = MaxPooling1DLayerSerialization| GlobalMaxPooling1DLayerSerialization| GlobalAveragePooling2DLayerSerialization| GlobalMaxPooling2DLayerSerialization; -; diff --git a/src/keras_format/regularizer_config.ts b/src/keras_format/regularizer_config.ts index f889c948d..7371e66ff 100644 --- a/src/keras_format/regularizer_config.ts +++ b/src/keras_format/regularizer_config.ts @@ -31,3 +31,5 @@ export type L2Serialization = BaseSerialization<'L2', L2Config>; export type RegularizerSerialization = L1L2Serialization|L1Serialization|L2Serialization; + +export type RegularizerClassName = RegularizerSerialization['class_name'];