Go 合并 Map

1. 特性

以下以将 map m2 合并到 map m1 进行说明。

  1. 如果 key 出现在 m2 中,但是 m1 中没有,则向 m1 添加(若使用了 WithIgnoreMissingKey() 选项,则忽略该 key)
  1. 如果 key 同时出现在 m1 和 m2 中,那么使用 m2 中相应的 value 覆盖 m1 的
  1. 如果 key 出现在 m1 中,但是 m2 中没有,那么保留
  1. 如果 key 或 value 的类型不一致,那么可以自定义类型转换函数(Converter)。通过类型转换函数也可以转换 slice,比如定义了 string 到 int 的转换函数,那么它也会被用于将 []string 转换为 []int
  1. 递归地处理 map 中嵌套的 map(不支持 slice 中嵌套的 map)

2. 实现

package pkg

import (
	"errors"
	"reflect"
)

// ErrMapTypeNeeded indicates that a value passed to MergeMap is not a map, or
// that MergeMap would have to write into a nil map.
var ErrMapTypeNeeded = errors.New("map type needed")

// ErrTypeIncompatible indicates that a key, value or slice element of the
// source map can neither be assigned nor converted to the type required by
// the target map.
var ErrTypeIncompatible = errors.New("type incompatible")

// Converter converts values of FromType into values of InType, for example
// string into int. MergeMap consults converters whenever a key, value or
// slice element of the source map is not assignable to the target type.
type Converter struct {
	// InType is the type required by the target map.
	InType reflect.Type

	// FromType is the type found in the source map.
	FromType reflect.Type

	// Func converts fromValue, whose type is FromType, into a value whose
	// type is assignable to InType (normally InType itself).
	Func func(fromValue reflect.Value) (reflect.Value, error)
}

// config holds the settings of a single MergeMap call.
type config struct {
	// addMissingKey controls whether keys absent from the target map are added.
	addMissingKey bool

	// converters are tried in order when types differ.
	converters []*Converter
}

// newDefaultConfig returns the default configuration: missing keys are added
// and no converters are registered.
func newDefaultConfig() *config {
	return &config{
		addMissingKey: true,
	}
}

// Option customizes a MergeMap call.
type Option func(*config)

// WithIgnoreMissingKey makes MergeMap skip keys that exist in the source map
// but not in the target map, instead of adding them.
func WithIgnoreMissingKey() Option {
	return func(c *config) {
		c.addMissingKey = false
	}
}

// WithConverters registers converters used when types differ. Converters are
// tried in registration order and the first one whose InType and FromType
// both match is used.
func WithConverters(converters ...*Converter) Option {
	return func(c *config) {
		c.converters = append(c.converters, converters...)
	}
}

// canBeNil reports whether the zero value of t is nil.
func canBeNil(t reflect.Type) bool {
	switch t.Kind() {
	case reflect.Chan, reflect.Func, reflect.Interface, reflect.Map,
		reflect.Pointer, reflect.Slice, reflect.UnsafePointer:
		return true
	}
	return false
}

// convertValue converts fromValue into a value assignable to inType.
//
// A fromValue of interface kind (e.g. an element of a []any) is first
// unwrapped to its dynamic value. An invalid fromValue (an untyped nil) yields
// the zero value of inType when that type can hold nil. A value that is
// already assignable is returned unchanged. Otherwise the first converter
// whose InType and FromType match is applied. ErrTypeIncompatible is returned
// when nothing matches or a converter returns an unusable value; a
// converter's own error is returned as is.
func convertValue(inType reflect.Type, fromValue reflect.Value, converters ...*Converter) (reflect.Value, error) {
	if fromValue.IsValid() && fromValue.Kind() == reflect.Interface {
		fromValue = fromValue.Elem()
	}
	if !fromValue.IsValid() {
		if canBeNil(inType) {
			return reflect.Zero(inType), nil
		}
		return reflect.Value{}, ErrTypeIncompatible
	}
	fromType := fromValue.Type()
	if fromType.AssignableTo(inType) {
		return fromValue, nil
	}
	for _, converter := range converters {
		if converter == nil || converter.Func == nil ||
			converter.InType != inType || converter.FromType != fromType {
			continue
		}
		converted, err := converter.Func(fromValue)
		if err != nil {
			return reflect.Value{}, err
		}
		// Guard against converters returning nothing or the wrong type, which
		// would otherwise panic in SetMapIndex/Append.
		if !converted.IsValid() || !converted.Type().AssignableTo(inType) {
			return reflect.Value{}, ErrTypeIncompatible
		}
		return converted, nil
	}
	return reflect.Value{}, ErrTypeIncompatible
}

