Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement robust regularization in 'survival:aft' objective #5473

Merged
merged 8 commits into from
Apr 4, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion demo/aft_survival/aft_survival_demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,4 +51,4 @@
print(df[np.isinf(df['Label (upper bound)'])])

# Save trained model
bst.save_model('aft_model.json')
bst.save_model('aft_model.json')
hcho3 marked this conversation as resolved.
Show resolved Hide resolved
2 changes: 1 addition & 1 deletion demo/aft_survival/aft_survival_demo_with_optuna.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,4 +75,4 @@ def objective(trial):
print(df[np.isinf(df['Label (upper bound)'])])

# Save trained model
bst.save_model('aft_best_model.json')
bst.save_model('aft_best_model.json')
2 changes: 1 addition & 1 deletion doc/tutorials/aft_survival_analysis.rst
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ Note that this model is a generalized form of a linear regression model :math:`Y
\ln{Y} = \mathcal{T}(\mathbf{x}) + \sigma Z
where :math:`\mathcal{T}(\mathbf{x})` represents the output from a decision tree ensemble, given input :math:`\mathbf{x}`. Since :math:`Z` is a random variable, we have a likelihood defined for the expression :math:`\ln{Y} = \mathcal{T}(\mathbf{x}) + \sigma Z`. So the goal for XGBoost is to maximize the (log) likelihood by fitting a good tree ensemble :math:`\mathbf{x}`.
where :math:`\mathcal{T}(\mathbf{x})` represents the output from a decision tree ensemble, given input :math:`\mathbf{x}`. Since :math:`Z` is a random variable, we have a likelihood defined for the expression :math:`\ln{Y} = \mathcal{T}(\mathbf{x}) + \sigma Z`. So the goal for XGBoost is to maximize the (log) likelihood by fitting a good tree ensemble :math:`\mathcal{T}(\mathbf{x})`.

**********
How to use
Expand Down
1 change: 1 addition & 0 deletions doc/tutorials/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ See `Awesome XGBoost </~https://github.com/dmlc/xgboost/tree/master/demo>`_ for mo
monotonic
rf
feature_interaction_constraint
aft_survival_analysis
input_format
param_tuning
external_memory
Expand Down
166 changes: 142 additions & 24 deletions src/common/survival_util.cc
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,106 @@
/~https://github.com/avinashbarnwal/GSOC-2019/blob/master/doc/Accelerated_Failure_Time.pdf
*/

namespace {

// Allowable range for gradient and hessian. Used for regularization
constexpr double kMinGradient = -15.0;
constexpr double kMaxGradient = 15.0;
constexpr double kMinHessian = 1e-16; // Ensure that no data point gets zero hessian
constexpr double kMaxHessian = 15.0;

constexpr double kEps = 1e-12; // A denomitor in a fraction should not be too small

// Clip (limit) x to fit range [x_min, x_max].
// If x < x_min, return x_min; if x > x_max, return x_max; if x_min <= x <= x_max, return x.
// This function assumes x_min < x_max; behavior is undefined if this assumption does not hold.
inline double Clip(double x, double x_min, double x_max) {
if (x < x_min) {
return x_min;
}
if (x > x_max) {
return x_max;
}
return x;
}

using xgboost::common::ProbabilityDistributionType;

enum class CensoringType : uint8_t {
kUncensored, kRightCensored, kLeftCensored, kIntervalCensored
};

using xgboost::GradientPairPrecise;

inline GradientPairPrecise GetLimitAtInfPred(ProbabilityDistributionType dist_type,
CensoringType censor_type,
double sign, double sigma) {
switch (censor_type) {
case CensoringType::kUncensored:
switch (dist_type) {
case ProbabilityDistributionType::kNormal:
return sign ? GradientPairPrecise{ kMinGradient, 1.0 / (sigma * sigma) }
: GradientPairPrecise{ kMaxGradient, 1.0 / (sigma * sigma) };
case ProbabilityDistributionType::kLogistic:
return sign ? GradientPairPrecise{ -1.0 / sigma, kMinHessian }
: GradientPairPrecise{ 1.0 / sigma, kMinHessian };
case ProbabilityDistributionType::kExtreme:
return sign ? GradientPairPrecise{ kMinGradient, kMaxHessian }
: GradientPairPrecise{ 1.0 / sigma, kMinHessian };
default:
LOG(FATAL) << "Unknown distribution type";
}
case CensoringType::kRightCensored:
switch (dist_type) {
case ProbabilityDistributionType::kNormal:
return sign ? GradientPairPrecise{ kMinGradient, 1.0 / (sigma * sigma) }
: GradientPairPrecise{ 0.0, kMinHessian };
case ProbabilityDistributionType::kLogistic:
return sign ? GradientPairPrecise{ -1.0 / sigma, kMinHessian }
: GradientPairPrecise{ 0.0, kMinHessian };
case ProbabilityDistributionType::kExtreme:
return sign ? GradientPairPrecise{ kMinGradient, kMaxHessian }
: GradientPairPrecise{ 0.0, kMinHessian };
default:
LOG(FATAL) << "Unknown distribution type";
}
case CensoringType::kLeftCensored:
switch (dist_type) {
case ProbabilityDistributionType::kNormal:
return sign ? GradientPairPrecise{ 0.0, kMinHessian }
: GradientPairPrecise{ kMaxGradient, 1.0 / (sigma * sigma) };
case ProbabilityDistributionType::kLogistic:
return sign ? GradientPairPrecise{ 0.0, kMinHessian }
: GradientPairPrecise{ 1.0 / sigma, kMinHessian };
case ProbabilityDistributionType::kExtreme:
return sign ? GradientPairPrecise{ 0.0, kMinHessian }
: GradientPairPrecise{ 1.0 / sigma, kMinHessian };
default:
LOG(FATAL) << "Unknown distribution type";
}
case CensoringType::kIntervalCensored:
switch (dist_type) {
case ProbabilityDistributionType::kNormal:
return sign ? GradientPairPrecise{ kMinGradient, 1.0 / (sigma * sigma) }
: GradientPairPrecise{ kMaxGradient, 1.0 / (sigma * sigma) };
case ProbabilityDistributionType::kLogistic:
return sign ? GradientPairPrecise{ -1.0 / sigma, kMinHessian }
: GradientPairPrecise{ 1.0 / sigma, kMinHessian };
case ProbabilityDistributionType::kExtreme:
return sign ? GradientPairPrecise{ kMinGradient, kMaxHessian }
: GradientPairPrecise{ 1.0 / sigma, kMinHessian };
default:
LOG(FATAL) << "Unknown distribution type";
}
default:
LOG(FATAL) << "Unknown censoring type";
}

return { 0.0, 0.0 };
}

} // anonymous namespace

