diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/feature/ChiSqSelectorBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/feature/ChiSqSelectorBatchOp.java index b71ed9c36..1004e7ddc 100644 --- a/core/src/main/java/com/alibaba/alink/operator/batch/feature/ChiSqSelectorBatchOp.java +++ b/core/src/main/java/com/alibaba/alink/operator/batch/feature/ChiSqSelectorBatchOp.java @@ -25,7 +25,7 @@ public ChiSqSelectorBatchOp linkFrom(BatchOperator... inputs) { String[] selectedColNames = getSelectedCols(); String labelColName = getLabelCol(); - String selectorType = getParams().get(SELECTOR_TYPE).trim().toLowerCase(); + SelectorType selectorType = getParams().get(SELECTOR_TYPE); int numTopFeatures = getParams().get(NUM_TOP_FEATURES); double percentile = getParams().get(PERCENTILE); double fpr = getParams().get(FPR); diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/feature/PcaTrainBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/feature/PcaTrainBatchOp.java index 41affaa4b..1e992e989 100644 --- a/core/src/main/java/com/alibaba/alink/operator/batch/feature/PcaTrainBatchOp.java +++ b/core/src/main/java/com/alibaba/alink/operator/batch/feature/PcaTrainBatchOp.java @@ -1,7 +1,6 @@ package com.alibaba.alink.operator.batch.feature; import com.alibaba.alink.common.linalg.*; -import com.alibaba.alink.operator.common.feature.pca.PcaTypeEnum; import org.apache.flink.api.common.functions.RichFlatMapFunction; import org.apache.flink.api.common.functions.RichMapPartitionFunction; import org.apache.flink.api.java.DataSet; @@ -328,7 +327,7 @@ public void mapPartition(Iterable> splitVec, Collec PcaModelData pcr = new PcaModelData(); //get correlation or covariance matrix - PcaTypeEnum pcaTypeEnum = PcaTypeEnum.valueOf(pcaType.toUpperCase()); + CalculationType pcaTypeEnum = CalculationType.valueOf(pcaType.toUpperCase()); double[][] corr = null; @@ -346,7 +345,7 @@ public void mapPartition(Iterable> splitVec, Collec DenseMatrix 
calculateMatrix = new DenseMatrix(corr); - if (pcaTypeEnum.equals(PcaTypeEnum.COVAR_POP)) { + if (pcaTypeEnum.equals(CalculationType.COVAR_POP)) { double cnt = counts[0]; if (cnt > 1) { calculateMatrix.scaleEqual(cnt / (cnt - 1)); diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/feature/VectorChiSqSelectorBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/feature/VectorChiSqSelectorBatchOp.java index dfeb3d5ee..4a7a4418c 100644 --- a/core/src/main/java/com/alibaba/alink/operator/batch/feature/VectorChiSqSelectorBatchOp.java +++ b/core/src/main/java/com/alibaba/alink/operator/batch/feature/VectorChiSqSelectorBatchOp.java @@ -24,7 +24,7 @@ public VectorChiSqSelectorBatchOp linkFrom(BatchOperator... inputs) { String vectorColName = getSelectedCol(); String labelColName = getLabelCol(); - String selectorType = getParams().get(SELECTOR_TYPE).trim().toLowerCase(); + SelectorType selectorType = getParams().get(SELECTOR_TYPE); int numTopFeatures = getParams().get(NUM_TOP_FEATURES); double percentile = getParams().get(PERCENTILE); double fpr = getParams().get(FPR); diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/regression/GlmEvaluationBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/regression/GlmEvaluationBatchOp.java index 511aa5013..2f0b66664 100644 --- a/core/src/main/java/com/alibaba/alink/operator/batch/regression/GlmEvaluationBatchOp.java +++ b/core/src/main/java/com/alibaba/alink/operator/batch/regression/GlmEvaluationBatchOp.java @@ -36,8 +36,8 @@ public GlmEvaluationBatchOp linkFrom(BatchOperator... 
inputs) { String weightColName = getWeightCol(); String offsetColName = getOffsetCol(); - String familyName = getFamily(); - String linkName = getLink(); + Family familyName = getFamily(); + Link linkName = getLink(); double variancePower = getVariancePower(); double linkPower = getLinkPower(); diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/regression/GlmTrainBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/regression/GlmTrainBatchOp.java index cfb5ec056..976c1bcdb 100644 --- a/core/src/main/java/com/alibaba/alink/operator/batch/regression/GlmTrainBatchOp.java +++ b/core/src/main/java/com/alibaba/alink/operator/batch/regression/GlmTrainBatchOp.java @@ -48,8 +48,8 @@ public GlmTrainBatchOp linkFrom(BatchOperator... inputs) { String weightColName = getWeightCol(); String offsetColName = getOffsetCol(); - String familyName = getFamily(); - String linkName = getLink(); + Family familyName = getFamily(); + Link linkName = getLink(); double variancePower = getVariancePower(); double linkPower = getLinkPower(); @@ -125,9 +125,9 @@ private static class BuildModel implements MapPartitionFunction... 
inputs) { //check col types must be double or bigint TableUtil.assertNumericalCols(in.getSchema(), selectedColNames); - String corrType = getMethod().trim().toLowerCase(); + Method corrType = getMethod(); - if ("pearson".equals(corrType)) { + if (Method.PEARSON == corrType) { DataSet> srt = StatisticsHelper.pearsonCorrelation(in, selectedColNames); diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/statistics/VectorCorrelationBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/statistics/VectorCorrelationBatchOp.java index 4c915bf6e..fe60d5247 100644 --- a/core/src/main/java/com/alibaba/alink/operator/batch/statistics/VectorCorrelationBatchOp.java +++ b/core/src/main/java/com/alibaba/alink/operator/batch/statistics/VectorCorrelationBatchOp.java @@ -39,9 +39,9 @@ public VectorCorrelationBatchOp linkFrom(BatchOperator... inputs) { BatchOperator in = checkAndGetFirst(inputs); String vectorColName = getSelectedCol(); - String corrType = getMethod().trim().toLowerCase(); + Method corrType = getMethod(); - if ("pearson".equals(corrType)) { + if (Method.PEARSON == corrType) { DataSet> srt = StatisticsHelper.vectorPearsonCorrelation(in, vectorColName); //block diff --git a/core/src/main/java/com/alibaba/alink/operator/common/feature/pca/PcaModelDataConverter.java b/core/src/main/java/com/alibaba/alink/operator/common/feature/pca/PcaModelDataConverter.java index b78b87e0e..4efbfe8dc 100644 --- a/core/src/main/java/com/alibaba/alink/operator/common/feature/pca/PcaModelDataConverter.java +++ b/core/src/main/java/com/alibaba/alink/operator/common/feature/pca/PcaModelDataConverter.java @@ -7,6 +7,7 @@ import org.apache.flink.ml.api.misc.param.Params; import com.alibaba.alink.common.utils.JsonConverter; +import com.alibaba.alink.params.feature.HasCalculationType; import com.alibaba.alink.params.shared.colname.HasFeatureColsDefaultAsNull; import com.alibaba.alink.params.feature.PcaTrainParams; diff --git 
a/core/src/main/java/com/alibaba/alink/operator/common/feature/pca/PcaModelMapper.java b/core/src/main/java/com/alibaba/alink/operator/common/feature/pca/PcaModelMapper.java index a3c0f788d..dc07020ae 100644 --- a/core/src/main/java/com/alibaba/alink/operator/common/feature/pca/PcaModelMapper.java +++ b/core/src/main/java/com/alibaba/alink/operator/common/feature/pca/PcaModelMapper.java @@ -6,6 +6,7 @@ import com.alibaba.alink.common.mapper.ModelMapper; import com.alibaba.alink.common.utils.OutputColsHelper; import com.alibaba.alink.common.utils.TableUtil; +import com.alibaba.alink.params.feature.HasCalculationType; import com.alibaba.alink.params.feature.PcaPredictParams; import org.apache.flink.ml.api.misc.param.Params; import org.apache.flink.table.api.TableSchema; @@ -25,7 +26,7 @@ public class PcaModelMapper extends ModelMapper { private int[] featureIdxs = null; private boolean isVector; - private String transformType = null; + private PcaPredictParams.TransformType transformType = null; private String pcaType = null; private double[] sourceMean = null; @@ -77,8 +78,7 @@ public void loadModel(List modelRows) { int nx = model.means.length; int p = model.p; - PcaTransformTypeEnum transformTypeEnum = PcaTransformTypeEnum.valueOf(this.transformType.toUpperCase()); - PcaTypeEnum pcaTypeEnum = PcaTypeEnum.valueOf(this.pcaType.toUpperCase()); + HasCalculationType.CalculationType pcaTypeEnum = HasCalculationType.CalculationType.valueOf(this.pcaType.toUpperCase()); //transform mean, stdDevs and scoreStd sourceMean = new double[nx]; @@ -88,11 +88,11 @@ public void loadModel(List modelRows) { Arrays.fill(sourceStd, 1); Arrays.fill(scoreStd, 1); - if (PcaTypeEnum.CORR.equals(pcaTypeEnum)) { + if (HasCalculationType.CalculationType.CORR.equals(pcaTypeEnum)) { sourceStd = model.stddevs; } - switch (transformTypeEnum) { + switch (transformType) { case SUBMEAN: sourceMean = model.means; break; @@ -108,6 +108,10 @@ public void loadModel(List modelRows) { scoreStd[i] = 
Math.sqrt(tmp); } break; + case SIMPLE: + break; + default: + throw new IllegalArgumentException("Error transformType: " + transformType); } } diff --git a/core/src/main/java/com/alibaba/alink/operator/common/feature/pca/PcaTransformTypeEnum.java b/core/src/main/java/com/alibaba/alink/operator/common/feature/pca/PcaTransformTypeEnum.java deleted file mode 100644 index 14f3a605a..000000000 --- a/core/src/main/java/com/alibaba/alink/operator/common/feature/pca/PcaTransformTypeEnum.java +++ /dev/null @@ -1,21 +0,0 @@ -package com.alibaba.alink.operator.common.feature.pca; - -/** - * pca transform type. - */ -public enum PcaTransformTypeEnum { - /** - * data * model - */ - SIMPLE, - - /** - * (data - mean) * model - */ - SUBMEAN, - - /** - * (data - mean) / stdVar * model - */ - NORMALIZATION -} diff --git a/core/src/main/java/com/alibaba/alink/operator/common/feature/pca/PcaTypeEnum.java b/core/src/main/java/com/alibaba/alink/operator/common/feature/pca/PcaTypeEnum.java deleted file mode 100644 index c814ec647..000000000 --- a/core/src/main/java/com/alibaba/alink/operator/common/feature/pca/PcaTypeEnum.java +++ /dev/null @@ -1,20 +0,0 @@ -package com.alibaba.alink.operator.common.feature.pca; - -/** - * pca calculation type. - */ -public enum PcaTypeEnum { - /** - * correlation - */ - CORR, - /** - * sample variance - */ - COV_SAMPLE, - - /** - * population variance - */ - COVAR_POP -} diff --git a/core/src/main/java/com/alibaba/alink/operator/common/regression/GlmModelData.java b/core/src/main/java/com/alibaba/alink/operator/common/regression/GlmModelData.java index 943d131a5..4b4223a1b 100644 --- a/core/src/main/java/com/alibaba/alink/operator/common/regression/GlmModelData.java +++ b/core/src/main/java/com/alibaba/alink/operator/common/regression/GlmModelData.java @@ -1,6 +1,8 @@ package com.alibaba.alink.operator.common.regression; +import com.alibaba.alink.params.regression.GlmTrainParams; + /** * Glm model data. 
*/ @@ -51,7 +53,7 @@ public class GlmModelData { /** * family name. */ - public String familyName; + public GlmTrainParams.Family familyName; /** * variance power of family. @@ -61,7 +63,7 @@ public class GlmModelData { /** * link function name. */ - public String linkName; + public GlmTrainParams.Link linkName; /** * power of link function. diff --git a/core/src/main/java/com/alibaba/alink/operator/common/regression/GlmModelMapper.java b/core/src/main/java/com/alibaba/alink/operator/common/regression/GlmModelMapper.java index 848fe9e2f..5dfeb9ce4 100644 --- a/core/src/main/java/com/alibaba/alink/operator/common/regression/GlmModelMapper.java +++ b/core/src/main/java/com/alibaba/alink/operator/common/regression/GlmModelMapper.java @@ -98,9 +98,9 @@ public void loadModel(List modelRows) { features = new double[featureColIdxs.length]; - String familyName = params.get(GlmTrainParams.FAMILY); + GlmTrainParams.Family familyName = params.get(GlmTrainParams.FAMILY); double variancePower = params.get(GlmTrainParams.VARIANCE_POWER); - String linkName = params.get(GlmTrainParams.LINK); + GlmTrainParams.Link linkName = params.get(GlmTrainParams.LINK); double linkPower = params.get(GlmTrainParams.LINK_POWER); familyLink = new FamilyLink(familyName, variancePower, linkName, linkPower); diff --git a/core/src/main/java/com/alibaba/alink/operator/common/regression/glm/FamilyLink.java b/core/src/main/java/com/alibaba/alink/operator/common/regression/glm/FamilyLink.java index 37ec4ce6e..423b23900 100644 --- a/core/src/main/java/com/alibaba/alink/operator/common/regression/glm/FamilyLink.java +++ b/core/src/main/java/com/alibaba/alink/operator/common/regression/glm/FamilyLink.java @@ -15,6 +15,7 @@ import com.alibaba.alink.operator.common.regression.glm.link.Power; import com.alibaba.alink.operator.common.regression.glm.link.Probit; import com.alibaba.alink.operator.common.regression.glm.link.Sqrt; +import com.alibaba.alink.params.regression.GlmTrainParams; import 
java.io.Serializable; @@ -32,57 +33,57 @@ public class FamilyLink implements Serializable { * @param linkName: link name. * @param linkPower: link power. */ - public FamilyLink(String familyName, double variancePower, String linkName, double linkPower) { - if (familyName == null || familyName.isEmpty()) { + public FamilyLink(GlmTrainParams.Family familyName, double variancePower, GlmTrainParams.Link linkName, double linkPower) { + if (familyName == null) { throw new RuntimeException("family can not be empty"); } - switch (familyName.toLowerCase()) { - case "gamma": + switch (familyName) { + case Gamma: family = new Gamma(); break; - case "binomial": + case Binomial: family = new Binomial(); break; - case "gaussian": + case Gaussian: family = new Gaussian(); break; - case "poisson": + case Poisson: family = new Poisson(); break; - case "tweedie": + case Tweedie: family = new Tweedie(variancePower); break; default: throw new RuntimeException("family is not support. "); } - if (linkName == null || linkName.isEmpty()) { + if (linkName == null) { link = family.getDefaultLink(); } else { - switch (linkName.toLowerCase()) { - case "cloglog": + switch (linkName) { + case CLogLog: link = new CLogLog(); break; - case "identity": + case Identity: link = new Identity(); break; - case "inverse": + case Inverse: link = new Inverse(); break; - case "log": + case Log: link = new Log(); break; - case "logit": + case Logit: link = new Logit(); break; - case "power": + case Power: link = new Power(linkPower); break; - case "probit": + case Probit: link = new Probit(); break; - case "sqrt": + case Sqrt: link = new Sqrt(); break; default: diff --git a/core/src/main/java/com/alibaba/alink/operator/common/statistics/ChiSquareTest.java b/core/src/main/java/com/alibaba/alink/operator/common/statistics/ChiSquareTest.java index 978a29b15..03525dc37 100644 --- a/core/src/main/java/com/alibaba/alink/operator/common/statistics/ChiSquareTest.java +++ 
b/core/src/main/java/com/alibaba/alink/operator/common/statistics/ChiSquareTest.java @@ -15,6 +15,7 @@ import com.alibaba.alink.common.utils.DataSetConversionUtil; import com.alibaba.alink.operator.common.feature.ChiSqSelectorModelDataConverter; +import com.alibaba.alink.params.feature.BasedChisqSelectorParams; import com.google.common.primitives.Ints; import org.apache.commons.math3.distribution.ChiSquaredDistribution; @@ -95,7 +96,7 @@ public void reduce(Iterable iterable, Collector> * @return selected col indices. */ protected static int[] selector(List chiSquareTest, - String selectorType, + BasedChisqSelectorParams.SelectorType selectorType, int numTopFeatures, double percentile, double fpr, @@ -105,14 +106,9 @@ protected static int[] selector(List chiSquareTest, int len = chiSquareTest.size(); - if(selectorType.toUpperCase().equals("NUMTOPFEATURES")) { - selectorType = "NUM_TOP_FEATURES"; - } - ChiSqSelectorType type = ChiSqSelectorType.valueOf(selectorType.toUpperCase()); - List selectedColIndices = new ArrayList<>(); - switch (type) { - case NUM_TOP_FEATURES: + switch (selectorType) { + case NumTopFeatures: chiSquareTest.sort(new RowAscComparator()); for (int i = 0; i < numTopFeatures && i < len; i++) { @@ -211,14 +207,14 @@ protected static Tuple4 test(Tuple2 { - private String selectorType; + private BasedChisqSelectorParams.SelectorType selectorType; private int numTopFeatures; private double percentile; private double fpr; private double fdr; private double fwe; - ChiSquareSelector(String selectorType, int numTopFeatures, + ChiSquareSelector(BasedChisqSelectorParams.SelectorType selectorType, int numTopFeatures, double percentile, double fpr, double fdr, double fwe) { this.selectorType = selectorType; @@ -291,35 +287,4 @@ public Row map(Tuple2 crossTabWithId) throws Exception { return row; } } - - /** - * chi-square selector type. - */ - - public enum ChiSqSelectorType { - /** - * select numTopFeatures features which are maximum chi-square value. 
- */ - NUM_TOP_FEATURES, - - /** - * select percentile * n features which are maximum chi-square value. - */ - PERCENTILE, - - /** - * select feature which chi-square value less than fpr. - */ - FPR, - - /** - * select feature which chi-square value less than fdr * (i + 1) / n. - */ - FDR, - - /** - * select feature which chi-square value less than fwe / n. - */ - FWE - } } diff --git a/core/src/main/java/com/alibaba/alink/operator/common/statistics/ChiSquareTestUtil.java b/core/src/main/java/com/alibaba/alink/operator/common/statistics/ChiSquareTestUtil.java index ed9423ef2..701aa252c 100644 --- a/core/src/main/java/com/alibaba/alink/operator/common/statistics/ChiSquareTestUtil.java +++ b/core/src/main/java/com/alibaba/alink/operator/common/statistics/ChiSquareTestUtil.java @@ -4,6 +4,7 @@ import com.alibaba.alink.operator.common.feature.ChiSqSelectorModelDataConverter; import com.alibaba.alink.common.utils.JsonConverter; import com.alibaba.alink.common.utils.DataSetConversionUtil; +import com.alibaba.alink.params.feature.BasedChisqSelectorParams; import org.apache.commons.lang3.ArrayUtils; import org.apache.flink.api.common.functions.MapFunction; import org.apache.flink.api.common.typeinfo.TypeInformation; @@ -64,7 +65,7 @@ public static DataSet test(BatchOperator in, public static Table selector(BatchOperator in, String[] selectedColNames, String labelColName, - String selectorType, + BasedChisqSelectorParams.SelectorType selectorType, int numTopFeatures, double percentile, double fpr, @@ -97,7 +98,7 @@ public static Table selector(BatchOperator in, public static Table vectorSelector(BatchOperator in, String selectedColName, String labelColName, - String selectorType, + BasedChisqSelectorParams.SelectorType selectorType, int numTopFeatures, double percentile, double fpr, diff --git a/core/src/main/java/com/alibaba/alink/params/feature/BasedChisqSelectorParams.java b/core/src/main/java/com/alibaba/alink/params/feature/BasedChisqSelectorParams.java index 
7cd6b3c12..c7e843a63 100644 --- a/core/src/main/java/com/alibaba/alink/params/feature/BasedChisqSelectorParams.java +++ b/core/src/main/java/com/alibaba/alink/params/feature/BasedChisqSelectorParams.java @@ -3,6 +3,7 @@ import org.apache.flink.ml.api.misc.param.ParamInfo; import org.apache.flink.ml.api.misc.param.ParamInfoFactory; +import com.alibaba.alink.params.ParamUtil; import com.alibaba.alink.params.shared.colname.HasLabelCol; /** @@ -11,10 +12,11 @@ public interface BasedChisqSelectorParams extends HasLabelCol { - ParamInfo SELECTOR_TYPE = ParamInfoFactory.createParamInfo("selectorType", String.class) - .setDescription("The selector supports different selection methods: `numTopFeatures`, `percentile`, `fpr`,\n" + + ParamInfo SELECTOR_TYPE = ParamInfoFactory.createParamInfo("selectorType", + SelectorType.class) + .setDescription("The selector supports different selection methods: `NumTopFeatures`, `percentile`, `fpr`,\n" + " `fdr`, `fwe`.\n" + - " - `numTopFeatures` chooses a fixed number of top features according to a chi-squared test.\n" + + " - `NumTopFeatures` chooses a fixed number of top features according to a chi-squared test.\n" + " - `percentile` is similar but chooses a fraction of all features instead of a fixed number.\n" + " - `fpr` chooses all features whose p-values are below a threshold, thus controlling the false\n" + " positive rate of selection.\n" + @@ -23,10 +25,11 @@ public interface BasedChisqSelectorParams extends " to choose all features whose false discovery rate is below a threshold.\n" + " - `fwe` chooses all features whose p-values are below a threshold. 
The threshold is scaled by\n" + " 1/numFeatures, thus controlling the family-wise error rate of selection.\n" + - " By default, the selection method is `numTopFeatures`, with the default number of top features") + " By default, the selection method is `NumTopFeatures`, with the default number of top features") .setOptional() - .setHasDefaultValue("numTopFeatures") + .setHasDefaultValue(SelectorType.NumTopFeatures) .build(); + ParamInfo NUM_TOP_FEATURES = ParamInfoFactory.createParamInfo("numTopFeatures", Integer.class) .setDescription("Number of features that selector will select, ordered by ascending p-value. If the" + " number of features is < numTopFeatures, then this will select all features." + @@ -58,14 +61,18 @@ public interface BasedChisqSelectorParams extends .setHasDefaultValue(0.05) .build(); - default String getSelectorType() { + default SelectorType getSelectorType() { return get(SELECTOR_TYPE); } - default T setSelectorType(String value) { + default T setSelectorType(SelectorType value) { return set(SELECTOR_TYPE, value); } + default T setSelectorType(String value) { + return set(SELECTOR_TYPE, ParamUtil.searchEnum(SELECTOR_TYPE, value)); + } + default Integer getNumTopFeatures() { return get(NUM_TOP_FEATURES); } @@ -106,4 +113,33 @@ default T setFwe(Double value) { return set(FWE, value); } + /** + * chi-square selector type. + */ + enum SelectorType { + /** + * select NumTopFeatures features which are maximum chi-square value. + */ + NumTopFeatures, + + /** + * select percentile * n features which are maximum chi-square value. + */ + PERCENTILE, + + /** + * select feature which chi-square value less than fpr. + */ + FPR, + + /** + * select feature which chi-square value less than fdr * (i + 1) / n. + */ + FDR, + + /** + * select feature which chi-square value less than fwe / n. 
+ */ + FWE + } } diff --git a/core/src/main/java/com/alibaba/alink/params/feature/HasCalculationType.java b/core/src/main/java/com/alibaba/alink/params/feature/HasCalculationType.java index fbda6dbac..36c602cc0 100644 --- a/core/src/main/java/com/alibaba/alink/params/feature/HasCalculationType.java +++ b/core/src/main/java/com/alibaba/alink/params/feature/HasCalculationType.java @@ -1,18 +1,19 @@ package com.alibaba.alink.params.feature; -import com.alibaba.alink.params.ParamUtil; import org.apache.flink.ml.api.misc.param.ParamInfo; import org.apache.flink.ml.api.misc.param.ParamInfoFactory; import org.apache.flink.ml.api.misc.param.WithParams; +import com.alibaba.alink.params.ParamUtil; + public interface HasCalculationType extends WithParams { ParamInfo CALCULATION_TYPE = ParamInfoFactory - .createParamInfo("calculationType", CalculationType.class) - .setDescription("compute type, be CORR, COV_SAMPLE, COVAR_POP.") - .setHasDefaultValue(CalculationType.CORR) - .setAlias(new String[]{"calcType", "pcaType"}) - .build(); + .createParamInfo("calculationType", CalculationType.class) + .setDescription("compute type, be CORR, COV_SAMPLE, COVAR_POP.") + .setHasDefaultValue(CalculationType.CORR) + .setAlias(new String[]{"calcType", "pcaType"}) + .build(); default CalculationType getCalculationType() { return get(CALCULATION_TYPE); @@ -44,4 +45,4 @@ enum CalculationType { */ COVAR_POP } -} \ No newline at end of file +} diff --git a/core/src/main/java/com/alibaba/alink/params/feature/PcaPredictParams.java b/core/src/main/java/com/alibaba/alink/params/feature/PcaPredictParams.java index 155e7cd1f..c4baa1084 100644 --- a/core/src/main/java/com/alibaba/alink/params/feature/PcaPredictParams.java +++ b/core/src/main/java/com/alibaba/alink/params/feature/PcaPredictParams.java @@ -3,6 +3,7 @@ import org.apache.flink.ml.api.misc.param.ParamInfo; import org.apache.flink.ml.api.misc.param.ParamInfoFactory; +import com.alibaba.alink.params.ParamUtil; import 
com.alibaba.alink.params.shared.colname.HasPredictionCol; import com.alibaba.alink.params.shared.colname.HasReservedCols; import com.alibaba.alink.params.shared.colname.HasVectorColDefaultAsNull; @@ -15,17 +16,41 @@ public interface PcaPredictParams extends HasPredictionCol, HasVectorColDefaultAsNull { - ParamInfo TRANSFORM_TYPE = ParamInfoFactory - .createParamInfo("transformType", String.class) + ParamInfo TRANSFORM_TYPE = ParamInfoFactory + .createParamInfo("transformType", TransformType.class) .setDescription("'SIMPLE' or 'SUBMEAN', SIMPLE is data * model, SUBMEAN is (data - mean) * model") - .setHasDefaultValue("SIMPLE") + .setHasDefaultValue(TransformType.SIMPLE) .build(); - default String getTransformType() { + default TransformType getTransformType() { return get(TRANSFORM_TYPE); } - default T setTransformType(String value) { + default T setTransformType(TransformType value) { return set(TRANSFORM_TYPE, value); } + + default T setTransformType(String value) { + return set(TRANSFORM_TYPE, ParamUtil.searchEnum(TRANSFORM_TYPE, value)); + } + + /** + * pca transform type. 
+ */ + enum TransformType { + /** + * data * model + */ + SIMPLE, + + /** + * (data - mean) * model + */ + SUBMEAN, + + /** + * (data - mean) / stdVar * model + */ + NORMALIZATION + } } diff --git a/core/src/main/java/com/alibaba/alink/params/feature/PcaTrainParams.java b/core/src/main/java/com/alibaba/alink/params/feature/PcaTrainParams.java index c1ccf2f58..04ec6cc70 100644 --- a/core/src/main/java/com/alibaba/alink/params/feature/PcaTrainParams.java +++ b/core/src/main/java/com/alibaba/alink/params/feature/PcaTrainParams.java @@ -16,5 +16,4 @@ public interface PcaTrainParams extends HasWithStd, HasK, HasCalculationType { - } diff --git a/core/src/main/java/com/alibaba/alink/params/regression/GlmTrainParams.java b/core/src/main/java/com/alibaba/alink/params/regression/GlmTrainParams.java index 9470faaeb..293528461 100644 --- a/core/src/main/java/com/alibaba/alink/params/regression/GlmTrainParams.java +++ b/core/src/main/java/com/alibaba/alink/params/regression/GlmTrainParams.java @@ -3,6 +3,8 @@ import org.apache.flink.ml.api.misc.param.ParamInfo; import org.apache.flink.ml.api.misc.param.ParamInfoFactory; +import com.alibaba.alink.operator.common.regression.glm.famliy.Family; +import com.alibaba.alink.params.ParamUtil; import com.alibaba.alink.params.shared.colname.HasWeightColDefaultAsNull; import com.alibaba.alink.params.shared.iter.HasMaxIterDefaultAs10; @@ -14,11 +16,11 @@ public interface GlmTrainParams extends HasWeightColDefaultAsNull, HasMaxIterDefaultAs10 { - ParamInfo FAMILY = ParamInfoFactory - .createParamInfo("family", String.class) + ParamInfo FAMILY = ParamInfoFactory + .createParamInfo("family", Family.class) .setDescription("the name of family which is a description of the error distribution. 
" + - "Supported options: gaussian, binomial, poisson, gamma and tweedie") - .setHasDefaultValue("gaussian") + "Supported options: Gaussian, Binomial, Poisson, Gamma and Tweedie") + .setHasDefaultValue(Family.Gaussian) .build(); ParamInfo VARIANCE_POWER = ParamInfoFactory .createParamInfo("variancePower", Double.class) @@ -26,10 +28,10 @@ public interface GlmTrainParams extends "It describe the relationship between the variance and mean of the distribution") .setHasDefaultValue(0.0) .build(); - ParamInfo LINK = ParamInfoFactory - .createParamInfo("link", String.class) + ParamInfo LINK = ParamInfoFactory + .createParamInfo("link", Link.class) .setDescription("The name of link function" + - "Supported options: cloglog, identity, inverse, log, logit, power, probit and sqrt") + "Supported options: CLogLog, Identity, Inverse, log, logit, power, probit and sqrt") .setHasDefaultValue(null) .build(); ParamInfo LINK_POWER = ParamInfoFactory @@ -59,14 +61,18 @@ public interface GlmTrainParams extends .setHasDefaultValue(1.0e-5) .build(); - default String getFamily() { + default Family getFamily() { return get(FAMILY); } - default T setFamily(String value) { + default T setFamily(Family value) { return set(FAMILY, value); } + default T setFamily(String value) { + return set(FAMILY, ParamUtil.searchEnum(FAMILY, value)); + } + default Double getVariancePower() { return get(VARIANCE_POWER); } @@ -75,14 +81,18 @@ default T setVariancePower(Double value) { return set(VARIANCE_POWER, value); } - default String getLink() { + default Link getLink() { return get(LINK); } - default T setLink(String value) { + default T setLink(Link value) { return set(LINK, value); } + default T setLink(String value) { + return set(LINK, ParamUtil.searchEnum(LINK, value)); + } + default Double getLinkPower() { return get(LINK_POWER); } @@ -119,4 +129,22 @@ default Double getEpsilon() { return get(EPSILON); } + enum Family { + Gamma, + Binomial, + Gaussian, + Poisson, + Tweedie + } + + enum Link { + 
CLogLog, + Identity, + Inverse, + Log, + Logit, + Power, + Probit, + Sqrt + } } diff --git a/core/src/main/java/com/alibaba/alink/params/statistics/HasMethod.java b/core/src/main/java/com/alibaba/alink/params/statistics/HasMethod.java index 4f23a4395..eec3d3982 100644 --- a/core/src/main/java/com/alibaba/alink/params/statistics/HasMethod.java +++ b/core/src/main/java/com/alibaba/alink/params/statistics/HasMethod.java @@ -5,23 +5,34 @@ import org.apache.flink.ml.api.misc.param.WithParams; +import com.alibaba.alink.params.ParamUtil; + /** * Parameter of correlation method. */ public interface HasMethod extends WithParams { - ParamInfo METHOD = ParamInfoFactory - .createParamInfo("method", String.class) - .setDescription("method: pearson, spearman. default pearson") - .setHasDefaultValue("pearson") + ParamInfo METHOD = ParamInfoFactory + .createParamInfo("method", Method.class) + .setDescription("method: PEARSON, SPEARMAN. default PEARSON") + .setHasDefaultValue(Method.PEARSON) .build(); - default String getMethod() { + default Method getMethod() { return get(METHOD); } - default T setMethod(String value) { + default T setMethod(Method value) { return set(METHOD, value); } + default T setMethod(String value) { + return set(METHOD, ParamUtil.searchEnum(METHOD, value)); + } + + enum Method { + PEARSON, /* Pearson correlation; the default value of METHOD. */ + SPEARMAN /* Spearman correlation; spelled to match the "spearman" option of the String-typed parameter this enum replaces. */ + } + } diff --git a/core/src/test/java/com/alibaba/alink/operator/common/statistics/ChiSquareTestTest.java b/core/src/test/java/com/alibaba/alink/operator/common/statistics/ChiSquareTestTest.java index c7c9534ca..5195fde48 100644 --- a/core/src/test/java/com/alibaba/alink/operator/common/statistics/ChiSquareTestTest.java +++ b/core/src/test/java/com/alibaba/alink/operator/common/statistics/ChiSquareTestTest.java @@ -5,6 +5,8 @@ import org.apache.flink.api.java.tuple.Tuple2; import org.apache.flink.api.java.tuple.Tuple4; import org.apache.flink.types.Row; + +import com.alibaba.alink.params.feature.BasedChisqSelectorParams; import org.junit.Test; import 
java.util.ArrayList; @@ -35,7 +37,7 @@ public void testChiSquare() { @Test public void testChiSqSelector() { - String selectorType = "NumTopFeatures"; + BasedChisqSelectorParams.SelectorType selectorType = BasedChisqSelectorParams.SelectorType.NumTopFeatures; int numTopFeatures = 2; double percentile = 0.5; double fpr = 0.5; @@ -51,7 +53,7 @@ public void testChiSqSelector() { @Test public void testChiSqSelector2() { - String selectorType = "percentile"; + BasedChisqSelectorParams.SelectorType selectorType = BasedChisqSelectorParams.SelectorType.PERCENTILE; int numTopFeatures = 2; double percentile = 0.5; double fpr = 0.5; @@ -67,7 +69,7 @@ public void testChiSqSelector2() { @Test public void testChiSqSelector3() { - String selectorType = "fpr"; + BasedChisqSelectorParams.SelectorType selectorType = BasedChisqSelectorParams.SelectorType.FPR; int numTopFeatures = 2; double percentile = 0.5; double fpr = 0.5; @@ -83,7 +85,7 @@ public void testChiSqSelector3() { @Test public void testChiSqSelector4() { - String selectorType = "fdr"; + BasedChisqSelectorParams.SelectorType selectorType = BasedChisqSelectorParams.SelectorType.FDR; int numTopFeatures = 2; double percentile = 0.5; double fpr = 0.5; @@ -99,7 +101,7 @@ public void testChiSqSelector4() { @Test public void testChiSqSelector5() { - String selectorType = "fwe"; + BasedChisqSelectorParams.SelectorType selectorType = BasedChisqSelectorParams.SelectorType.FWE; int numTopFeatures = 2; double percentile = 0.5; double fpr = 0.5; @@ -112,7 +114,7 @@ public void testChiSqSelector5() { assertEquals(0, selectedIndices[0]); } - private int[] testSelector(String selectorType, int numTopFeatures, + private int[] testSelector(BasedChisqSelectorParams.SelectorType selectorType, int numTopFeatures, double percentile, double fpr, double fdr,