Skip to content

Commit

Permalink
Implement robust regularization in 'survival:aft' objective (#5473)
Browse files Browse the repository at this point in the history
* Robust regularization of AFT gradient and hessian

* Fix AFT doc; expose it to tutorial TOC

* Apply robust regularization to uncensored case too

* Revise unit test slightly

* Fix lint

* Update test_survival.py

* Use GradientPairPrecise

* Remove unused variables
  • Loading branch information
hcho3 authored Apr 4, 2020
1 parent 9399736 commit 5fc5ec5
Show file tree
Hide file tree
Showing 9 changed files with 205 additions and 42 deletions.
2 changes: 1 addition & 1 deletion demo/aft_survival/aft_survival_demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,4 +51,4 @@
print(df[np.isinf(df['Label (upper bound)'])])

# Save trained model
bst.save_model('aft_model.json')
bst.save_model('aft_model.json')
2 changes: 1 addition & 1 deletion demo/aft_survival/aft_survival_demo_with_optuna.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,4 +75,4 @@ def objective(trial):
print(df[np.isinf(df['Label (upper bound)'])])

# Save trained model
bst.save_model('aft_best_model.json')
bst.save_model('aft_best_model.json')
2 changes: 1 addition & 1 deletion doc/tutorials/aft_survival_analysis.rst
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ Note that this model is a generalized form of a linear regression model :math:`Y
\ln{Y} = \mathcal{T}(\mathbf{x}) + \sigma Z
where :math:`\mathcal{T}(\mathbf{x})` represents the output from a decision tree ensemble, given input :math:`\mathbf{x}`. Since :math:`Z` is a random variable, we have a likelihood defined for the expression :math:`\ln{Y} = \mathcal{T}(\mathbf{x}) + \sigma Z`. So the goal for XGBoost is to maximize the (log) likelihood by fitting a good tree ensemble :math:`\mathbf{x}`.
where :math:`\mathcal{T}(\mathbf{x})` represents the output from a decision tree ensemble, given input :math:`\mathbf{x}`. Since :math:`Z` is a random variable, we have a likelihood defined for the expression :math:`\ln{Y} = \mathcal{T}(\mathbf{x}) + \sigma Z`. So the goal for XGBoost is to maximize the (log) likelihood by fitting a good tree ensemble :math:`\mathcal{T}(\mathbf{x})`.

**********
How to use
Expand Down
1 change: 1 addition & 0 deletions doc/tutorials/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ See `Awesome XGBoost </~https://github.com/dmlc/xgboost/tree/master/demo>`_ for mo
monotonic
rf
feature_interaction_constraint
aft_survival_analysis
input_format
param_tuning
external_memory
Expand Down
166 changes: 142 additions & 24 deletions src/common/survival_util.cc
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,106 @@
/~https://github.com/avinashbarnwal/GSOC-2019/blob/master/doc/Accelerated_Failure_Time.pdf
*/

namespace {

// Allowable range for gradient and hessian. Used for regularization
constexpr double kMinGradient = -15.0;
constexpr double kMaxGradient = 15.0;
constexpr double kMinHessian = 1e-16; // Ensure that no data point gets zero hessian
constexpr double kMaxHessian = 15.0;

constexpr double kEps = 1e-12; // A denomitor in a fraction should not be too small

// Clip (limit) x to fit range [x_min, x_max].
// If x < x_min, return x_min; if x > x_max, return x_max; if x_min <= x <= x_max, return x.
// This function assumes x_min < x_max; behavior is undefined if this assumption does not hold.
inline double Clip(double x, double x_min, double x_max) {
if (x < x_min) {
return x_min;
}
if (x > x_max) {
return x_max;
}
return x;
}

using xgboost::common::ProbabilityDistributionType;

enum class CensoringType : uint8_t {
kUncensored, kRightCensored, kLeftCensored, kIntervalCensored
};

using xgboost::GradientPairPrecise;

inline GradientPairPrecise GetLimitAtInfPred(ProbabilityDistributionType dist_type,
CensoringType censor_type,
double sign, double sigma) {
switch (censor_type) {
case CensoringType::kUncensored:
switch (dist_type) {
case ProbabilityDistributionType::kNormal:
return sign ? GradientPairPrecise{ kMinGradient, 1.0 / (sigma * sigma) }
: GradientPairPrecise{ kMaxGradient, 1.0 / (sigma * sigma) };
case ProbabilityDistributionType::kLogistic:
return sign ? GradientPairPrecise{ -1.0 / sigma, kMinHessian }
: GradientPairPrecise{ 1.0 / sigma, kMinHessian };
case ProbabilityDistributionType::kExtreme:
return sign ? GradientPairPrecise{ kMinGradient, kMaxHessian }
: GradientPairPrecise{ 1.0 / sigma, kMinHessian };
default:
LOG(FATAL) << "Unknown distribution type";
}
case CensoringType::kRightCensored:
switch (dist_type) {
case ProbabilityDistributionType::kNormal:
return sign ? GradientPairPrecise{ kMinGradient, 1.0 / (sigma * sigma) }
: GradientPairPrecise{ 0.0, kMinHessian };
case ProbabilityDistributionType::kLogistic:
return sign ? GradientPairPrecise{ -1.0 / sigma, kMinHessian }
: GradientPairPrecise{ 0.0, kMinHessian };
case ProbabilityDistributionType::kExtreme:
return sign ? GradientPairPrecise{ kMinGradient, kMaxHessian }
: GradientPairPrecise{ 0.0, kMinHessian };
default:
LOG(FATAL) << "Unknown distribution type";
}
case CensoringType::kLeftCensored:
switch (dist_type) {
case ProbabilityDistributionType::kNormal:
return sign ? GradientPairPrecise{ 0.0, kMinHessian }
: GradientPairPrecise{ kMaxGradient, 1.0 / (sigma * sigma) };
case ProbabilityDistributionType::kLogistic:
return sign ? GradientPairPrecise{ 0.0, kMinHessian }
: GradientPairPrecise{ 1.0 / sigma, kMinHessian };
case ProbabilityDistributionType::kExtreme:
return sign ? GradientPairPrecise{ 0.0, kMinHessian }
: GradientPairPrecise{ 1.0 / sigma, kMinHessian };
default:
LOG(FATAL) << "Unknown distribution type";
}
case CensoringType::kIntervalCensored:
switch (dist_type) {
case ProbabilityDistributionType::kNormal:
return sign ? GradientPairPrecise{ kMinGradient, 1.0 / (sigma * sigma) }
: GradientPairPrecise{ kMaxGradient, 1.0 / (sigma * sigma) };
case ProbabilityDistributionType::kLogistic:
return sign ? GradientPairPrecise{ -1.0 / sigma, kMinHessian }
: GradientPairPrecise{ 1.0 / sigma, kMinHessian };
case ProbabilityDistributionType::kExtreme:
return sign ? GradientPairPrecise{ kMinGradient, kMaxHessian }
: GradientPairPrecise{ 1.0 / sigma, kMinHessian };
default:
LOG(FATAL) << "Unknown distribution type";
}
default:
LOG(FATAL) << "Unknown censoring type";
}

return { 0.0, 0.0 };
}

} // anonymous namespace

namespace xgboost {
namespace common {

Expand All @@ -26,14 +126,14 @@ DMLC_REGISTER_PARAMETER(AFTParam);
double AFTLoss::Loss(double y_lower, double y_upper, double y_pred, double sigma) {
const double log_y_lower = std::log(y_lower);
const double log_y_upper = std::log(y_upper);
const double eps = 1e-12;

double cost;

if (y_lower == y_upper) { // uncensored
const double z = (log_y_lower - y_pred) / sigma;
const double pdf = dist_->PDF(z);
// Regularize the denominator with eps, to avoid INF or NAN
cost = -std::log(std::max(pdf / (sigma * y_lower), eps));
cost = -std::log(std::max(pdf / (sigma * y_lower), kEps));
} else { // censored; now check what type of censorship we have
double z_u, z_l, cdf_u, cdf_l;
if (std::isinf(y_upper)) { // right-censored
Expand All @@ -49,7 +149,7 @@ double AFTLoss::Loss(double y_lower, double y_upper, double y_pred, double sigma
cdf_l = dist_->CDF(z_l);
}
// Regularize the denominator with eps, to avoid INF or NAN
cost = -std::log(std::max(cdf_u - cdf_l, eps));
cost = -std::log(std::max(cdf_u - cdf_l, kEps));
}

return cost;
Expand All @@ -58,20 +158,25 @@ double AFTLoss::Loss(double y_lower, double y_upper, double y_pred, double sigma
double AFTLoss::Gradient(double y_lower, double y_upper, double y_pred, double sigma) {
const double log_y_lower = std::log(y_lower);
const double log_y_upper = std::log(y_upper);
double gradient;
const double eps = 1e-12;
double numerator, denominator, gradient; // numerator and denominator of gradient
CensoringType censor_type;
bool z_sign; // sign of z-score

if (y_lower == y_upper) { // uncensored
const double z = (log_y_lower - y_pred) / sigma;
const double pdf = dist_->PDF(z);
const double grad_pdf = dist_->GradPDF(z);
// Regularize the denominator with eps, so that gradient doesn't get too big
gradient = grad_pdf / (sigma * std::max(pdf, eps));
censor_type = CensoringType::kUncensored;
numerator = grad_pdf;
denominator = sigma * pdf;
z_sign = (z > 0);
} else { // censored; now check what type of censorship we have
double z_u, z_l, pdf_u, pdf_l, cdf_u, cdf_l;
double z_u = 0.0, z_l = 0.0, pdf_u, pdf_l, cdf_u, cdf_l;
censor_type = CensoringType::kIntervalCensored;
if (std::isinf(y_upper)) { // right-censored
pdf_u = 0;
cdf_u = 1;
censor_type = CensoringType::kRightCensored;
} else { // interval-censored or left-censored
z_u = (log_y_upper - y_pred) / sigma;
pdf_u = dist_->PDF(z_u);
Expand All @@ -80,38 +185,48 @@ double AFTLoss::Gradient(double y_lower, double y_upper, double y_pred, double s
if (std::isinf(y_lower)) { // left-censored
pdf_l = 0;
cdf_l = 0;
censor_type = CensoringType::kLeftCensored;
} else { // interval-censored or right-censored
z_l = (log_y_lower - y_pred) / sigma;
pdf_l = dist_->PDF(z_l);
cdf_l = dist_->CDF(z_l);
}
// Regularize the denominator with eps, so that gradient doesn't get too big
gradient = (pdf_u - pdf_l) / (sigma * std::max(cdf_u - cdf_l, eps));
z_sign = (z_u > 0 || z_l > 0);
numerator = pdf_u - pdf_l;
denominator = sigma * (cdf_u - cdf_l);
}
gradient = numerator / denominator;
if (denominator < kEps && (std::isnan(gradient) || std::isinf(gradient))) {
gradient = GetLimitAtInfPred(dist_type_, censor_type, z_sign, sigma).GetGrad();
}

return gradient;
return Clip(gradient, kMinGradient, kMaxGradient);
}

double AFTLoss::Hessian(double y_lower, double y_upper, double y_pred, double sigma) {
const double log_y_lower = std::log(y_lower);
const double log_y_upper = std::log(y_upper);
const double eps = 1e-12;
double hessian;
double numerator, denominator, hessian; // numerator and denominator of hessian
CensoringType censor_type;
bool z_sign; // sign of z-score

if (y_lower == y_upper) { // uncensored
const double z = (log_y_lower - y_pred) / sigma;
const double pdf = dist_->PDF(z);
const double grad_pdf = dist_->GradPDF(z);
const double hess_pdf = dist_->HessPDF(z);
// Regularize the denominator with eps, so that gradient doesn't get too big
hessian = -(pdf * hess_pdf - std::pow(grad_pdf, 2))
/ (std::pow(sigma, 2) * std::pow(std::max(pdf, eps), 2));
censor_type = CensoringType::kUncensored;
numerator = -(pdf * hess_pdf - grad_pdf * grad_pdf);
denominator = sigma * sigma * pdf * pdf;
z_sign = (z > 0);
} else { // censored; now check what type of censorship we have
double z_u, z_l, grad_pdf_u, grad_pdf_l, pdf_u, pdf_l, cdf_u, cdf_l;
double z_u = 0.0, z_l = 0.0, grad_pdf_u, grad_pdf_l, pdf_u, pdf_l, cdf_u, cdf_l;
censor_type = CensoringType::kIntervalCensored;
if (std::isinf(y_upper)) { // right-censored
pdf_u = 0;
cdf_u = 1;
grad_pdf_u = 0;
censor_type = CensoringType::kRightCensored;
} else { // interval-censored or left-censored
z_u = (log_y_upper - y_pred) / sigma;
pdf_u = dist_->PDF(z_u);
Expand All @@ -122,6 +237,7 @@ double AFTLoss::Hessian(double y_lower, double y_upper, double y_pred, double si
pdf_l = 0;
cdf_l = 0;
grad_pdf_l = 0;
censor_type = CensoringType::kLeftCensored;
} else { // interval-censored or right-censored
z_l = (log_y_lower - y_pred) / sigma;
pdf_l = dist_->PDF(z_l);
Expand All @@ -131,15 +247,17 @@ double AFTLoss::Hessian(double y_lower, double y_upper, double y_pred, double si
const double cdf_diff = cdf_u - cdf_l;
const double pdf_diff = pdf_u - pdf_l;
const double grad_diff = grad_pdf_u - grad_pdf_l;
// Regularize the denominator with eps, so that gradient doesn't get too big
const double cdf_diff_thresh = std::max(cdf_diff, eps);
const double numerator = -(cdf_diff * grad_diff - pdf_diff * pdf_diff);
const double sqrt_denominator = sigma * cdf_diff_thresh;
const double denominator = sqrt_denominator * sqrt_denominator;
hessian = numerator / denominator;
const double sqrt_denominator = sigma * cdf_diff;
z_sign = (z_u > 0 || z_l > 0);
numerator = -(cdf_diff * grad_diff - pdf_diff * pdf_diff);
denominator = sqrt_denominator * sqrt_denominator;
}
hessian = numerator / denominator;
if (denominator < kEps && (std::isnan(hessian) || std::isinf(hessian))) {
hessian = GetLimitAtInfPred(dist_type_, censor_type, z_sign, sigma).GetHess();
}

return hessian;
return Clip(hessian, kMinHessian, kMaxHessian);
}

} // namespace common
Expand Down
9 changes: 5 additions & 4 deletions src/common/survival_util.h
Original file line number Diff line number Diff line change
Expand Up @@ -42,15 +42,16 @@ struct AFTParam : public XGBoostParameter<AFTParam> {
class AFTLoss {
private:
std::unique_ptr<ProbabilityDistribution> dist_;
ProbabilityDistributionType dist_type_;

public:
/*!
* \brief Constructor for AFT loss function
* \param dist Choice of probability distribution for the noise term in AFT
* \param dist_type Choice of probability distribution for the noise term in AFT
*/
explicit AFTLoss(ProbabilityDistributionType dist) {
dist_.reset(ProbabilityDistribution::Create(dist));
}
explicit AFTLoss(ProbabilityDistributionType dist_type)
: dist_(ProbabilityDistribution::Create(dist_type)),
dist_type_(dist_type) {}

public:
/*!
Expand Down
44 changes: 44 additions & 0 deletions tests/cpp/common/test_survival_util.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
/*!
* Copyright (c) by Contributors 2020
*/
#include <gtest/gtest.h>

#include "../../../src/common/survival_util.h"

namespace xgboost {
namespace common {

inline static void RobustTestSuite(ProbabilityDistributionType dist_type,
double y_lower, double y_upper, double sigma) {
AFTLoss loss(dist_type);
for (int i = 50; i >= -50; --i) {
const double y_pred = std::pow(10.0, static_cast<double>(i));
const double z = (std::log(y_lower) - std::log(y_pred)) / sigma;
const double gradient = loss.Gradient(y_lower, y_upper, std::log(y_pred), sigma);
const double hessian = loss.Hessian(y_lower, y_upper, std::log(y_pred), sigma);
ASSERT_FALSE(std::isnan(gradient)) << "z = " << z << ", y \\in ["
<< y_lower << ", " << y_upper << "], y_pred = " << y_pred
<< ", dist = " << static_cast<int>(dist_type);
ASSERT_FALSE(std::isinf(gradient)) << "z = " << z << ", y \\in ["
<< y_lower << ", " << y_upper << "], y_pred = " << y_pred
<< ", dist = " << static_cast<int>(dist_type);
ASSERT_FALSE(std::isnan(hessian)) << "z = " << z << ", y \\in ["
<< y_lower << ", " << y_upper << "], y_pred = " << y_pred
<< ", dist = " << static_cast<int>(dist_type);
ASSERT_FALSE(std::isinf(hessian)) << "z = " << z << ", y \\in ["
<< y_lower << ", " << y_upper << "], y_pred = " << y_pred
<< ", dist = " << static_cast<int>(dist_type);
}
}

TEST(AFTLoss, RobustGradientPair) { // Ensure that INF and NAN don't show up in gradient pair
RobustTestSuite(ProbabilityDistributionType::kNormal, 16.0, 200.0, 2.0);
RobustTestSuite(ProbabilityDistributionType::kLogistic, 16.0, 200.0, 2.0);
RobustTestSuite(ProbabilityDistributionType::kExtreme, 16.0, 200.0, 2.0);
RobustTestSuite(ProbabilityDistributionType::kNormal, 100.0, 100.0, 2.0);
RobustTestSuite(ProbabilityDistributionType::kLogistic, 100.0, 100.0, 2.0);
RobustTestSuite(ProbabilityDistributionType::kExtreme, 100.0, 100.0, 2.0);
}

} // namespace common
} // namespace xgboost
15 changes: 7 additions & 8 deletions tests/cpp/objective/test_aft_obj.cc
Original file line number Diff line number Diff line change
Expand Up @@ -93,10 +93,10 @@ TEST(Objective, AFTObjGPairUncensoredLabels) {
{ 0.0384f, 0.0624f, 0.0997f, 0.1551f, 0.2316f, 0.3254f, 0.4200f, 0.4861f, 0.4962f, 0.4457f,
0.3567f, 0.2601f, 0.1772f, 0.1152f, 0.0726f, 0.0449f, 0.0275f, 0.0167f, 0.0101f, 0.0061f });
CheckGPairOverGridPoints(obj.get(), 100.0f, 100.0f, "extreme",
{ -0.0000f, -29.0026f, -17.0031f, -9.8028f, -5.4822f, -2.8897f, -1.3340f, -0.4005f, 0.1596f,
{ -15.0000f, -15.0000f, -15.0000f, -9.8028f, -5.4822f, -2.8897f, -1.3340f, -0.4005f, 0.1596f,
0.4957f, 0.6974f, 0.8184f, 0.8910f, 0.9346f, 0.9608f, 0.9765f, 0.9859f, 0.9915f, 0.9949f,
0.9969f },
{ 0.0000f, 30.0026f, 18.0031f, 10.8028f, 6.4822f, 3.8897f, 2.3340f, 1.4005f, 0.8404f, 0.5043f,
{ 15.0000f, 15.0000f, 15.0000f, 10.8028f, 6.4822f, 3.8897f, 2.3340f, 1.4005f, 0.8404f, 0.5043f,
0.3026f, 0.1816f, 0.1090f, 0.0654f, 0.0392f, 0.0235f, 0.0141f, 0.0085f, 0.0051f, 0.0031f });
}

Expand All @@ -106,10 +106,9 @@ TEST(Objective, AFTObjGPairLeftCensoredLabels) {

CheckGPairOverGridPoints(obj.get(), -std::numeric_limits<float>::infinity(), 20.0f, "normal",
{ 0.0285f, 0.0832f, 0.1951f, 0.3804f, 0.6403f, 0.9643f, 1.3379f, 1.7475f, 2.1828f, 2.6361f,
3.1023f, 3.5779f, 4.0603f, 4.5479f, 5.0394f, 5.5340f, 6.0309f, 6.5298f, 7.0303f, 0.5072f },
3.1023f, 3.5779f, 4.0603f, 4.5479f, 5.0394f, 5.5340f, 6.0309f, 6.5298f, 7.0303f, 7.5326f },
{ 0.0663f, 0.1559f, 0.2881f, 0.4378f, 0.5762f, 0.6878f, 0.7707f, 0.8300f, 0.8719f, 0.9016f,
0.9229f, 0.9385f, 0.9501f, 0.9588f, 0.9656f, 0.9709f, 0.9751f, 0.9785f, 0.9812f, 0.0045f },
2e-4);
0.9229f, 0.9385f, 0.9501f, 0.9588f, 0.9656f, 0.9709f, 0.9751f, 0.9785f, 0.9813f, 0.9877f });
CheckGPairOverGridPoints(obj.get(), -std::numeric_limits<float>::infinity(), 20.0f, "logistic",
{ 0.0909f, 0.1428f, 0.2174f, 0.3164f, 0.4355f, 0.5625f, 0.6818f, 0.7812f, 0.8561f, 0.9084f,
0.9429f, 0.9650f, 0.9787f, 0.9871f, 0.9922f, 0.9953f, 0.9972f, 0.9983f, 0.9990f, 0.9994f },
Expand Down Expand Up @@ -139,10 +138,10 @@ TEST(Objective, AFTObjGPairRightCensoredLabels) {
{ 0.0312f, 0.0499f, 0.0776f, 0.1158f, 0.1627f, 0.2100f, 0.2430f, 0.2481f, 0.2228f, 0.1783f,
0.1300f, 0.0886f, 0.0576f, 0.0363f, 0.0225f, 0.0137f, 0.0083f, 0.0050f, 0.0030f, 0.0018f });
CheckGPairOverGridPoints(obj.get(), 60.0f, std::numeric_limits<float>::infinity(), "extreme",
{ -2.8073f, -18.0015f, -10.8018f, -6.4817f, -3.8893f, -2.3338f, -1.4004f, -0.8403f, -0.5042f,
{ -15.0000f, -15.0000f, -10.8018f, -6.4817f, -3.8893f, -2.3338f, -1.4004f, -0.8403f, -0.5042f,
-0.3026f, -0.1816f, -0.1089f, -0.0654f, -0.0392f, -0.0235f, -0.0141f, -0.0085f, -0.0051f,
-0.0031f, -0.0018f },
{ 0.2614f, 18.0015f, 10.8018f, 6.4817f, 3.8893f, 2.3338f, 1.4004f, 0.8403f, 0.5042f, 0.3026f,
-0.0031f, -0.0018f },
{ 15.0000f, 15.0000f, 10.8018f, 6.4817f, 3.8893f, 2.3338f, 1.4004f, 0.8403f, 0.5042f, 0.3026f,
0.1816f, 0.1089f, 0.0654f, 0.0392f, 0.0235f, 0.0141f, 0.0085f, 0.0051f, 0.0031f, 0.0018f });
}

Expand Down
Loading

0 comments on commit 5fc5ec5

Please sign in to comment.