namespace xgboost {
namespace common {

Expand All @@ -26,14 +126,14 @@ DMLC_REGISTER_PARAMETER(AFTParam);
double AFTLoss::Loss(double y_lower, double y_upper, double y_pred, double sigma) {
const double log_y_lower = std::log(y_lower);
const double log_y_upper = std::log(y_upper);
const double eps = 1e-12;

double cost;

if (y_lower == y_upper) { // uncensored
const double z = (log_y_lower - y_pred) / sigma;
const double pdf = dist_->PDF(z);
// Regularize the denominator with eps, to avoid INF or NAN
cost = -std::log(std::max(pdf / (sigma * y_lower), eps));
cost = -std::log(std::max(pdf / (sigma * y_lower), kEps));
} else { // censored; now check what type of censorship we have
double z_u, z_l, cdf_u, cdf_l;
if (std::isinf(y_upper)) { // right-censored
Expand All @@ -49,7 +149,7 @@ double AFTLoss::Loss(double y_lower, double y_upper, double y_pred, double sigma
cdf_l = dist_->CDF(z_l);
}
// Regularize the denominator with eps, to avoid INF or NAN
cost = -std::log(std::max(cdf_u - cdf_l, eps));
cost = -std::log(std::max(cdf_u - cdf_l, kEps));
}

return cost;
Expand All @@ -58,20 +158,25 @@ double AFTLoss::Loss(double y_lower, double y_upper, double y_pred, double sigma
double AFTLoss::Gradient(double y_lower, double y_upper, double y_pred, double sigma) {
const double log_y_lower = std::log(y_lower);
const double log_y_upper = std::log(y_upper);
double gradient;
const double eps = 1e-12;
double numerator, denominator, gradient; // numerator and denominator of gradient
CensoringType censor_type;
bool z_sign; // sign of z-score

if (y_lower == y_upper) { // uncensored
const double z = (log_y_lower - y_pred) / sigma;
const double pdf = dist_->PDF(z);
const double grad_pdf = dist_->GradPDF(z);
// Regularize the denominator with eps, so that gradient doesn't get too big
gradient = grad_pdf / (sigma * std::max(pdf, eps));
censor_type = CensoringType::kUncensored;
numerator = grad_pdf;
denominator = sigma * pdf;
z_sign = (z > 0);
} else { // censored; now check what type of censorship we have
double z_u, z_l, pdf_u, pdf_l, cdf_u, cdf_l;
double z_u = 0.0, z_l = 0.0, pdf_u, pdf_l, cdf_u, cdf_l;
censor_type = CensoringType::kIntervalCensored;
if (std::isinf(y_upper)) { // right-censored
pdf_u = 0;
cdf_u = 1;
censor_type = CensoringType::kRightCensored;
} else { // interval-censored or left-censored
z_u = (log_y_upper - y_pred) / sigma;
pdf_u = dist_->PDF(z_u);
Expand All @@ -80,38 +185,48 @@ double AFTLoss::Gradient(double y_lower, double y_upper, double y_pred, double s
if (std::isinf(y_lower)) { // left-censored
pdf_l = 0;
cdf_l = 0;
censor_type = CensoringType::kLeftCensored;
} else { // interval-censored or right-censored
z_l = (log_y_lower - y_pred) / sigma;
pdf_l = dist_->PDF(z_l);
cdf_l = dist_->CDF(z_l);
}
// Regularize the denominator with eps, so that gradient doesn't get too big
gradient = (pdf_u - pdf_l) / (sigma * std::max(cdf_u - cdf_l, eps));
z_sign = (z_u > 0 || z_l > 0);
numerator = pdf_u - pdf_l;
denominator = sigma * (cdf_u - cdf_l);
}
gradient = numerator / denominator;
if (denominator < kEps && (std::isnan(gradient) || std::isinf(gradient))) {
gradient = GetLimitAtInfPred(dist_type_, censor_type, z_sign, sigma).GetGrad();
}

return gradient;
return Clip(gradient, kMinGradient, kMaxGradient);
}

double AFTLoss::Hessian(double y_lower, double y_upper, double y_pred, double sigma) {
const double log_y_lower = std::log(y_lower);
const double log_y_upper = std::log(y_upper);
const double eps = 1e-12;
double hessian;
double numerator, denominator, hessian; // numerator and denominator of hessian
CensoringType censor_type;
bool z_sign; // sign of z-score

if (y_lower == y_upper) { // uncensored
const double z = (log_y_lower - y_pred) / sigma;
const double pdf = dist_->PDF(z);
const double grad_pdf = dist_->GradPDF(z);
const double hess_pdf = dist_->HessPDF(z);
// Regularize the denominator with eps, so that gradient doesn't get too big
hessian = -(pdf * hess_pdf - std::pow(grad_pdf, 2))
/ (std::pow(sigma, 2) * std::pow(std::max(pdf, eps), 2));
censor_type = CensoringType::kUncensored;
numerator = -(pdf * hess_pdf - grad_pdf * grad_pdf);
denominator = sigma * sigma * pdf * pdf;
z_sign = (z > 0);
} else { // censored; now check what type of censorship we have
double z_u, z_l, grad_pdf_u, grad_pdf_l, pdf_u, pdf_l, cdf_u, cdf_l;
double z_u = 0.0, z_l = 0.0, grad_pdf_u, grad_pdf_l, pdf_u, pdf_l, cdf_u, cdf_l;
censor_type = CensoringType::kIntervalCensored;
if (std::isinf(y_upper)) { // right-censored
pdf_u = 0;
cdf_u = 1;
grad_pdf_u = 0;
censor_type = CensoringType::kRightCensored;
} else { // interval-censored or left-censored
z_u = (log_y_upper - y_pred) / sigma;
pdf_u = dist_->PDF(z_u);
Expand All @@ -122,6 +237,7 @@ double AFTLoss::Hessian(double y_lower, double y_upper, double y_pred, double si
pdf_l = 0;
cdf_l = 0;
grad_pdf_l = 0;
censor_type = CensoringType::kLeftCensored;
} else { // interval-censored or right-censored
z_l = (log_y_lower - y_pred) / sigma;
pdf_l = dist_->PDF(z_l);
Expand All @@ -131,15 +247,17 @@ double AFTLoss::Hessian(double y_lower, double y_upper, double y_pred, double si
const double cdf_diff = cdf_u - cdf_l;
const double pdf_diff = pdf_u - pdf_l;
const double grad_diff = grad_pdf_u - grad_pdf_l;
// Regularize the denominator with eps, so that gradient doesn't get too big
const double cdf_diff_thresh = std::max(cdf_diff, eps);
const double numerator = -(cdf_diff * grad_diff - pdf_diff * pdf_diff);
const double sqrt_denominator = sigma * cdf_diff_thresh;
const double denominator = sqrt_denominator * sqrt_denominator;
hessian = numerator / denominator;
const double sqrt_denominator = sigma * cdf_diff;
z_sign = (z_u > 0 || z_l > 0);
numerator = -(cdf_diff * grad_diff - pdf_diff * pdf_diff);
denominator = sqrt_denominator * sqrt_denominator;
}
hessian = numerator / denominator;
if (denominator < kEps && (std::isnan(hessian) || std::isinf(hessian))) {
hessian = GetLimitAtInfPred(dist_type_, censor_type, z_sign, sigma).GetHess();
}

return hessian;
return Clip(hessian, kMinHessian, kMaxHessian);
}

} // namespace common
Expand Down
9 changes: 5 additions & 4 deletions src/common/survival_util.h
Original file line number Diff line number Diff line change
Expand Up @@ -42,15 +42,16 @@ struct AFTParam : public XGBoostParameter<AFTParam> {
class AFTLoss {
private:
std::unique_ptr<ProbabilityDistribution> dist_;
ProbabilityDistributionType dist_type_;

public:
/*!
* \brief Constructor for AFT loss function
* \param dist Choice of probability distribution for the noise term in AFT
* \param dist_type Choice of probability distribution for the noise term in AFT
*/
explicit AFTLoss(ProbabilityDistributionType dist) {
dist_.reset(ProbabilityDistribution::Create(dist));
}
explicit AFTLoss(ProbabilityDistributionType dist_type)
: dist_(ProbabilityDistribution::Create(dist_type)),
dist_type_(dist_type) {}

public:
/*!
Expand Down
44 changes: 44 additions & 0 deletions tests/cpp/common/test_survival_util.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
/*!
* Copyright (c) by Contributors 2020
*/
#include <gtest/gtest.h>

#include "../../../src/common/survival_util.h"

namespace xgboost {
namespace common {

inline static void RobustTestSuite(ProbabilityDistributionType dist_type,
double y_lower, double y_upper, double sigma) {
AFTLoss loss(dist_type);
for (int i = 50; i >= -50; --i) {
const double y_pred = std::pow(10.0, static_cast<double>(i));
const double z = (std::log(y_lower) - std::log(y_pred)) / sigma;
const double gradient = loss.Gradient(y_lower, y_upper, std::log(y_pred), sigma);
const double hessian = loss.Hessian(y_lower, y_upper, std::log(y_pred), sigma);
ASSERT_FALSE(std::isnan(gradient)) << "z = " << z << ", y \\in ["
<< y_lower << ", " << y_upper << "], y_pred = " << y_pred
<< ", dist = " << static_cast<int>(dist_type);
ASSERT_FALSE(std::isinf(gradient)) << "z = " << z << ", y \\in ["
<< y_lower << ", " << y_upper << "], y_pred = " << y_pred
<< ", dist = " << static_cast<int>(dist_type);
ASSERT_FALSE(std::isnan(hessian)) << "z = " << z << ", y \\in ["
<< y_lower << ", " << y_upper << "], y_pred = " << y_pred
<< ", dist = " << static_cast<int>(dist_type);
ASSERT_FALSE(std::isinf(hessian)) << "z = " << z << ", y \\in ["
<< y_lower << ", " << y_upper << "], y_pred = " << y_pred
<< ", dist = " << static_cast<int>(dist_type);
}
}

TEST(AFTLoss, RobustGradientPair) { // Ensure that INF and NAN don't show up in gradient pair
RobustTestSuite(ProbabilityDistributionType::kNormal, 16.0, 200.0, 2.0);
RobustTestSuite(ProbabilityDistributionType::kLogistic, 16.0, 200.0, 2.0);
RobustTestSuite(ProbabilityDistributionType::kExtreme, 16.0, 200.0, 2.0);
RobustTestSuite(ProbabilityDistributionType::kNormal, 100.0, 100.0, 2.0);
RobustTestSuite(ProbabilityDistributionType::kLogistic, 100.0, 100.0, 2.0);
RobustTestSuite(ProbabilityDistributionType::kExtreme, 100.0, 100.0, 2.0);
}

} // namespace common
} // namespace xgboost
15 changes: 7 additions & 8 deletions tests/cpp/objective/test_aft_obj.cc
Original file line number Diff line number Diff line change
Expand Up @@ -93,10 +93,10 @@ TEST(Objective, AFTObjGPairUncensoredLabels) {
{ 0.0384f, 0.0624f, 0.0997f, 0.1551f, 0.2316f, 0.3254f, 0.4200f, 0.4861f, 0.4962f, 0.4457f,
0.3567f, 0.2601f, 0.1772f, 0.1152f, 0.0726f, 0.0449f, 0.0275f, 0.0167f, 0.0101f, 0.0061f });
CheckGPairOverGridPoints(obj.get(), 100.0f, 100.0f, "extreme",
{ -0.0000f, -29.0026f, -17.0031f, -9.8028f, -5.4822f, -2.8897f, -1.3340f, -0.4005f, 0.1596f,
{ -15.0000f, -15.0000f, -15.0000f, -9.8028f, -5.4822f, -2.8897f, -1.3340f, -0.4005f, 0.1596f,
0.4957f, 0.6974f, 0.8184f, 0.8910f, 0.9346f, 0.9608f, 0.9765f, 0.9859f, 0.9915f, 0.9949f,
0.9969f },
{ 0.0000f, 30.0026f, 18.0031f, 10.8028f, 6.4822f, 3.8897f, 2.3340f, 1.4005f, 0.8404f, 0.5043f,
{ 15.0000f, 15.0000f, 15.0000f, 10.8028f, 6.4822f, 3.8897f, 2.3340f, 1.4005f, 0.8404f, 0.5043f,
0.3026f, 0.1816f, 0.1090f, 0.0654f, 0.0392f, 0.0235f, 0.0141f, 0.0085f, 0.0051f, 0.0031f });
}

Expand All @@ -106,10 +106,9 @@ TEST(Objective, AFTObjGPairLeftCensoredLabels) {

CheckGPairOverGridPoints(obj.get(), -std::numeric_limits<float>::infinity(), 20.0f, "normal",
{ 0.0285f, 0.0832f, 0.1951f, 0.3804f, 0.6403f, 0.9643f, 1.3379f, 1.7475f, 2.1828f, 2.6361f,
3.1023f, 3.5779f, 4.0603f, 4.5479f, 5.0394f, 5.5340f, 6.0309f, 6.5298f, 7.0303f, 0.5072f },
3.1023f, 3.5779f, 4.0603f, 4.5479f, 5.0394f, 5.5340f, 6.0309f, 6.5298f, 7.0303f, 7.5326f },
{ 0.0663f, 0.1559f, 0.2881f, 0.4378f, 0.5762f, 0.6878f, 0.7707f, 0.8300f, 0.8719f, 0.9016f,
0.9229f, 0.9385f, 0.9501f, 0.9588f, 0.9656f, 0.9709f, 0.9751f, 0.9785f, 0.9812f, 0.0045f },
2e-4);
0.9229f, 0.9385f, 0.9501f, 0.9588f, 0.9656f, 0.9709f, 0.9751f, 0.9785f, 0.9813f, 0.9877f });
CheckGPairOverGridPoints(obj.get(), -std::numeric_limits<float>::infinity(), 20.0f, "logistic",
{ 0.0909f, 0.1428f, 0.2174f, 0.3164f, 0.4355f, 0.5625f, 0.6818f, 0.7812f, 0.8561f, 0.9084f,
0.9429f, 0.9650f, 0.9787f, 0.9871f, 0.9922f, 0.9953f, 0.9972f, 0.9983f, 0.9990f, 0.9994f },
Expand Down Expand Up @@ -139,10 +138,10 @@ TEST(Objective, AFTObjGPairRightCensoredLabels) {
{ 0.0312f, 0.0499f, 0.0776f, 0.1158f, 0.1627f, 0.2100f, 0.2430f, 0.2481f, 0.2228f, 0.1783f,
0.1300f, 0.0886f, 0.0576f, 0.0363f, 0.0225f, 0.0137f, 0.0083f, 0.0050f, 0.0030f, 0.0018f });
CheckGPairOverGridPoints(obj.get(), 60.0f, std::numeric_limits<float>::infinity(), "extreme",
{ -2.8073f, -18.0015f, -10.8018f, -6.4817f, -3.8893f, -2.3338f, -1.4004f, -0.8403f, -0.5042f,
{ -15.0000f, -15.0000f, -10.8018f, -6.4817f, -3.8893f, -2.3338f, -1.4004f, -0.8403f, -0.5042f,
-0.3026f, -0.1816f, -0.1089f, -0.0654f, -0.0392f, -0.0235f, -0.0141f, -0.0085f, -0.0051f,
-0.0031f, -0.0018f },
{ 0.2614f, 18.0015f, 10.8018f, 6.4817f, 3.8893f, 2.3338f, 1.4004f, 0.8403f, 0.5042f, 0.3026f,
-0.0031f, -0.0018f },
{ 15.0000f, 15.0000f, 10.8018f, 6.4817f, 3.8893f, 2.3338f, 1.4004f, 0.8403f, 0.5042f, 0.3026f,
0.1816f, 0.1089f, 0.0654f, 0.0392f, 0.0235f, 0.0141f, 0.0085f, 0.0051f, 0.0031f, 0.0018f });
}

Expand Down
Loading