// convertSlice converts fromValue into a value of sliceType. A slice whose
// type is not assignable to sliceType is converted element by element (so a
// string -> int converter turns []string into []int); any other value,
// including an assignable slice or nil, is converted as a whole.
func convertSlice(sliceType reflect.Type, fromValue reflect.Value, converters []*Converter) (reflect.Value, error) {
	if !fromValue.IsValid() || fromValue.Kind() != reflect.Slice || fromValue.Type().AssignableTo(sliceType) {
		return convertValue(sliceType, fromValue, converters...)
	}
	elemType := sliceType.Elem()
	n := fromValue.Len()
	out := reflect.MakeSlice(sliceType, 0, n)
	for i := 0; i < n; i++ {
		elem, err := convertValue(elemType, fromValue.Index(i), converters...)
		if err != nil {
			return reflect.Value{}, err
		}
		out = reflect.Append(out, elem)
	}
	return out, nil
}

// MergeMap merges the map from into the map in, modifying in in place, and
// returns in.
//
// For every key of from (converted to in's key type if needed):
//   - if in lacks the key, the entry is added unless WithIgnoreMissingKey is set;
//   - if in's value is a map and from's value is non-nil, the two are merged
//     recursively with the same options;
//   - otherwise from's value replaces in's value, converted to in's element
//     type if needed (slices are converted element by element).
//
// Keys present only in in are kept. Assigned values are not deep-copied, so
// slices and maps may end up shared with from.
//
// ErrMapTypeNeeded is returned if in or from is not a map, or if a nil map
// would have to be written to. ErrTypeIncompatible (or a converter's error) is
// returned when a key or value cannot be converted. On error, entries already
// processed (in unspecified map order) stay merged.
func MergeMap(in, from any, options ...Option) (any, error) {
	cfg := newDefaultConfig()
	for _, option := range options {
		option(cfg)
	}

	inValue := reflect.ValueOf(in)
	fromValue := reflect.ValueOf(from)
	if inValue.Kind() != reflect.Map || fromValue.Kind() != reflect.Map {
		return in, ErrMapTypeNeeded
	}
	keyType := inValue.Type().Key()
	elemType := inValue.Type().Elem()

	iter := fromValue.MapRange()
	for iter.Next() {
		// Unwrap interface-typed keys/values (e.g. from a map[any]any) so the
		// conversion logic sees their dynamic types.
		key, err := convertValue(keyType, reflect.ValueOf(iter.Key().Interface()), cfg.converters...)
		if err != nil {
			return nil, err
		}
		fromElem := reflect.ValueOf(iter.Value().Interface())

		if existing := inValue.MapIndex(key); existing.IsValid() {
			current := reflect.ValueOf(existing.Interface())
			if current.Kind() == reflect.Map && fromElem.IsValid() {
				if _, err := MergeMap(current.Interface(), fromElem.Interface(), options...); err != nil {
					return nil, err
				}
				continue
			}
		} else if !cfg.addMissingKey {
			continue
		}

		var v reflect.Value
		if elemType.Kind() == reflect.Slice {
			v, err = convertSlice(elemType, fromElem, cfg.converters)
		} else {
			v, err = convertValue(elemType, fromElem, cfg.converters...)
		}
		if err != nil {
			return nil, err
		}
		// SetMapIndex on a nil map panics; report it instead.
		if inValue.IsNil() {
			return nil, ErrMapTypeNeeded
		}
		inValue.SetMapIndex(key, v)
	}
	return in, nil
}

2.1. 测试用例

package pkg

import (
	"fmt"
	"reflect"
	"strconv"
	"testing"
)

