From 4c71fd468705ecefbf8ebf65f01d06255f819550 Mon Sep 17 00:00:00 2001 From: Klaus Post Date: Wed, 30 Oct 2024 01:42:08 -0700 Subject: [PATCH] Add "newtime" directive to use official messagepack time format (#378) This adds `msgp:newtime` file directive that will encode all time fields using the -1 extension as defined in the [(revised) messagepack spec](/~https://github.com/msgpack/msgpack/blob/master/spec.md#timestamp-extension-type) ReadTime/ReadTimeBytes will now support both types natively, and will accept either as input. Extensions should remain unaffected. Fixes #300 --- _generated/newtime.go | 36 ++++++++++ _generated/newtime_test.go | 130 +++++++++++++++++++++++++++++++++++++ gen/encode.go | 3 + gen/marshal.go | 4 ++ gen/spec.go | 8 ++- msgp/errors.go | 25 +++++++ msgp/extension.go | 42 ++++++++---- msgp/json_bytes.go | 4 +- msgp/read.go | 81 +++++++++++++++++++---- msgp/read_bytes.go | 55 +++++++++++++--- msgp/write.go | 43 ++++++++++++ msgp/write_bytes.go | 35 ++++++++++ msgp/write_bytes_test.go | 10 +++ parse/directives.go | 7 ++ parse/getast.go | 4 +- 15 files changed, 446 insertions(+), 41 deletions(-) create mode 100644 _generated/newtime.go create mode 100644 _generated/newtime_test.go diff --git a/_generated/newtime.go b/_generated/newtime.go new file mode 100644 index 00000000..4c3ed9f9 --- /dev/null +++ b/_generated/newtime.go @@ -0,0 +1,36 @@ +package _generated + +import "time" + +//go:generate msgp -v + +//msgp:newtime + +type NewTime struct { + T time.Time + Array []time.Time + Map map[string]time.Time +} + +func (t1 NewTime) Equal(t2 NewTime) bool { + if !t1.T.Equal(t2.T) { + return false + } + if len(t1.Array) != len(t2.Array) { + return false + } + for i := range t1.Array { + if !t1.Array[i].Equal(t2.Array[i]) { + return false + } + } + if len(t1.Map) != len(t2.Map) { + return false + } + for k, v := range t1.Map { + if !t2.Map[k].Equal(v) { + return false + } + } + return true +} diff --git a/_generated/newtime_test.go b/_generated/newtime_test.go new file mode 100644 index 00000000..2e44fabc --- /dev/null +++ b/_generated/newtime_test.go @@ -0,0 +1,130 @@ +package _generated + +import ( + "bytes" + "math/rand" + "testing" + "time" + + "github.com/tinylib/msgp/msgp" +) + +func TestNewTime(t *testing.T) { + value := NewTime{ + T: time.Now().UTC(), + Array: []time.Time{time.Now().UTC(), time.Now().UTC()}, + Map: map[string]time.Time{ + "a": time.Now().UTC(), + }, + } + encoded, err := value.MarshalMsg(nil) + if err != nil { + t.Fatal(err) + } + checkExtMinusOne(t, encoded) + var got NewTime + _, err = got.UnmarshalMsg(encoded) + if err != nil { + t.Fatal(err) + } + if !value.Equal(got) { + t.Errorf("UnmarshalMsg got %v want %v", value, got) + } + + var buf bytes.Buffer + w := msgp.NewWriter(&buf) + err = value.EncodeMsg(w) + if err != nil { + t.Fatal(err) + } + w.Flush() + checkExtMinusOne(t, buf.Bytes()) + + got = NewTime{} + r := msgp.NewReader(&buf) + err = got.DecodeMsg(r) + if err != nil { + t.Fatal(err) + } + if !value.Equal(got) { + t.Errorf("DecodeMsg got %v want %v", value, got) + } +} + +func checkExtMinusOne(t *testing.T, b []byte) { + r := msgp.NewReader(bytes.NewBuffer(b)) + _, err := r.ReadMapHeader() + if err != nil { + t.Fatal(err) + } + key, err := r.ReadMapKey(nil) + if err != nil { + t.Fatal(err) + } + for !bytes.Equal(key, []byte("T")) { + key, err = r.ReadMapKey(nil) + if err != nil { + t.Fatal(err) + } + } + n, _, err := r.ReadExtensionRaw() + if err != nil { + t.Fatal(err) + } + if n != -1 { + t.Fatalf("got %v want -1", n) + } + t.Log("Was -1 extension") +} + +func TestNewTimeRandom(t *testing.T) { + rng := rand.New(rand.NewSource(0)) + runs := int(1e6) + if testing.Short() { + runs = 1e4 + } + for i := 0; i < runs; i++ { + nanos := rng.Int63n(999999999 + 1) + secs := rng.Uint64() + // Tweak the distribution, so we get more than average number of + // length 4 and 8 timestamps. + if rng.Intn(5) == 0 { + secs %= uint64(time.Now().Unix()) + if rng.Intn(2) == 0 { + nanos = 0 + } + } + + value := NewTime{ + T: time.Unix(int64(secs), nanos), + } + encoded, err := value.MarshalMsg(nil) + if err != nil { + t.Fatal(err) + } + var got NewTime + _, err = got.UnmarshalMsg(encoded) + if err != nil { + t.Fatal(err) + } + if !value.Equal(got) { + t.Fatalf("UnmarshalMsg got %v want %v", value, got) + } + var buf bytes.Buffer + w := msgp.NewWriter(&buf) + err = value.EncodeMsg(w) + if err != nil { + t.Fatal(err) + } + w.Flush() + got = NewTime{} + r := msgp.NewReader(&buf) + err = got.DecodeMsg(r) + if err != nil { + t.Fatal(err) + } + if !value.Equal(got) { + t.Fatalf("DecodeMsg got %v want %v", value, got) + } + } +} diff --git a/gen/encode.go b/gen/encode.go index 680d7bd0..92aa7681 100644 --- a/gen/encode.go +++ b/gen/encode.go @@ -30,6 +30,9 @@ func (e *encodeGen) writeAndCheck(typ string, argfmt string, arg interface{}) { if e.ctx.compFloats && typ == "Float64" { typ = "Float" } + if e.ctx.newTime && typ == "Time" { + typ = "TimeExt" + } e.p.printf("\nerr = en.Write%s(%s)", typ, fmt.Sprintf(argfmt, arg)) e.p.wrapErrCheck(e.ctx.ArgsStr()) diff --git a/gen/marshal.go b/gen/marshal.go index 6fb95ec6..bb5cdbc6 100644 --- a/gen/marshal.go +++ b/gen/marshal.go @@ -65,6 +65,10 @@ func (m *marshalGen) rawAppend(typ string, argfmt string, arg interface{}) { if m.ctx.compFloats && typ == "Float64" { typ = "Float" } + if m.ctx.newTime && typ == "Time" { + typ = "TimeExt" + } + m.p.printf("\no = msgp.Append%s(o, %s)", typ, fmt.Sprintf(argfmt, arg)) } diff --git a/gen/spec.go b/gen/spec.go index 36af0427..ba31ff39 100644 --- a/gen/spec.go +++ b/gen/spec.go @@ -79,6 +79,7 @@ type Printer struct { gens []generator CompactFloats bool ClearOmitted bool + NewTime bool } func NewPrinter(m Method, out io.Writer, tests io.Writer) *Printer { @@ -148,7 +149,11 @@ func (p *Printer) Print(e Elem) error { // collisions between idents created during SetVarname and idents created during Print, // hence the separate prefixes. resetIdent("zb") - err := g.Execute(e, Context{compFloats: p.CompactFloats, clearOmitted: p.ClearOmitted}) + err := g.Execute(e, Context{ + compFloats: p.CompactFloats, + clearOmitted: p.ClearOmitted, + newTime: p.NewTime, + }) resetIdent("za") if err != nil { @@ -178,6 +183,7 @@ type Context struct { path []contextItem compFloats bool clearOmitted bool + newTime bool } func (c *Context) PushString(s string) { diff --git a/msgp/errors.go b/msgp/errors.go index 984cca32..e6b42b68 100644 --- a/msgp/errors.go +++ b/msgp/errors.go @@ -212,6 +212,31 @@ func (u UintOverflow) Resumable() bool { return true } func (u UintOverflow) withContext(ctx string) error { u.ctx = addCtx(u.ctx, ctx); return u } +// InvalidTimestamp is returned when an invalid timestamp is encountered +type InvalidTimestamp struct { + Nanos int64 // value of the nano, if invalid + FieldLength int // Unexpected field length. + ctx string +} + +// Error implements the error interface +func (u InvalidTimestamp) Error() (str string) { + if u.Nanos > 0 { + str = "msgp: timestamp nanosecond field value " + strconv.FormatInt(u.Nanos, 10) + " exceeds maximum allows of 999999999" + } else if u.FieldLength >= 0 { + str = "msgp: invalid timestamp field length " + strconv.FormatInt(int64(u.FieldLength), 10) + " - must be 4, 8 or 12" + } + if u.ctx != "" { + str += " at " + u.ctx + } + return str +} + +// Resumable is always 'true' for overflows +func (u InvalidTimestamp) Resumable() bool { return true } + +func (u InvalidTimestamp) withContext(ctx string) error { u.ctx = addCtx(u.ctx, ctx); return u } + // UintBelowZero is returned when a call // would cast a signed integer below zero // to an unsigned integer. diff --git a/msgp/extension.go b/msgp/extension.go index 5f762473..cda71c98 100644 --- a/msgp/extension.go +++ b/msgp/extension.go @@ -15,8 +15,15 @@ const ( // TimeExtension is the extension number used for time.Time TimeExtension = 5 + + // MsgTimeExtension is the extension number for timestamps as defined in + // /~https://github.com/msgpack/msgpack/blob/master/spec.md#timestamp-extension-type + MsgTimeExtension = -1 ) +// msgTimeExtension is a painful workaround to avoid "constant -1 overflows byte". +var msgTimeExtension = int8(MsgTimeExtension) + // our extensions live here var extensionReg = make(map[int8]func() Extension) @@ -477,15 +484,27 @@ func AppendExtension(b []byte, e Extension) ([]byte, error) { // - InvalidPrefixError // - An umarshal error returned from e.UnmarshalBinary func ReadExtensionBytes(b []byte, e Extension) ([]byte, error) { + typ, remain, data, err := readExt(b) + if err != nil { + return b, err + } + if typ != e.ExtensionType() { + return b, errExt(typ, e.ExtensionType()) + } + return remain, e.UnmarshalBinary(data) +} + +// readExt will read the extension type, and return remaining bytes, +// as well as the data of the extension. +func readExt(b []byte) (typ int8, remain []byte, data []byte, err error) { l := len(b) if l < 3 { - return b, ErrShortBytes + return 0, b, nil, ErrShortBytes } lead := b[0] var ( sz int // size of 'data' off int // offset of 'data' - typ int8 ) switch lead { case mfixext1: @@ -513,35 +532,30 @@ func ReadExtensionBytes(b []byte, e Extension) ([]byte, error) { typ = int8(b[2]) off = 3 if sz == 0 { - return b[3:], e.UnmarshalBinary(b[3:3]) + return typ, b[3:], b[3:3], nil } case mext16: if l < 4 { - return b, ErrShortBytes + return 0, b, nil, ErrShortBytes } sz = int(big.Uint16(b[1:])) typ = int8(b[3]) off = 4 case mext32: if l < 6 { - return b, ErrShortBytes + return 0, b, nil, ErrShortBytes } sz = int(big.Uint32(b[1:])) typ = int8(b[5]) off = 6 default: - return b, badPrefix(ExtensionType, lead) - } - - if typ != e.ExtensionType() { - return b, errExt(typ, e.ExtensionType()) + return 0, b, nil, badPrefix(ExtensionType, lead) } - // the data of the extension starts // at 'off' and is 'sz' bytes long + tot := off + sz if len(b[off:]) < sz { - return b, ErrShortBytes + return 0, b, nil, ErrShortBytes } - tot := off + sz - return b[tot:], e.UnmarshalBinary(b[off:tot]) + return typ, b[tot:], b[off:tot:tot], nil } diff --git a/msgp/json_bytes.go b/msgp/json_bytes.go index 88ec6045..d4fbda63 100644 --- a/msgp/json_bytes.go +++ b/msgp/json_bytes.go @@ -72,7 +72,7 @@ func writeNext(w jsWriter, msg []byte, scratch []byte, depth int) ([]byte, []byt if err != nil { return nil, scratch, err } - if et == TimeExtension { + if et == TimeExtension || et == MsgTimeExtension { t = TimeType } } @@ -276,7 +276,7 @@ func rwExtensionBytes(w jsWriter, msg []byte, scratch []byte, depth int) ([]byte } // if it's time.Time - if et == TimeExtension { + if et == TimeExtension || et == MsgTimeExtension { var tm time.Time tm, msg, err = ReadTimeBytes(msg) if err != nil { diff --git a/msgp/read.go b/msgp/read.go index 5eb0b107..d15355e2 100644 --- a/msgp/read.go +++ b/msgp/read.go @@ -1,6 +1,7 @@ package msgp import ( + "encoding/binary" "encoding/json" "io" "math" @@ -259,7 +260,7 @@ func (m *Reader) NextType() (Type, error) { return Complex64Type, nil case Complex128Extension: return Complex128Type, nil - case TimeExtension: + case TimeExtension, MsgTimeExtension: return TimeType, nil } } @@ -1262,22 +1263,76 @@ func (m *Reader) ReadMapStrIntf(mp map[string]interface{}) (err error) { // ReadTime reads a time.Time object from the reader. // The returned time's location will be set to time.Local. func (m *Reader) ReadTime() (t time.Time, err error) { - var p []byte - p, err = m.R.Peek(15) + offset, length, extType, err := m.peekExtensionHeader() if err != nil { - return - } - if p[0] != mext8 || p[1] != 12 { - err = badPrefix(TimeType, p[0]) - return + return t, err } - if int8(p[2]) != TimeExtension { - err = errExt(int8(p[2]), TimeExtension) + + switch extType { + case TimeExtension: + var p []byte + p, err = m.R.Peek(15) + if err != nil { + return + } + if p[0] != mext8 || p[1] != 12 { + err = badPrefix(TimeType, p[0]) + return + } + if int8(p[2]) != TimeExtension { + err = errExt(int8(p[2]), TimeExtension) + return + } + sec, nsec := getUnix(p[3:]) + t = time.Unix(sec, int64(nsec)).Local() + _, err = m.R.Skip(15) return + case MsgTimeExtension: + switch length { + case 4, 8, 12: + var tmp [12]byte + _, err = m.R.Skip(offset) + if err != nil { + return + } + var n int + n, err = m.R.Read(tmp[:length]) + if err != nil { + return + } + if n != length { + err = ErrShortBytes + return + } + b := tmp[:length] + switch length { + case 4: + t = time.Unix(int64(binary.BigEndian.Uint32(b)), 0).Local() + case 8: + v := binary.BigEndian.Uint64(b) + nanos := int64(v >> 34) + if nanos > 999999999 { + // In timestamp 64 and timestamp 96 formats, nanoseconds must not be larger than 999999999. + err = InvalidTimestamp{Nanos: nanos} + return + } + t = time.Unix(int64(v&(1<<34-1)), nanos).Local() + case 12: + nanos := int64(binary.BigEndian.Uint32(b)) + if nanos > 999999999 { + // In timestamp 64 and timestamp 96 formats, nanoseconds must not be larger than 999999999. + err = InvalidTimestamp{Nanos: nanos} + return + } + ux := int64(binary.BigEndian.Uint64(b[4:])) + t = time.Unix(ux, nanos).Local() + } + default: + err = InvalidTimestamp{FieldLength: length} + } + default: + err = errExt(extType, TimeExtension) } - sec, nsec := getUnix(p[3:]) - t = time.Unix(sec, int64(nsec)).Local() - _, err = m.R.Skip(15) return } diff --git a/msgp/read_bytes.go b/msgp/read_bytes.go index cd20e97f..292704dd 100644 --- a/msgp/read_bytes.go +++ b/msgp/read_bytes.go @@ -29,7 +29,7 @@ func NextType(b []byte) Type { tp = int8(b[spec.size-1]) } switch tp { - case TimeExtension: + case TimeExtension, MsgTimeExtension: return TimeType case Complex128Extension: return Complex128Type @@ -1069,29 +1069,64 @@ func ReadComplex64Bytes(b []byte) (c complex64, o []byte, err error) { // ReadTimeBytes reads a time.Time // extension object from 'b' and returns the // remaining bytes. +// Both the official and the format in this package will be read. // // Possible errors: // // - [ErrShortBytes] (not enough bytes in 'b') -// - [TypeError] (object not a complex64) +// - [TypeError] (object not a time extension 5 or -1) // - [ExtensionTypeError] (object an extension of the correct size, but not a time.Time) func ReadTimeBytes(b []byte) (t time.Time, o []byte, err error) { - if len(b) < 15 { + if len(b) < 6 { err = ErrShortBytes return } - if b[0] != mext8 || b[1] != 12 { - err = badPrefix(TimeType, b[0]) + typ, o, b, err := readExt(b) + if err != nil { return } - if int8(b[2]) != TimeExtension { + switch typ { + case TimeExtension: + if len(b) != 12 { + err = ErrShortBytes + return + } + sec, nsec := getUnix(b) + t = time.Unix(sec, int64(nsec)).Local() + return + case MsgTimeExtension: + switch len(b) { + case 4: + t = time.Unix(int64(binary.BigEndian.Uint32(b)), 0).Local() + return + case 8: + v := binary.BigEndian.Uint64(b) + nanos := int64(v >> 34) + if nanos > 999999999 { + // In timestamp 64 and timestamp 96 formats, nanoseconds must not be larger than 999999999. + err = InvalidTimestamp{Nanos: nanos} + return + } + t = time.Unix(int64(v&(1<<34-1)), nanos).Local() + return + case 12: + nanos := int64(binary.BigEndian.Uint32(b)) + if nanos > 999999999 { + // In timestamp 64 and timestamp 96 formats, nanoseconds must not be larger than 999999999. + err = InvalidTimestamp{Nanos: nanos} + return + } + ux := int64(binary.BigEndian.Uint64(b[4:])) + t = time.Unix(ux, nanos).Local() + return + default: + err = InvalidTimestamp{FieldLength: len(b)} + return + } + default: err = errExt(int8(b[2]), TimeExtension) return } - sec, nsec := getUnix(b[3:]) - t = time.Unix(sec, int64(nsec)).Local() - o = b[15:] - return } // ReadMapStrIntfBytes reads a map[string]interface{} diff --git a/msgp/write.go b/msgp/write.go index dfe0d3e8..352350f9 100644 --- a/msgp/write.go +++ b/msgp/write.go @@ -1,6 +1,7 @@ package msgp import ( + "encoding/binary" "encoding/json" "errors" "io" @@ -635,6 +636,48 @@ func (mw *Writer) WriteTime(t time.Time) error { return nil } +// WriteTimeExt will write t using the official msgpack extension spec. +// /~https://github.com/msgpack/msgpack/blob/master/spec.md#timestamp-extension-type +func (mw *Writer) WriteTimeExt(t time.Time) error { + // Time rounded towards zero. + secPrec := t.Truncate(time.Second) + remain := t.Sub(secPrec).Nanoseconds() + asSecs := secPrec.Unix() + switch { + case remain == 0 && asSecs > 0 && asSecs <= math.MaxUint32: + // 4 bytes + o, err := mw.require(6) + if err != nil { + return err + } + mw.buf[o] = mfixext4 + mw.buf[o+1] = byte(msgTimeExtension) + binary.BigEndian.PutUint32(mw.buf[o+2:], uint32(asSecs)) + return nil + case asSecs < 0 || asSecs >= (1<<34): + // 12 bytes + o, err := mw.require(12 + 3) + if err != nil { + return err + } + mw.buf[o] = mext8 + mw.buf[o+1] = 12 + mw.buf[o+2] = byte(msgTimeExtension) + binary.BigEndian.PutUint32(mw.buf[o+3:], uint32(remain)) + binary.BigEndian.PutUint64(mw.buf[o+3+4:], uint64(asSecs)) + default: + // 8 bytes + o, err := mw.require(10) + if err != nil { + return err + } + mw.buf[o] = mfixext8 + mw.buf[o+1] = byte(msgTimeExtension) + binary.BigEndian.PutUint64(mw.buf[o+2:], uint64(asSecs)|(uint64(remain)<<34)) + } + return nil +} + // WriteJSONNumber writes the json.Number to the stream as either integer or float. func (mw *Writer) WriteJSONNumber(n json.Number) error { if n == "" { diff --git a/msgp/write_bytes.go b/msgp/write_bytes.go index a95b1d0b..70450174 100644 --- a/msgp/write_bytes.go +++ b/msgp/write_bytes.go @@ -1,6 +1,7 @@ package msgp import ( + "encoding/binary" "encoding/json" "errors" "math" @@ -322,6 +323,40 @@ func AppendTime(b []byte, t time.Time) []byte { return o } +// AppendTimeExt will write t using the official msgpack extension spec. +// /~https://github.com/msgpack/msgpack/blob/master/spec.md#timestamp-extension-type +func AppendTimeExt(b []byte, t time.Time) []byte { + // Time rounded towards zero. + secPrec := t.Truncate(time.Second) + remain := t.Sub(secPrec).Nanoseconds() + asSecs := secPrec.Unix() + switch { + case remain == 0 && asSecs > 0 && asSecs <= math.MaxUint32: + // 4 bytes + o, n := ensure(b, 2+4) + o[n+0] = mfixext4 + o[n+1] = byte(msgTimeExtension) + binary.BigEndian.PutUint32(o[n+2:], uint32(asSecs)) + return o + case asSecs < 0 || asSecs >= (1<<34): + // 12 bytes + o, n := ensure(b, 3+12) + o[n+0] = mext8 + o[n+1] = 12 + o[n+2] = byte(msgTimeExtension) + binary.BigEndian.PutUint32(o[n+3:], uint32(remain)) + binary.BigEndian.PutUint64(o[n+3+4:], uint64(asSecs)) + return o + default: + // 8 bytes + o, n := ensure(b, 2+8) + o[n+0] = mfixext8 + o[n+1] = byte(msgTimeExtension) + binary.BigEndian.PutUint64(o[n+2:], uint64(asSecs)|(uint64(remain)<<34)) + return o + } +} + // AppendMapStrStr appends a map[string]string to the slice // as a MessagePack map with 'str'-type keys and values func AppendMapStrStr(b []byte, m map[string]string) []byte { diff --git a/msgp/write_bytes_test.go b/msgp/write_bytes_test.go index 68f2eb3e..86d93d3b 100644 --- a/msgp/write_bytes_test.go +++ b/msgp/write_bytes_test.go @@ -420,6 +420,16 @@ func BenchmarkAppendTime(b *testing.B) { } } +func BenchmarkAppendTimeExt(b *testing.B) { + t := time.Now() + buf := make([]byte, 0, 15) + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + AppendTimeExt(buf[0:0], t) + } +} + // TestEncodeDecode does a back-and-forth test of encoding and decoding and compare the value with a given output. func TestEncodeDecode(t *testing.T) { for _, tc := range []struct { diff --git a/parse/directives.go b/parse/directives.go index 1a50a98a..c48caeb4 100644 --- a/parse/directives.go +++ b/parse/directives.go @@ -28,6 +28,7 @@ var directives = map[string]directive{ "tuple": astuple, "compactfloats": compactfloats, "clearomitted": clearomitted, + "newtime": newtime, } // map of all recognized directives which will be applied @@ -200,3 +201,9 @@ func clearomitted(text []string, f *FileSet) error { f.ClearOmitted = true return nil } + +//msgp:newtime +func newtime(text []string, f *FileSet) error { + f.NewTime = true + return nil +} diff --git a/parse/getast.go b/parse/getast.go index 35c2bd8f..9c319dc4 100644 --- a/parse/getast.go +++ b/parse/getast.go @@ -23,6 +23,7 @@ type FileSet struct { Imports []*ast.ImportSpec // imports CompactFloats bool // Use smaller floats when feasible ClearOmitted bool // Set omitted fields to zero value + NewTime bool // Set to use -1 extension for time.Time tagName string // tag to read field names from pointerRcv bool // generate with pointer receivers. } @@ -273,6 +274,7 @@ loop: } p.CompactFloats = f.CompactFloats p.ClearOmitted = f.ClearOmitted + p.NewTime = f.NewTime } func (f *FileSet) PrintTo(p *gen.Printer) error { @@ -530,7 +532,7 @@ func (fs *FileSet) parseExpr(e ast.Expr) gen.Elem { case *ast.Ident: b := gen.Ident(e.Name) - // work to resove this expression + // work to resolve this expression // can be done later, once we've resolved // everything else. if b.Value == gen.IDENT {