diff --git a/client.go b/client.go index 7a24309c..4c4b6b82 100644 --- a/client.go +++ b/client.go @@ -6,12 +6,14 @@ package websocket import ( "bytes" + "context" "crypto/tls" "errors" "io" "io/ioutil" "net" "net/http" + "net/http/httptrace" "net/url" "strings" "time" @@ -51,6 +53,10 @@ type Dialer struct { // NetDial is nil, net.Dial is used. NetDial func(network, addr string) (net.Conn, error) + // NetDialContext specifies the dial function for creating TCP connections. If + // NetDialContext is nil, net.DialContext is used. + NetDialContext func(ctx context.Context, network, addr string) (net.Conn, error) + // Proxy specifies a function to return a proxy for a given // Request. If the function returns a non-nil error, the // request is aborted with the provided error. @@ -95,6 +101,11 @@ type Dialer struct { Jar http.CookieJar } +// Dial creates a new client connection by calling DialContext with a background context. +func (d *Dialer) Dial(urlStr string, requestHeader http.Header) (*Conn, *http.Response, error) { + return d.DialContext(urlStr, requestHeader, context.Background()) +} + var errMalformedURL = errors.New("malformed ws or wss URL") func hostPortNoPort(u *url.URL) (hostPort, hostNoPort string) { @@ -124,17 +135,18 @@ var DefaultDialer = &Dialer{ // nilDialer is dialer to use when receiver is nil. var nilDialer Dialer = *DefaultDialer -// Dial creates a new client connection. Use requestHeader to specify the +// DialContext creates a new client connection. Use requestHeader to specify the // origin (Origin), subprotocols (Sec-WebSocket-Protocol) and cookies (Cookie). // Use the response.Header to get the selected subprotocol // (Sec-WebSocket-Protocol) and cookies (Set-Cookie). // +// The context will be used in the request and in the Dialer +// // If the WebSocket handshake fails, ErrBadHandshake is returned along with a // non-nil *http.Response so that callers can handle redirects, authentication, // etcetera. The response body may not contain the entire response and does not // need to be closed by the application. -func (d *Dialer) Dial(urlStr string, requestHeader http.Header) (*Conn, *http.Response, error) { - +func (d *Dialer) DialContext(urlStr string, requestHeader http.Header, ctx context.Context) (*Conn, *http.Response, error) { if d == nil { d = &nilDialer } @@ -172,6 +184,7 @@ func (d *Dialer) Dial(urlStr string, requestHeader http.Header) (*Conn, *http.Re Header: make(http.Header), Host: u.Host, } + req = req.WithContext(ctx) // Set the cookies present in the cookie jar of the dialer if d.Jar != nil { @@ -215,20 +228,30 @@ func (d *Dialer) Dial(urlStr string, requestHeader http.Header) (*Conn, *http.Re req.Header["Sec-WebSocket-Extensions"] = []string{"permessage-deflate; server_no_context_takeover; client_no_context_takeover"} } - var deadline time.Time if d.HandshakeTimeout != 0 { - deadline = time.Now().Add(d.HandshakeTimeout) + var cancel func() + ctx, cancel = context.WithTimeout(ctx, d.HandshakeTimeout) + defer cancel() } // Get network dial function. - netDial := d.NetDial - if netDial == nil { - netDialer := &net.Dialer{Deadline: deadline} - netDial = netDialer.Dial + var netDial func(network, add string) (net.Conn, error) + + if d.NetDialContext != nil { + netDial = func(network, addr string) (net.Conn, error) { + return d.NetDialContext(ctx, network, addr) + } + } else if d.NetDial != nil { + netDial = d.NetDial + } else { + netDialer := &net.Dialer{} + netDial = func(network, addr string) (net.Conn, error) { + return netDialer.DialContext(ctx, network, addr) + } } // If needed, wrap the dial function to set the connection deadline. - if !deadline.Equal(time.Time{}) { + if deadline, ok := ctx.Deadline(); ok { forwardDial := netDial netDial = func(network, addr string) (net.Conn, error) { c, err := forwardDial(network, addr) @@ -260,7 +283,17 @@ func (d *Dialer) Dial(urlStr string, requestHeader http.Header) (*Conn, *http.Re } hostPort, hostNoPort := hostPortNoPort(u) + trace := httptrace.ContextClientTrace(ctx) + if trace != nil && trace.GetConn != nil { + trace.GetConn(hostPort) + } + netConn, err := netDial("tcp", hostPort) + if trace != nil && trace.GotConn != nil { + trace.GotConn(httptrace.GotConnInfo{ + Conn: netConn, + }) + } if err != nil { return nil, nil, err } @@ -278,13 +311,16 @@ func (d *Dialer) Dial(urlStr string, requestHeader http.Header) (*Conn, *http.Re } tlsConn := tls.Client(netConn, cfg) netConn = tlsConn - if err := tlsConn.Handshake(); err != nil { - return nil, nil, err + + var err error + if trace != nil { + err = doHandshakeWithTrace(trace, tlsConn, cfg) + } else { + err = doHandshake(tlsConn, cfg) } - if !cfg.InsecureSkipVerify { - if err := tlsConn.VerifyHostname(cfg.ServerName); err != nil { - return nil, nil, err - } + + if err != nil { + return nil, nil, err } } @@ -294,6 +330,12 @@ func (d *Dialer) Dial(urlStr string, requestHeader http.Header) (*Conn, *http.Re return nil, nil, err } + if trace != nil && trace.GotFirstResponseByte != nil { + if peek, err := conn.br.Peek(1); err == nil && len(peek) == 1 { + trace.GotFirstResponseByte() + } + } + resp, err := http.ReadResponse(conn.br, req) if err != nil { return nil, nil, err @@ -339,3 +381,15 @@ func (d *Dialer) Dial(urlStr string, requestHeader http.Header) (*Conn, *http.Re netConn = nil // to avoid close in defer. return conn, resp, nil } + +func doHandshake(tlsConn *tls.Conn, cfg *tls.Config) error { + if err := tlsConn.Handshake(); err != nil { + return err + } + if !cfg.InsecureSkipVerify { + if err := tlsConn.VerifyHostname(cfg.ServerName); err != nil { + return err + } + } + return nil +} diff --git a/client_server_test.go b/client_server_test.go index cdbc9d4b..f829566a 100644 --- a/client_server_test.go +++ b/client_server_test.go @@ -6,6 +6,7 @@ package websocket import ( "bytes" + "context" "crypto/tls" "crypto/x509" "encoding/base64" @@ -16,6 +17,7 @@ import ( "net/http" "net/http/cookiejar" "net/http/httptest" + "net/http/httptrace" "net/url" "reflect" "strings" @@ -40,6 +42,12 @@ var cstDialer = Dialer{ HandshakeTimeout: 30 * time.Second, } +var cstDialerWithoutHandshakeTimeout = Dialer{ + Subprotocols: []string{"p1", "p2"}, + ReadBufferSize: 1024, + WriteBufferSize: 1024, +} + type cstHandler struct{ *testing.T } type cstServer struct { @@ -403,6 +411,26 @@ func TestHandshakeTimeout(t *testing.T) { ws.Close() } +func TestHandshakeTimeoutInContext(t *testing.T) { + s := newServer(t) + defer s.Close() + + d := cstDialerWithoutHandshakeTimeout + d.NetDialContext = func(ctx context.Context, n, a string) (net.Conn, error) { + netDialer := &net.Dialer{} + c, err := netDialer.DialContext(ctx, n, a) + return &requireDeadlineNetConn{c: c, t: t}, err + } + + ctx, cancel := context.WithDeadline(context.Background(), time.Now().Add(30*time.Second)) + defer cancel() + ws, _, err := d.DialContext(s.URL, nil, ctx) + if err != nil { + t.Fatal("Dial:", err) + } + ws.Close() +} + func TestDialBadScheme(t *testing.T) { s := newServer(t) defer s.Close() @@ -659,3 +687,104 @@ func TestSocksProxyDial(t *testing.T) { defer ws.Close() sendRecv(t, ws) } + +func TestTracingDialWithContext(t *testing.T) { + + var headersWrote, requestWrote, getConn, gotConn, connectDone, gotFirstResponseByte bool + trace := &httptrace.ClientTrace{ + WroteHeaders: func() { + headersWrote = true + }, + WroteRequest: func(httptrace.WroteRequestInfo) { + requestWrote = true + }, + GetConn: func(hostPort string) { + getConn = true + }, + GotConn: func(info httptrace.GotConnInfo) { + gotConn = true + }, + ConnectDone: func(network, addr string, err error) { + connectDone = true + }, + GotFirstResponseByte: func() { + gotFirstResponseByte = true + }, + } + ctx := httptrace.WithClientTrace(context.Background(), trace) + + s := newTLSServer(t) + defer s.Close() + + certs := x509.NewCertPool() + for _, c := range s.TLS.Certificates { + roots, err := x509.ParseCertificates(c.Certificate[len(c.Certificate)-1]) + if err != nil { + t.Fatalf("error parsing server's root cert: %v", err) + } + for _, root := range roots { + certs.AddCert(root) + } + } + + d := cstDialer + d.TLSClientConfig = &tls.Config{RootCAs: certs} + + ws, _, err := d.DialContext(s.URL, nil, ctx) + if err != nil { + t.Fatalf("Dial: %v", err) + } + + if !headersWrote { + t.Fatal("Headers was not written") + } + if !requestWrote { + t.Fatal("Request was not written") + } + if !getConn { + t.Fatal("getConn was not called") + } + if !gotConn { + t.Fatal("gotConn was not called") + } + if !connectDone { + t.Fatal("connectDone was not called") + } + if !gotFirstResponseByte { + t.Fatal("GotFirstResponseByte was not called") + } + + defer ws.Close() + sendRecv(t, ws) +} + +func TestEmptyTracingDialWithContext(t *testing.T) { + + trace := &httptrace.ClientTrace{} + ctx := httptrace.WithClientTrace(context.Background(), trace) + + s := newTLSServer(t) + defer s.Close() + + certs := x509.NewCertPool() + for _, c := range s.TLS.Certificates { + roots, err := x509.ParseCertificates(c.Certificate[len(c.Certificate)-1]) + if err != nil { + t.Fatalf("error parsing server's root cert: %v", err) + } + for _, root := range roots { + certs.AddCert(root) + } + } + + d := cstDialer + d.TLSClientConfig = &tls.Config{RootCAs: certs} + + ws, _, err := d.DialContext(s.URL, nil, ctx) + if err != nil { + t.Fatalf("Dial: %v", err) + } + + defer ws.Close() + sendRecv(t, ws) +} diff --git a/trace.go b/trace.go new file mode 100644 index 00000000..834f122a --- /dev/null +++ b/trace.go @@ -0,0 +1,19 @@ +// +build go1.8 + +package websocket + +import ( + "crypto/tls" + "net/http/httptrace" +) + +func doHandshakeWithTrace(trace *httptrace.ClientTrace, tlsConn *tls.Conn, cfg *tls.Config) error { + if trace.TLSHandshakeStart != nil { + trace.TLSHandshakeStart() + } + err := doHandshake(tlsConn, cfg) + if trace.TLSHandshakeDone != nil { + trace.TLSHandshakeDone(tlsConn.ConnectionState(), err) + } + return err +} diff --git a/trace_17.go b/trace_17.go new file mode 100644 index 00000000..77d05a0b --- /dev/null +++ b/trace_17.go @@ -0,0 +1,12 @@ +// +build !go1.8 + +package websocket + +import ( + "crypto/tls" + "net/http/httptrace" +) + +func doHandshakeWithTrace(trace *httptrace.ClientTrace, tlsConn *tls.Conn, cfg *tls.Config) error { + return doHandshake(tlsConn, cfg) +}