Skip to content

Commit

Permalink
client: add client opts to enable system certificates
Browse files Browse the repository at this point in the history
Additionally, this splits out WithCredentials into separate methods:

- WithCredentials configures the client credentials presented to the
  server
- WithServerConfig configures the name and ca certificate to check the
  server's certificate against
- WithServerConfigSystem configures the name to check the server's
  certificate against - the ca certificate is automatically pulled from
  the system store.

Signed-off-by: Justin Chadwell <me@jedevc.com>
  • Loading branch information
jedevc committed Mar 30, 2023
1 parent 8b7bcb9 commit 8f66706
Show file tree
Hide file tree
Showing 2 changed files with 106 additions and 36 deletions.
135 changes: 101 additions & 34 deletions client/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,8 +45,6 @@ func New(ctx context.Context, address string, opts ...ClientOpt) (*Client, error
grpc.WithDefaultCallOptions(grpc.MaxCallSendMsgSize(defaults.DefaultMaxSendMsgSize)),
}
needDialer := true
needWithInsecure := true
tlsServerName := ""

var unary []grpc.UnaryClientInterceptor
var stream []grpc.StreamClientInterceptor
Expand All @@ -56,19 +54,17 @@ func New(ctx context.Context, address string, opts ...ClientOpt) (*Client, error
var tracerDelegate TracerDelegate
var sessionDialer func(context.Context, string, map[string][]string) (net.Conn, error)
var customDialOptions []grpc.DialOption
var creds *withCredentials

for _, o := range opts {
if _, ok := o.(*withFailFast); ok {
gopts = append(gopts, grpc.FailOnNonTempDialError(true))
}
if credInfo, ok := o.(*withCredentials); ok {
opt, err := loadCredentials(credInfo)
if err != nil {
return nil, err
if creds == nil {
creds = &withCredentials{}
}
gopts = append(gopts, opt)
needWithInsecure = false
tlsServerName = credInfo.ServerName
creds = creds.merge(credInfo)
}
if wt, ok := o.(*withTracer); ok {
customTracer = true
Expand All @@ -89,6 +85,16 @@ func New(ctx context.Context, address string, opts ...ClientOpt) (*Client, error
}
}

if creds == nil {
gopts = append(gopts, grpc.WithTransportCredentials(insecure.NewCredentials()))
} else {
credOpts, err := loadCredentials(creds)
if err != nil {
return nil, err
}
gopts = append(gopts, credOpts)
}

if !customTracer {
if span := trace.SpanFromContext(ctx); span.SpanContext().IsValid() {
tracerProvider = span.TracerProvider()
Expand All @@ -108,9 +114,6 @@ func New(ctx context.Context, address string, opts ...ClientOpt) (*Client, error
}
gopts = append(gopts, grpc.WithContextDialer(dialFn))
}
if needWithInsecure {
gopts = append(gopts, grpc.WithTransportCredentials(insecure.NewCredentials()))
}
if address == "" {
address = appdefaults.Address
}
Expand All @@ -122,7 +125,10 @@ func New(ctx context.Context, address string, opts ...ClientOpt) (*Client, error
// ref: https://datatracker.ietf.org/doc/html/rfc7540#section-8.1.2.3
// - However, when TLS specified, grpc-go requires it must match
// with its servername specified for certificate validation.
authority := tlsServerName
var authority string
if creds != nil && creds.serverName != "" {
authority = creds.serverName
}
if authority == "" {
// authority as hostname from target address
uri, err := url.Parse(address)
Expand Down Expand Up @@ -201,47 +207,108 @@ func WithContextDialer(df func(context.Context, string) (net.Conn, error)) Clien
}

type withCredentials struct {
ServerName string
CACert string
Cert string
Key string
// server options
serverName string
caCert string
caCertSystem bool

// client options
cert string
key string
}

func (opts *withCredentials) merge(opts2 *withCredentials) *withCredentials {
result := *opts
if opts2 == nil {
return &result
}

// server options
if opts2.serverName != "" {
result.serverName = opts2.serverName
}
if opts2.caCert != "" {
result.caCert = opts2.caCert
}
if opts2.caCertSystem {
result.caCertSystem = opts2.caCertSystem
}

// client options
if opts2.cert != "" {
result.cert = opts2.cert
}
if opts2.key != "" {
result.key = opts2.key
}

return &result
}

func (*withCredentials) isClientOpt() {}

// WithCredentials configures the TLS parameters of the client.
// Arguments:
// * serverName: specifies the name of the target server
// * ca: specifies the filepath of the CA certificate to use for verification
// * cert: specifies the filepath of the client certificate
// * key: specifies the filepath of the client key
func WithCredentials(serverName, ca, cert, key string) ClientOpt {
return &withCredentials{serverName, ca, cert, key}
// * cert: specifies the filepath of the client certificate
// * key: specifies the filepath of the client key
func WithCredentials(cert, key string) ClientOpt {
return &withCredentials{
cert: cert,
key: key,
}
}

// WithServerConfig configures the TLS parameters to connect to the server.
// Arguments:
// * serverName: specifies the server name to verify the hostname
// * caCert: specifies the filepath of the CA certificate
func WithServerConfig(serverName, caCert string) ClientOpt {
return &withCredentials{
serverName: serverName,
caCert: caCert,
}
}

// WithServerConfigSystem configures the TLS parameters to connect to the
// server, using the system's certificate pool.
func WithServerConfigSystem(serverName string) ClientOpt {
return &withCredentials{
serverName: serverName,
caCertSystem: true,
}
}

func loadCredentials(opts *withCredentials) (grpc.DialOption, error) {
ca, err := os.ReadFile(opts.CACert)
if err != nil {
return nil, errors.Wrap(err, "could not read ca certificate")
cfg := &tls.Config{}

if opts.caCertSystem {
cfg.RootCAs, _ = x509.SystemCertPool()
}
if cfg.RootCAs == nil {
cfg.RootCAs = x509.NewCertPool()
}

certPool := x509.NewCertPool()
if ok := certPool.AppendCertsFromPEM(ca); !ok {
return nil, errors.New("failed to append ca certs")
if opts.caCert != "" {
ca, err := os.ReadFile(opts.caCert)
if err != nil {
return nil, errors.Wrap(err, "could not read ca certificate")
}
if ok := cfg.RootCAs.AppendCertsFromPEM(ca); !ok {
return nil, errors.New("failed to append ca certs")
}
}

cfg := &tls.Config{
ServerName: opts.ServerName,
RootCAs: certPool,
if opts.serverName != "" {
cfg.ServerName = opts.serverName
}

// we will produce an error if the user forgot about either cert or key if at least one is specified
if opts.Cert != "" || opts.Key != "" {
cert, err := tls.LoadX509KeyPair(opts.Cert, opts.Key)
if opts.cert != "" || opts.key != "" {
cert, err := tls.LoadX509KeyPair(opts.cert, opts.key)
if err != nil {
return nil, errors.Wrap(err, "could not read certificate/key")
}
cfg.Certificates = []tls.Certificate{cert}
cfg.Certificates = append(cfg.Certificates, cert)
}

return grpc.WithTransportCredentials(credentials.NewTLS(cfg)), nil
Expand Down
7 changes: 5 additions & 2 deletions cmd/buildctl/common/common.go
Original file line number Diff line number Diff line change
Expand Up @@ -82,8 +82,11 @@ func ResolveClient(c *cli.Context) (*client.Client, error) {
}
}

if caCert != "" || cert != "" || key != "" {
opts = append(opts, client.WithCredentials(serverName, caCert, cert, key))
if caCert != "" {
opts = append(opts, client.WithServerConfig(serverName, caCert))
}
if cert != "" || key != "" {
opts = append(opts, client.WithCredentials(cert, key))
}

timeout := time.Duration(c.GlobalInt("timeout"))
Expand Down

0 comments on commit 8f66706

Please sign in to comment.