diff --git a/data/passfile b/data/passfile new file mode 100644 index 000000000..3f44b6fe7 --- /dev/null +++ b/data/passfile @@ -0,0 +1,2 @@ +localhost:5432:dbname:username:password +127.0.0.1:5432:*:*:password2 diff --git a/go.mod b/go.mod index 6056b37c2..98d9c1b8e 100644 --- a/go.mod +++ b/go.mod @@ -27,6 +27,7 @@ require ( github.com/go-playground/universal-translator v0.18.0 // indirect github.com/go-playground/validator/v10 v10.9.0 // indirect github.com/golang/protobuf v1.5.2 // indirect + github.com/jackc/pgpassfile v1.0.0 // indirect github.com/json-iterator/go v1.1.12 // indirect github.com/leodido/go-urn v1.2.1 // indirect github.com/mattn/go-isatty v0.0.14 // indirect diff --git a/go.sum b/go.sum index 3b0194cdd..dd6aa2a74 100644 --- a/go.sum +++ b/go.sum @@ -32,6 +32,8 @@ github.com/golang/protobuf v1.5.2/go.mod h1:XVQd3VNwM+JqD3oG2Ue2ip4fOMUkwXdXDdiu github.com/google/go-cmp v0.5.5 h1:Khx7svrCpmxxtHBq5j2mp/xVjsi8hQMfNLvJFAlrGgU= github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= +github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM= +github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg= github.com/jessevdk/go-flags v1.5.0 h1:1jKYvbxEjfUl0fmqTCOfonvskHHXMjBySTLW4y9LFvc= github.com/jessevdk/go-flags v1.5.0/go.mod h1:Fw0T6WPc1dYxT4mKEZRfG5kJhaTDP9pj1c2EWnYs/m4= github.com/jmoiron/sqlx v1.3.5 h1:vFFPA71p1o5gAeqtEAwLU4dnX2napprKtHr7PYIcN3g= diff --git a/pkg/api/api.go b/pkg/api/api.go index ee8f91a71..0849bede0 100644 --- a/pkg/api/api.go +++ b/pkg/api/api.go @@ -142,22 +142,22 @@ func Connect(c *gin.Context) { return } - var sshInfo *shared.SSHInfo url := c.Request.FormValue("url") - if url == "" { badRequest(c, errURLRequired) return } - opts := command.Options{URL: url} - url, err := connection.FormatURL(opts) - + url, err := connection.FormatURL(command.Options{ + URL: url, + Passfile: command.Opts.Passfile, + }) if err != nil { badRequest(c, err) return } + var sshInfo *shared.SSHInfo if c.Request.FormValue("ssh") != "" { sshInfo = parseSshInfo(c) } diff --git a/pkg/command/options.go b/pkg/command/options.go index 348deb4ce..ac942481d 100644 --- a/pkg/command/options.go +++ b/pkg/command/options.go @@ -5,8 +5,10 @@ import ( "fmt" "os" "os/user" + "path/filepath" "strings" + "github.com/jackc/pgpassfile" "github.com/jessevdk/go-flags" "github.com/sirupsen/logrus" ) @@ -26,6 +28,7 @@ type Options struct { Port int `long:"port" description:"Server port" default:"5432"` User string `long:"user" description:"Database user"` Pass string `long:"pass" description:"Password for user"` + Passfile string `long:"passfile" description:"Local passwords file location"` DbName string `long:"db" description:"Database name"` SSLMode string `long:"ssl" description:"SSL mode"` SSLRootCert string `long:"ssl-rootcert" description:"SSL certificate authority file"` @@ -79,6 +82,23 @@ func ParseOptions(args []string) (Options, error) { opts.Prefix = getPrefixedEnvVar("URL_PREFIX") } + if opts.Passfile == "" { + passfile := os.Getenv("PGPASSFILE") + if passfile == "" { + passfile = filepath.Join(os.Getenv("HOME"), ".pgpass") + } + + _, err := os.Stat(passfile) + if err == nil { + _, err = pgpassfile.ReadPassfile(passfile) + if err == nil { + opts.Passfile = passfile + } else { + fmt.Printf("[WARN] Pgpass file unreadable: %s\n", err) + } + } + } + // Handle edge case where pgweb is started with a default host `localhost` and no user. // When user is not set the `lib/pq` connection will fail and cause pgweb's termination. if (opts.Host == "localhost" || opts.Host == "127.0.0.1") && opts.User == "" { diff --git a/pkg/command/options_test.go b/pkg/command/options_test.go index 961370c2f..75eed5f7b 100644 --- a/pkg/command/options_test.go +++ b/pkg/command/options_test.go @@ -1,47 +1,75 @@ package command import ( + "os" "testing" "github.com/stretchr/testify/assert" ) func TestParseOptions(t *testing.T) { - // Test default behavior - opts, err := ParseOptions([]string{}) - assert.NoError(t, err) - assert.Equal(t, false, opts.Sessions) - assert.Equal(t, "", opts.Prefix) - assert.Equal(t, "", opts.ConnectToken) - assert.Equal(t, "", opts.ConnectHeaders) - assert.Equal(t, false, opts.DisableSSH) - assert.Equal(t, false, opts.DisablePrettyJSON) - assert.Equal(t, false, opts.DisableConnectionIdleTimeout) - assert.Equal(t, 180, opts.ConnectionIdleTimeout) - assert.Equal(t, false, opts.Cors) - assert.Equal(t, "*", opts.CorsOrigin) - - // Test sessions - opts, err = ParseOptions([]string{"--sessions", "1"}) - assert.NoError(t, err) - assert.Equal(t, true, opts.Sessions) - - // Test url prefix - opts, err = ParseOptions([]string{"--prefix", "pgweb"}) - assert.NoError(t, err) - assert.Equal(t, "pgweb/", opts.Prefix) - - opts, err = ParseOptions([]string{"--prefix", "pgweb/"}) - assert.NoError(t, err) - assert.Equal(t, "pgweb/", opts.Prefix) - - // Test connect backend options - opts, err = ParseOptions([]string{"--connect-backend", "test"}) - assert.EqualError(t, err, "--sessions flag must be set") - - opts, err = ParseOptions([]string{"--connect-backend", "test", "--sessions"}) - assert.EqualError(t, err, "--connect-token flag must be set") - - opts, err = ParseOptions([]string{"--connect-backend", "test", "--sessions", "--connect-token", "token"}) - assert.NoError(t, err) + t.Run("defaults", func(t *testing.T) { + opts, err := ParseOptions([]string{}) + assert.NoError(t, err) + assert.Equal(t, false, opts.Sessions) + assert.Equal(t, "", opts.Prefix) + assert.Equal(t, "", opts.ConnectToken) + assert.Equal(t, "", opts.ConnectHeaders) + assert.Equal(t, false, opts.DisableSSH) + assert.Equal(t, false, opts.DisablePrettyJSON) + assert.Equal(t, false, opts.DisableConnectionIdleTimeout) + assert.Equal(t, 180, opts.ConnectionIdleTimeout) + assert.Equal(t, false, opts.Cors) + assert.Equal(t, "*", opts.CorsOrigin) + assert.Equal(t, "", opts.Passfile) + }) + + t.Run("sessions", func(t *testing.T) { + opts, err := ParseOptions([]string{"--sessions", "1"}) + assert.NoError(t, err) + assert.Equal(t, true, opts.Sessions) + }) + + t.Run("url prefix", func(t *testing.T) { + opts, err := ParseOptions([]string{"--prefix", "pgweb"}) + assert.NoError(t, err) + assert.Equal(t, "pgweb/", opts.Prefix) + + opts, err = ParseOptions([]string{"--prefix", "pgweb/"}) + assert.NoError(t, err) + assert.Equal(t, "pgweb/", opts.Prefix) + }) + + t.Run("connect backend", func(t *testing.T) { + _, err := ParseOptions([]string{"--connect-backend", "test"}) + assert.EqualError(t, err, "--sessions flag must be set") + + _, err = ParseOptions([]string{"--connect-backend", "test", "--sessions"}) + assert.EqualError(t, err, "--connect-token flag must be set") + + _, err = ParseOptions([]string{"--connect-backend", "test", "--sessions", "--connect-token", "token"}) + assert.NoError(t, err) + }) + + t.Run("passfile", func(t *testing.T) { + defer os.Unsetenv("PGPASSFILE") + + // File does not exist + os.Setenv("PGPASSFILE", "/tmp/foo") + opts, err := ParseOptions([]string{}) + assert.NoError(t, err) + assert.Equal(t, "", opts.Passfile) + + // File exists and valid + os.Setenv("PGPASSFILE", "../../data/passfile") + opts, err = ParseOptions([]string{}) + assert.NoError(t, err) + assert.Equal(t, "../../data/passfile", opts.Passfile) + + // Set via flag + os.Unsetenv("PGPASSFILE") + opts, err = ParseOptions([]string{"--passfile", "../../data/passfile"}) + assert.NoError(t, err) + assert.Equal(t, "../../data/passfile", opts.Passfile) + }) } diff --git a/pkg/connection/connection_string.go b/pkg/connection/connection_string.go index 935b4d995..4ed8e7679 100644 --- a/pkg/connection/connection_string.go +++ b/pkg/connection/connection_string.go @@ -8,6 +8,7 @@ import ( "os/user" "strings" + "github.com/jackc/pgpassfile" "github.com/sosedoff/pgweb/pkg/command" ) @@ -76,6 +77,17 @@ func FormatURL(opts command.Options) (string, error) { } } + // When password is not provided, look it up from a .pgpass file + if uri.User != nil { + pass, _ := uri.User.Password() + if pass == "" && opts.Passfile != "" { + pass = lookupPassword(opts, uri) + if pass != "" { + uri.User = neturl.UserPassword(uri.User.Username(), pass) + } + } + } + // Rebuild query params query := neturl.Values{} for k, v := range params { @@ -125,6 +137,11 @@ func BuildStringFromOptions(opts command.Options) (string, error) { query.Add("sslrootcert", opts.SSLRootCert) } + // Grab password from .pgpass file if it's available + if opts.Pass == "" && opts.Passfile != "" { + opts.Pass = lookupPassword(opts, nil) + } + url := neturl.URL{ Scheme: "postgres", Host: fmt.Sprintf("%v:%v", opts.Host, opts.Port), @@ -135,3 +152,34 @@ func BuildStringFromOptions(opts command.Options) (string, error) { return url.String(), nil } + +func lookupPassword(opts command.Options, url *neturl.URL) string { + if opts.Passfile == "" { + return "" + } + + passfile, err := pgpassfile.ReadPassfile(opts.Passfile) + if err != nil { + fmt.Println("[WARN] .pgpassfile", opts.Passfile, "is not readable") + return "" + } + + if url != nil { + var dbName string + fmt.Sscanf(url.Path, "/%s", &dbName) + + return passfile.FindPassword( + url.Hostname(), + url.Port(), + dbName, + url.User.Username(), + ) + } + + return passfile.FindPassword( + opts.Host, + fmt.Sprintf("%d", opts.Port), + opts.DbName, + opts.User, + ) +} diff --git a/pkg/connection/connection_string_test.go b/pkg/connection/connection_string_test.go index ac734daae..4e656fa3f 100644 --- a/pkg/connection/connection_string_test.go +++ b/pkg/connection/connection_string_test.go @@ -10,163 +10,243 @@ import ( "github.com/stretchr/testify/assert" ) -func Test_Invalid_Url(t *testing.T) { - opts := command.Options{} - examples := []string{ - "postgre://foobar", - "tcp://blah", - "foobar", - } - - for _, val := range examples { - opts.URL = val - str, err := BuildStringFromOptions(opts) - - assert.Equal(t, "", str) - assert.Error(t, err) - assert.Equal(t, "Invalid URL. Valid format: postgres://user:password@host:port/db?sslmode=mode", err.Error()) - } -} - -func Test_Valid_Url(t *testing.T) { - url := "postgres://myhost/database" - str, err := BuildStringFromOptions(command.Options{URL: url}) - - assert.Equal(t, nil, err) - assert.Equal(t, url, str) -} +func TestBuildStringFromOptions(t *testing.T) { + t.Run("valid url", func(t *testing.T) { + url := "postgres://myhost/database" + str, err := BuildStringFromOptions(command.Options{URL: url}) -func Test_Url_And_Ssl_Flag(t *testing.T) { - str, err := BuildStringFromOptions(command.Options{ - URL: "postgres://myhost/database", - SSLMode: "disable", + assert.NoError(t, err) + assert.Equal(t, url, str) }) - assert.Equal(t, nil, err) - assert.Equal(t, "postgres://myhost/database?sslmode=disable", str) -} + t.Run("with sslmode param", func(t *testing.T) { + str, err := BuildStringFromOptions(command.Options{ + URL: "postgres://myhost/database", + SSLMode: "disable", + }) -func Test_Localhost_Url_And_No_Ssl_Flag(t *testing.T) { - str, err := BuildStringFromOptions(command.Options{ - URL: "postgres://localhost/database", + assert.NoError(t, err) + assert.Equal(t, "postgres://myhost/database?sslmode=disable", str) }) - assert.Equal(t, nil, err) - assert.Equal(t, "postgres://localhost/database?sslmode=disable", str) - str, err = BuildStringFromOptions(command.Options{ - URL: "postgres://127.0.0.1/database", + t.Run("sets sslmode param if not set", func(t *testing.T) { + str, err := BuildStringFromOptions(command.Options{ + URL: "postgres://localhost/database", + }) + assert.NoError(t, err) + assert.Equal(t, "postgres://localhost/database?sslmode=disable", str) + + str, err = BuildStringFromOptions(command.Options{ + URL: "postgres://127.0.0.1/database", + }) + assert.NoError(t, err) + assert.Equal(t, "postgres://127.0.0.1/database?sslmode=disable", str) }) - assert.Equal(t, nil, err) - assert.Equal(t, "postgres://127.0.0.1/database?sslmode=disable", str) -} -func Test_Localhost_Url_And_Ssl_Flag(t *testing.T) { - str, err := BuildStringFromOptions(command.Options{ - URL: "postgres://localhost/database", - SSLMode: "require", + t.Run("sslmode as an option", func(t *testing.T) { + str, err := BuildStringFromOptions(command.Options{ + URL: "postgres://localhost/database", + SSLMode: "require", + }) + assert.NoError(t, err) + assert.Equal(t, "postgres://localhost/database?sslmode=require", str) + + str, err = BuildStringFromOptions(command.Options{ + URL: "postgres://127.0.0.1/database", + SSLMode: "require", + }) + assert.NoError(t, err) + assert.Equal(t, "postgres://127.0.0.1/database?sslmode=require", str) }) - assert.Equal(t, nil, err) - assert.Equal(t, "postgres://localhost/database?sslmode=require", str) - str, err = BuildStringFromOptions(command.Options{ - URL: "postgres://127.0.0.1/database", - SSLMode: "require", + t.Run("localhost and sslmode flag", func(t *testing.T) { + str, err := BuildStringFromOptions(command.Options{ + URL: "postgres://localhost/database?sslmode=require", + }) + assert.NoError(t, err) + assert.Equal(t, "postgres://localhost/database?sslmode=require", str) + + str, err = BuildStringFromOptions(command.Options{ + URL: "postgres://127.0.0.1/database?sslmode=require", + }) + assert.NoError(t, err) + assert.Equal(t, "postgres://127.0.0.1/database?sslmode=require", str) }) - assert.Equal(t, nil, err) - assert.Equal(t, "postgres://127.0.0.1/database?sslmode=require", str) -} -func Test_Localhost_Url_And_Ssl_Arg(t *testing.T) { - str, err := BuildStringFromOptions(command.Options{ - URL: "postgres://localhost/database?sslmode=require", + t.Run("extended options", func(t *testing.T) { + str, err := BuildStringFromOptions(command.Options{ + URL: "postgres://localhost/database?sslmode=require&sslcert=cert&sslkey=key&sslrootcert=ca", + }) + assert.NoError(t, err) + assert.Equal(t, "postgres://localhost/database?sslcert=cert&sslkey=key&sslmode=require&sslrootcert=ca", str) }) - assert.Equal(t, nil, err) - assert.Equal(t, "postgres://localhost/database?sslmode=require", str) - str, err = BuildStringFromOptions(command.Options{ - URL: "postgres://127.0.0.1/database?sslmode=require", + t.Run("from flags", func(t *testing.T) { + str, err := BuildStringFromOptions(command.Options{ + Host: "host", + Port: 5432, + User: "user", + Pass: "password", + DbName: "db", + }) + + assert.NoError(t, err) + assert.Equal(t, "postgres://user:password@host:5432/db", str) }) - assert.Equal(t, nil, err) - assert.Equal(t, "postgres://127.0.0.1/database?sslmode=require", str) -} -func Test_ExtendedSSLFlags(t *testing.T) { - str, err := BuildStringFromOptions(command.Options{ - URL: "postgres://localhost/database?sslmode=require&sslcert=cert&sslkey=key&sslrootcert=ca", - }) - assert.Equal(t, nil, err) - assert.Equal(t, "postgres://localhost/database?sslcert=cert&sslkey=key&sslmode=require&sslrootcert=ca", str) -} + t.Run("localhost", func(t *testing.T) { + opts := command.Options{ + Host: "localhost", + Port: 5432, + User: "user", + Pass: "password", + DbName: "db", + } + + str, err := BuildStringFromOptions(opts) + assert.NoError(t, err) + assert.Equal(t, "postgres://user:password@localhost:5432/db?sslmode=disable", str) -func Test_Flag_Args(t *testing.T) { - str, err := BuildStringFromOptions(command.Options{ - Host: "host", - Port: 5432, - User: "user", - Pass: "password", - DbName: "db", + opts.Host = "127.0.0.1" + str, err = BuildStringFromOptions(opts) + assert.NoError(t, err) + assert.Equal(t, "postgres://user:password@127.0.0.1:5432/db?sslmode=disable", str) }) - assert.Equal(t, nil, err) - assert.Equal(t, "postgres://user:password@host:5432/db", str) -} + t.Run("localhost and ssl", func(t *testing.T) { + opts := command.Options{ + Host: "localhost", + Port: 5432, + User: "user", + Pass: "password", + DbName: "db", + SSLMode: "require", + SSLKey: "keyPath", + SSLCert: "certPath", + SSLRootCert: "caPath", + } -func Test_Localhost(t *testing.T) { - opts := command.Options{ - Host: "localhost", - Port: 5432, - User: "user", - Pass: "password", - DbName: "db", - } + str, err := BuildStringFromOptions(opts) + assert.NoError(t, err) + assert.Equal(t, "postgres://user:password@localhost:5432/db?sslcert=certPath&sslkey=keyPath&sslmode=require&sslrootcert=caPath", str) + }) - str, err := BuildStringFromOptions(opts) - assert.Equal(t, nil, err) - assert.Equal(t, "postgres://user:password@localhost:5432/db?sslmode=disable", str) + t.Run("no user", func(t *testing.T) { + opts := command.Options{Host: "host", Port: 5432, DbName: "db"} + u, _ := user.Current() + str, err := BuildStringFromOptions(opts) + userAndPass := url.UserPassword(u.Username, "").String() - opts.Host = "127.0.0.1" - str, err = BuildStringFromOptions(opts) - assert.Equal(t, nil, err) - assert.Equal(t, "postgres://user:password@127.0.0.1:5432/db?sslmode=disable", str) -} + assert.NoError(t, err) + assert.Equal(t, fmt.Sprintf("postgres://%s@host:5432/db", userAndPass), str) + }) -func Test_Localhost_And_Ssl(t *testing.T) { - opts := command.Options{ - Host: "localhost", - Port: 5432, - User: "user", - Pass: "password", - DbName: "db", - SSLMode: "require", - SSLKey: "keyPath", - SSLCert: "certPath", - SSLRootCert: "caPath", - } + t.Run("port", func(t *testing.T) { + opts := command.Options{Host: "host", User: "user", Port: 5000, DbName: "db"} + str, err := BuildStringFromOptions(opts) + assert.NoError(t, err) + assert.Equal(t, "postgres://user:@host:5000/db", str) + }) - str, err := BuildStringFromOptions(opts) - assert.Equal(t, nil, err) - assert.Equal(t, "postgres://user:password@localhost:5432/db?sslcert=certPath&sslkey=keyPath&sslmode=require&sslrootcert=caPath", str) -} + t.Run("with pgpass", func(t *testing.T) { + opts := command.Options{ + Host: "localhost", + Port: 5432, + User: "username", + DbName: "dbname", + Passfile: "../../data/passfile", + } -func Test_No_User(t *testing.T) { - opts := command.Options{Host: "host", Port: 5432, DbName: "db"} - u, _ := user.Current() - str, err := BuildStringFromOptions(opts) - userAndPass := url.UserPassword(u.Username, "").String() + str, err := BuildStringFromOptions(opts) + assert.NoError(t, err) + assert.Equal(t, "postgres://username:password@localhost:5432/dbname?sslmode=disable", str) + + opts.User = "foobar" + str, err = BuildStringFromOptions(opts) + assert.NoError(t, err) + assert.Equal(t, "postgres://foobar:@localhost:5432/dbname?sslmode=disable", str) + + opts.Host = "127.0.0.1" + opts.DbName = "foobar2" + str, err = BuildStringFromOptions(opts) + assert.NoError(t, err) + assert.Equal(t, "postgres://foobar:password2@127.0.0.1:5432/foobar2?sslmode=disable", str) + }) - assert.Equal(t, nil, err) - assert.Equal(t, fmt.Sprintf("postgres://%s@host:5432/db", userAndPass), str) + t.Run("invalid url", func(t *testing.T) { + opts := command.Options{} + examples := []string{ + "postgre://foobar", + "tcp://blah", + "foobar", + } + + for _, val := range examples { + opts.URL = val + str, err := BuildStringFromOptions(opts) + + assert.Equal(t, "", str) + assert.Error(t, err) + assert.Equal(t, "Invalid URL. Valid format: postgres://user:password@host:port/db?sslmode=mode", err.Error()) + } + }) } -func Test_Port(t *testing.T) { - opts := command.Options{Host: "host", User: "user", Port: 5000, DbName: "db"} - str, err := BuildStringFromOptions(opts) +func TestFormatURL(t *testing.T) { + examples := []struct { + name string + input command.Options + result string + err string + }{ + { + name: "empty opts", + input: command.Options{}, + }, + { + name: "invalid url", + input: command.Options{URL: "barurl"}, + err: "Invalid URL", + }, + { + name: "good", + input: command.Options{ + URL: "postgres://user:pass@localhost:5432/dbname", + }, + result: "postgres://user:pass@localhost:5432/dbname?sslmode=disable", + }, + { + name: "password lookup, password set", + input: command.Options{ + URL: "postgres://username:@localhost:5432/dbname", + Passfile: "../../data/passfile", + }, + result: "postgres://username:password@localhost:5432/dbname?sslmode=disable", + }, + { + name: "password lookup, password not set", + input: command.Options{ + URL: "postgres://username@localhost:5432/dbname", + Passfile: "../../data/passfile", + }, + result: "postgres://username:password@localhost:5432/dbname?sslmode=disable", + }, + } + + for _, ex := range examples { + t.Run(ex.name, func(t *testing.T) { + str, err := FormatURL(ex.input) - assert.Equal(t, nil, err) - assert.Equal(t, "postgres://user:@host:5000/db", str) + if ex.err != "" { + assert.Error(t, err) + assert.Contains(t, err.Error(), ex.err) + } + assert.Equal(t, ex.result, str) + }) + } } -func Test_Blank(t *testing.T) { +func TestIsBlank(t *testing.T) { assert.Equal(t, true, IsBlank(command.Options{})) assert.Equal(t, false, IsBlank(command.Options{Host: "host", User: "user"})) assert.Equal(t, false, IsBlank(command.Options{Host: "host", User: "user", DbName: "db"}))