Skip to content

Commit d62a084

Browse files
serialization Add support for decoding enum strings to their underlying values when using mapstructure tags (#741)
<!-- Copyright (C) 2020-2022 Arm Limited or its affiliates and Contributors. All rights reserved. SPDX-License-Identifier: Apache-2.0 --> ### Description <!-- Please add any detail or context that would be useful to a reviewer. --> Add support for decoding enum strings to their underlying values when using mapstructure tags ### Test Coverage <!-- Please put an `x` in the correct box e.g. `[x]` to indicate the testing coverage of this change. --> - [x] This change is covered by existing or additional automated tests. - [ ] Manual testing has been performed (and evidence provided) as automated testing was not feasible. - [ ] Additional tests are not required for this change (e.g. documentation update).
1 parent c3966be commit d62a084

File tree

6 files changed

+272
-10
lines changed

6 files changed

+272
-10
lines changed

changes/20251030161706.feature

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
:sparkles: `serialization` Add support for decoding enum strings to their underlying values when using mapstructure tags

utils/config/service_configuration.go

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ import (
2020
"github.com/ARM-software/golang-utils/utils/field"
2121
"github.com/ARM-software/golang-utils/utils/keyring"
2222
"github.com/ARM-software/golang-utils/utils/reflection"
23+
"github.com/ARM-software/golang-utils/utils/serialization/maps" //nolint:misspell
2324
)
2425

2526
const (
@@ -113,7 +114,12 @@ func LoadFromEnvironmentAndSystem(viperSession *viper.Viper, envVarPrefix string
113114
}
114115

115116
// Merge together all the sources and unmarshal into struct
116-
err = viperSession.Unmarshal(configurationToSet)
117+
err = viperSession.Unmarshal(configurationToSet, viper.DecodeHook(mapstructure.ComposeDecodeHookFunc(
118+
maps.CustomTypeHookFunc(),
119+
// Keep these two as they are the default values used by viper and we don't want to override them
120+
mapstructure.StringToTimeDurationHookFunc(),
121+
mapstructure.StringToSliceHookFunc(","),
122+
)))
117123
if err != nil {
118124
err = commonerrors.WrapError(commonerrors.ErrMarshalling, err, "unable to fill configuration structure from the configuration session")
119125
return

utils/config/service_configuration_test.go

Lines changed: 109 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ import (
2626
"github.com/ARM-software/golang-utils/utils/commonerrors"
2727
"github.com/ARM-software/golang-utils/utils/commonerrors/errortest"
2828
"github.com/ARM-software/golang-utils/utils/keyring"
29+
mapstest "github.com/ARM-software/golang-utils/utils/serialization/maps/testing" //nolint:misspell
2930
)
3031

3132
var (
@@ -39,13 +40,15 @@ var (
3940
)
4041

4142
type DummyConfiguration struct {
42-
Host string `mapstructure:"dummy_host"`
43-
Port int `mapstructure:"port"`
44-
DB string `mapstructure:"db"`
45-
User string `mapstructure:"user"`
46-
Password string `mapstructure:"password"`
47-
Flag bool `mapstructure:"flag"`
48-
HealthCheckPeriod time.Duration `mapstructure:"healthcheck_period"`
43+
Host string `mapstructure:"dummy_host"`
44+
Port int `mapstructure:"port"`
45+
DB string `mapstructure:"db"`
46+
User string `mapstructure:"user"`
47+
Password string `mapstructure:"password"`
48+
Flag bool `mapstructure:"flag"`
49+
TestEnum mapstest.TestEnumWithUnmarshal `mapstructure:"enum"`
50+
TestEnum1 mapstest.TestEnumWithoutUnmarshal `mapstructure:"enum1"`
51+
HealthCheckPeriod time.Duration `mapstructure:"healthcheck_period"`
4952
}
5053

5154
func (cfg *DummyConfiguration) Validate() error {
@@ -55,6 +58,7 @@ func (cfg *DummyConfiguration) Validate() error {
5558
validation.Field(&cfg.DB, validation.Required),
5659
validation.Field(&cfg.User, validation.Required),
5760
validation.Field(&cfg.Password, validation.Required),
61+
validation.Field(&cfg.TestEnum, validation.By(mapstest.ValidationFunc)),
5862
)
5963
}
6064

@@ -189,6 +193,14 @@ func TestServiceConfigurationLoad(t *testing.T) {
189193
require.NoError(t, err)
190194
err = os.Setenv("TEST_DUMMY_CONFIG_USER", "a test user")
191195
require.NoError(t, err)
196+
err = os.Setenv("TEST_DUMMY_CONFIG_ENUM", mapstest.TestEnumStringVer1)
197+
require.NoError(t, err)
198+
err = os.Setenv("TEST_DUMMYCONFIG_ENUM", mapstest.TestEnumStringVer1)
199+
require.NoError(t, err)
200+
err = os.Setenv("TEST_DUMMY_CONFIG_ENUM1", "1")
201+
require.NoError(t, err)
202+
err = os.Setenv("TEST_DUMMYCONFIG_ENUM1", "1")
203+
require.NoError(t, err)
192204
err = os.Setenv("TEST_DUMMYCONFIG_DB", "a test db")
193205
require.NoError(t, err)
194206
err = os.Setenv("TEST_DUMMY_CONFIG_DB", expectedDB)
@@ -407,6 +419,8 @@ func TestFlagsBinding(t *testing.T) {
407419
flagSet.Int("int", 0, "dummy int")
408420
flagSet.Duration("time", time.Second, "dummy time")
409421
flagSet.Bool("flag", false, "dummy flag")
422+
flagSet.String("enum", mapstest.TestEnumStringVer1, "dummy enum")
423+
flagSet.String("enum1", "1", "dummy enum")
410424
err = BindFlagsToEnv(session, prefix, "TEST_DUMMYCONFIG_DUMMY_HOST", flagSet.Lookup("host2"), flagSet.Lookup("host2"))
411425
require.NoError(t, err)
412426
err = BindFlagsToEnv(session, prefix, "TEST_DUMMY_CONFIG_DUMMY_HOST", flagSet.Lookup("host1"), flagSet.Lookup("host2"))
@@ -419,6 +433,10 @@ func TestFlagsBinding(t *testing.T) {
419433
require.NoError(t, err)
420434
err = BindFlagsToEnv(session, prefix, "DUMMY_CONFIG_USER", flagSet.Lookup("user1"), flagSet.Lookup("user2"))
421435
require.NoError(t, err)
436+
err = BindFlagToEnv(session, prefix, "DUMMY_CONFIG_ENUM", flagSet.Lookup("enum"))
437+
require.NoError(t, err)
438+
err = BindFlagToEnv(session, prefix, "DUMMY_CONFIG_ENUM1", flagSet.Lookup("enum1"))
439+
require.NoError(t, err)
422440
err = BindFlagsToEnv(session, prefix, "TEST_DUMMYCONFIG_DB", flagSet.Lookup("db"))
423441
require.NoError(t, err)
424442
err = BindFlagsToEnv(session, prefix, "DUMMY_CONFIG_DB", flagSet.Lookup("db2"), flagSet.Lookup("db2"), flagSet.Lookup("db2"), flagSet.Lookup("db2"))
@@ -476,6 +494,12 @@ func TestFlagsBinding(t *testing.T) {
476494
assert.Equal(t, expectedHost, configTest.TestConfig2.Host)
477495
assert.Equal(t, expectedPassword, configTest.TestConfig.Password)
478496
assert.Equal(t, expectedPassword, configTest.TestConfig2.Password)
497+
assert.NotEqual(t, mapstest.TestEnumStringVer1, configTest.TestConfig2.TestEnum)
498+
assert.NotEqual(t, mapstest.TestEnumStringVer0, configTest.TestConfig.TestEnum)
499+
assert.Equal(t, mapstest.TestEnumWithUnmarshal1, configTest.TestConfig2.TestEnum)
500+
assert.Equal(t, mapstest.TestEnumWithUnmarshal0, configTest.TestConfig.TestEnum)
501+
assert.Equal(t, mapstest.TestEnumWithoutUnmarshal1, configTest.TestConfig2.TestEnum1)
502+
assert.Equal(t, mapstest.TestEnumWithoutUnmarshal0, configTest.TestConfig.TestEnum1)
479503
assert.Equal(t, expectedDB, configTest.TestConfig.DB)
480504
assert.Equal(t, aDifferentDB, configTest.TestConfig2.DB)
481505
assert.NotEqual(t, expectedDB, configTest.TestConfig2.DB)
@@ -516,6 +540,8 @@ func TestFlagBindingDefaults(t *testing.T) {
516540
flagSet.Int("int", expectedInt, "dummy int")
517541
flagSet.Duration("time", expectedDuration, "dummy time")
518542
flagSet.Bool("flag", !DefaultDummyConfiguration().Flag, "dummy flag")
543+
flagSet.String("enum", mapstest.TestEnumStringVer0, "dummy enum")
544+
flagSet.String("enum1", "0", "dummy enum")
519545
err = BindFlagToEnv(session, prefix, "TEST_DUMMYCONFIG_DUMMY_HOST", flagSet.Lookup("host"))
520546
require.NoError(t, err)
521547
err = BindFlagToEnv(session, prefix, "TEST_DUMMY_CONFIG_DUMMY_HOST", flagSet.Lookup("host2"))
@@ -538,6 +564,10 @@ func TestFlagBindingDefaults(t *testing.T) {
538564
require.NoError(t, err)
539565
err = BindFlagToEnv(session, prefix, "DUMMY_Time", flagSet.Lookup("time"))
540566
require.NoError(t, err)
567+
err = BindFlagToEnv(session, prefix, "DUMMY_enum", flagSet.Lookup("enum"))
568+
require.NoError(t, err)
569+
err = BindFlagToEnv(session, prefix, "DUMMY_enum1", flagSet.Lookup("enum1"))
570+
require.NoError(t, err)
541571
err = os.Setenv("TEST_DUMMY_CONFIG_DB", expectedDB) // Should take precedence over flag default
542572
require.NoError(t, err)
543573
err = LoadFromViper(session, prefix, configTest, defaults)
@@ -556,6 +586,10 @@ func TestFlagBindingDefaults(t *testing.T) {
556586
assert.Equal(t, expectedPassword, configTest.TestConfig2.Password)
557587
assert.Equal(t, aDifferentDB, configTest.TestConfig.DB)
558588
assert.Equal(t, expectedDB, configTest.TestConfig2.DB)
589+
assert.Equal(t, mapstest.TestEnumWithUnmarshal0, configTest.TestConfig2.TestEnum)
590+
assert.Equal(t, mapstest.TestEnumWithUnmarshal0, configTest.TestConfig.TestEnum)
591+
assert.Equal(t, mapstest.TestEnumWithoutUnmarshal0, configTest.TestConfig2.TestEnum1)
592+
assert.Equal(t, mapstest.TestEnumWithoutUnmarshal0, configTest.TestConfig.TestEnum1)
559593
// Defaults from the default structure provided take precedence over defaults from flags when empty.
560594
assert.Equal(t, DefaultConfiguration().TestConfig.Flag, configTest.TestConfig.Flag)
561595
assert.Equal(t, DefaultConfiguration().TestConfig.Flag, configTest.TestConfig2.Flag)
@@ -575,6 +609,8 @@ func TestGenerateEnvFile_Defaults(t *testing.T) {
575609
"TEST_PASSWORD": configTest.Password,
576610
"TEST_PORT": configTest.Port,
577611
"TEST_USER": configTest.User,
612+
"TEST_ENUM": configTest.TestEnum,
613+
"TEST_ENUM1": configTest.TestEnum1,
578614
}
579615

580616
// Generate env file
@@ -601,6 +637,8 @@ func TestGenerateEnvFile_Populated(t *testing.T) {
601637
flagSet.String("password", "a password", "dummy password")
602638
flagSet.String("user", "a user", "dummy user")
603639
flagSet.String("db", "a db", "dummy db")
640+
flagSet.String("enum", mapstest.TestEnumStringVer1, "dummy enum")
641+
flagSet.String("enum1", "1", "dummy enum")
604642
err = BindFlagToEnv(session, prefix, "TEST_DUMMY_HOST", flagSet.Lookup("host"))
605643
require.NoError(t, err)
606644
err = BindFlagToEnv(session, prefix, "PASSWORD", flagSet.Lookup("password"))
@@ -611,6 +649,10 @@ func TestGenerateEnvFile_Populated(t *testing.T) {
611649
require.NoError(t, err)
612650
err = BindFlagToEnv(session, prefix, "USER", flagSet.Lookup("user"))
613651
require.NoError(t, err)
652+
err = BindFlagToEnv(session, prefix, "ENUM", flagSet.Lookup("enum"))
653+
require.NoError(t, err)
654+
err = BindFlagToEnv(session, prefix, "ENUM1", flagSet.Lookup("enum1"))
655+
require.NoError(t, err)
614656
err = flagSet.Set("host", expectedHost)
615657
require.NoError(t, err)
616658
err = flagSet.Set("password", expectedPassword)
@@ -630,6 +672,8 @@ func TestGenerateEnvFile_Populated(t *testing.T) {
630672
"TEST_PASSWORD": configTest.Password,
631673
"TEST_PORT": configTest.Port,
632674
"TEST_USER": configTest.User,
675+
"TEST_ENUM": configTest.TestEnum,
676+
"TEST_ENUM1": configTest.TestEnum1,
633677
}
634678

635679
// Generate env file
@@ -670,13 +714,17 @@ func TestGenerateEnvFile_Nested(t *testing.T) {
670714
"TEST_DEEP_CONFIG_DUMMYCONFIG_PASSWORD": configTest.TestConfigDeep.TestConfig.Password,
671715
"TEST_DEEP_CONFIG_DUMMYCONFIG_PORT": configTest.TestConfigDeep.TestConfig.Port,
672716
"TEST_DEEP_CONFIG_DUMMYCONFIG_USER": configTest.TestConfigDeep.TestConfig.User,
717+
"TEST_DEEP_CONFIG_DUMMYCONFIG_ENUM": configTest.TestConfigDeep.TestConfig.TestEnum,
718+
"TEST_DEEP_CONFIG_DUMMYCONFIG_ENUM1": configTest.TestConfigDeep.TestConfig.TestEnum1,
673719
"TEST_DEEP_CONFIG_DUMMY_CONFIG_DB": configTest.TestConfigDeep.TestConfig2.DB,
674720
"TEST_DEEP_CONFIG_DUMMY_CONFIG_DUMMY_HOST": configTest.TestConfigDeep.TestConfig2.Host,
675721
"TEST_DEEP_CONFIG_DUMMY_CONFIG_FLAG": configTest.TestConfigDeep.TestConfig2.Flag,
676722
"TEST_DEEP_CONFIG_DUMMY_CONFIG_HEALTHCHECK_PERIOD": configTest.TestConfigDeep.TestConfig2.HealthCheckPeriod,
677723
"TEST_DEEP_CONFIG_DUMMY_CONFIG_PASSWORD": configTest.TestConfigDeep.TestConfig2.Password,
678724
"TEST_DEEP_CONFIG_DUMMY_CONFIG_PORT": configTest.TestConfigDeep.TestConfig2.Port,
679725
"TEST_DEEP_CONFIG_DUMMY_CONFIG_USER": configTest.TestConfigDeep.TestConfig2.User,
726+
"TEST_DEEP_CONFIG_DUMMY_CONFIG_ENUM": configTest.TestConfigDeep.TestConfig2.TestEnum,
727+
"TEST_DEEP_CONFIG_DUMMY_CONFIG_ENUM1": configTest.TestConfigDeep.TestConfig2.TestEnum1,
680728
"TEST_DEEP_CONFIG_DUMMY_INT": configTest.TestConfigDeep.TestInt,
681729
"TEST_DUMMY_STRING": configTest.TestString,
682730
"TEST_DEEP_CONFIG_DUMMY_TIME": configTest.TestConfigDeep.TestTime,
@@ -1040,3 +1088,57 @@ func loadEnvIntoEnvironment(t *testing.T, envPath string) (err error) {
10401088

10411089
return
10421090
}
1091+
1092+
func TestCustomTypeHook_Success(t *testing.T) {
1093+
t.Cleanup(os.Clearenv)
1094+
os.Clearenv()
1095+
1096+
cfg := &ConfigurationTest{}
1097+
defaults := DefaultConfiguration()
1098+
1099+
require.NoError(t, os.Setenv("TEST_DUMMYCONFIG_DUMMY_HOST", expectedHost))
1100+
require.NoError(t, os.Setenv("TEST_DUMMYCONFIG_PASSWORD", expectedPassword))
1101+
require.NoError(t, os.Setenv("TEST_DUMMYCONFIG_USER", "user"))
1102+
require.NoError(t, os.Setenv("TEST_DUMMYCONFIG_DB", expectedDB))
1103+
require.NoError(t, os.Setenv("TEST_DUMMY_CONFIG_DUMMY_HOST", expectedHost))
1104+
require.NoError(t, os.Setenv("TEST_DUMMY_CONFIG_PASSWORD", expectedPassword))
1105+
require.NoError(t, os.Setenv("TEST_DUMMY_CONFIG_USER", "user"))
1106+
require.NoError(t, os.Setenv("TEST_DUMMY_CONFIG_DB", expectedDB))
1107+
require.NoError(t, os.Setenv("TEST_DUMMY_INT", fmt.Sprintf("%v", expectedInt)))
1108+
require.NoError(t, os.Setenv("TEST_DUMMY_TIME", expectedDuration.String()))
1109+
1110+
require.NoError(t, os.Setenv("TEST_DUMMYCONFIG_ENUM", mapstest.TestEnumStringVer1))
1111+
require.NoError(t, os.Setenv("TEST_DUMMYCONFIG_ENUM1", "1"))
1112+
1113+
err := Load("test", cfg, defaults)
1114+
require.NoError(t, err)
1115+
require.NoError(t, cfg.Validate())
1116+
1117+
assert.Equal(t, mapstest.TestEnumWithUnmarshal1, cfg.TestConfig.TestEnum)
1118+
assert.Equal(t, mapstest.TestEnumWithoutUnmarshal1, cfg.TestConfig.TestEnum1)
1119+
}
1120+
1121+
func TestCustomTypeHook_InvalidValue(t *testing.T) {
1122+
t.Cleanup(os.Clearenv)
1123+
os.Clearenv()
1124+
1125+
cfg := &ConfigurationTest{}
1126+
defaults := DefaultConfiguration()
1127+
1128+
require.NoError(t, os.Setenv("TEST_DUMMYCONFIG_DUMMY_HOST", expectedHost))
1129+
require.NoError(t, os.Setenv("TEST_DUMMYCONFIG_PASSWORD", expectedPassword))
1130+
require.NoError(t, os.Setenv("TEST_DUMMYCONFIG_USER", "user"))
1131+
require.NoError(t, os.Setenv("TEST_DUMMYCONFIG_DB", expectedDB))
1132+
require.NoError(t, os.Setenv("TEST_DUMMY_CONFIG_DUMMY_HOST", expectedHost))
1133+
require.NoError(t, os.Setenv("TEST_DUMMY_CONFIG_PASSWORD", expectedPassword))
1134+
require.NoError(t, os.Setenv("TEST_DUMMY_CONFIG_USER", "user"))
1135+
require.NoError(t, os.Setenv("TEST_DUMMY_CONFIG_DB", expectedDB))
1136+
require.NoError(t, os.Setenv("TEST_DUMMY_INT", fmt.Sprintf("%v", expectedInt)))
1137+
require.NoError(t, os.Setenv("TEST_DUMMY_TIME", expectedDuration.String()))
1138+
1139+
require.NoError(t, os.Setenv("TEST_DUMMYCONFIG_ENUM", "4"))
1140+
1141+
err := Load("test", cfg, defaults)
1142+
errortest.AssertError(t, err, commonerrors.ErrInvalid)
1143+
errortest.AssertErrorDescription(t, err, "structure failed validation")
1144+
}

utils/serialization/maps/map.go

Lines changed: 51 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,16 @@
11
package maps
22

33
import (
4+
"encoding"
45
"reflect"
6+
"strconv"
57
"time"
68

79
"github.com/go-viper/mapstructure/v2"
810

911
"github.com/ARM-software/golang-utils/utils/commonerrors"
1012
"github.com/ARM-software/golang-utils/utils/maps"
13+
"github.com/ARM-software/golang-utils/utils/safecast"
1114
)
1215

1316
// ToMapFromPointer is like ToMap but deals with a pointer.
@@ -72,7 +75,7 @@ func FromMapToPointer[T any](m map[string]string, o T) (err error) {
7275

7376
err = mapstructureDecoder(expandedMap, o)
7477
if err != nil {
75-
err = commonerrors.WrapError(commonerrors.ErrMarshalling, err, "failed to deserialise upload request")
78+
err = commonerrors.WrapError(commonerrors.ErrMarshalling, err, "failed to deserialise the map")
7679
}
7780
return
7881
}
@@ -148,11 +151,57 @@ func toTime(f reflect.Type, t reflect.Type, data any) (any, error) {
148151
}
149152
}
150153

154+
func toCustomTypeIntFallback(f reflect.Type, t reflect.Type, data any) (any, error) {
155+
if f == nil || t == nil || f.Kind() != reflect.String {
156+
return data, nil
157+
}
158+
if t.Kind() != reflect.Int {
159+
return data, nil
160+
}
161+
162+
s, ok := data.(string)
163+
if !ok {
164+
return data, nil
165+
}
166+
167+
i, err := strconv.Atoi(s)
168+
if err != nil {
169+
return data, nil
170+
}
171+
172+
ptr := reflect.New(t).Elem()
173+
ptr.SetInt(safecast.ToInt64(i))
174+
175+
return ptr.Interface(), nil
176+
}
177+
178+
func toCustomType(f reflect.Type, t reflect.Type, data any) (any, error) {
179+
if f == nil || t == nil || f.Kind() != reflect.String {
180+
return data, nil
181+
}
182+
183+
customType, ok := reflect.New(t).Interface().(encoding.TextUnmarshaler)
184+
if !ok {
185+
return toCustomTypeIntFallback(f, t, data)
186+
}
187+
188+
err := customType.UnmarshalText([]byte(data.(string))) // we know it is a string based on reflection
189+
if err != nil {
190+
return toCustomTypeIntFallback(f, t, data)
191+
}
192+
193+
return customType, nil
194+
}
195+
196+
func CustomTypeHookFunc() mapstructure.DecodeHookFunc {
197+
return toCustomType
198+
}
199+
151200
func mapstructureDecoder(input, result any) error {
152201
decoder, err := mapstructure.NewDecoder(&mapstructure.DecoderConfig{
153202
WeaklyTypedInput: true,
154203
DecodeHook: mapstructure.ComposeDecodeHookFunc(
155-
timeHookFunc(), mapstructure.StringToTimeDurationHookFunc(), mapstructure.StringToURLHookFunc(), mapstructure.StringToIPHookFunc()),
204+
timeHookFunc(), CustomTypeHookFunc(), mapstructure.StringToTimeDurationHookFunc(), mapstructure.StringToURLHookFunc(), mapstructure.StringToIPHookFunc()),
156205
Result: result,
157206
})
158207
if err != nil {

0 commit comments

Comments
 (0)