bugfix: Fix data race in batchTimeSeries state
Signed-off-by: Arthur Silva Sens <arthursens2005@gmail.com>
ArthurSens committed Nov 25, 2024
1 parent 2052bd1 commit 910ad54
Showing 3 changed files with 36 additions and 31 deletions.
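The race being fixed: a single batchTimeSeriesState hangs off the exporter, and concurrent PushMetrics calls read and write its plain int buffer-size fields inside batchTimeSeries with no synchronization. The commit swaps those fields for sync/atomic.Int64 and shares the state by pointer. Below is a minimal, runnable sketch of that pattern; the names (state, batch, nextBufferSize) are illustrative, not the exporter's own code.

// Minimal sketch of the fix: a sizing hint shared by concurrent goroutines
// lives in an atomic.Int64 instead of a plain int, so the reads and writes
// below never race.
package main

import (
	"fmt"
	"sync"
	"sync/atomic"
)

type state struct {
	nextBufferSize atomic.Int64 // previously a plain int, mutated without locks
}

func (s *state) batch(n int) {
	_ = s.nextBufferSize.Load()          // read the last hint
	s.nextBufferSize.Store(int64(2 * n)) // publish a new hint for the next batch
}

func main() {
	s := &state{}
	s.nextBufferSize.Store(10)

	var wg sync.WaitGroup
	for _, n := range []int{1, 2, 3, 4} { // stand-ins for concurrent PushMetrics calls
		wg.Add(1)
		go func(n int) {
			defer wg.Done()
			s.batch(n)
		}(n)
	}
	wg.Wait()
	fmt.Println("last stored hint:", s.nextBufferSize.Load())
}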
6 changes: 3 additions & 3 deletions exporter/prometheusremotewriteexporter/exporter.go
@@ -84,7 +84,8 @@ type prwExporter struct {
wal *prweWAL
exporterSettings prometheusremotewrite.Settings
telemetry prwTelemetry
- batchTimeSeriesState batchTimeSeriesState
+
+ batchTimeSeriesState *batchTimeSeriesState
}

func newPRWTelemetry(set exporter.Settings) (prwTelemetry, error) {
@@ -191,7 +192,6 @@ func (prwe *prwExporter) PushMetrics(ctx context.Context, md pmetric.Metrics) er
case <-prwe.closeChan:
return errors.New("shutdown has been called")
default:
-
tsMap, err := prometheusremotewrite.FromMetrics(md, prwe.exporterSettings)
if err != nil {
prwe.telemetry.recordTranslationFailure(ctx)
@@ -229,7 +229,7 @@ func (prwe *prwExporter) handleExport(ctx context.Context, tsMap map[string]*pro
}

// Calls the helper function to convert and batch the TsMap to the desired format
- requests, err := batchTimeSeries(tsMap, prwe.maxBatchSizeBytes, m, &prwe.batchTimeSeriesState)
+ requests, err := batchTimeSeries(tsMap, prwe.maxBatchSizeBytes, m, prwe.batchTimeSeriesState)
if err != nil {
return err
}
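On the exporter side the change is mechanical: the field becomes *batchTimeSeriesState and handleExport no longer takes its address. A likely reason for the pointer, not stated in the commit, is that a struct holding atomic.Int64 values should not be copied after first use; keeping it behind a pointer means every helper sees the same counters. An illustrative sketch with made-up names (counters, fakeExporter):

// Illustrative only: sharing the state through a pointer keeps every helper on
// the same atomics. Copying the struct by value would fork the counters, and
// go vet's copylocks check reports copies of values containing atomic.Int64.
package main

import (
	"fmt"
	"sync/atomic"
)

type counters struct {
	nextRequestBufferSize atomic.Int64
}

type fakeExporter struct {
	state *counters // pointer field, mirroring prwExporter.batchTimeSeriesState
}

func main() {
	e := fakeExporter{state: &counters{}}
	e.state.nextRequestBufferSize.Store(36)
	fmt.Println(e.state.nextRequestBufferSize.Load()) // 36, via the shared pointer
}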
37 changes: 21 additions & 16 deletions exporter/prometheusremotewriteexporter/helper.go
@@ -7,24 +7,29 @@ import (
"errors"
"math"
"sort"
+ "sync/atomic"

"github.com/prometheus/prometheus/prompb"
)

type batchTimeSeriesState struct {
// Track batch sizes sent to avoid over allocating huge buffers.
// This helps in the case where large batches are sent to avoid allocating too much unused memory
- nextTimeSeriesBufferSize int
- nextMetricMetadataBufferSize int
- nextRequestBufferSize int
+ nextTimeSeriesBufferSize atomic.Int64
+ nextMetricMetadataBufferSize atomic.Int64
+ nextRequestBufferSize atomic.Int64
}

- func newBatchTimeSericesState() batchTimeSeriesState {
- return batchTimeSeriesState{
- nextTimeSeriesBufferSize: math.MaxInt,
- nextMetricMetadataBufferSize: math.MaxInt,
- nextRequestBufferSize: 0,
+ func newBatchTimeSericesState() *batchTimeSeriesState {
+ state := &batchTimeSeriesState{
+ nextTimeSeriesBufferSize: atomic.Int64{},
+ nextMetricMetadataBufferSize: atomic.Int64{},
+ nextRequestBufferSize: atomic.Int64{},
}
+ state.nextTimeSeriesBufferSize.Store(math.MaxInt64)
+ state.nextMetricMetadataBufferSize.Store(math.MaxInt64)
+ state.nextRequestBufferSize.Store(0)
+ return state
}

// batchTimeSeries splits series into multiple batch write requests.
@@ -34,22 +39,22 @@ func batchTimeSeries(tsMap map[string]*prompb.TimeSeries, maxBatchByteSize int,
}

// Allocate a buffer size of at least 10, or twice the last # of requests we sent
- requests := make([]*prompb.WriteRequest, 0, max(10, state.nextRequestBufferSize))
+ requests := make([]*prompb.WriteRequest, 0, max(10, state.nextRequestBufferSize.Load()))

// Allocate a time series buffer 2x the last time series batch size or the length of the input if smaller
- tsArray := make([]prompb.TimeSeries, 0, min(state.nextTimeSeriesBufferSize, len(tsMap)))
+ tsArray := make([]prompb.TimeSeries, 0, min(state.nextTimeSeriesBufferSize.Load(), int64(len(tsMap))))
sizeOfCurrentBatch := 0

i := 0
for _, v := range tsMap {
sizeOfSeries := v.Size()

if sizeOfCurrentBatch+sizeOfSeries >= maxBatchByteSize {
- state.nextTimeSeriesBufferSize = max(10, 2*len(tsArray))
+ state.nextTimeSeriesBufferSize.Store(int64(max(10, 2*len(tsArray))))
wrapped := convertTimeseriesToRequest(tsArray)
requests = append(requests, wrapped)

tsArray = make([]prompb.TimeSeries, 0, min(state.nextTimeSeriesBufferSize, len(tsMap)-i))
tsArray = make([]prompb.TimeSeries, 0, min(state.nextTimeSeriesBufferSize.Load(), int64(len(tsMap)-i)))
sizeOfCurrentBatch = 0
}

@@ -64,18 +69,18 @@ func batchTimeSeries(tsMap map[string]*prompb.TimeSeries, maxBatchByteSize int,
}

// Allocate a metric metadata buffer 2x the last metric metadata batch size or the length of the input if smaller
- mArray := make([]prompb.MetricMetadata, 0, min(state.nextMetricMetadataBufferSize, len(m)))
+ mArray := make([]prompb.MetricMetadata, 0, min(state.nextMetricMetadataBufferSize.Load(), int64(len(m))))
sizeOfCurrentBatch = 0
i = 0
for _, v := range m {
sizeOfM := v.Size()

if sizeOfCurrentBatch+sizeOfM >= maxBatchByteSize {
- state.nextMetricMetadataBufferSize = max(10, 2*len(mArray))
+ state.nextMetricMetadataBufferSize.Store(int64(max(10, 2*len(mArray))))
wrapped := convertMetadataToRequest(mArray)
requests = append(requests, wrapped)

mArray = make([]prompb.MetricMetadata, 0, min(state.nextMetricMetadataBufferSize, len(m)-i))
mArray = make([]prompb.MetricMetadata, 0, min(state.nextMetricMetadataBufferSize.Load(), int64(len(m)-i)))
sizeOfCurrentBatch = 0
}

@@ -89,7 +94,7 @@ func batchTimeSeries(tsMap map[string]*prompb.TimeSeries, maxBatchByteSize int,
requests = append(requests, wrapped)
}

- state.nextRequestBufferSize = 2 * len(requests)
+ state.nextRequestBufferSize.Store(int64(2 * len(requests)))
return requests, nil
}

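helper.go keeps its adaptive allocation strategy; only the bookkeeping moves to atomics, loaded once when sizing a slice and stored after a batch is cut. The same pattern reduced to a self-contained sketch, using illustrative names (nextBufferSize, buildBatch) rather than the exporter's, and the Go 1.21 built-in min the exporter already relies on:

// Sketch of the Load/Store sizing pattern: read the shared hint once to size a
// buffer, then publish roughly twice the batch size for the next allocation.
package main

import (
	"fmt"
	"math"
	"sync/atomic"
)

var nextBufferSize atomic.Int64 // shared across concurrent batch builders

func buildBatch(items []string) []string {
	// Cap the pre-allocation at the input length, as the exporter caps at len(tsMap).
	hint := min(nextBufferSize.Load(), int64(len(items)))
	batch := make([]string, 0, hint)
	batch = append(batch, items...)

	// Publish roughly twice this batch size as the hint for the next call.
	nextBufferSize.Store(int64(2 * len(batch)))
	return batch
}

func main() {
	nextBufferSize.Store(math.MaxInt64) // "unknown" until the first batch, like the exporter
	out := buildBatch([]string{"a", "b", "c"})
	fmt.Println(len(out), nextBufferSize.Load()) // 3 6
}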
24 changes: 12 additions & 12 deletions exporter/prometheusremotewriteexporter/helper_test.go
@@ -59,7 +59,7 @@ func Test_batchTimeSeries(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
state := newBatchTimeSericesState()
- requests, err := batchTimeSeries(tt.tsMap, tt.maxBatchByteSize, nil, &state)
+ requests, err := batchTimeSeries(tt.tsMap, tt.maxBatchByteSize, nil, state)
if tt.returnErr {
assert.Error(t, err)
return
@@ -68,13 +68,13 @@
assert.NoError(t, err)
assert.Len(t, requests, tt.numExpectedRequests)
if tt.numExpectedRequests <= 1 {
- assert.Equal(t, math.MaxInt, state.nextTimeSeriesBufferSize)
- assert.Equal(t, math.MaxInt, state.nextMetricMetadataBufferSize)
- assert.Equal(t, 2*len(requests), state.nextRequestBufferSize)
+ assert.Equal(t, int64(math.MaxInt64), state.nextTimeSeriesBufferSize.Load())
+ assert.Equal(t, int64(math.MaxInt64), state.nextMetricMetadataBufferSize.Load())
+ assert.Equal(t, int64(2*len(requests)), state.nextRequestBufferSize.Load())
} else {
- assert.Equal(t, max(10, len(requests[len(requests)-2].Timeseries)*2), state.nextTimeSeriesBufferSize)
- assert.Equal(t, math.MaxInt, state.nextMetricMetadataBufferSize)
- assert.Equal(t, 2*len(requests), state.nextRequestBufferSize)
+ assert.Equal(t, int64(max(10, len(requests[len(requests)-2].Timeseries)*2)), state.nextTimeSeriesBufferSize.Load())
+ assert.Equal(t, int64(math.MaxInt64), state.nextMetricMetadataBufferSize.Load())
+ assert.Equal(t, int64(2*len(requests)), state.nextRequestBufferSize.Load())
}
})
}
@@ -97,13 +97,13 @@ func Test_batchTimeSeriesUpdatesStateForLargeBatches(t *testing.T) {
tsMap1 := getTimeseriesMap(tsArray)

state := newBatchTimeSericesState()
- requests, err := batchTimeSeries(tsMap1, 1000000, nil, &state)
+ requests, err := batchTimeSeries(tsMap1, 1000000, nil, state)

assert.NoError(t, err)
assert.Len(t, requests, 18)
- assert.Equal(t, len(requests[len(requests)-2].Timeseries)*2, state.nextTimeSeriesBufferSize)
- assert.Equal(t, math.MaxInt, state.nextMetricMetadataBufferSize)
- assert.Equal(t, 36, state.nextRequestBufferSize)
+ assert.Equal(t, int64(len(requests[len(requests)-2].Timeseries)*2), state.nextTimeSeriesBufferSize.Load())
+ assert.Equal(t, int64(math.MaxInt64), state.nextMetricMetadataBufferSize.Load())
+ assert.Equal(t, int64(36), state.nextRequestBufferSize.Load())
}

// Benchmark_batchTimeSeries checks batchTimeSeries
@@ -132,7 +132,7 @@ func Benchmark_batchTimeSeries(b *testing.B) {
state := newBatchTimeSericesState()
// Run batchTimeSeries 100 times with a 1mb max request size
for i := 0; i < b.N; i++ {
- requests, err := batchTimeSeries(tsMap1, 1000000, nil, &state)
+ requests, err := batchTimeSeries(tsMap1, 1000000, nil, state)
assert.NoError(b, err)
assert.Len(b, requests, 18)
}
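The updated tests read the counters through Load but still drive batchTimeSeries from a single goroutine; the race itself is what go test -race would catch. A hypothetical concurrency test for this package, not part of this commit and assuming the file's existing prompb/testify imports plus sync and Go 1.22's integer range, might look like:

// Hypothetical test: hammer one shared state from several goroutines. Under
// go test -race the old plain-int fields would be reported; the atomic.Int64
// fields are not.
func Test_batchTimeSeriesConcurrent(t *testing.T) {
	tsMap := map[string]*prompb.TimeSeries{
		"series-1": {Samples: []prompb.Sample{{Value: 1, Timestamp: 1}}},
		"series-2": {Samples: []prompb.Sample{{Value: 2, Timestamp: 2}}},
	}

	state := newBatchTimeSericesState()
	var wg sync.WaitGroup
	for range 4 {
		wg.Add(1)
		go func() {
			defer wg.Done()
			requests, err := batchTimeSeries(tsMap, 1000000, nil, state)
			assert.NoError(t, err)
			assert.NotEmpty(t, requests)
		}()
	}
	wg.Wait()
}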
