Skip to content

Commit

Permalink
tailscale: support reading & writing ACL content as HuJSON
Browse files Browse the repository at this point in the history
Updates tailscale/terraform-provider-tailscale#227

Signed-off-by: Anton Tolchanov <anton@tailscale.com>
  • Loading branch information
knyar committed Feb 14, 2024
1 parent 6b54b06 commit ce83419
Show file tree
Hide file tree
Showing 6 changed files with 186 additions and 44 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ on:
jobs:
test:
runs-on: ubuntu-latest
container: golang:1.19
container: golang:1.22
steps:
- name: Checkout
uses: actions/checkout@v3
Expand Down
2 changes: 1 addition & 1 deletion go.mod
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
module github.com/tailscale/tailscale-client-go

go 1.19
go 1.22.0

require (
github.com/stretchr/testify v1.8.4
Expand Down
1 change: 1 addition & 0 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ github.com/golang/protobuf v1.5.3 h1:KhyjKVUg7Usr/dYsdSqoFveMYd5ko72D+zANwlG1mmg
github.com/golang/protobuf v1.5.3/go.mod h1:XVQd3VNwM+JqD3oG2Ue2ip4fOMUkwXdXDdiuN0vRsmY=
github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38=
github.com/google/go-cmp v0.5.9/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk=
Expand Down
168 changes: 128 additions & 40 deletions tailscale/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ type (
)

const baseURL = "https://api.tailscale.com"
const contentType = "application/json"
const defaultContentType = "application/json"
const defaultHttpClientTimeout = time.Minute
const defaultUserAgent = "tailscale-client-go"

Expand Down Expand Up @@ -134,18 +134,57 @@ func WithUserAgent(ua string) ClientOption {
}
}

