Skip to content

Commit

Permalink
Updated JPMML-Python dependency
Browse files: browse the repository at this point in the history
  • Loading branch information
vruusmann committed Apr 8, 2024
1 parent 2776ebe commit dc9a33b
Show file tree
Hide file tree
Showing 142 changed files with 496 additions and 618 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ public DataType getDataType(){
@Override
public List<String> getFeatureNamesIn(){

if(containsKey(SkLearnFields.FEATURE_NAMES_IN)){
if(hasattr(SkLearnFields.FEATURE_NAMES_IN)){
return getListLike(SkLearnFields.FEATURE_NAMES_IN, String.class);
}

Expand All @@ -61,19 +61,19 @@ public PMML encodePMML(){
return TransformerUtil.encodePMML(this);
}

public List<?> getCols(){
return getList("cols");
public List<Object> getCols(){
return getObjectList("cols");
}

public List<String> getInvariantCols(){

// CategoryEncoders 2.3
if(containsKey("drop_cols")){
return getList("drop_cols", String.class);
if(hasattr("drop_cols")){
return getStringList("drop_cols");
}

// CategoryEncoders 2.5+
return getList("invariant_cols", String.class);
return getStringList("invariant_cols");
}

public Boolean getDropInvariant(){
Expand All @@ -83,12 +83,12 @@ public Boolean getDropInvariant(){
public List<String> getFeatureNamesOut(){

// CategoryEncoders 2.5.1post0
if(containsKey("feature_names")){
return getList("feature_names", String.class);
if(hasattr("feature_names")){
return getStringList("feature_names");
}

// CategoryEncoders 2.6+
return getList("feature_names_out_", String.class);
return getStringList("feature_names_out_");
}

public String getHandleMissing(){
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ public List<Feature> encode(List<Feature> features, SkLearnEncoder encoder){
public BaseNEncoder getBaseNEncoder(){

// CategoryEncoders 2.3
if(containsKey("base_n_encoder")){
if(hasattr("base_n_encoder")){
return get("base_n_encoder", BaseNEncoder.class);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
import org.dmg.pmml.DataType;
import org.jpmml.converter.Feature;
import org.jpmml.converter.ValueUtil;
import org.jpmml.python.AttributeException;
import org.jpmml.python.ClassDictUtil;
import org.jpmml.python.PythonObject;
import org.jpmml.sklearn.SkLearnEncoder;
Expand Down Expand Up @@ -107,7 +108,7 @@ public List<Mapping> getMapping(){
@Override
public Mapping apply(Map<String, ?> map){
Mapping mapping = new Mapping(getClassName(), "mapping");
mapping.putAll(map);
mapping.update(map);

return mapping;
}
Expand All @@ -130,7 +131,7 @@ public Map<?, Integer> getCategoryMapping(){
Series mapping = get("mapping", Series.class);

return SeriesUtil.toMap(mapping, Functions.identity(), ValueUtil::asInteger);
} catch(IllegalArgumentException iae){
} catch(AttributeException ae){
return (Map)getDict("mapping");
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -84,8 +84,8 @@ public List<Feature> encodeFeatures(List<Feature> features, SkLearnEncoder encod
String key = entry.getKey();
Object value = entry.getValue();

if(!optimalBinning.containsKey(key)){
optimalBinning.put(key, value);
if(!optimalBinning.hasattr(key)){
optimalBinning.setattr(key, value);
}
}
}
Expand Down Expand Up @@ -135,7 +135,7 @@ public List<Boolean> getSupport(){
}

public List<String> getVariableNames(){
return (List)getListLike("variable_names", String.class);
return getListLike("variable_names", String.class);
}

public Map<String, String> getVariableDTypes(){
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -276,8 +276,8 @@ private MapValues encodeCategoricalBinning(Feature feature, List<Number> splits,
return mapValues;
}

public List<?> getCategoriesIn(){
return getArray("_categories");
public List<Object> getCategoriesIn(){
return getObjectArray("_categories");
}

public List<Double> getCategoriesOut(){
Expand Down Expand Up @@ -333,15 +333,15 @@ public String getDefaultMetric(){

public String getMetric(){

if(!containsKey("metric")){
if(!hasattr("metric")){
return getDefaultMetric();
}

return getString("metric");
}

public OptimalBinning setMetric(String metric){
put("metric", metric);
setattr("metric", metric);

return this;
}
Expand All @@ -355,7 +355,7 @@ public List<Integer> getNumberOfNonEvents(){
}

public List<Number> getSpecialCodes(){
Object specialCodes = get("special_codes");
Object specialCodes = getOptionalObject("special_codes");

if(specialCodes == null){
return Collections.emptyList();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ public List<Feature> encodeFeatures(List<Feature> features, SkLearnEncoder encod
public Map<String, ?> getToOther(){

// PyCaret 3.0.0-RC
if(containsKey("_to_other")){
if(hasattr("_to_other")){
return getDict("_to_other");
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,11 +57,11 @@ public List<Feature> encodeFeatures(List<Feature> features, SkLearnEncoder encod
public List<String> getDrop(){

// PyCaret 3.0.0-RC
if(containsKey("_drop")){
return getList("_drop", String.class);
if(hasattr("_drop")){
return getStringList("_drop");
}

// PyCaret 3.0.0+
return getList("drop_", String.class);
return getStringList("drop_");
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -181,15 +181,15 @@ public List<Feature> encodeFeatures(List<Feature> features, SkLearnEncoder encod
}

public List<String> getFeatureNames(){
return getList("_feature_names_in", String.class);
return getStringList("_feature_names_in");
}

public List<String> getExclude(){
return getList("_exclude", String.class);
return getStringList("_exclude");
}

public List<String> getInclude(){
return getList("_include", String.class);
return getStringList("_include");
}

public String getTargetName(){
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -177,7 +177,7 @@ public Object apply(Object object){

@Override
protected String formatMessage(Object object){
return "Dict attribute \'estimators_\' contains an unsupported item value (" + ClassDictUtil.formatClass(object) + ")";
return "The item value object (" + ClassDictUtil.formatClass(object) + ") is not a supported Classifier";
}
};

Expand All @@ -189,7 +189,7 @@ protected String formatMessage(Object object){

private Map<?, ?> getEstimatorCategories(){

if(!containsKey(SkLearn2PMMLFields.PMML_CLASSES)){
if(!hasattr(SkLearn2PMMLFields.PMML_CLASSES)){
return null;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ public int getNumberOfOutputs(){
public List<?> getClasses(){
MojoModel mojoModel = getMojoModel();

if(containsKey(SkLearn2PMMLFields.PMML_CLASSES)){
if(hasattr(SkLearn2PMMLFields.PMML_CLASSES)){
List<?> values = getListLike(SkLearn2PMMLFields.PMML_CLASSES);

return Classifier.canonicalizeValues(values);
Expand Down Expand Up @@ -260,7 +260,7 @@ public String getMojoPath(){
}

public H2OEstimator setMojoPath(String mojoPath){
put("_mojo_path", mojoPath);
setattr("_mojo_path", mojoPath);

return this;
}
Expand Down Expand Up @@ -291,7 +291,7 @@ private MojoModel loadMojoModel(){

try {

if(containsKey("_mojo_bytes")){
if(hasattr("_mojo_bytes")){
byte[] mojoBytes = getMojoBytes();

try(InputStream is = new ByteArrayInputStream(mojoBytes)){
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ private GBDT loadGBDT(){
public String getHandle(){

// LightGBM 3.3.5
if(containsKey("handle")){
if(hasattr("handle")){
return getString("handle");
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ private Learner loadLearner(ByteOrder byteOrder, String charset){

public Integer getBestNTreeLimit(){

if(!containsKey("best_ntree_limit")){
if(!hasattr("best_ntree_limit")){
return null;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ public DataType getDataType(){
public List<?> getClasses(){

// XGBoost 0.4 through 1.7
if(containsKey("_le") || containsKey(SkLearnFields.CLASSES) || containsKey(SkLearn2PMMLFields.PMML_CLASSES)){
if(hasattr("_le") || hasattr(SkLearnFields.CLASSES) || hasattr(SkLearn2PMMLFields.PMML_CLASSES)){
return super.getClasses();
}

Expand Down
16 changes: 5 additions & 11 deletions pmml-sklearn/src/main/java/chaid/Split.java
Original file line number Diff line number Diff line change
Expand Up @@ -32,27 +32,21 @@ public Split(String module, String name){
}

public Integer getColumnId(){
Object columnId = get("column_id");

if(columnId == null){
return null;
}

return getInteger("column_id");
return getOptionalInteger("column_id");
}

/**
 * Looks up the {@code _invalid_reason} attribute of this split through
 * {@code getOptional}, requesting it as an {@link InvalidSplitReason}.
 *
 * NOTE(review): presumably yields {@code null} when the attribute is absent or unset;
 * confirm against the contract of {@code getOptional}.
 *
 * @return the value produced by {@code getOptional} for the {@code _invalid_reason} attribute.
 */
public InvalidSplitReason getInvalidReason(){
	InvalidSplitReason invalidReason = getOptional("_invalid_reason", InvalidSplitReason.class);

	return invalidReason;
}

public List<List<Integer>> getSplits(){
List<?> splits = getList("splits");
List<Object> splits = getObjectList("splits");

return decodeSplits(splits);
}

public List<List<?>> getSplitMap(){
List<?> splitMap = getList("split_map");
public List<List<Object>> getSplitMap(){
List<Object> splitMap = getObjectList("split_map");

return decodeSplitMap(splitMap);
}
Expand All @@ -67,7 +61,7 @@ private List<List<Integer>> decodeSplits(List<?> splits){
}

static
public List<List<?>> decodeSplitMap(List<?> splits){
public List<List<Object>> decodeSplitMap(List<?> splits){
return Lists.transform(splits, split -> {
return Lists.transform((List<?>)split, value -> {
return ScalarUtil.decode(value);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import java.util.Collections;
import java.util.List;

import org.jpmml.python.AttributeException;
import org.jpmml.python.ClassDictUtil;
import sklearn.Estimator;
import sklearn.HasNumberOfFeatures;
Expand Down Expand Up @@ -63,7 +64,7 @@ public List<String> generateFeatureNames(Step step){
int numberOfFeatures = step.getNumberOfFeatures();

if(numberOfFeatures == HasNumberOfFeatures.UNKNOWN){
throw new IllegalArgumentException("Attribute \'" + ClassDictUtil.formatMember(step, SkLearnFields.N_FEATURES_IN) + "\' is not set");
throw new AttributeException("Attribute \'" + ClassDictUtil.formatMember(step, SkLearnFields.N_FEATURES_IN) + "\' is not set");
}

return generateNames("x", numberOfFeatures, true);
Expand All @@ -74,7 +75,7 @@ public List<String> generateOutputNames(Estimator estimator){
int numberOfOutputs = estimator.getNumberOfOutputs();

if(numberOfOutputs == HasNumberOfOutputs.UNKNOWN){
throw new IllegalArgumentException("Attribute \'" + ClassDictUtil.formatMember(estimator, SkLearnFields.N_OUTPUTS) + "\' is not set");
throw new AttributeException("Attribute \'" + ClassDictUtil.formatMember(estimator, SkLearnFields.N_OUTPUTS) + "\' is not set");
}

return generateNames("y", numberOfOutputs, false);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,8 +49,7 @@ public List<DefineFunction> createSplineFunction(BSpline bspline, SkLearnEncoder
List<Number> c = bspline.getC();
List<Number> t = bspline.getT();

// XXX
int[] cShape = bspline.getArrayShape("c");
int[] cShape = bspline.getCShape();

int n = (t.size() - k - 1);

Expand Down
13 changes: 11 additions & 2 deletions pmml-sklearn/src/main/java/sklearn/Classifier.java
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ public int getNumberOfOutputs(){
@Override
public List<?> getClasses(){

if(containsKey(SkLearn2PMMLFields.PMML_CLASSES)){
if(hasattr(SkLearn2PMMLFields.PMML_CLASSES)){
return getClasses(SkLearn2PMMLFields.PMML_CLASSES);
}

Expand All @@ -93,7 +93,16 @@ protected List<?> getClasses(String name){
List<?> values = getListLike(name);

values = values.stream()
.map(value -> (value instanceof HasArray) ? canonicalizeValues(((HasArray)value).getArrayContent()) : value)
.map(value -> {

if(value instanceof HasArray){
HasArray hasArray = (HasArray)value;

return canonicalizeValues(hasArray.getArrayContent());
}

return value;
})
.collect(Collectors.toList());

return canonicalizeValues(values);
Expand Down
Loading

0 comments on commit dc9a33b

Please sign in to comment.