Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: lula validations in store don't persist validation state #795

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 10 additions & 7 deletions src/pkg/common/oscal/assessment-results.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,12 @@ import (

"github.com/defenseunicorns/go-oscal/src/pkg/uuid"
oscalTypes_1_1_2 "github.com/defenseunicorns/go-oscal/src/types/oscal-1-1-2"
"gopkg.in/yaml.v3"

"github.com/defenseunicorns/lula/src/config"
"github.com/defenseunicorns/lula/src/pkg/common"
"github.com/defenseunicorns/lula/src/pkg/common/result"
"github.com/defenseunicorns/lula/src/types"
"gopkg.in/yaml.v3"
)

const OSCAL_VERSION = "1.1.2"
Expand Down Expand Up @@ -295,12 +296,14 @@ func CreateObservation(method string, relevantEvidence *[]oscalTypes_1_1_2.Relev
Description: fmt.Sprintf(descriptionPattern, descriptionArgs...),
RelevantEvidence: relevantEvidence,
}
observation.Props = &[]oscalTypes_1_1_2.Property{
{
Name: "validation",
Ns: "https://docs.lula.dev/oscal/ns",
Value: common.AddIdPrefix(validation.UUID),
},
if validation != nil {
observation.Props = &[]oscalTypes_1_1_2.Property{
{
Name: "validation",
Ns: "https://docs.lula.dev/oscal/ns",
Value: common.AddIdPrefix(validation.UUID),
},
}
}
if resourcesHref != "" {
observation.Links = &[]oscalTypes_1_1_2.Link{
Expand Down
124 changes: 64 additions & 60 deletions src/pkg/common/validation-store/validation-store.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,15 +17,15 @@ import (

type ValidationStore struct {
backMatterMap map[string]string
validationMap types.LulaValidationMap
validationMap map[string]*types.LulaValidation
observationMap map[string]*oscalTypes_1_1_2.Observation
}

// NewValidationStore creates a new validation store
func NewValidationStore() *ValidationStore {
return &ValidationStore{
backMatterMap: make(map[string]string),
validationMap: make(types.LulaValidationMap),
validationMap: make(map[string]*types.LulaValidation),
observationMap: make(map[string]*oscalTypes_1_1_2.Observation),
}
}
Expand All @@ -34,7 +34,7 @@ func NewValidationStore() *ValidationStore {
func NewValidationStoreFromBackMatter(backMatter oscalTypes_1_1_2.BackMatter) *ValidationStore {
return &ValidationStore{
backMatterMap: oscal.BackMatterToMap(backMatter),
validationMap: make(types.LulaValidationMap),
validationMap: make(map[string]*types.LulaValidation),
observationMap: make(map[string]*oscalTypes_1_1_2.Observation),
}
}
Expand All @@ -54,7 +54,8 @@ func (v *ValidationStore) AddValidation(validation *common.Validation) (id strin
validation.Metadata.UUID = uuid.NewUUID()
}

v.validationMap[validation.Metadata.UUID], err = validation.ToLulaValidation(validation.Metadata.UUID)
lulaValidation, err := validation.ToLulaValidation(validation.Metadata.UUID)
v.validationMap[validation.Metadata.UUID] = &lulaValidation

if err != nil {
return "", err
Expand All @@ -66,23 +67,23 @@ func (v *ValidationStore) AddValidation(validation *common.Validation) (id strin
// AddLulaValidation adds a LulaValidation to the store
func (v *ValidationStore) AddLulaValidation(validation *types.LulaValidation, id string) {
trimmedId := common.TrimIdPrefix(id)
v.validationMap[trimmedId] = *validation
v.validationMap[trimmedId] = validation
}

// GetLulaValidation gets the LulaValidation from the store
func (v *ValidationStore) GetLulaValidation(id string) (validation *types.LulaValidation, err error) {
trimmedId := common.TrimIdPrefix(id)

if validation, ok := v.validationMap[trimmedId]; ok {
return &validation, nil
return validation, nil
}

if validationString, ok := v.backMatterMap[trimmedId]; ok {
lulaValidation, err := common.ValidationFromString(validationString, trimmedId)
if err != nil {
return &lulaValidation, err
}
v.validationMap[trimmedId] = lulaValidation
v.validationMap[trimmedId] = &lulaValidation
return &lulaValidation, nil
}

Expand All @@ -93,7 +94,7 @@ func (v *ValidationStore) GetLulaValidation(id string) (validation *types.LulaVa
func (v *ValidationStore) DryRun() (executable bool, msg string) {
executableValidations := make([]string, 0)
for k, val := range v.validationMap {
if val.Domain != nil {
if val != nil && val.Domain != nil {
if (*val.Domain).IsExecutable() {
executableValidations = append(executableValidations, k)
}
Expand All @@ -110,67 +111,70 @@ func (v *ValidationStore) RunValidations(ctx context.Context, confirmExecution,
observations := make([]oscalTypes_1_1_2.Observation, 0, len(v.validationMap))

for k, val := range v.validationMap {
completedText := "evaluated"
spinnerMessage := fmt.Sprintf("Running validation %s", k)
spinner := message.NewProgressSpinner("%s", spinnerMessage)
defer spinner.Stop()
err := val.Validate(ctx, types.ExecutionAllowed(confirmExecution))
if err != nil {
message.Debugf("Error running validation %s: %v", k, err)
// Update validation with failed results
val.Result.State = "not-satisfied"
val.Result.Observations = map[string]string{
"Error running validation": err.Error(),
if val != nil {
// Create observation for each non-nil validation
completedText := "evaluated"
spinnerMessage := fmt.Sprintf("Running validation %s", k)
spinner := message.NewProgressSpinner("%s", spinnerMessage)
defer spinner.Stop()
err := val.Validate(ctx, types.ExecutionAllowed(confirmExecution))
if err != nil {
message.Debugf("Error running validation %s: %v", k, err)
// Update validation with failed results
val.Result.State = "not-satisfied"
val.Result.Observations = map[string]string{
"Error running validation": err.Error(),
}
completedText = "NOT evaluated"
}
completedText = "NOT evaluated"
}

// Update individual result state
if val.Result.Passing > 0 && val.Result.Failing <= 0 {
val.Result.State = "satisfied"
} else {
val.Result.State = "not-satisfied"
}
// Update individual result state
if val.Result.Passing > 0 && val.Result.Failing <= 0 {
val.Result.State = "satisfied"
} else {
val.Result.State = "not-satisfied"
}

// Add the observation to the observation map
var remarks string
if len(val.Result.Observations) > 0 {
for k, v := range val.Result.Observations {
remarks += fmt.Sprintf("%s: %s\n", k, v)
// Add the observation to the observation map
var remarks string
if len(val.Result.Observations) > 0 {
for k, v := range val.Result.Observations {
remarks += fmt.Sprintf("%s: %s\n", k, v)
}
}
}

// Save Resources if specified
var resourceHref string
if saveResources {
resourceUuid := uuid.NewUUID()
// Create a remote resource file -> create directory 'resources' in the assessment-results directory -> create file with UUID as name
filename := fmt.Sprintf("%s.json", resourceUuid)
resourceFile := filepath.Join(resourcesDir, "resources", filename)
err := os.MkdirAll(filepath.Dir(resourceFile), os.ModePerm) // #nosec G301
if err != nil {
message.Debugf("Error creating directory for remote resource: %v", err)
// Save Resources if specified
var resourceHref string
if saveResources {
resourceUuid := uuid.NewUUID()
// Create a remote resource file -> create directory 'resources' in the assessment-results directory -> create file with UUID as name
filename := fmt.Sprintf("%s.json", resourceUuid)
resourceFile := filepath.Join(resourcesDir, "resources", filename)
err := os.MkdirAll(filepath.Dir(resourceFile), os.ModePerm) // #nosec G301
if err != nil {
message.Debugf("Error creating directory for remote resource: %v", err)
}
jsonData := val.GetDomainResourcesAsJSON()
err = files.WriteOutput(jsonData, resourceFile)
if err != nil {
message.Debugf("Error writing remote resource file: %v", err)
}
resourceHref = fmt.Sprintf("file://./resources/%s", filename)
}
jsonData := val.GetDomainResourcesAsJSON()
err = files.WriteOutput(jsonData, resourceFile)
if err != nil {
message.Debugf("Error writing remote resource file: %v", err)

// Create an observation
relevantEvidence := &[]oscalTypes_1_1_2.RelevantEvidence{
{
Description: fmt.Sprintf("Result: %s\n", val.Result.State),
Remarks: remarks,
},
}
resourceHref = fmt.Sprintf("file://./resources/%s", filename)
}
observation := oscal.CreateObservation("TEST", relevantEvidence, val, resourceHref, "[TEST]: %s - %s\n", k, val.Name)
v.observationMap[k] = &observation
observations = append(observations, observation)

// Create an observation
relevantEvidence := &[]oscalTypes_1_1_2.RelevantEvidence{
{
Description: fmt.Sprintf("Result: %s\n", val.Result.State),
Remarks: remarks,
},
spinner.Successf("%s -> %s -> %s", spinnerMessage, completedText, val.Result.State)
}
observation := oscal.CreateObservation("TEST", relevantEvidence, &val, resourceHref, "[TEST]: %s - %s\n", k, val.Name)
v.observationMap[k] = &observation
observations = append(observations, observation)

spinner.Successf("%s -> %s -> %s", spinnerMessage, completedText, val.Result.State)
}
return observations
}
Expand Down
21 changes: 20 additions & 1 deletion src/pkg/common/validation-store/validation-store_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,12 @@ import (

"github.com/defenseunicorns/go-oscal/src/pkg/uuid"
oscalTypes_1_1_2 "github.com/defenseunicorns/go-oscal/src/types/oscal-1-1-2"
"github.com/stretchr/testify/require"

"github.com/defenseunicorns/lula/src/pkg/common"
validationstore "github.com/defenseunicorns/lula/src/pkg/common/validation-store"
"github.com/defenseunicorns/lula/src/pkg/message"
"github.com/defenseunicorns/lula/src/types"
"github.com/stretchr/testify/require"
)

const (
Expand Down Expand Up @@ -173,6 +174,24 @@ func TestRunValidations(t *testing.T) {
}
})
}

// Test that validations are stored as pointers -> store data across accesses
t.Run("Validations store data", func(t *testing.T) {
// Create a new validation store, add the validation
validationUuid := uuid.NewUUID()
v := validationstore.NewValidationStore()
v.AddLulaValidation(validation, validationUuid)

// Run the validations on the store
_ = v.RunValidations(context.Background(), true, false, "")

// Check that the validation data has been stored in v, state should be satisfied
val, err := v.GetLulaValidation(validationUuid)
require.NoError(t, err)
require.NotNil(t, val)
require.NotNil(t, val.Result)
require.Equal(t, "satisfied", val.Result.State)
})
}

func TestGetRelatedObservation(t *testing.T) {
Expand Down
29 changes: 13 additions & 16 deletions src/types/types.go
Original file line number Diff line number Diff line change
Expand Up @@ -67,9 +67,6 @@ func CreatePassingLulaValidation(name string) *LulaValidation {
}
}

// LulaValidationMap is a map of LulaValidation objects
type LulaValidationMap = map[string]LulaValidation

// Lula Validation Options settings
type lulaValidationOptions struct {
staticResources DomainResources
Expand Down Expand Up @@ -117,16 +114,16 @@ func GetResourcesOnly(onlyResources bool) LulaValidationOption {
}

// Perform the validation, and store the result in the LulaValidation struct
func (val *LulaValidation) Validate(ctx context.Context, opts ...LulaValidationOption) error {
if !val.Evaluated {
func (v *LulaValidation) Validate(ctx context.Context, opts ...LulaValidationOption) error {
if !v.Evaluated {
var result Result
var err error
var resources DomainResources

// Update the validation
val.DomainResources = &resources
val.Result = &result
val.Evaluated = true
v.DomainResources = &resources
v.Result = &result
v.Evaluated = true

// Set Validation config from options passed
config := &lulaValidationOptions{
Expand All @@ -141,7 +138,7 @@ func (val *LulaValidation) Validate(ctx context.Context, opts ...LulaValidationO
}

// Check if confirmation needed before execution
if (*val.Domain).IsExecutable() && config.staticResources == nil {
if (*v.Domain).IsExecutable() && config.staticResources == nil {
if !config.executionAllowed {
if config.isInteractive {
// Run confirmation user prompt
Expand All @@ -158,7 +155,7 @@ func (val *LulaValidation) Validate(ctx context.Context, opts ...LulaValidationO
if config.staticResources != nil {
resources = config.staticResources
} else {
resources, err = (*val.Domain).GetResources(ctx)
resources, err = (*v.Domain).GetResources(ctx)
if err != nil {
return fmt.Errorf("%w: %v", ErrDomainGetResources, err)
}
Expand All @@ -168,7 +165,7 @@ func (val *LulaValidation) Validate(ctx context.Context, opts ...LulaValidationO
}

// Perform the evaluation using the provider
result, err = (*val.Provider).Evaluate(resources)
result, err = (*v.Provider).Evaluate(resources)
if err != nil {
return fmt.Errorf("%w: %v", ErrProviderEvaluate, err)
}
Expand All @@ -177,16 +174,16 @@ func (val *LulaValidation) Validate(ctx context.Context, opts ...LulaValidationO
}

// Check if the validation requires confirmation before possible execution code is run
func (val *LulaValidation) RequireExecutionConfirmation() (confirm bool) {
return !(*val.Domain).IsExecutable()
func (v *LulaValidation) RequireExecutionConfirmation() (confirm bool) {
return !(*v.Domain).IsExecutable()
}

// Return domain resources as a json []byte
func (val *LulaValidation) GetDomainResourcesAsJSON() []byte {
if val.DomainResources == nil {
func (v *LulaValidation) GetDomainResourcesAsJSON() []byte {
if v.DomainResources == nil {
return []byte("{}")
}
jsonData, err := json.MarshalIndent(val.DomainResources, "", " ")
jsonData, err := json.MarshalIndent(v.DomainResources, "", " ")
if err != nil {
message.Debugf("Error marshalling domain resources to JSON: %v", err)
jsonData = []byte(`{"Error": "Error marshalling to JSON"}`)
Expand Down
Loading