From 027c73e7aae1619f765ea113b1b838171fdaeeb0 Mon Sep 17 00:00:00 2001 From: Nick Logan Date: Sat, 11 May 2024 20:00:18 -0500 Subject: [PATCH] Make rawQuery timeout handling more thread-safe Previous an `elapsed` duration was saved on the `Client` object and was updated multiple times on each call to `rawQuery`. This could lead to a race condition when calling `whois.Whois(...)` in parallel, where one thread halfway through the function has a non-zero elapsed duration and another thread entering the function sets it to 0. This removes the `elapsed` field on the `Client` object and uses a local variable instead, which avoids the aforementioned issue. --- whois.go | 12 ++++-------- 1 file changed, 4 insertions(+), 8 deletions(-) diff --git a/whois.go b/whois.go index b91f310..0f336c3 100644 --- a/whois.go +++ b/whois.go @@ -48,7 +48,6 @@ var DefaultClient = NewClient() type Client struct { dialer proxy.Dialer timeout time.Duration - elapsed time.Duration disableStats bool disableReferral bool } @@ -175,7 +174,6 @@ func (c *Client) Whois(domain string, servers ...string) (result string, err err // rawQuery do raw query to the server func (c *Client) rawQuery(domain, server, port string) (string, error) { - c.elapsed = 0 start := time.Now() if server == "whois.arin.net" { @@ -202,24 +200,22 @@ func (c *Client) rawQuery(domain, server, port string) (string, error) { } defer conn.Close() - c.elapsed = time.Since(start) + elapsed := time.Since(start) - _ = conn.SetWriteDeadline(time.Now().Add(c.timeout - c.elapsed)) + _ = conn.SetWriteDeadline(time.Now().Add(c.timeout - elapsed)) _, err = conn.Write([]byte(domain + "\r\n")) if err != nil { return "", fmt.Errorf("whois: send to whois server failed: %w", err) } - c.elapsed = time.Since(start) + elapsed = time.Since(start) - _ = conn.SetReadDeadline(time.Now().Add(c.timeout - c.elapsed)) + _ = conn.SetReadDeadline(time.Now().Add(c.timeout - elapsed)) buffer, err := io.ReadAll(conn) if err != nil { return "", fmt.Errorf("whois: read from whois server failed: %w", err) } - c.elapsed = time.Since(start) - return string(buffer), nil }