map.go 4.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155
  1. // Copyright 2014 Dario Castañé. All rights reserved.
  2. // Copyright 2009 The Go Authors. All rights reserved.
  3. // Use of this source code is governed by a BSD-style
  4. // license that can be found in the LICENSE file.
  5. // Based on src/pkg/reflect/deepequal.go from official
  6. // golang's stdlib.
  7. package mergo
  8. import (
  9. "fmt"
  10. "reflect"
  11. "unicode"
  12. "unicode/utf8"
  13. )
  14. func changeInitialCase(s string, mapper func(rune) rune) string {
  15. if s == "" {
  16. return s
  17. }
  18. r, n := utf8.DecodeRuneInString(s)
  19. return string(mapper(r)) + s[n:]
  20. }
  21. func isExported(field reflect.StructField) bool {
  22. r, _ := utf8.DecodeRuneInString(field.Name)
  23. return r >= 'A' && r <= 'Z'
  24. }
  25. // Traverses recursively both values, assigning src's fields values to dst.
  26. // The map argument tracks comparisons that have already been seen, which allows
  27. // short circuiting on recursive types.
  28. func deepMap(dst, src reflect.Value, visited map[uintptr]*visit, depth int, overwrite bool) (err error) {
  29. if dst.CanAddr() {
  30. addr := dst.UnsafeAddr()
  31. h := 17 * addr
  32. seen := visited[h]
  33. typ := dst.Type()
  34. for p := seen; p != nil; p = p.next {
  35. if p.ptr == addr && p.typ == typ {
  36. return nil
  37. }
  38. }
  39. // Remember, remember...
  40. visited[h] = &visit{addr, typ, seen}
  41. }
  42. zeroValue := reflect.Value{}
  43. switch dst.Kind() {
  44. case reflect.Map:
  45. dstMap := dst.Interface().(map[string]interface{})
  46. for i, n := 0, src.NumField(); i < n; i++ {
  47. srcType := src.Type()
  48. field := srcType.Field(i)
  49. if !isExported(field) {
  50. continue
  51. }
  52. fieldName := field.Name
  53. fieldName = changeInitialCase(fieldName, unicode.ToLower)
  54. if v, ok := dstMap[fieldName]; !ok || (isEmptyValue(reflect.ValueOf(v)) || overwrite) {
  55. dstMap[fieldName] = src.Field(i).Interface()
  56. }
  57. }
  58. case reflect.Struct:
  59. srcMap := src.Interface().(map[string]interface{})
  60. for key := range srcMap {
  61. srcValue := srcMap[key]
  62. fieldName := changeInitialCase(key, unicode.ToUpper)
  63. dstElement := dst.FieldByName(fieldName)
  64. if dstElement == zeroValue {
  65. // We discard it because the field doesn't exist.
  66. continue
  67. }
  68. srcElement := reflect.ValueOf(srcValue)
  69. dstKind := dstElement.Kind()
  70. srcKind := srcElement.Kind()
  71. if srcKind == reflect.Ptr && dstKind != reflect.Ptr {
  72. srcElement = srcElement.Elem()
  73. srcKind = reflect.TypeOf(srcElement.Interface()).Kind()
  74. } else if dstKind == reflect.Ptr {
  75. // Can this work? I guess it can't.
  76. if srcKind != reflect.Ptr && srcElement.CanAddr() {
  77. srcPtr := srcElement.Addr()
  78. srcElement = reflect.ValueOf(srcPtr)
  79. srcKind = reflect.Ptr
  80. }
  81. }
  82. if !srcElement.IsValid() {
  83. continue
  84. }
  85. if srcKind == dstKind {
  86. if err = deepMerge(dstElement, srcElement, visited, depth+1, overwrite); err != nil {
  87. return
  88. }
  89. } else {
  90. if srcKind == reflect.Map {
  91. if err = deepMap(dstElement, srcElement, visited, depth+1, overwrite); err != nil {
  92. return
  93. }
  94. } else {
  95. return fmt.Errorf("type mismatch on %s field: found %v, expected %v", fieldName, srcKind, dstKind)
  96. }
  97. }
  98. }
  99. }
  100. return
  101. }
  102. // Map sets fields' values in dst from src.
  103. // src can be a map with string keys or a struct. dst must be the opposite:
  104. // if src is a map, dst must be a valid pointer to struct. If src is a struct,
  105. // dst must be map[string]interface{}.
  106. // It won't merge unexported (private) fields and will do recursively
  107. // any exported field.
  108. // If dst is a map, keys will be src fields' names in lower camel case.
  109. // Missing key in src that doesn't match a field in dst will be skipped. This
  110. // doesn't apply if dst is a map.
  111. // This is separated method from Merge because it is cleaner and it keeps sane
  112. // semantics: merging equal types, mapping different (restricted) types.
  113. func Map(dst, src interface{}) error {
  114. return _map(dst, src, false)
  115. }
  116. func MapWithOverwrite(dst, src interface{}) error {
  117. return _map(dst, src, true)
  118. }
  119. func _map(dst, src interface{}, overwrite bool) error {
  120. var (
  121. vDst, vSrc reflect.Value
  122. err error
  123. )
  124. if vDst, vSrc, err = resolveValues(dst, src); err != nil {
  125. return err
  126. }
  127. // To be friction-less, we redirect equal-type arguments
  128. // to deepMerge. Only because arguments can be anything.
  129. if vSrc.Kind() == vDst.Kind() {
  130. return deepMerge(vDst, vSrc, make(map[uintptr]*visit), 0, overwrite)
  131. }
  132. switch vSrc.Kind() {
  133. case reflect.Struct:
  134. if vDst.Kind() != reflect.Map {
  135. return ErrExpectedMapAsDestination
  136. }
  137. case reflect.Map:
  138. if vDst.Kind() != reflect.Struct {
  139. return ErrExpectedStructAsDestination
  140. }
  141. default:
  142. return ErrNotSupported
  143. }
  144. return deepMap(vDst, vSrc, make(map[uintptr]*visit), 0, overwrite)
  145. }