diff --git a/lib/_tls_wrap.js b/lib/_tls_wrap.js index c1037a7096a755..84b02a731bef16 100644 --- a/lib/_tls_wrap.js +++ b/lib/_tls_wrap.js @@ -295,9 +295,15 @@ TLSSocket.prototype._wrapHandle = function(handle) { } }); + this.on('close', this._destroySSL); + return res; }; +TLSSocket.prototype._destroySSL = function _destroySSL() { + return this.ssl.destroySSL(); +}; + TLSSocket.prototype._init = function(socket, wrap) { var self = this; var options = this._tlsOptions; @@ -416,6 +422,9 @@ TLSSocket.prototype.renegotiate = function(options, callback) { var requestCert = this._requestCert, rejectUnauthorized = this._rejectUnauthorized; + if (this.destroyed) + return; + if (typeof options.requestCert !== 'undefined') requestCert = !!options.requestCert; if (typeof options.rejectUnauthorized !== 'undefined') diff --git a/src/node_crypto.cc b/src/node_crypto.cc index 04b50361bab264..1a7388bea0ff32 100644 --- a/src/node_crypto.cc +++ b/src/node_crypto.cc @@ -131,6 +131,7 @@ template int SSLWrap::SelectNextProtoCallback( void* arg); #endif template int SSLWrap::TLSExtStatusCallback(SSL* s, void* arg); +template void SSLWrap::DestroySSL(); static void crypto_threadid_cb(CRYPTO_THREADID* tid) { @@ -1871,6 +1872,16 @@ void SSLWrap::SSLGetter(Local property, } +template +void SSLWrap::DestroySSL() { + if (ssl_ == nullptr) + return; + + SSL_free(ssl_); + ssl_ = nullptr; +} + + void Connection::OnClientHelloParseEnd(void* arg) { Connection* conn = static_cast(arg); diff --git a/src/node_crypto.h b/src/node_crypto.h index 75ffe4f31ddeea..8fec4bb6253c2e 100644 --- a/src/node_crypto.h +++ b/src/node_crypto.h @@ -144,10 +144,7 @@ class SSLWrap { } virtual ~SSLWrap() { - if (ssl_ != nullptr) { - SSL_free(ssl_); - ssl_ = nullptr; - } + DestroySSL(); if (next_sess_ != nullptr) { SSL_SESSION_free(next_sess_); next_sess_ = nullptr; @@ -221,6 +218,8 @@ class SSLWrap { static void SSLGetter(v8::Local property, const v8::PropertyCallbackInfo& info); + void DestroySSL(); + inline Environment* ssl_env() const { return env_; } diff --git a/src/tls_wrap.cc b/src/tls_wrap.cc index c774a8490b54f2..c013bf935eb9cb 100644 --- a/src/tls_wrap.cc +++ b/src/tls_wrap.cc @@ -208,7 +208,7 @@ void TLSWrap::Receive(const FunctionCallbackInfo& args) { uv_buf_t buf; // Copy given buffer entirely or partiall if handle becomes closed - while (len > 0 && !wrap->IsClosing()) { + while (len > 0 && wrap->IsAlive() && !wrap->IsClosing()) { wrap->stream_->OnAlloc(len, &buf); size_t copy = buf.len > len ? len : buf.len; memcpy(buf.base, data, copy); @@ -282,6 +282,9 @@ void TLSWrap::EncOut() { if (established_ && !write_item_queue_.IsEmpty()) MakePending(); + if (ssl_ == nullptr) + return; + // No data to write if (BIO_pending(enc_out_) == 0) { if (clear_in_->Length() == 0) @@ -396,7 +399,8 @@ void TLSWrap::ClearOut() { if (eof_) return; - CHECK_NE(ssl_, nullptr); + if (ssl_ == nullptr) + return; char out[kClearOutChunkSize]; int read; @@ -451,6 +455,9 @@ bool TLSWrap::ClearIn() { if (!hello_parser_.IsEnded()) return false; + if (ssl_ == nullptr) + return false; + int written = 0; while (clear_in_->Length() > 0) { size_t avail = 0; @@ -503,7 +510,7 @@ int TLSWrap::GetFD() { bool TLSWrap::IsAlive() { - return stream_->IsAlive(); + return ssl_ != nullptr && stream_->IsAlive(); } @@ -573,6 +580,9 @@ int TLSWrap::DoWrite(WriteWrap* w, return 0; } + if (ssl_ == nullptr) + return UV_EPROTO; + int written = 0; for (i = 0; i < count; i++) { written = SSL_write(ssl_, bufs[i].base, bufs[i].len); @@ -660,7 +670,10 @@ void TLSWrap::DoRead(ssize_t nread, } // Only client connections can receive data - CHECK_NE(ssl_, nullptr); + if (ssl_ == nullptr) { + OnRead(UV_EPROTO, nullptr); + return; + } // Commit read data NodeBIO* enc_in = NodeBIO::FromBIO(enc_in_); @@ -680,7 +693,7 @@ void TLSWrap::DoRead(ssize_t nread, int TLSWrap::DoShutdown(ShutdownWrap* req_wrap) { - if (SSL_shutdown(ssl_) == 0) + if (ssl_ != nullptr && SSL_shutdown(ssl_) == 0) SSL_shutdown(ssl_); shutdown_ = true; EncOut(); @@ -696,6 +709,9 @@ void TLSWrap::SetVerifyMode(const FunctionCallbackInfo& args) { if (args.Length() < 2 || !args[0]->IsBoolean() || !args[1]->IsBoolean()) return env->ThrowTypeError("Bad arguments, expected two booleans"); + if (wrap->ssl_ == nullptr) + return env->ThrowTypeError("SetVerifyMode after destroySSL"); + int verify_mode; if (wrap->is_server()) { bool request_cert = args[0]->IsTrue(); @@ -735,6 +751,14 @@ void TLSWrap::EnableHelloParser(const FunctionCallbackInfo& args) { } +void TLSWrap::DestroySSL(const FunctionCallbackInfo& args) { + TLSWrap* wrap = Unwrap(args.Holder()); + wrap->SSLWrap::DestroySSL(); + delete wrap->clear_in_; + wrap->clear_in_ = nullptr; +} + + void TLSWrap::OnClientHelloParseEnd(void* arg) { TLSWrap* c = static_cast(arg); c->Cycle(); @@ -747,6 +771,8 @@ void TLSWrap::GetServername(const FunctionCallbackInfo& args) { TLSWrap* wrap = Unwrap(args.Holder()); + CHECK_NE(wrap->ssl_, nullptr); + const char* servername = SSL_get_servername(wrap->ssl_, TLSEXT_NAMETYPE_host_name); if (servername != nullptr) { @@ -771,6 +797,8 @@ void TLSWrap::SetServername(const FunctionCallbackInfo& args) { if (!wrap->is_client()) return; + CHECK_NE(wrap->ssl_, nullptr); + #ifdef SSL_CTRL_SET_TLSEXT_SERVERNAME_CB node::Utf8Value servername(env->isolate(), args[0].As()); SSL_set_tlsext_host_name(wrap->ssl_, *servername); @@ -830,6 +858,7 @@ void TLSWrap::Initialize(Handle target, env->SetProtoMethod(t, "setVerifyMode", SetVerifyMode); env->SetProtoMethod(t, "enableSessionCallbacks", EnableSessionCallbacks); env->SetProtoMethod(t, "enableHelloParser", EnableHelloParser); + env->SetProtoMethod(t, "destroySSL", DestroySSL); StreamBase::AddMethods(env, t, StreamBase::kFlagHasWritev); SSLWrap::AddMethods(env, t); diff --git a/src/tls_wrap.h b/src/tls_wrap.h index 9f095355bb58bd..25088d30261189 100644 --- a/src/tls_wrap.h +++ b/src/tls_wrap.h @@ -132,6 +132,7 @@ class TLSWrap : public crypto::SSLWrap, const v8::FunctionCallbackInfo& args); static void EnableHelloParser( const v8::FunctionCallbackInfo& args); + static void DestroySSL(const v8::FunctionCallbackInfo& args); #ifdef SSL_CTRL_SET_TLSEXT_SERVERNAME_CB static void GetServername(const v8::FunctionCallbackInfo& args);