// TODO: consider setting `headers` and `body` via opts to decrease the number of arguments.
func (c *Client) buildRequest(ctx context.Context, method, uri string, headers map[string]string, body interface{}) (*http.Request, error) {
type requestParams struct {
headers map[string]string
body any
contentType string
}

type requestOption func(*requestParams)

func requestBody(body any) requestOption {
return func(rof *requestParams) {
rof.body = body
}
}

func requestHeaders(headers map[string]string) requestOption {
return func(rof *requestParams) {
rof.headers = headers
}
}

func requestContentType(ct string) requestOption {
return func(rof *requestParams) {
rof.contentType = ct
}
}

func (c *Client) buildRequest(ctx context.Context, method, uri string, opts ...requestOption) (*http.Request, error) {
rof := &requestParams{
contentType: defaultContentType,
}
for _, opt := range opts {
opt(rof)
}

u, err := c.baseURL.Parse(uri)
if err != nil {
return nil, err
}

var bodyBytes []byte
if body != nil {
bodyBytes, err = json.MarshalIndent(body, "", " ")
if err != nil {
return nil, err
if rof.body != nil {
switch body := rof.body.(type) {
case string:
bodyBytes = []byte(body)
case []byte:
bodyBytes = body
default:
bodyBytes, err = json.MarshalIndent(rof.body, "", " ")
if err != nil {
return nil, err
}
}
}

Expand All @@ -158,15 +197,15 @@ func (c *Client) buildRequest(ctx context.Context, method, uri string, headers m
req.Header.Set("User-Agent", c.userAgent)
}

for k, v := range headers {
for k, v := range rof.headers {
req.Header.Set(k, v)
}

switch {
case body == nil:
req.Header.Set("Accept", contentType)
case rof.body == nil:
req.Header.Set("Accept", rof.contentType)
default:
req.Header.Set("Content-Type", contentType)
req.Header.Set("Content-Type", rof.contentType)
}

// c.apiKey will not be set on the client was configured with WithOAuthClientCredentials()
Expand Down Expand Up @@ -197,6 +236,12 @@ func (c *Client) performRequest(req *http.Request, out interface{}) error {
return nil
}

// If we're expected to write result into a []byte, do not attempt to parse it.
if o, ok := out.(*[]byte); ok {
*o = bytes.Clone(body)
return nil
}

// If we've got hujson back, convert it to JSON, so we can natively parse it.
if !json.Valid(body) {
body, err = hujson.Standardize(body)
Expand Down Expand Up @@ -229,9 +274,9 @@ func (err APIError) Error() string {
func (c *Client) SetDNSSearchPaths(ctx context.Context, searchPaths []string) error {
const uriFmt = "/api/v2/tailnet/%v/dns/searchpaths"

req, err := c.buildRequest(ctx, http.MethodPost, fmt.Sprintf(uriFmt, c.tailnet), nil, map[string][]string{
req, err := c.buildRequest(ctx, http.MethodPost, fmt.Sprintf(uriFmt, c.tailnet), requestBody(map[string][]string{
"searchPaths": searchPaths,
})
}))
if err != nil {
return err
}
Expand All @@ -243,7 +288,7 @@ func (c *Client) SetDNSSearchPaths(ctx context.Context, searchPaths []string) er
func (c *Client) DNSSearchPaths(ctx context.Context) ([]string, error) {
const uriFmt = "/api/v2/tailnet/%v/dns/searchpaths"

req, err := c.buildRequest(ctx, http.MethodGet, fmt.Sprintf(uriFmt, c.tailnet), nil, nil)
req, err := c.buildRequest(ctx, http.MethodGet, fmt.Sprintf(uriFmt, c.tailnet))
if err != nil {
return nil, err
}
Expand All @@ -261,9 +306,9 @@ func (c *Client) DNSSearchPaths(ctx context.Context) ([]string, error) {
func (c *Client) SetDNSNameservers(ctx context.Context, dns []string) error {
const uriFmt = "/api/v2/tailnet/%v/dns/nameservers"

req, err := c.buildRequest(ctx, http.MethodPost, fmt.Sprintf(uriFmt, c.tailnet), nil, map[string][]string{
req, err := c.buildRequest(ctx, http.MethodPost, fmt.Sprintf(uriFmt, c.tailnet), requestBody(map[string][]string{
"dns": dns,
})
}))
if err != nil {
return err
}
Expand All @@ -275,7 +320,7 @@ func (c *Client) SetDNSNameservers(ctx context.Context, dns []string) error {
func (c *Client) DNSNameservers(ctx context.Context) ([]string, error) {
const uriFmt = "/api/v2/tailnet/%v/dns/nameservers"

req, err := c.buildRequest(ctx, http.MethodGet, fmt.Sprintf(uriFmt, c.tailnet), nil, nil)
req, err := c.buildRequest(ctx, http.MethodGet, fmt.Sprintf(uriFmt, c.tailnet))
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -389,7 +434,7 @@ type (
func (c *Client) ACL(ctx context.Context) (*ACL, error) {
const uriFmt = "/api/v2/tailnet/%s/acl"

req, err := c.buildRequest(ctx, http.MethodGet, fmt.Sprintf(uriFmt, c.tailnet), nil, nil)
req, err := c.buildRequest(ctx, http.MethodGet, fmt.Sprintf(uriFmt, c.tailnet))
if err != nil {
return nil, err
}
Expand All @@ -402,6 +447,24 @@ func (c *Client) ACL(ctx context.Context) (*ACL, error) {
return &resp, nil
}

// RawACL retrieves the ACL that is currently set for the given tailnet
// as a HuJSON string.
func (c *Client) RawACL(ctx context.Context) (string, error) {
const uriFmt = "/api/v2/tailnet/%s/acl"

req, err := c.buildRequest(ctx, http.MethodGet, fmt.Sprintf(uriFmt, c.tailnet), requestContentType("application/hujson"))
if err != nil {
return "", err
}

var resp []byte
if err = c.performRequest(req, &resp); err != nil {
return "", err
}

return string(resp), nil
}

type setACLParams struct {
headers map[string]string
}
Expand All @@ -415,28 +478,53 @@ func WithETag(etag string) SetACLOption {
}
}

// SetACL sets the ACL for the given tailnet.
func (c *Client) SetACL(ctx context.Context, acl ACL, opts ...SetACLOption) error {
// SetACL sets the ACL for the given tailnet. `acl` can either be an [ACL],
// or a HuJSON string.
func (c *Client) SetACL(ctx context.Context, acl any, opts ...SetACLOption) error {
const uriFmt = "/api/v2/tailnet/%s/acl"

p := &setACLParams{headers: make(map[string]string)}
for _, opt := range opts {
opt(p)
}

req, err := c.buildRequest(ctx, http.MethodPost, fmt.Sprintf(uriFmt, c.tailnet), p.headers, acl)
reqOpts := []requestOption{
requestHeaders(p.headers),
requestBody(acl),
}
switch v := acl.(type) {
case ACL:
case string:
reqOpts = append(reqOpts, requestContentType("application/hujson"))
default:
return fmt.Errorf("expected ACL content as a string or as ACL struct; got %T", v)
}

req, err := c.buildRequest(ctx, http.MethodPost, fmt.Sprintf(uriFmt, c.tailnet), reqOpts...)
if err != nil {
return err
}

return c.performRequest(req, nil)
}

// ValidateACL validates the provided ACL via the API.
func (c *Client) ValidateACL(ctx context.Context, acl ACL) error {
// ValidateACL validates the provided ACL via the API. `acl` can either be an [ACL],
// or a HuJSON string.
func (c *Client) ValidateACL(ctx context.Context, acl any) error {
const uriFmt = "/api/v2/tailnet/%s/acl/validate"

req, err := c.buildRequest(ctx, http.MethodPost, fmt.Sprintf(uriFmt, c.tailnet), nil, acl)
reqOpts := []requestOption{
requestBody(acl),
}
switch v := acl.(type) {
case ACL:
case string:
reqOpts = append(reqOpts, requestContentType("application/hujson"))
default:
return fmt.Errorf("expected ACL content as a string or as ACL struct; got %T", v)
}

req, err := c.buildRequest(ctx, http.MethodPost, fmt.Sprintf(uriFmt, c.tailnet), reqOpts...)
if err != nil {
return err
}
Expand All @@ -460,7 +548,7 @@ type DNSPreferences struct {
func (c *Client) DNSPreferences(ctx context.Context) (*DNSPreferences, error) {
const uriFmt = "/api/v2/tailnet/%s/dns/preferences"

req, err := c.buildRequest(ctx, http.MethodGet, fmt.Sprintf(uriFmt, c.tailnet), nil, nil)
req, err := c.buildRequest(ctx, http.MethodGet, fmt.Sprintf(uriFmt, c.tailnet))
if err != nil {
return nil, err
}
Expand All @@ -478,7 +566,7 @@ func (c *Client) DNSPreferences(ctx context.Context) (*DNSPreferences, error) {
func (c *Client) SetDNSPreferences(ctx context.Context, preferences DNSPreferences) error {
const uriFmt = "/api/v2/tailnet/%s/dns/preferences"

req, err := c.buildRequest(ctx, http.MethodPost, fmt.Sprintf(uriFmt, c.tailnet), nil, preferences)
req, err := c.buildRequest(ctx, http.MethodPost, fmt.Sprintf(uriFmt, c.tailnet), requestBody(preferences))
if err != nil {
return nil
}
Expand All @@ -498,9 +586,9 @@ type (
func (c *Client) SetDeviceSubnetRoutes(ctx context.Context, deviceID string, routes []string) error {
const uriFmt = "/api/v2/device/%s/routes"

req, err := c.buildRequest(ctx, http.MethodPost, fmt.Sprintf(uriFmt, deviceID), nil, map[string][]string{
req, err := c.buildRequest(ctx, http.MethodPost, fmt.Sprintf(uriFmt, deviceID), requestBody(map[string][]string{
"routes": routes,
})
}))
if err != nil {
return err
}
Expand All @@ -514,7 +602,7 @@ func (c *Client) SetDeviceSubnetRoutes(ctx context.Context, deviceID string, rou
func (c *Client) DeviceSubnetRoutes(ctx context.Context, deviceID string) (*DeviceRoutes, error) {
const uriFmt = "/api/v2/device/%s/routes"

req, err := c.buildRequest(ctx, http.MethodGet, fmt.Sprintf(uriFmt, deviceID), nil, nil)
req, err := c.buildRequest(ctx, http.MethodGet, fmt.Sprintf(uriFmt, deviceID))
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -576,7 +664,7 @@ type Device struct {
func (c *Client) Devices(ctx context.Context) ([]Device, error) {
const uriFmt = "/api/v2/tailnet/%s/devices"

req, err := c.buildRequest(ctx, http.MethodGet, fmt.Sprintf(uriFmt, c.tailnet), nil, nil)
req, err := c.buildRequest(ctx, http.MethodGet, fmt.Sprintf(uriFmt, c.tailnet))
if err != nil {
return nil, err
}
Expand All @@ -598,9 +686,9 @@ func (c *Client) AuthorizeDevice(ctx context.Context, deviceID string) error {
func (c *Client) SetDeviceAuthorized(ctx context.Context, deviceID string, authorized bool) error {
const uriFmt = "/api/v2/device/%s/authorized"

req, err := c.buildRequest(ctx, http.MethodPost, fmt.Sprintf(uriFmt, deviceID), nil, map[string]bool{
req, err := c.buildRequest(ctx, http.MethodPost, fmt.Sprintf(uriFmt, deviceID), requestBody(map[string]bool{
"authorized": authorized,
})
}))
if err != nil {
return err
}
Expand All @@ -611,7 +699,7 @@ func (c *Client) SetDeviceAuthorized(ctx context.Context, deviceID string, autho
// DeleteDevice deletes the device given its deviceID.
func (c *Client) DeleteDevice(ctx context.Context, deviceID string) error {
const uriFmt = "/api/v2/device/%s"
req, err := c.buildRequest(ctx, http.MethodDelete, fmt.Sprintf(uriFmt, deviceID), nil, nil)
req, err := c.buildRequest(ctx, http.MethodDelete, fmt.Sprintf(uriFmt, deviceID))
if err != nil {
return err
}
Expand Down Expand Up @@ -686,7 +774,7 @@ func (c *Client) CreateKey(ctx context.Context, capabilities KeyCapabilities, op
}
}

req, err := c.buildRequest(ctx, http.MethodPost, fmt.Sprintf(uriFmt, c.tailnet), nil, ckr)
req, err := c.buildRequest(ctx, http.MethodPost, fmt.Sprintf(uriFmt, c.tailnet), requestBody(ckr))
if err != nil {
return Key{}, err
}
Expand All @@ -700,7 +788,7 @@ func (c *Client) CreateKey(ctx context.Context, capabilities KeyCapabilities, op
func (c *Client) GetKey(ctx context.Context, id string) (Key, error) {
const uriFmt = "/api/v2/tailnet/%s/keys/%s"

req, err := c.buildRequest(ctx, http.MethodGet, fmt.Sprintf(uriFmt, c.tailnet, id), nil, nil)
req, err := c.buildRequest(ctx, http.MethodGet, fmt.Sprintf(uriFmt, c.tailnet, id))
if err != nil {
return Key{}, err
}
Expand All @@ -714,7 +802,7 @@ func (c *Client) GetKey(ctx context.Context, id string) (Key, error) {
func (c *Client) Keys(ctx context.Context) ([]Key, error) {
const uriFmt = "/api/v2/tailnet/%s/keys"

req, err := c.buildRequest(ctx, http.MethodGet, fmt.Sprintf(uriFmt, c.tailnet), nil, nil)
req, err := c.buildRequest(ctx, http.MethodGet, fmt.Sprintf(uriFmt, c.tailnet))
if err != nil {
return nil, err
}
Expand All @@ -731,7 +819,7 @@ func (c *Client) Keys(ctx context.Context) ([]Key, error) {
func (c *Client) DeleteKey(ctx context.Context, id string) error {
const uriFmt = "/api/v2/tailnet/%s/keys/%s"

req, err := c.buildRequest(ctx, http.MethodDelete, fmt.Sprintf(uriFmt, c.tailnet, id), nil, nil)
req, err := c.buildRequest(ctx, http.MethodDelete, fmt.Sprintf(uriFmt, c.tailnet, id))
if err != nil {
return err
}
Expand All @@ -743,9 +831,9 @@ func (c *Client) DeleteKey(ctx context.Context, id string) error {
func (c *Client) SetDeviceTags(ctx context.Context, deviceID string, tags []string) error {
const uriFmt = "/api/v2/device/%s/tags"

req, err := c.buildRequest(ctx, http.MethodPost, fmt.Sprintf(uriFmt, deviceID), nil, map[string][]string{
req, err := c.buildRequest(ctx, http.MethodPost, fmt.Sprintf(uriFmt, deviceID), requestBody(map[string][]string{
"tags": tags,
})
}))
if err != nil {
return err
}
Expand All @@ -765,7 +853,7 @@ type (
func (c *Client) SetDeviceKey(ctx context.Context, deviceID string, key DeviceKey) error {
const uriFmt = "/api/v2/device/%s/key"

req, err := c.buildRequest(ctx, http.MethodPost, fmt.Sprintf(uriFmt, deviceID), nil, key)
req, err := c.buildRequest(ctx, http.MethodPost, fmt.Sprintf(uriFmt, deviceID), requestBody(key))
if err != nil {
return err
}
Expand Down
Loading

0 comments on commit ce83419

Please sign in to comment.