// TestMergeMap checks MergeMap against table-driven cases; each case runs as
// its own subtest so one failure does not hide the others.
func TestMergeMap(t *testing.T) {
	testCases := []struct {
		name     string
		in       map[string]any
		from     map[string]any
		options  []Option
		expected map[string]any
	}{
		{
			name: "nested map with key and value conversion",
			in: map[string]any{
				"nested": map[int32]string{
					1: "1",
				},
			},
			from: map[string]any{
				"nested": map[string]int32{
					"1": int32(100),
				},
			},
			options: []Option{
				WithConverters(
					// convert string to int32
					&Converter{
						InType:   reflect.TypeOf(int32(0)),
						FromType: reflect.TypeOf(""),
						Func: func(fromValue reflect.Value) (reflect.Value, error) {
							s, ok := fromValue.Interface().(string)
							if !ok {
								return reflect.ValueOf(int32(0)), ErrTypeIncompatible
							}
							// bitSize 32 rejects out-of-range input instead of
							// silently wrapping on the int32 conversion.
							i64, err := strconv.ParseInt(s, 10, 32)
							if err != nil {
								return reflect.ValueOf(int32(0)), ErrTypeIncompatible
							}
							return reflect.ValueOf(int32(i64)), nil
						},
					},
					// convert int32 to string
					&Converter{
						InType:   reflect.TypeOf(""),
						FromType: reflect.TypeOf(int32(0)),
						Func: func(fromValue reflect.Value) (reflect.Value, error) {
							i, ok := fromValue.Interface().(int32)
							if !ok {
								return reflect.ValueOf(""), ErrTypeIncompatible
							}
							return reflect.ValueOf(fmt.Sprintf("%d", i)), nil
						},
					},
				),
			},
			expected: map[string]any{
				"nested": map[int32]string{
					1: "100",
				},
			},
		},
		{
			name: "replace, keep, add and merge recursively",
			in: map[string]any{
				// field a will be replaced
				"a": "a_in",
				// field b will be kept
				"b": "b_in",
				// field c will be replaced
				"c": []int{1, 2, 3},
				// field d will be merged
				"d": map[string]any{
					"e": "e_in",
					"f": "f_in",
					"g": map[string]any{
						// field d.g.h will be replaced
						"h": "h_in",
					},
				},
			},
			from: map[string]any{
				"a": "a_from",
				"c": []int{4, 5, 6},
				"d": map[string]any{
					"e": "e_from",
					"g": map[string]any{
						"h": "h_from",
						// field d.g.j will be added
						"j": "j_from",
					},
				},
			},
			expected: map[string]any{
				"a": "a_from",
				"b": "b_in",
				"c": []int{4, 5, 6},
				"d": map[string]any{
					"e": "e_from",
					"f": "f_in",
					"g": map[string]any{
						"h": "h_from",
						"j": "j_from",
					},
				},
			},
		},
		{
			name: "ignore missing keys",
			in: map[string]any{
				// field a will be kept
				"a": "a_in",
				// field c will be replaced
				"c": []int{1, 2, 3},
			},
			from: map[string]any{
				"b": "b_from",
				"c": []int{4, 5, 6},
			},
			options: []Option{
				// field b will be ignored
				WithIgnoreMissingKey(),
			},
			expected: map[string]any{
				"a": "a_in",
				"c": []int{4, 5, 6},
			},
		},
	}
	for _, tc := range testCases {
		tc := tc
		t.Run(tc.name, func(t *testing.T) {
			if _, err := MergeMap(tc.in, tc.from, tc.options...); err != nil {
				t.Fatalf("failed to merge map: %v", err)
			}
			if !reflect.DeepEqual(tc.in, tc.expected) {
				t.Fatalf("expected: %v, got: %v", tc.expected, tc.in)
			}
		})
	}
}

2.2. 类型转换示例

package main

import (
	"fmt"
	"reflect"
	"strconv"
	"test/pkg"
)

// main demonstrates MergeMap with converters: an int key is converted to a
// string key and []string values are converted element by element to []int.
func main() {
	m1 := map[string][]int{
		// kept as is: the key is absent from m2
		"one": {1, 2, 3},
		// replaced by the converted value from m2
		"three": {7, 8, 9},
	}
	m2 := map[any]any{
		// converted from []string to []int and added to m1
		"two": []string{"4", "5", "6"},
		// converted from []string to []int, then replaces the slice in m1
		"three": []string{"10", "11", "12"},
		// the int key 4 is converted to the string key "4"
		4: []int{13, 14, 15},
		// added to m1 as is
		"five": []int{16, 17, 18},
	}
	_, err := pkg.MergeMap(
		m1, m2, pkg.WithConverters(
			// converts int to string (used for the key 4)
			&pkg.Converter{
				InType:   reflect.TypeOf(""),
				FromType: reflect.TypeOf(0),
				Func: func(fromValue reflect.Value) (reflect.Value, error) {
					i, ok := fromValue.Interface().(int)
					if !ok {
						return reflect.Value{}, pkg.ErrTypeIncompatible
					}
					return reflect.ValueOf(strconv.Itoa(i)), nil
				},
			},
			// converts string to int (used for elements of []string values)
			&pkg.Converter{
				InType:   reflect.TypeOf(0),
				FromType: reflect.TypeOf(""),
				Func: func(fromValue reflect.Value) (reflect.Value, error) {
					s, ok := fromValue.Interface().(string)
					if !ok {
						return reflect.Value{}, pkg.ErrTypeIncompatible
					}
					// Atoi reports out-of-range input for the platform's int
					// size instead of silently truncating.
					i, err := strconv.Atoi(s)
					if err != nil {
						return reflect.Value{}, err
					}
					return reflect.ValueOf(i), nil
				},
			},
		),
	)
	if err != nil {
		panic(fmt.Sprintf("failed to merge map: %v", err))
	}
	fmt.Println(m1)
}