diff --git a/dcm/src/main/java/com/vmware/dcm/backend/ortools/OrToolsSolver.java b/dcm/src/main/java/com/vmware/dcm/backend/ortools/OrToolsSolver.java index efc09702..fd528792 100644 --- a/dcm/src/main/java/com/vmware/dcm/backend/ortools/OrToolsSolver.java +++ b/dcm/src/main/java/com/vmware/dcm/backend/ortools/OrToolsSolver.java @@ -1237,61 +1237,52 @@ protected JavaExpression visitFunctionCall(final FunctionCall node, final Transl // Functions always apply on a vector. We compute the arguments to the function, and in doing so, // add declarations to the corresponding for-loop that extracts the relevant columns/expressions from views. final OutputIR.Block forLoop = context.currentScope().getForLoopByName(vectorName); - - if (node.getArgument().size() == 1) { - context.enterScope(forLoop); - final JavaExpression processedArgument = visit(node.getArgument().get(0), - context.withEnterFunctionContext()); - context.leaveScope(); - final JavaExpression listOfProcessedItem = extractListFromLoop(processedArgument, - context.currentScope(), forLoop); - switch (node.getFunction()) { - case SUM: - case COUNT: - case MAX: - case MIN: - case ANY: - case ALL: - case ALL_EQUAL: - case ALL_DIFFERENT: - case INCREASING: - final boolean supportsAssumptions = supportsAssumptions(node); - final JavaType argType = processedArgument.type(); - final JavaType outType = tupleMetadata.inferType(node); - final String functionName = String.format("%s%s", camelCase(node.getFunction().toString()), - argType); - final String argumentString = listOfProcessedItem.asString(); - return supportsAssumptions ? - // Use the current scope as the assumption context - new JavaExpression(CodeBlock.of("o.$L($L, $S)", functionName, argumentString, - context.currentScope().getName()).toString(), outType) - : new JavaExpression(CodeBlock.of("o.$L($L)", functionName, argumentString).toString(), - outType); - default: - throw new UnsupportedOperationException("Unsupported unary aggregate function " - + node.getFunction()); - } - } else if (node.getArgument().size() == 2) { - switch (node.getFunction()) { - case SCALAR_PRODUCT: - context.enterScope(forLoop); - final JavaExpression arg1 = visit(node.getArgument().get(0), - context.withEnterFunctionContext()); - final JavaExpression arg2 = visit(node.getArgument().get(1), - context.withEnterFunctionContext()); - context.leaveScope(); - final JavaExpression listOfArg1 = extractListFromLoop(arg1, context.currentScope(), forLoop); - final JavaExpression listOfArg2 = extractListFromLoop(arg2, context.currentScope(), forLoop); - final JavaType arg2Type = arg2.type(); - return new JavaExpression(CodeBlock.of("o.scalProd$L($L, $L)", arg2Type.toString(), - listOfArg1.asString(), listOfArg2.asString()).toString(), - JavaType.IntVar); - default: - throw new UnsupportedOperationException("Unsupported binary aggregate function " - + node.getFunction()); - } + switch (node.getFunction()) { + // Start with unary functions + case SUM: + case COUNT: + case MAX: + case MIN: + case ANY: + case ALL: + case ALL_EQUAL: + case ALL_DIFFERENT: + case INCREASING: + context.enterScope(forLoop); + final JavaExpression processedArgument = visit(node.getArgument().get(0), + context.withEnterFunctionContext()); + context.leaveScope(); + final JavaExpression listOfProcessedItem = extractListFromLoop(processedArgument, + context.currentScope(), forLoop); + final boolean supportsAssumptions = supportsAssumptions(node); + final JavaType argType = processedArgument.type(); + final JavaType outType = tupleMetadata.inferType(node); + final String functionName = String.format("%s%s", camelCase(node.getFunction().toString()), + argType); + final String argumentString = listOfProcessedItem.asString(); + return supportsAssumptions ? + // Use the current scope as the assumption context + new JavaExpression(CodeBlock.of("o.$L($L, $S)", functionName, argumentString, + context.currentScope().getName()).toString(), outType) + : new JavaExpression(CodeBlock.of("o.$L($L)", functionName, argumentString).toString(), + outType); + // Binary functions + case SCALAR_PRODUCT: + context.enterScope(forLoop); + final JavaExpression arg1 = visit(node.getArgument().get(0), + context.withEnterFunctionContext()); + final JavaExpression arg2 = visit(node.getArgument().get(1), + context.withEnterFunctionContext()); + context.leaveScope(); + final JavaExpression listOfArg1 = extractListFromLoop(arg1, context.currentScope(), forLoop); + final JavaExpression listOfArg2 = extractListFromLoop(arg2, context.currentScope(), forLoop); + final JavaType arg2Type = arg2.type(); + return new JavaExpression(CodeBlock.of("o.scalProd$L($L, $L)", arg2Type.toString(), + listOfArg1.asString(), listOfArg2.asString()).toString(), + JavaType.IntVar); + default: + throw new UnsupportedOperationException("Unsupported aggregate function " + node.getFunction()); } - throw new UnsupportedOperationException("Unsupported aggregate function " + node.getFunction()); } private boolean supportsAssumptions(final FunctionCall node) {