Skip to content

Commit

Permalink
add: get sagemaker-user-pool (#229)
Browse files Browse the repository at this point in the history
  • Loading branch information
aaroniscode authored Sep 3, 2024
1 parent 5d813ea commit d8fb5f5
Show file tree
Hide file tree
Showing 9 changed files with 267 additions and 10 deletions.
2 changes: 2 additions & 0 deletions cmd/get/sagemaker.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package get
import (
"github.com/awslabs/eksdemo/pkg/resource"
"github.com/awslabs/eksdemo/pkg/resource/sagemaker/domain"
"github.com/awslabs/eksdemo/pkg/resource/sagemaker/userprofile"
"github.com/spf13/cobra"
)

Expand All @@ -28,5 +29,6 @@ func NewGetSageMakerCmd() *cobra.Command {
func init() {
sagemaker = []func() *resource.Resource{
domain.New,
userprofile.New,
}
}
2 changes: 1 addition & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ require (
github.com/aws/aws-sdk-go-v2/service/organizations v1.25.1
github.com/aws/aws-sdk-go-v2/service/route53 v1.40.1
github.com/aws/aws-sdk-go-v2/service/s3 v1.51.1
github.com/aws/aws-sdk-go-v2/service/sagemaker v1.154.0
github.com/aws/aws-sdk-go-v2/service/sqs v1.31.1
github.com/aws/aws-sdk-go-v2/service/ssm v1.49.1
github.com/aws/aws-sdk-go-v2/service/sts v1.28.1
Expand Down Expand Up @@ -76,7 +77,6 @@ require (
github.com/aws/aws-sdk-go-v2/service/internal/checksum v1.3.2 // indirect
github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.11.2 // indirect
github.com/aws/aws-sdk-go-v2/service/internal/s3shared v1.17.2 // indirect
github.com/aws/aws-sdk-go-v2/service/sagemaker v1.154.0 // indirect
github.com/aws/aws-sdk-go-v2/service/sso v1.20.1 // indirect
github.com/aws/aws-sdk-go-v2/service/ssooidc v1.23.1 // indirect
github.com/beorn7/perks v1.0.1 // indirect
Expand Down
8 changes: 0 additions & 8 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -71,8 +71,6 @@ github.com/asaskevich/govalidator v0.0.0-20200428143746-21a406dcc535 h1:4daAzAu0
github.com/asaskevich/govalidator v0.0.0-20200428143746-21a406dcc535/go.mod h1:oGkLhpf+kjZl6xBf758TQhh5XrAeiJv/7FRz/2spLIg=
github.com/aws/aws-sdk-go v1.34.9 h1:cUGBW9CVdi0mS7K1hDzxIqTpfeWhpoQiguq81M1tjK0=
github.com/aws/aws-sdk-go v1.34.9/go.mod h1:5zCpMtNQVjRREroY7sYe8lOMRSxkhG6MZveU8YkpAk0=
github.com/aws/aws-sdk-go-v2 v1.25.2 h1:/uiG1avJRgLGiQM9X3qJM8+Qa6KRGK5rRPuXE0HUM+w=
github.com/aws/aws-sdk-go-v2 v1.25.2/go.mod h1:Evoc5AsmtveRt1komDwIsjHFyrP5tDuF1D1U+6z6pNo=
github.com/aws/aws-sdk-go-v2 v1.30.4 h1:frhcagrVNrzmT95RJImMHgabt99vkXGslubDaDagTk8=
github.com/aws/aws-sdk-go-v2 v1.30.4/go.mod h1:CT+ZPWXbYrci8chcARI3OmI/qgd+f6WtuLOoaIA8PR0=
github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.6.1 h1:gTK2uhtAPtFcdRRJilZPx8uJLL2J85xK11nKtWL0wfU=
Expand All @@ -83,12 +81,8 @@ github.com/aws/aws-sdk-go-v2/credentials v1.17.4 h1:h5Vztbd8qLppiPwX+y0Q6WiwMZgp
github.com/aws/aws-sdk-go-v2/credentials v1.17.4/go.mod h1:+30tpwrkOgvkJL1rUZuRLoxcJwtI/OkeBLYnHxJtVe0=
github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.15.2 h1:AK0J8iYBFeUk2Ax7O8YpLtFsfhdOByh2QIkHmigpRYk=
github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.15.2/go.mod h1:iRlGzMix0SExQEviAyptRWRGdYNo3+ufW/lCzvKVTUc=
github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.2 h1:bNo4LagzUKbjdxE0tIcR9pMzLR2U/Tgie1Hq1HQ3iH8=
github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.2/go.mod h1:wRQv0nN6v9wDXuWThpovGQjqF1HFdcgWjporw14lS8k=
github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.16 h1:TNyt/+X43KJ9IJJMjKfa3bNTiZbUP7DeCxfbTROESwY=
github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.16/go.mod h1:2DwJF39FlNAUiX5pAc0UNeiz16lK2t7IaFcm0LFHEgc=
github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.2 h1:EtOU5jsPdIQNP+6Q2C5e3d65NKT1PeCiQk+9OdzO12Q=
github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.2/go.mod h1:tyF5sKccmDz0Bv4NrstEr+/9YkSPJHrcO7UsUKf7pWM=
github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.16 h1:jYfy8UPmd+6kJW5YhY0L1/KftReOGxI/4NtVSTh9O/I=
github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.16/go.mod h1:7ZfEPZxkW42Afq4uQB8H2E2e6ebh6mXTueEpYzjCzcs=
github.com/aws/aws-sdk-go-v2/internal/ini v1.8.0 h1:hT8rVHwugYE2lEfdFE0QWVo81lF7jMrYJVDWI+f+VxU=
Expand Down Expand Up @@ -159,8 +153,6 @@ github.com/aws/aws-sdk-go-v2/service/vpclattice v1.7.1 h1:004A7QJBNhZjD9M3YZ9td+
github.com/aws/aws-sdk-go-v2/service/vpclattice v1.7.1/go.mod h1:Hk6AN73+u06HX+Ggnv/dO5clORF+vggub+955WHOGzQ=
github.com/aws/session-manager-plugin v0.0.0-20221012155945-c523002ee02c h1:6cCrrTmS+7B+saEBhMnNblArJpA7BNmjd9F6MUHS6sQ=
github.com/aws/session-manager-plugin v0.0.0-20221012155945-c523002ee02c/go.mod h1:7n17tunRPUsniNBu5Ja9C7WwJWTdOzaLqr/H0Ns3uuI=
github.com/aws/smithy-go v1.20.1 h1:4SZlSlMr36UEqC7XOyRVb27XMeZubNcBNN+9IgEPIQw=
github.com/aws/smithy-go v1.20.1/go.mod h1:krry+ya/rV9RDcV/Q16kpu6ypI4K2czasz0NC3qS14E=
github.com/aws/smithy-go v1.20.4 h1:2HK1zBdPgRbjFOHlfeQZfpC4r72MOb9bZkiFwggKO+4=
github.com/aws/smithy-go v1.20.4/go.mod h1:irrKGvNn1InZwb2d7fkIRNucdfwR8R+Ts3wxYa/cJHg=
github.com/beorn7/perks v0.0.0-20180321164747-3a771d992973/go.mod h1:Dwedo/Wpr24TaqPxmxbtue+5NUziq4I4S80YR8gNf3Q=
Expand Down
2 changes: 1 addition & 1 deletion pkg/aws/aws.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ func FormatError(err error) error {
func FormatErrorAsMessageOnly(err error) error {
var ae smithy.APIError
if err != nil && errors.As(err, &ae) {
return fmt.Errorf(ae.ErrorMessage())
return fmt.Errorf("%s", ae.ErrorMessage())
}
return err
}
33 changes: 33 additions & 0 deletions pkg/aws/sagemaker.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,19 @@ func (c *SageMakerClient) DescribeDomain(id string) (*sagemaker.DescribeDomainOu
return result, nil
}

func (c *SageMakerClient) DescribeUserProfile(domainID, userProfileName string) (*sagemaker.DescribeUserProfileOutput, error) {
result, err := c.Client.DescribeUserProfile(context.Background(), &sagemaker.DescribeUserProfileInput{
DomainId: aws.String(domainID),
UserProfileName: aws.String(userProfileName),
})

if err != nil {
return nil, err
}

return result, nil
}

func (c *SageMakerClient) ListDomains() ([]types.DomainDetails, error) {
result, err := c.Client.ListDomains(context.Background(), &sagemaker.ListDomainsInput{})

Expand All @@ -37,3 +50,23 @@ func (c *SageMakerClient) ListDomains() ([]types.DomainDetails, error) {

return result.Domains, nil
}

func (c *SageMakerClient) ListUserProfiles(domainID, userProfileNameContains string) ([]types.UserProfileDetails, error) {
input := sagemaker.ListUserProfilesInput{}

if domainID != "" {
input.DomainIdEquals = aws.String(domainID)
}

if userProfileNameContains != "" {
input.UserProfileNameContains = aws.String(userProfileNameContains)
}

result, err := c.Client.ListUserProfiles(context.Background(), &input)

if err != nil {
return nil, err
}

return result.UserProfiles, nil
}
116 changes: 116 additions & 0 deletions pkg/resource/sagemaker/userprofile/get.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,116 @@
package userprofile

import (
"fmt"
"os"
"strings"

awssdk "github.com/aws/aws-sdk-go-v2/aws"
"github.com/aws/aws-sdk-go-v2/service/sagemaker"
"github.com/aws/aws-sdk-go-v2/service/sagemaker/types"
"github.com/awslabs/eksdemo/pkg/aws"
"github.com/awslabs/eksdemo/pkg/printer"
"github.com/awslabs/eksdemo/pkg/resource"
)

type Getter struct {
sagemakerClient *aws.SageMakerClient
}

func NewGetter(sagemakerClient *aws.SageMakerClient) *Getter {
return &Getter{sagemakerClient}
}

func (g *Getter) Init() {
if g.sagemakerClient == nil {
g.sagemakerClient = aws.NewSageMakerClient()
}
}

func (g *Getter) Get(userProfileName string, output printer.Output, o resource.Options) error {
options, ok := o.(*Options)
if !ok {
return fmt.Errorf("internal error, unable to cast options to domain.Options")
}

var userProfile *sagemaker.DescribeUserProfileOutput
var userProfiles []*sagemaker.DescribeUserProfileOutput
var err error

switch {
case userProfileName != "":
userProfile, err = g.GetUserProfileByName(userProfileName)
userProfiles = []*sagemaker.DescribeUserProfileOutput{userProfile}
case options.DomainID != "":
userProfiles, err = g.GetUserProfilesByDomainID(options.DomainID)
default:
userProfiles, err = g.GetAllUserProfiles()
}

if err != nil {
return err
}

return output.Print(os.Stdout, NewPrinter(userProfiles))
}

func (g *Getter) GetAllUserProfiles() ([]*sagemaker.DescribeUserProfileOutput, error) {
return g.getUserProfiles("", "")
}

func (g *Getter) GetUserProfilesByDomainID(domainID string) ([]*sagemaker.DescribeUserProfileOutput, error) {
return g.getUserProfiles(domainID, "")

}

func (g *Getter) GetUserProfileByName(userProfileName string) (*sagemaker.DescribeUserProfileOutput, error) {
profileDetails, err := g.sagemakerClient.ListUserProfiles("", "")
if err != nil {
return nil, err
}

found := []types.UserProfileDetails{}

for _, up := range profileDetails {
if strings.EqualFold(userProfileName, awssdk.ToString(up.UserProfileName)) {
found = append(found, up)
}
}

if len(found) == 0 {
return nil, &resource.NotFoundByNameError{Type: "sagemaker-user-profile", Name: userProfileName}
}

if len(found) > 1 {
return nil, fmt.Errorf("multiple sagemaker user profiles found with name: %s", userProfileName)
}

userProfile, err := g.sagemakerClient.DescribeUserProfile(awssdk.ToString(found[0].DomainId), userProfileName)
if err != nil {
return nil, err
}

return userProfile, nil
}

func (g *Getter) getUserProfiles(domainID, userProfileNameContains string) ([]*sagemaker.DescribeUserProfileOutput, error) {
profileDetails, err := g.sagemakerClient.ListUserProfiles(domainID, userProfileNameContains)
if err != nil {
return nil, err
}

userprofiles := make([]*sagemaker.DescribeUserProfileOutput, 0, len(profileDetails))

for _, up := range profileDetails {
result, err := g.sagemakerClient.DescribeUserProfile(
awssdk.ToString(up.DomainId),
awssdk.ToString(up.UserProfileName),
)
if err != nil {
return nil, err
}
userprofiles = append(userprofiles, result)
}

return userprofiles, nil
}
42 changes: 42 additions & 0 deletions pkg/resource/sagemaker/userprofile/options.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
package userprofile

import (
"github.com/awslabs/eksdemo/pkg/cmd"
"github.com/awslabs/eksdemo/pkg/resource"
"github.com/spf13/cobra"
)

type Options struct {
resource.CommonOptions

// Get
DomainID string
}

func newOptions() (options *Options, getFlags cmd.Flags) {
options = &Options{
CommonOptions: resource.CommonOptions{
Name: "sagemaker-domain",
ClusterFlagDisabled: true,
},
}

getFlags = cmd.Flags{
&cmd.StringFlag{
CommandFlag: cmd.CommandFlag{
Name: "domain-id",
Description: "id of the sagemaker domain",
Shorthand: "D",
Validate: func(_ *cobra.Command, args []string) error {
if len(args) > 0 && options.DomainID != "" {
return &cmd.ArgumentAndFlagCantBeUsedTogetherError{Arg: "USER_PROFILE_NAME", Flag: "domain-id"}
}
return nil
},
},
Option: &options.DomainID,
},
}

return
}
47 changes: 47 additions & 0 deletions pkg/resource/sagemaker/userprofile/printer.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
package userprofile

import (
"io"
"time"

"github.com/aws/aws-sdk-go-v2/aws"
"github.com/aws/aws-sdk-go-v2/service/sagemaker"
"github.com/awslabs/eksdemo/pkg/printer"
"github.com/hako/durafmt"
)

type Printer struct {
userProfiles []*sagemaker.DescribeUserProfileOutput
}

func NewPrinter(userProfiles []*sagemaker.DescribeUserProfileOutput) *Printer {
return &Printer{userProfiles}
}

func (p *Printer) PrintTable(writer io.Writer) error {
table := printer.NewTablePrinter()
table.SetHeader([]string{"Age", "Status", "User Profile", "Domain Id"})

for _, up := range p.userProfiles {
age := durafmt.ParseShort(time.Since(aws.ToTime(up.CreationTime)))

table.AppendRow([]string{
age.String(),
string(up.Status),
aws.ToString(up.UserProfileName),
aws.ToString(up.DomainId),
})
}

table.Print(writer)

return nil
}

func (p *Printer) PrintJSON(writer io.Writer) error {
return printer.EncodeJSON(writer, p.userProfiles)
}

func (p *Printer) PrintYAML(writer io.Writer) error {
return printer.EncodeYAML(writer, p.userProfiles)
}
25 changes: 25 additions & 0 deletions pkg/resource/sagemaker/userprofile/user_profile.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
package userprofile

import (
"github.com/awslabs/eksdemo/pkg/cmd"
"github.com/awslabs/eksdemo/pkg/resource"
)

func New() *resource.Resource {
options, getFlags := newOptions()

return &resource.Resource{
Command: cmd.Command{
Name: "user-profile",
Description: "SageMaker User Profile",
Aliases: []string{"user-profiles", "userprofiles", "userprofile", "up"},
Args: []string{"USER_PROFILE_NAME"},
},

GetFlags: getFlags,

Getter: &Getter{},

Options: options,
}
}

0 comments on commit d8fb5f5

Please sign in to comment.