Skip to content

Commit

Permalink
Use constraint distance instead of bool validity in STOMP cost functi…
Browse files Browse the repository at this point in the history
…on (#2418)

* Use constraint distance instead of bool validity in STOMP cost function

* Fix comment
  • Loading branch information
henningkayser authored Oct 24, 2023
1 parent 19c58b8 commit c19d2aa
Showing 1 changed file with 13 additions and 16 deletions.
29 changes: 13 additions & 16 deletions moveit_planners/stomp/include/stomp_moveit/cost_functions.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -49,15 +49,15 @@

namespace stomp_moveit
{
// Decides if the given state position vector is valid or not - example use cases are collision or constraint checking
using StateValidatorFn = std::function<bool(const Eigen::VectorXd& state_positions)>;
// Validates a given state and produces a scalar cost penalty - example use cases are collision or constraint checking
using StateValidatorFn = std::function<double(const Eigen::VectorXd& state_positions)>;

namespace costs
{

// Interpolation step size for collision checking (joint space, L2 norm)
constexpr double COL_CHECK_DISTANCE = 0.05;
constexpr double CONSTRAINT_CHECK_DISTANCE = 0.1;
constexpr double CONSTRAINT_CHECK_DISTANCE = 0.05;

/**
* Creates a cost function from a binary robot state validation function.
Expand All @@ -69,12 +69,10 @@ constexpr double CONSTRAINT_CHECK_DISTANCE = 0.1;
*
* @param state_validator_fn The validator function that tests for binary conditions
* @param interpolation_step_size The L2 norm distance step used for interpolation
* @param penalty The penalty cost value applied to invalid states
*
* @return Cost function that computes smooth costs for binary validity conditions
*/
CostFn get_cost_function_from_state_validator(const StateValidatorFn& state_validator_fn,
double interpolation_step_size, double penalty)
CostFn get_cost_function_from_state_validator(const StateValidatorFn& state_validator_fn, double interpolation_step_size)
{
CostFn cost_fn = [=](const Eigen::MatrixXd& values, Eigen::VectorXd& costs, bool& validity) {
costs.setZero(values.cols());
Expand All @@ -96,11 +94,13 @@ CostFn get_cost_function_from_state_validator(const StateValidatorFn& state_vali
double interpolation_fraction = 0.0;
const double interpolation_step = std::min(0.5, interpolation_step_size / segment_distance);
bool found_invalid_state = false;
double penalty = 0.0;
while (!found_invalid_state && interpolation_fraction < 1.0)
{
Eigen::VectorXd sample_vec = (1 - interpolation_fraction) * current + interpolation_fraction * next;

found_invalid_state = !state_validator_fn(sample_vec);
penalty = state_validator_fn(sample_vec);
found_invalid_state = penalty > 0.0;
interpolation_fraction += interpolation_step;
}

Expand Down Expand Up @@ -178,10 +178,10 @@ CostFn get_collision_cost_function(const std::shared_ptr<const planning_scene::P
set_joint_positions(positions, joints, state);
state.update();

return !planning_scene->isStateColliding(state, group_name);
return planning_scene->isStateColliding(state, group_name) ? collision_penalty : 0.0;
};

return get_cost_function_from_state_validator(collision_validator_fn, COL_CHECK_DISTANCE, collision_penalty);
return get_cost_function_from_state_validator(collision_validator_fn, COL_CHECK_DISTANCE);
}

/**
Expand All @@ -192,13 +192,13 @@ CostFn get_collision_cost_function(const std::shared_ptr<const planning_scene::P
* @param planning_scene The planning scene instance to use for computing transforms
* @param group The group to use for computing link transforms from joint positions
* @param constraints_msg The constraints used for validating group states
* @param constraints_penalty The penalty cost value applied to invalid states
* @param cost_scale A scalar factor applied to the distance cost of invalid states
*
* @return Cost function that computes smooth costs for invalid path segments
*/
CostFn get_constraints_cost_function(const std::shared_ptr<const planning_scene::PlanningScene>& planning_scene,
const moveit::core::JointModelGroup* group,
const moveit_msgs::msg::Constraints& constraints_msg, double constraints_penalty)
const moveit_msgs::msg::Constraints& constraints_msg, double cost_scale)
{
const auto& joints = group ? group->getActiveJointModels() : planning_scene->getRobotModel()->getActiveJointModels();

Expand All @@ -212,13 +212,10 @@ CostFn get_constraints_cost_function(const std::shared_ptr<const planning_scene:
set_joint_positions(positions, joints, state);
state.update();

// NOTE: the returned ConstraintEvaluationResult also provides a `double distance` which might be used as an
// actual cost gradient instead of the binary state penalty
return constraints.decide(state).satisfied;
return constraints.decide(state).distance * cost_scale;
};

return get_cost_function_from_state_validator(constraints_validator_fn, CONSTRAINT_CHECK_DISTANCE,
constraints_penalty);
return get_cost_function_from_state_validator(constraints_validator_fn, CONSTRAINT_CHECK_DISTANCE);
}

/**
Expand Down

0 comments on commit c19d2aa

Please sign in to comment.