Skip to content

Commit

Permalink
Adapt enum type params in pca, chi square test, glm and correlation. …
Browse files Browse the repository at this point in the history
…see #73
  • Loading branch information
shaomengwang authored and ning.cain committed Apr 9, 2020
1 parent 5a090a3 commit 309ecb6
Show file tree
Hide file tree
Showing 23 changed files with 205 additions and 171 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ public ChiSqSelectorBatchOp linkFrom(BatchOperator<?>... inputs) {
String[] selectedColNames = getSelectedCols();
String labelColName = getLabelCol();

String selectorType = getParams().get(SELECTOR_TYPE).trim().toLowerCase();
SelectorType selectorType = getParams().get(SELECTOR_TYPE);
int numTopFeatures = getParams().get(NUM_TOP_FEATURES);
double percentile = getParams().get(PERCENTILE);
double fpr = getParams().get(FPR);
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
package com.alibaba.alink.operator.batch.feature;

import com.alibaba.alink.common.linalg.*;
import com.alibaba.alink.operator.common.feature.pca.PcaTypeEnum;
import org.apache.flink.api.common.functions.RichFlatMapFunction;
import org.apache.flink.api.common.functions.RichMapPartitionFunction;
import org.apache.flink.api.java.DataSet;
Expand Down Expand Up @@ -328,7 +327,7 @@ public void mapPartition(Iterable<Tuple2<Integer, DenseVector>> splitVec, Collec
PcaModelData pcr = new PcaModelData();

//get correlation or covariance matrix
PcaTypeEnum pcaTypeEnum = PcaTypeEnum.valueOf(pcaType.toUpperCase());
CalculationType pcaTypeEnum = CalculationType.valueOf(pcaType.toUpperCase());

double[][] corr = null;

Expand All @@ -346,7 +345,7 @@ public void mapPartition(Iterable<Tuple2<Integer, DenseVector>> splitVec, Collec


DenseMatrix calculateMatrix = new DenseMatrix(corr);
if (pcaTypeEnum.equals(PcaTypeEnum.COVAR_POP)) {
if (pcaTypeEnum.equals(CalculationType.COVAR_POP)) {
double cnt = counts[0];
if (cnt > 1) {
calculateMatrix.scaleEqual(cnt / (cnt - 1));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ public VectorChiSqSelectorBatchOp linkFrom(BatchOperator<?>... inputs) {
String vectorColName = getSelectedCol();
String labelColName = getLabelCol();

String selectorType = getParams().get(SELECTOR_TYPE).trim().toLowerCase();
SelectorType selectorType = getParams().get(SELECTOR_TYPE);
int numTopFeatures = getParams().get(NUM_TOP_FEATURES);
double percentile = getParams().get(PERCENTILE);
double fpr = getParams().get(FPR);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,8 @@ public GlmEvaluationBatchOp linkFrom(BatchOperator<?>... inputs) {
String weightColName = getWeightCol();
String offsetColName = getOffsetCol();

String familyName = getFamily();
String linkName = getLink();
Family familyName = getFamily();
Link linkName = getLink();
double variancePower = getVariancePower();
double linkPower = getLinkPower();

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,8 +48,8 @@ public GlmTrainBatchOp linkFrom(BatchOperator<?>... inputs) {
String weightColName = getWeightCol();
String offsetColName = getOffsetCol();

String familyName = getFamily();
String linkName = getLink();
Family familyName = getFamily();
Link linkName = getLink();
double variancePower = getVariancePower();
double linkPower = getLinkPower();

Expand Down Expand Up @@ -125,9 +125,9 @@ private static class BuildModel implements MapPartitionFunction<GlmUtil.Weighted
private String weightColName;
private String labelColName;

private String familyName;
private Family familyName;
private double variancePower;
private String linkName;
private Link linkName;
private double linkPower;

private boolean fitIntercept;
Expand All @@ -136,8 +136,8 @@ private static class BuildModel implements MapPartitionFunction<GlmUtil.Weighted

public BuildModel(String[] featureColNames, String offsetColName,
String weightColName, String labelColName,
String familyName, double variancePower,
String linkName, double linkPower,
Family familyName, double variancePower,
Link linkName, double linkPower,
Boolean fitIntercept, int numIter, double epsilon) {
this.featureColNames = featureColNames;
this.offsetColName = offsetColName;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,9 +47,9 @@ public CorrelationBatchOp linkFrom(BatchOperator<?>... inputs) {
//check col types must be double or bigint
TableUtil.assertNumericalCols(in.getSchema(), selectedColNames);

String corrType = getMethod().trim().toLowerCase();
Method corrType = getMethod();

if ("pearson".equals(corrType)) {
if (Method.PEARSON == corrType) {

DataSet<Tuple2<TableSummary, CorrelationResult>> srt = StatisticsHelper.pearsonCorrelation(in, selectedColNames);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,9 +39,9 @@ public VectorCorrelationBatchOp linkFrom(BatchOperator<?>... inputs) {
BatchOperator<?> in = checkAndGetFirst(inputs);
String vectorColName = getSelectedCol();

String corrType = getMethod().trim().toLowerCase();
Method corrType = getMethod();

if ("pearson".equals(corrType)) {
if (Method.PEARSON == corrType) {
DataSet<Tuple2<BaseVectorSummary, CorrelationResult>> srt = StatisticsHelper.vectorPearsonCorrelation(in, vectorColName);

//block
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import org.apache.flink.ml.api.misc.param.Params;

import com.alibaba.alink.common.utils.JsonConverter;
import com.alibaba.alink.params.feature.HasCalculationType;
import com.alibaba.alink.params.shared.colname.HasFeatureColsDefaultAsNull;
import com.alibaba.alink.params.feature.PcaTrainParams;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import com.alibaba.alink.common.mapper.ModelMapper;
import com.alibaba.alink.common.utils.OutputColsHelper;
import com.alibaba.alink.common.utils.TableUtil;
import com.alibaba.alink.params.feature.HasCalculationType;
import com.alibaba.alink.params.feature.PcaPredictParams;
import org.apache.flink.ml.api.misc.param.Params;
import org.apache.flink.table.api.TableSchema;
Expand All @@ -25,7 +26,7 @@ public class PcaModelMapper extends ModelMapper {
private int[] featureIdxs = null;
private boolean isVector;

private String transformType = null;
private PcaPredictParams.TransformType transformType = null;
private String pcaType = null;

private double[] sourceMean = null;
Expand Down Expand Up @@ -77,8 +78,7 @@ public void loadModel(List<Row> modelRows) {
int nx = model.means.length;
int p = model.p;

PcaTransformTypeEnum transformTypeEnum = PcaTransformTypeEnum.valueOf(this.transformType.toUpperCase());
PcaTypeEnum pcaTypeEnum = PcaTypeEnum.valueOf(this.pcaType.toUpperCase());
HasCalculationType.CalculationType pcaTypeEnum = HasCalculationType.CalculationType.valueOf(this.pcaType.toUpperCase());

//transform mean, stdDevs and scoreStd
sourceMean = new double[nx];
Expand All @@ -88,11 +88,11 @@ public void loadModel(List<Row> modelRows) {
Arrays.fill(sourceStd, 1);
Arrays.fill(scoreStd, 1);

if (PcaTypeEnum.CORR.equals(pcaTypeEnum)) {
if (HasCalculationType.CalculationType.CORR.equals(pcaTypeEnum)) {
sourceStd = model.stddevs;
}

switch (transformTypeEnum) {
switch (transformType) {
case SUBMEAN:
sourceMean = model.means;
break;
Expand All @@ -108,6 +108,10 @@ public void loadModel(List<Row> modelRows) {
scoreStd[i] = Math.sqrt(tmp);
}
break;
case SIMPLE:
break;
default:
throw new IllegalArgumentException("Error transformType: " + transformType);
}
}

Expand Down

This file was deleted.

This file was deleted.

Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
package com.alibaba.alink.operator.common.regression;


import com.alibaba.alink.params.regression.GlmTrainParams;

/**
* Glm model data.
*/
Expand Down Expand Up @@ -51,7 +53,7 @@ public class GlmModelData {
/**
* family name.
*/
public String familyName;
public GlmTrainParams.Family familyName;

/**
* variance power of family.
Expand All @@ -61,7 +63,7 @@ public class GlmModelData {
/**
* link function name.
*/
public String linkName;
public GlmTrainParams.Link linkName;

/**
* power of link function.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -98,9 +98,9 @@ public void loadModel(List<Row> modelRows) {

features = new double[featureColIdxs.length];

String familyName = params.get(GlmTrainParams.FAMILY);
GlmTrainParams.Family familyName = params.get(GlmTrainParams.FAMILY);
double variancePower = params.get(GlmTrainParams.VARIANCE_POWER);
String linkName = params.get(GlmTrainParams.LINK);
GlmTrainParams.Link linkName = params.get(GlmTrainParams.LINK);
double linkPower = params.get(GlmTrainParams.LINK_POWER);

familyLink = new FamilyLink(familyName, variancePower, linkName, linkPower);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import com.alibaba.alink.operator.common.regression.glm.link.Power;
import com.alibaba.alink.operator.common.regression.glm.link.Probit;
import com.alibaba.alink.operator.common.regression.glm.link.Sqrt;
import com.alibaba.alink.params.regression.GlmTrainParams;

import java.io.Serializable;

Expand All @@ -32,57 +33,57 @@ public class FamilyLink implements Serializable {
* @param linkName: link name.
* @param linkPower: link power.
*/
public FamilyLink(String familyName, double variancePower, String linkName, double linkPower) {
if (familyName == null || familyName.isEmpty()) {
public FamilyLink(GlmTrainParams.Family familyName, double variancePower, GlmTrainParams.Link linkName, double linkPower) {
if (familyName == null) {
throw new RuntimeException("family can not be empty");
}

switch (familyName.toLowerCase()) {
case "gamma":
switch (familyName) {
case Gamma:
family = new Gamma();
break;
case "binomial":
case Binomial:
family = new Binomial();
break;
case "gaussian":
case Gaussian:
family = new Gaussian();
break;
case "poisson":
case Poisson:
family = new Poisson();
break;
case "tweedie":
case Tweedie:
family = new Tweedie(variancePower);
break;
default:
throw new RuntimeException("family is not support. ");
}

if (linkName == null || linkName.isEmpty()) {
if (linkName == null) {
link = family.getDefaultLink();
} else {
switch (linkName.toLowerCase()) {
case "cloglog":
switch (linkName) {
case CLogLog:
link = new CLogLog();
break;
case "identity":
case Identity:
link = new Identity();
break;
case "inverse":
case Inverse:
link = new Inverse();
break;
case "log":
case Log:
link = new Log();
break;
case "logit":
case Logit:
link = new Logit();
break;
case "power":
case Power:
link = new Power(linkPower);
break;
case "probit":
case Probit:
link = new Probit();
break;
case "sqrt":
case Sqrt:
link = new Sqrt();
break;
default:
Expand Down
Loading

0 comments on commit 309ecb6

Please sign in to comment.