Skip to content
This repository has been archived by the owner on May 10, 2024. It is now read-only.

Commit

Permalink
ortools-backend: simplify aggregate function code
Browse files Browse the repository at this point in the history
Signed-off-by: Lalith Suresh <lsuresh@vmware.com>
  • Loading branch information
lalithsuresh committed Jul 19, 2021
1 parent be7b2c7 commit 5733753
Showing 1 changed file with 45 additions and 54 deletions.
99 changes: 45 additions & 54 deletions dcm/src/main/java/com/vmware/dcm/backend/ortools/OrToolsSolver.java
Original file line number Diff line number Diff line change
Expand Up @@ -1237,61 +1237,52 @@ protected JavaExpression visitFunctionCall(final FunctionCall node, final Transl
// Functions always apply on a vector. We compute the arguments to the function, and in doing so,
// add declarations to the corresponding for-loop that extracts the relevant columns/expressions from views.
final OutputIR.Block forLoop = context.currentScope().getForLoopByName(vectorName);

if (node.getArgument().size() == 1) {
context.enterScope(forLoop);
final JavaExpression processedArgument = visit(node.getArgument().get(0),
context.withEnterFunctionContext());
context.leaveScope();
final JavaExpression listOfProcessedItem = extractListFromLoop(processedArgument,
context.currentScope(), forLoop);
switch (node.getFunction()) {
case SUM:
case COUNT:
case MAX:
case MIN:
case ANY:
case ALL:
case ALL_EQUAL:
case ALL_DIFFERENT:
case INCREASING:
final boolean supportsAssumptions = supportsAssumptions(node);
final JavaType argType = processedArgument.type();
final JavaType outType = tupleMetadata.inferType(node);
final String functionName = String.format("%s%s", camelCase(node.getFunction().toString()),
argType);
final String argumentString = listOfProcessedItem.asString();
return supportsAssumptions ?
// Use the current scope as the assumption context
new JavaExpression(CodeBlock.of("o.$L($L, $S)", functionName, argumentString,
context.currentScope().getName()).toString(), outType)
: new JavaExpression(CodeBlock.of("o.$L($L)", functionName, argumentString).toString(),
outType);
default:
throw new UnsupportedOperationException("Unsupported unary aggregate function "
+ node.getFunction());
}
} else if (node.getArgument().size() == 2) {
switch (node.getFunction()) {
case SCALAR_PRODUCT:
context.enterScope(forLoop);
final JavaExpression arg1 = visit(node.getArgument().get(0),
context.withEnterFunctionContext());
final JavaExpression arg2 = visit(node.getArgument().get(1),
context.withEnterFunctionContext());
context.leaveScope();
final JavaExpression listOfArg1 = extractListFromLoop(arg1, context.currentScope(), forLoop);
final JavaExpression listOfArg2 = extractListFromLoop(arg2, context.currentScope(), forLoop);
final JavaType arg2Type = arg2.type();
return new JavaExpression(CodeBlock.of("o.scalProd$L($L, $L)", arg2Type.toString(),
listOfArg1.asString(), listOfArg2.asString()).toString(),
JavaType.IntVar);
default:
throw new UnsupportedOperationException("Unsupported binary aggregate function "
+ node.getFunction());
}
switch (node.getFunction()) {
// Start with unary functions
case SUM:
case COUNT:
case MAX:
case MIN:
case ANY:
case ALL:
case ALL_EQUAL:
case ALL_DIFFERENT:
case INCREASING:
context.enterScope(forLoop);
final JavaExpression processedArgument = visit(node.getArgument().get(0),
context.withEnterFunctionContext());
context.leaveScope();
final JavaExpression listOfProcessedItem = extractListFromLoop(processedArgument,
context.currentScope(), forLoop);
final boolean supportsAssumptions = supportsAssumptions(node);
final JavaType argType = processedArgument.type();
final JavaType outType = tupleMetadata.inferType(node);
final String functionName = String.format("%s%s", camelCase(node.getFunction().toString()),
argType);
final String argumentString = listOfProcessedItem.asString();
return supportsAssumptions ?
// Use the current scope as the assumption context
new JavaExpression(CodeBlock.of("o.$L($L, $S)", functionName, argumentString,
context.currentScope().getName()).toString(), outType)
: new JavaExpression(CodeBlock.of("o.$L($L)", functionName, argumentString).toString(),
outType);
// Binary functions
case SCALAR_PRODUCT:
context.enterScope(forLoop);
final JavaExpression arg1 = visit(node.getArgument().get(0),
context.withEnterFunctionContext());
final JavaExpression arg2 = visit(node.getArgument().get(1),
context.withEnterFunctionContext());
context.leaveScope();
final JavaExpression listOfArg1 = extractListFromLoop(arg1, context.currentScope(), forLoop);
final JavaExpression listOfArg2 = extractListFromLoop(arg2, context.currentScope(), forLoop);
final JavaType arg2Type = arg2.type();
return new JavaExpression(CodeBlock.of("o.scalProd$L($L, $L)", arg2Type.toString(),
listOfArg1.asString(), listOfArg2.asString()).toString(),
JavaType.IntVar);
default:
throw new UnsupportedOperationException("Unsupported aggregate function " + node.getFunction());
}
throw new UnsupportedOperationException("Unsupported aggregate function " + node.getFunction());
}

private boolean supportsAssumptions(final FunctionCall node) {
Expand Down

0 comments on commit 5733753

Please sign in to comment.