diff --git a/decimal.go b/decimal.go
index 475161dc..e68ad2c8 100644
--- a/decimal.go
+++ b/decimal.go
@@ -1381,6 +1381,33 @@ func (d NullDecimal) MarshalJSON() ([]byte, error) {
return d.Decimal.MarshalJSON()
}
+// UnmarshalText implements the encoding.TextUnmarshaler interface for XML
+// deserialization
+func (d *NullDecimal) UnmarshalText(text []byte) error {
+ str := string(text)
+
+ // check for empty XML or XML without body e.g.,
+ if str == "" {
+ d.Valid = false
+ return nil
+ }
+ if err := d.Decimal.UnmarshalText(text); err != nil {
+ d.Valid = false
+ return err
+ }
+ d.Valid = true
+ return nil
+}
+
+// MarshalText implements the encoding.TextMarshaler interface for XML
+// serialization.
+func (d NullDecimal) MarshalText() (text []byte, err error) {
+ if !d.Valid {
+ return []byte{}, nil
+ }
+ return d.Decimal.MarshalText()
+}
+
// Trig functions
// Atan returns the arctangent, in radians, of x.
diff --git a/decimal_test.go b/decimal_test.go
index 72750abb..1bfca281 100644
--- a/decimal_test.go
+++ b/decimal_test.go
@@ -766,6 +766,96 @@ func TestBadXML(t *testing.T) {
}
}
+func TestNullDecimalXML(t *testing.T) {
+ // test valid values
+ for _, x := range testTable {
+ s := x.short
+ var doc struct {
+ XMLName xml.Name `xml:"account"`
+ Amount NullDecimal `xml:"amount"`
+ }
+ docStr := `` + s + ``
+ err := xml.Unmarshal([]byte(docStr), &doc)
+ if err != nil {
+ t.Errorf("error unmarshaling %s: %v", docStr, err)
+ } else if doc.Amount.Decimal.String() != s {
+ t.Errorf("expected %s, got %s (%s, %d)",
+ s, doc.Amount.Decimal.String(),
+ doc.Amount.Decimal.value.String(), doc.Amount.Decimal.exp)
+ }
+
+ out, err := xml.Marshal(&doc)
+ if err != nil {
+ t.Errorf("error marshaling %+v: %v", doc, err)
+ } else if string(out) != docStr {
+ t.Errorf("expected %s, got %s", docStr, string(out))
+ }
+ }
+
+ var doc struct {
+ XMLName xml.Name `xml:"account"`
+ Amount NullDecimal `xml:"amount"`
+ }
+
+ // test for XML with empty body
+ docStr := ``
+ err := xml.Unmarshal([]byte(docStr), &doc)
+ if err != nil {
+ t.Errorf("error unmarshaling: %s: %v", docStr, err)
+ } else if doc.Amount.Valid {
+ t.Errorf("expected null value to have Valid = false, got Valid = true and Decimal = %s (%s, %d)",
+ doc.Amount.Decimal.String(),
+ doc.Amount.Decimal.value.String(), doc.Amount.Decimal.exp)
+ }
+
+ expected := ``
+ out, err := xml.Marshal(&doc)
+ if err != nil {
+ t.Errorf("error marshaling %+v: %v", doc, err)
+ } else if string(out) != expected {
+ t.Errorf("expected %s, got %s", expected, string(out))
+ }
+
+ // test for empty XML
+ docStr = ``
+ err = xml.Unmarshal([]byte(docStr), &doc)
+ if err != nil {
+ t.Errorf("error unmarshaling: %s: %v", docStr, err)
+ } else if doc.Amount.Valid {
+ t.Errorf("expected null value to have Valid = false, got Valid = true and Decimal = %s (%s, %d)",
+ doc.Amount.Decimal.String(),
+ doc.Amount.Decimal.value.String(), doc.Amount.Decimal.exp)
+ }
+
+ expected = ``
+ out, err = xml.Marshal(&doc)
+ if err != nil {
+ t.Errorf("error marshaling %+v: %v", doc, err)
+ } else if string(out) != expected {
+ t.Errorf("expected %s, got %s", expected, string(out))
+ }
+}
+
+func TestNullDecimalBadXML(t *testing.T) {
+ for _, testCase := range []string{
+ "o_o",
+ "7",
+ ``,
+ `nope`,
+ `0.333`,
+ } {
+ var doc struct {
+ XMLName xml.Name `xml:"account"`
+ Amount NullDecimal `xml:"amount"`
+ }
+ err := xml.Unmarshal([]byte(testCase), &doc)
+ if err == nil {
+ t.Errorf("expected error, got %+v", doc)
+ }
+ }
+}
+
func TestDecimal_rescale(t *testing.T) {
type Inp struct {
int int64