tree_test.py 47 KB


  1. # Copyright 2019 DeepMind Technologies Limited. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ==============================================================================
  15. """Tests for utilities working with arbitrarily nested structures."""
import collections
import doctest
import re
import types
from typing import Any, Iterator, Mapping
import unittest

from absl.testing import parameterized
import attr
import numpy as np
import tree
import wrapt
# Reference fixture: nested tuples holding six integer leaves.
STRUCTURE1 = (((1, 2), 3), 4, (5, 6))
# Same nesting as STRUCTURE1, but with string leaves.
STRUCTURE2 = ((("foo1", "foo2"), "foo3"), "foo4", ("foo5", "foo6"))
# Flat two-element tuple of strings; its nesting does not match STRUCTURE1.
STRUCTURE_DIFFERENT_NUM_ELEMENTS = ("spam", "eggs")
# Six integer leaves like STRUCTURE1, but grouped differently at the top level.
STRUCTURE_DIFFERENT_NESTING = (((1, 2), 3), 4, 5, (6,))
  30. class DoctestTest(parameterized.TestCase):
  31. def testDoctest(self):
  32. extraglobs = {
  33. "collections": collections,
  34. "tree": tree,
  35. }
  36. num_failed, num_attempted = doctest.testmod(
  37. tree, extraglobs=extraglobs, optionflags=doctest.ELLIPSIS)
  38. self.assertGreater(num_attempted, 0, "No doctests found.")
  39. self.assertEqual(num_failed, 0, "{} doctests failed".format(num_failed))
  40. class NestTest(parameterized.TestCase):
  def assertAllEquals(self, a, b):
    """Asserts that `a`, converted to a NumPy array, equals `b` element-wise."""
    elementwise_match = np.asarray(a) == b
    self.assertTrue(elementwise_match.all())
  def testAttrsFlattenAndUnflatten(self):
    """Round-trips an attrs instance through flatten and unflatten_as."""

    @attr.s
    class SampleAttr(object):
      field1 = attr.ib()
      field2 = attr.ib()

    class BadAttr(object):
      """Class that has a non-iterable __attrs_attrs__."""
      __attrs_attrs__ = None

    values = [1, 2]
    instance = SampleAttr(*values)
    self.assertFalse(tree._is_attrs(values))
    self.assertTrue(tree._is_attrs(instance))

    leaves = tree.flatten(instance)
    self.assertEqual(values, leaves)

    rebuilt = tree.unflatten_as(instance, leaves)
    self.assertIsInstance(rebuilt, SampleAttr)
    self.assertEqual(rebuilt, instance)

    # Check that flatten fails if attributes are not iterable
    with self.assertRaisesRegex(TypeError, "object is not iterable"):
      tree.flatten(BadAttr())
  @parameterized.parameters([
      (1, 2, 3),
      ({"B": 10, "A": 20}, [1, 2], 3),
      ((1, 2), [3, 4], 5),
      (collections.namedtuple("Point", ["x", "y"])(1, 2), 3, 4),
      wrapt.ObjectProxy(
          (collections.namedtuple("Point", ["x", "y"])(1, 2), 3, 4))
  ])
  def testAttrsMapStructure(self, *field_values):
    """Identity-mapping an attrs instance yields an equal instance."""

    # Fields are deliberately declared out of alphabetical order.
    @attr.s
    class SampleAttr(object):
      field3 = attr.ib()
      field1 = attr.ib()
      field2 = attr.ib()

    original = SampleAttr(*field_values)
    mapped = tree.map_structure(lambda x: x, original)
    self.assertEqual(original, mapped)
  def testFlattenAndUnflatten(self):
    """Exercises flatten / unflatten_as on tuples, namedtuples and scalars."""
    nested = ((3, 4), 5, (6, 7, (9, 10), 8))
    letters = ["a", "b", "c", "d", "e", "f", "g", "h"]
    self.assertEqual(tree.flatten(nested), [3, 4, 5, 6, 7, 9, 10, 8])
    expected_letters = (("a", "b"), "c", ("d", "e", ("f", "g"), "h"))
    self.assertEqual(tree.unflatten_as(nested, letters), expected_letters)

    # Namedtuples: the round trip must give back equal structures whose
    # fields are still reachable by name.
    point = collections.namedtuple("Point", ["x", "y"])
    nested = (point(x=4, y=2), ((point(x=1, y=0),),))
    leaves = [4, 2, 1, 0]
    self.assertEqual(tree.flatten(nested), leaves)
    rebuilt = tree.unflatten_as(nested, leaves)
    self.assertEqual(rebuilt, nested)
    outer_point = rebuilt[0]
    self.assertEqual(outer_point.x, 4)
    self.assertEqual(outer_point.y, 2)
    inner_point = rebuilt[1][0][0]
    self.assertEqual(inner_point.x, 1)
    self.assertEqual(inner_point.y, 0)

    # Scalars (including NumPy arrays and strings) are single leaves.
    self.assertEqual([5], tree.flatten(5))
    self.assertEqual([np.array([5])], tree.flatten(np.array([5])))
    self.assertEqual("a", tree.unflatten_as(5, ["a"]))
    self.assertEqual(
        np.array([5]), tree.unflatten_as("scalar", [np.array([5])]))

    # Error cases.
    with self.assertRaisesRegex(ValueError, "Structure is a scalar"):
      tree.unflatten_as("scalar", [4, 5])

    with self.assertRaisesRegex(TypeError, "flat_sequence"):
      tree.unflatten_as([4, 5], "bad_sequence")

    with self.assertRaises(ValueError):
      tree.unflatten_as([5, 6, [7, 8]], ["a", "b", "c"])
  def testFlattenDictOrder(self):
    """Flattening dicts yields values ordered by key, not insertion order."""
    items = [("d", 3), ("b", 1), ("a", 0), ("c", 2)]
    values_by_sorted_key = [0, 1, 2, 3]
    flat_from_ordered = tree.flatten(collections.OrderedDict(items))
    flat_from_plain = tree.flatten(dict(items))
    self.assertEqual(values_by_sorted_key, flat_from_ordered)
    self.assertEqual(values_by_sorted_key, flat_from_plain)
  def testUnflattenDictOrder(self):
    """unflatten_as assigns flat values to dict keys in sorted-key order."""
    keys = ["d", "b", "a", "c"]
    ordered_template = collections.OrderedDict((key, 0) for key in keys)
    plain_template = dict.fromkeys(keys, 0)
    flat_values = [0, 1, 2, 3]

    rebuilt_ordered = tree.unflatten_as(ordered_template, flat_values)
    rebuilt_plain = tree.unflatten_as(plain_template, flat_values)

    expected_items = [("d", 3), ("b", 1), ("a", 0), ("c", 2)]
    self.assertEqual(collections.OrderedDict(expected_items), rebuilt_ordered)
    self.assertEqual(dict(expected_items), rebuilt_plain)
  def testFlattenAndUnflatten_withDicts(self):
    """Round-trips a mix of lists, namedtuples, dicts and OrderedDicts."""
    # A nice messy mix of tuples, lists, dicts, and `OrderedDict`s.
    named_tuple = collections.namedtuple("A", ("b", "c"))
    inner_ordered = collections.OrderedDict([("b", 3), ("a", 2)])
    mess = ["z", named_tuple(3, 4), {"c": [1, inner_ordered], "b": 5}, 17]

    flattened = tree.flatten(mess)
    self.assertEqual(flattened, ["z", 3, 4, 5, 1, 2, 3, 17])

    # Same shape as `mess`, but with unrelated leaf values.
    template_ordered = collections.OrderedDict([("b", 9), ("a", 8)])
    structure_of_mess = [
        14,
        named_tuple("a", True),
        {"c": [0, template_ordered], "b": 3},
        "hi everybody",
    ]
    self.assertEqual(mess, tree.unflatten_as(structure_of_mess, flattened))

    # Check also that the OrderedDict was created, with the correct key order.
    rebuilt = tree.unflatten_as(structure_of_mess, flattened)
    unflattened_ordered_dict = rebuilt[2]["c"][1]
    self.assertIsInstance(unflattened_ordered_dict, collections.OrderedDict)
    self.assertEqual(list(unflattened_ordered_dict.keys()), ["b", "a"])
  def testFlatten_numpyIsNotFlattened(self):
    """A NumPy array flattens to a single leaf."""
    leaves = tree.flatten(np.array([1, 2, 3]))
    self.assertLen(leaves, 1)
  def testFlatten_stringIsNotFlattened(self):
    """A string flattens to a single leaf and round-trips unchanged."""
    text = "lots of letters"
    leaves = tree.flatten(text)
    self.assertLen(leaves, 1)
    rebuilt = tree.unflatten_as("goodbye", leaves)
    self.assertEqual(text, rebuilt)
  def testFlatten_bytearrayIsNotFlattened(self):
    """A bytearray flattens to a single leaf and round-trips unchanged."""
    payload = bytearray("bytes in an array", "ascii")
    leaves = tree.flatten(payload)
    self.assertLen(leaves, 1)
    self.assertEqual(leaves, [payload])
    template = bytearray("hello", "ascii")
    self.assertEqual(payload, tree.unflatten_as(template, leaves))
  def testUnflattenSequenceAs_notIterableError(self):
    """Passing a string as the flat sequence raises TypeError."""
    expected_regex = "flat_sequence must be a sequence"
    with self.assertRaisesRegex(TypeError, expected_regex):
      tree.unflatten_as("hi", "bye")
  def testUnflattenSequenceAs_wrongLengthsError(self):
    """A flat sequence longer than the structure raises ValueError."""
    expected_regex = (
        "Structure had 2 elements, but flat_sequence had 3 elements.")
    with self.assertRaisesRegex(ValueError, expected_regex):
      tree.unflatten_as(["hello", "world"], ["and", "goodbye", "again"])
  def testUnflattenSequenceAs_defaultdict(self):
    """unflatten_as fills a defaultdict template with the flat values."""
    template = collections.defaultdict(
        list, [("a", [None]), ("b", [None, None])])
    flat_values = [1, 2, 3]
    want = collections.defaultdict(list, [("a", [1]), ("b", [2, 3])])
    got = tree.unflatten_as(template, flat_values)
    self.assertEqual(want, got)
  def testIsSequence(self):
    """Checks which kinds of objects is_nested reports as nested."""
    # Strings, bytes and bytearrays are leaves.
    for text_like in ("1234", b"1234", u"1234", bytearray("1234", "ascii")):
      self.assertFalse(tree.is_nested(text_like))

    # Lists, tuples and dicts are nested, even when empty.
    for container in ([1, 3, [4, 5]], ((7, 8), (5, 6)), [], {"a": 1, "b": 2}):
      self.assertTrue(tree.is_nested(container))

    # Sets are leaves.
    self.assertFalse(tree.is_nested({1, 2}))

    # NumPy arrays are leaves.
    ones = np.ones([2, 3])
    for array in (ones, np.tanh(ones), np.ones((4, 5))):
      self.assertFalse(tree.is_nested(array))
  # pylint does not correctly recognize these as class names and
  # suggests to use variable style under_score naming.
  # pylint: disable=invalid-name
  # Namedtuple fixtures for the assert_same_structure tests. They differ in
  # the type name passed to `namedtuple` and/or in the field names.
  # Different type names, same fields ("a", "b").
  Named0ab = collections.namedtuple("named_0", ("a", "b"))
  Named1ab = collections.namedtuple("named_1", ("a", "b"))
  # Two distinct classes created with identical type name and fields.
  SameNameab = collections.namedtuple("same_name", ("a", "b"))
  SameNameab2 = collections.namedtuple("same_name", ("a", "b"))
  # Same type name as SameNameab, but fields ("x", "y").
  SameNamexy = collections.namedtuple("same_name", ("x", "y"))
  # Two distinct classes created with identical type name and fields.
  SameName1xy = collections.namedtuple("same_name_1", ("x", "y"))
  SameName1xy2 = collections.namedtuple("same_name_1", ("x", "y"))
  # Same fields as SameNameab, but a different type name.
  NotSameName = collections.namedtuple("not_same_name", ("a", "b"))
  # pylint: enable=invalid-name
  # Subclass of the SameNameab namedtuple that adds nothing of its own; used
  # to compare a class wrapping a namedtuple against the namedtuple itself.
  class SameNamedType1(SameNameab):
    pass
  225. # pylint: disable=g-error-prone-assert-raises
  def testAssertSameStructure(self):
    """Pairs that assert_same_structure must accept without raising."""
    matching_pairs = (
        (STRUCTURE1, STRUCTURE2),
        ("abc", 1.0),
        (b"abc", 1.0),
        (u"abc", 1.0),
        (bytearray("abc", "ascii"), 1.0),
        ("abc", np.array([0, 1])),
    )
    for first, second in matching_pairs:
      tree.assert_same_structure(first, second)
  def testAssertSameStructure_differentNumElements(self):
    """A nested tuple vs. a flat tuple raises a detailed ValueError."""
    expected_regex = (
        "The two structures don't have the same nested structure\\.\n\n"
        "First structure:.*?\n\n"
        "Second structure:.*\n\n"
        "More specifically: Substructure "
        r'"type=tuple str=\(\(1, 2\), 3\)" is a sequence, while '
        'substructure "type=str str=spam" is not\n'
        "Entire first structure:\n"
        r"\(\(\(\., \.\), \.\), \., \(\., \.\)\)\n"
        "Entire second structure:\n"
        r"\(\., \.\)")
    with self.assertRaisesRegex(ValueError, expected_regex):
      tree.assert_same_structure(STRUCTURE1, STRUCTURE_DIFFERENT_NUM_ELEMENTS)
  def testAssertSameStructure_listVsNdArray(self):
    """A list and an ndarray with the same contents are not the same structure."""
    expected_regex = (
        "The two structures don't have the same nested structure\\.\n\n"
        "First structure:.*?\n\n"
        "Second structure:.*\n\n"
        r'More specifically: Substructure "type=list str=\[0, 1\]" '
        r'is a sequence, while substructure "type=ndarray str=\[0 1\]" '
        "is not")
    with self.assertRaisesRegex(ValueError, expected_regex):
      tree.assert_same_structure([0, 1], np.array([0, 1]))
  def testAssertSameStructure_intVsList(self):
    """An int and a list are not the same structure."""
    expected_regex = (
        "The two structures don't have the same nested structure\\.\n\n"
        "First structure:.*?\n\n"
        "Second structure:.*\n\n"
        r'More specifically: Substructure "type=list str=\[0, 1\]" '
        'is a sequence, while substructure "type=int str=0" '
        "is not")
    with self.assertRaisesRegex(ValueError, expected_regex):
      tree.assert_same_structure(0, [0, 1])
  def testAssertSameStructure_tupleVsList(self):
    """A tuple and a list with equal contents raise TypeError."""
    as_tuple = (0, 1)
    as_list = [0, 1]
    with self.assertRaises(TypeError):
      tree.assert_same_structure(as_tuple, as_list)
  def testAssertSameStructure_differentNesting(self):
    """Differently grouped tuples raise ValueError naming both structures."""
    expected_regex = (
        "don't have the same nested structure\\.\n\n"
        "First structure: .*?\n\nSecond structure: ")
    with self.assertRaisesRegex(ValueError, expected_regex):
      tree.assert_same_structure(STRUCTURE1, STRUCTURE_DIFFERENT_NESTING)
  def testAssertSameStructure_tupleVsNamedTuple(self):
    """A plain tuple and a namedtuple raise TypeError."""
    plain = (0, 1)
    named = NestTest.Named0ab("a", "b")
    with self.assertRaises(TypeError):
      tree.assert_same_structure(plain, named)
  def testAssertSameStructure_sameNamedTupleDifferentContents(self):
    """Instances of one namedtuple type match regardless of leaf values."""
    with_ints = NestTest.Named0ab(3, 4)
    with_strings = NestTest.Named0ab("a", "b")
    tree.assert_same_structure(with_ints, with_strings)
  def testAssertSameStructure_differentNamedTuples(self):
    """Namedtuples with different type names raise TypeError."""
    first = NestTest.Named0ab(3, 4)
    second = NestTest.Named1ab(3, 4)
    with self.assertRaises(TypeError):
      tree.assert_same_structure(first, second)
  def testAssertSameStructure_sameNamedTupleDifferentStructuredContents(self):
    """Same namedtuple type with differently nested fields raises ValueError."""
    expected_regex = (
        "don't have the same nested structure\\.\n\n"
        "First structure: .*?\n\nSecond structure: ")
    with self.assertRaisesRegex(ValueError, expected_regex):
      tree.assert_same_structure(
          NestTest.Named0ab(3, 4), NestTest.Named0ab([3], 4))
  def testAssertSameStructure_differentlyNestedLists(self):
    """Lists nested in different positions raise ValueError."""
    expected_regex = (
        "don't have the same nested structure\\.\n\n"
        "First structure: .*?\n\nSecond structure: ")
    with self.assertRaisesRegex(ValueError, expected_regex):
      tree.assert_same_structure([[3], 4], [3, [4]])
  def testAssertSameStructure_listStructureWithAndWithoutTypes(self):
    """Tuple-vs-list nesting fails by default and passes with check_types=False."""
    list_version = [[[1, 2], 3], 4, [5, 6]]

    # With type checking (the default) the sequence types must agree.
    with self.assertRaisesRegex(TypeError, "don't have the same sequence type"):
      tree.assert_same_structure(STRUCTURE1, list_version)

    # Without type checking only the nesting has to match.
    for other in (STRUCTURE2, list_version):
      tree.assert_same_structure(STRUCTURE1, other, check_types=False)
  def testAssertSameStructure_dictionaryDifferentKeys(self):
    """Dicts with different key sets raise ValueError."""
    expected_regex = "don't have the same set of keys"
    with self.assertRaisesRegex(ValueError, expected_regex):
      tree.assert_same_structure({"a": 1}, {"b": 1})
  def testAssertSameStructure_sameNameNamedTuples(self):
    """Distinct namedtuple classes with equal name and fields are accepted."""
    first = NestTest.SameNameab(0, 1)
    second = NestTest.SameNameab2(2, 3)
    tree.assert_same_structure(first, second)
  def testAssertSameStructure_sameNameNamedTuplesNested(self):
    """Same-name namedtuples are also accepted when nested in each other."""
    # This assertion is expected to pass: two namedtuples with the same
    # name and field names are considered to be identical.
    first = NestTest.SameNameab(NestTest.SameName1xy(0, 1), 2)
    second = NestTest.SameNameab2(NestTest.SameName1xy2(2, 3), 4)
    tree.assert_same_structure(first, second)
  def testAssertSameStructure_sameNameNamedTuplesDifferentStructure(self):
    """Same-name namedtuples nested in different fields raise ValueError."""
    nested_in_second_field = NestTest.SameNameab(
        0, NestTest.SameNameab2(1, 2))
    nested_in_first_field = NestTest.SameNameab2(
        NestTest.SameNameab(0, 1), 2)
    with self.assertRaisesRegex(
        ValueError, "The two structures don't have the same.*"):
      tree.assert_same_structure(nested_in_second_field, nested_in_first_field)
  def testAssertSameStructure_differentNameNamedStructures(self):
    """Namedtuples with equal fields but different names raise TypeError."""
    first = NestTest.SameNameab(0, 1)
    second = NestTest.NotSameName(2, 3)
    with self.assertRaises(TypeError):
      tree.assert_same_structure(first, second)
  def testAssertSameStructure_sameNameDifferentFieldNames(self):
    """Namedtuples with equal names but different fields raise TypeError."""
    first = NestTest.SameNameab(0, 1)
    second = NestTest.SameNamexy(2, 3)
    with self.assertRaises(TypeError):
      tree.assert_same_structure(first, second)
  def testAssertSameStructure_classWrappingNamedTuple(self):
    """A namedtuple and a subclass of it raise TypeError."""
    base_instance = NestTest.SameNameab(0, 1)
    subclass_instance = NestTest.SameNamedType1(2, 3)
    with self.assertRaises(TypeError):
      tree.assert_same_structure(base_instance, subclass_instance)
  331. # pylint: enable=g-error-prone-assert-raises
  def testMapStructure(self):
    """Covers map_structure on nested, scalar and empty inputs plus errors."""

    def increment(x):
      return x + 1

    def add(x, y):
      return x + y

    def discard_pair(x, y):
      return None

    other = (((7, 8), 9), 10, (11, 12))

    # One structure: the shape is kept and every leaf is transformed.
    incremented = tree.map_structure(increment, STRUCTURE1)
    tree.assert_same_structure(STRUCTURE1, incremented)
    self.assertAllEquals([2, 3, 4, 5, 6, 7], tree.flatten(incremented))

    # Two structures: leaves are combined pairwise.
    summed = tree.map_structure(add, STRUCTURE1, other)
    self.assertEqual((((8, 10), 12), 14, (16, 18)), summed)

    # Scalars are treated as single-leaf structures.
    self.assertEqual(3, tree.map_structure(lambda x: x - 1, 4))
    self.assertEqual(7, tree.map_structure(add, 3, 4))

    # Empty structures
    self.assertEqual((), tree.map_structure(increment, ()))
    self.assertEqual([], tree.map_structure(increment, []))
    self.assertEqual({}, tree.map_structure(increment, {}))
    empty_nt = collections.namedtuple("empty_nt", "")
    self.assertEqual(empty_nt(), tree.map_structure(increment, empty_nt()))

    # This is checking actual equality of types, empty list != empty tuple
    self.assertNotEqual((), tree.map_structure(increment, []))

    # Invalid arguments.
    with self.assertRaisesRegex(TypeError, "callable"):
      tree.map_structure("bad", incremented)

    with self.assertRaisesRegex(ValueError, "at least one structure"):
      tree.map_structure(lambda x: x)

    # Mismatched structures.
    with self.assertRaisesRegex(ValueError, "same number of elements"):
      tree.map_structure(discard_pair, (3, 4), (3, 4, 5))

    with self.assertRaisesRegex(ValueError, "same nested structure"):
      tree.map_structure(discard_pair, 3, (3,))

    with self.assertRaisesRegex(TypeError, "same sequence type"):
      tree.map_structure(discard_pair, ((3, 4), 5), [(3, 4), 5])

    with self.assertRaisesRegex(ValueError, "same nested structure"):
      tree.map_structure(discard_pair, ((3, 4), 5), (3, (4, 5)))

    # check_types=False relaxes the sequence-type check, not the nesting one.
    list_version = [[[1, 2], 3], 4, [5, 6]]
    with self.assertRaisesRegex(TypeError, "same sequence type"):
      tree.map_structure(discard_pair, STRUCTURE1, list_version)

    tree.map_structure(
        discard_pair, STRUCTURE1, list_version, check_types=False)

    with self.assertRaisesRegex(ValueError, "same nested structure"):
      tree.map_structure(
          discard_pair, ((3, 4), 5), (3, (4, 5)), check_types=False)

    # Unknown keyword arguments are rejected.
    with self.assertRaisesRegex(ValueError, "Only valid keyword argument.*foo"):
      tree.map_structure(lambda x: None, STRUCTURE1, foo="a")

    with self.assertRaisesRegex(ValueError, "Only valid keyword argument.*foo"):
      tree.map_structure(lambda x: None, STRUCTURE1, check_types=False, foo="a")
  def testMapStructureWithStrings(self):
    """String leaves are passed whole to the mapped function."""
    ab_tuple = collections.namedtuple("ab_tuple", "a, b")

    # Repeat each string by the count found at the same position.
    strings = ab_tuple(a="foo", b=("bar", "baz"))
    counts = ab_tuple(a=2, b=(1, 3))
    repeated = tree.map_structure(
        lambda string, repeats: string * repeats, strings, counts)
    self.assertEqual("foofoo", repeated.a)
    first_b, second_b = repeated.b[0], repeated.b[1]
    self.assertEqual("bar", first_b)
    self.assertEqual("bazbazbaz", second_b)

    nt = ab_tuple(a=("something", "something_else"), b="yet another thing")
    rev_nt = tree.map_structure(lambda x: x[::-1], nt)
    # Check the output is the correct structure, and all strings are reversed.
    tree.assert_same_structure(nt, rev_nt)
    for index in (0, 1):
      self.assertEqual(nt.a[index][::-1], rev_nt.a[index])
    self.assertEqual(nt.b[::-1], rev_nt.b)
  def testAssertShallowStructure(self):
    """Checks the errors raised by tree._assert_shallow_structure.

    Covers mismatching lengths, mismatching sequence types (with and without
    `check_types`), shallow keys missing from the input, OrderedDicts with
    different key order, and a dict-vs-list comparison with type checking off.
    """
    inp_ab = ["a", "b"]
    inp_abc = ["a", "b", "c"]
    # The expected message is plain text, not a pattern: escape it so that
    # characters such as "." are matched literally instead of as regex syntax.
    with self.assertRaisesRegex(
        ValueError,
        re.escape(
            tree._STRUCTURES_HAVE_MISMATCHING_LENGTHS.format(
                input_length=len(inp_ab),
                shallow_length=len(inp_abc)))):
      tree._assert_shallow_structure(inp_abc, inp_ab)

    # Tuples in the input vs. lists in the shallow tree.
    inp_ab1 = [(1, 1), (2, 2)]
    inp_ab2 = [[1, 1], [2, 2]]
    with self.assertRaisesWithLiteralMatch(
        TypeError,
        tree._STRUCTURES_HAVE_MISMATCHING_TYPES.format(
            shallow_type=type(inp_ab2[0]),
            input_type=type(inp_ab1[0]))):
      tree._assert_shallow_structure(shallow_tree=inp_ab2, input_tree=inp_ab1)

    # The same pair is accepted once type checking is disabled.
    tree._assert_shallow_structure(inp_ab2, inp_ab1, check_types=False)

    # The shallow tree has a nested key ("d") that the input does not have.
    inp_ab1 = {"a": (1, 1), "b": {"c": (2, 2)}}
    inp_ab2 = {"a": (1, 1), "b": {"d": (2, 2)}}
    with self.assertRaisesWithLiteralMatch(
        ValueError,
        tree._SHALLOW_TREE_HAS_INVALID_KEYS.format(["d"])):
      tree._assert_shallow_structure(inp_ab2, inp_ab1)

    # OrderedDicts holding the same items in a different order are accepted.
    inp_ab = collections.OrderedDict([("a", 1), ("b", (2, 3))])
    inp_ba = collections.OrderedDict([("b", (2, 3)), ("a", 1)])
    tree._assert_shallow_structure(inp_ab, inp_ba)

    # regression test for b/130633904
    tree._assert_shallow_structure({0: "foo"}, ["foo"], check_types=False)
  def testFlattenUpTo(self):
    """Checks flatten_up_to for nested, mapping-like and scalar shallow trees."""
    # Shallow tree ends at scalar.
    deep = [[[2, 2], [3, 3]], [[4, 9], [5, 5]]]
    shallow = [[True, True], [False, True]]
    deep_up_to_shallow = tree.flatten_up_to(shallow, deep)
    shallow_up_to_itself = tree.flatten_up_to(shallow, shallow)
    self.assertEqual(deep_up_to_shallow, [[2, 2], [3, 3], [4, 9], [5, 5]])
    self.assertEqual(shallow_up_to_itself, [True, True, False, True])

    # Shallow tree ends at string.
    deep = [[("a", 1), [("b", 2), [("c", 3), [("d", 4)]]]]]
    shallow = [["level_1", ["level_2", ["level_3", ["level_4"]]]]]
    deep_up_to_shallow = tree.flatten_up_to(shallow, deep)
    deep_fully_flat = tree.flatten(deep)
    self.assertEqual(deep_up_to_shallow,
                     [("a", 1), ("b", 2), ("c", 3), ("d", 4)])
    self.assertEqual(deep_fully_flat, ["a", 1, "b", 2, "c", 3, "d", 4])

    # Make sure dicts are correctly flattened, yielding values, not keys.
    deep = {"a": 1, "b": {"c": 2}, "d": [3, (4, 5)]}
    shallow = {"a": 0, "b": 0, "d": [0, 0]}
    deep_up_to_shallow = tree.flatten_up_to(shallow, deep)
    self.assertEqual(deep_up_to_shallow, [1, {"c": 2}, 3, (4, 5)])

    # Namedtuples.
    ab_tuple = collections.namedtuple("ab_tuple", "a, b")
    deep = ab_tuple(a=[0, 1], b=2)
    shallow = ab_tuple(a=0, b=1)
    deep_up_to_shallow = tree.flatten_up_to(shallow, deep)
    self.assertEqual(deep_up_to_shallow, [[0, 1], 2])

    # Attrs.
    @attr.s
    class ABAttr(object):
      a = attr.ib()
      b = attr.ib()

    deep = ABAttr(a=[0, 1], b=2)
    shallow = ABAttr(a=0, b=1)
    deep_up_to_shallow = tree.flatten_up_to(shallow, deep)
    self.assertEqual(deep_up_to_shallow, [[0, 1], 2])

    # Nested dicts, OrderedDicts and namedtuples.
    deep = collections.OrderedDict(
        [("a", ab_tuple(a=[0, {"b": 1}], b=2)),
         ("c", {"d": 3, "e": collections.OrderedDict([("f", 4)])})])
    shallow = deep
    deep_up_to_shallow = tree.flatten_up_to(shallow, deep)
    self.assertEqual(deep_up_to_shallow, [0, 1, 2, 3, 4])

    shallow = collections.OrderedDict([("a", 0), ("c", {"d": 3, "e": 1})])
    deep_up_to_shallow = tree.flatten_up_to(shallow, deep)
    self.assertEqual(
        deep_up_to_shallow,
        [ab_tuple(a=[0, {"b": 1}], b=2),
         3,
         collections.OrderedDict([("f", 4)])])

    shallow = collections.OrderedDict([("a", 0), ("c", 0)])
    deep_up_to_shallow = tree.flatten_up_to(shallow, deep)
    self.assertEqual(
        deep_up_to_shallow,
        [ab_tuple(a=[0, {"b": 1}], b=2),
         {"d": 3, "e": collections.OrderedDict([("f", 4)])}])

    # When the shallow tree is not nested, the input tree (nested or not) is
    # returned whole as the single flattened element.
    scalar_shallow_cases = (
        ## Shallow non-list edge-case.
        # Using iterable elements.
        (["input_tree"], "shallow_tree"),
        (["input_tree_0", "input_tree_1"], "shallow_tree"),
        # Using non-iterable elements.
        ([0], 9),
        ([0, 1], 9),
        ## Both non-list edge-case.
        # Using iterable elements.
        ("input_tree", "shallow_tree"),
        # Using non-iterable elements.
        (0, 0),
    )
    for input_tree, shallow_tree in scalar_shallow_cases:
      flattened_input_tree = tree.flatten_up_to(shallow_tree, input_tree)
      flattened_shallow_tree = tree.flatten_up_to(shallow_tree, shallow_tree)
      self.assertEqual(flattened_input_tree, [input_tree])
      self.assertEqual(flattened_shallow_tree, [shallow_tree])

    # When the shallow tree is a sequence but the input is not, flattening the
    # input raises, while the shallow tree still flattens against itself.
    scalar_input_cases = (
        ## Input non-list edge-case.
        # Using iterable elements.
        ("input_tree", ["shallow_tree"]),
        ("input_tree", ["shallow_tree_9", "shallow_tree_8"]),
        # Using non-iterable elements.
        (0, [9]),
        (0, [9, 8]),
    )
    for input_tree, shallow_tree in scalar_input_cases:
      expected_message = tree._IF_SHALLOW_IS_SEQ_INPUT_MUST_BE_SEQ.format(
          type(input_tree))
      with self.assertRaisesWithLiteralMatch(TypeError, expected_message):
        tree.flatten_up_to(shallow_tree, input_tree)
      flattened_shallow_tree = tree.flatten_up_to(shallow_tree, shallow_tree)
      self.assertEqual(flattened_shallow_tree, shallow_tree)
  def testByteStringsNotTreatedAsIterable(self):
    """Unicode and byte strings are kept whole by flatten_up_to."""
    strings = [u"unicode string", b"byte string"]
    leaves = tree.flatten_up_to(strings, strings)
    self.assertEqual(strings, leaves)
  571. def testFlattenWithPathUpTo(self):
  572. def get_paths_and_values(shallow_tree, input_tree):
  573. path_value_pairs = tree.flatten_with_path_up_to(shallow_tree, input_tree)
  574. paths = [p for p, _ in path_value_pairs]
  575. values = [v for _, v in path_value_pairs]
  576. return paths, values
  577. # Shallow tree ends at scalar.
  578. input_tree = [[[2, 2], [3, 3]], [[4, 9], [5, 5]]]
  579. shallow_tree = [[True, True], [False, True]]
  580. (flattened_input_tree_paths,
  581. flattened_input_tree) = get_paths_and_values(shallow_tree, input_tree)
  582. (flattened_shallow_tree_paths,
  583. flattened_shallow_tree) = get_paths_and_values(shallow_tree, shallow_tree)
  584. self.assertEqual(flattened_input_tree_paths,
  585. [(0, 0), (0, 1), (1, 0), (1, 1)])
  586. self.assertEqual(flattened_input_tree, [[2, 2], [3, 3], [4, 9], [5, 5]])
  587. self.assertEqual(flattened_shallow_tree_paths,
  588. [(0, 0), (0, 1), (1, 0), (1, 1)])
  589. self.assertEqual(flattened_shallow_tree, [True, True, False, True])
  590. # Shallow tree ends at string.
  591. input_tree = [[("a", 1), [("b", 2), [("c", 3), [("d", 4)]]]]]
  592. shallow_tree = [["level_1", ["level_2", ["level_3", ["level_4"]]]]]
  593. (input_tree_flattened_as_shallow_tree_paths,
  594. input_tree_flattened_as_shallow_tree) = get_paths_and_values(shallow_tree,
  595. input_tree)
  596. input_tree_flattened_paths = [
  597. p for p, _ in tree.flatten_with_path(input_tree)
  598. ]
  599. input_tree_flattened = tree.flatten(input_tree)
  600. self.assertEqual(input_tree_flattened_as_shallow_tree_paths,
  601. [(0, 0), (0, 1, 0), (0, 1, 1, 0), (0, 1, 1, 1, 0)])
  602. self.assertEqual(input_tree_flattened_as_shallow_tree,
  603. [("a", 1), ("b", 2), ("c", 3), ("d", 4)])
  604. self.assertEqual(input_tree_flattened_paths,
  605. [(0, 0, 0), (0, 0, 1),
  606. (0, 1, 0, 0), (0, 1, 0, 1),
  607. (0, 1, 1, 0, 0), (0, 1, 1, 0, 1),
  608. (0, 1, 1, 1, 0, 0), (0, 1, 1, 1, 0, 1)])
  609. self.assertEqual(input_tree_flattened, ["a", 1, "b", 2, "c", 3, "d", 4])
  610. # Make sure dicts are correctly flattened, yielding values, not keys.
  611. input_tree = {"a": 1, "b": {"c": 2}, "d": [3, (4, 5)]}
  612. shallow_tree = {"a": 0, "b": 0, "d": [0, 0]}
  613. (input_tree_flattened_as_shallow_tree_paths,
  614. input_tree_flattened_as_shallow_tree) = get_paths_and_values(shallow_tree,
  615. input_tree)
  616. self.assertEqual(input_tree_flattened_as_shallow_tree_paths,
  617. [("a",), ("b",), ("d", 0), ("d", 1)])
  618. self.assertEqual(input_tree_flattened_as_shallow_tree,
  619. [1, {"c": 2}, 3, (4, 5)])
  620. # Namedtuples.
  621. ab_tuple = collections.namedtuple("ab_tuple", "a, b")
  622. input_tree = ab_tuple(a=[0, 1], b=2)
  623. shallow_tree = ab_tuple(a=0, b=1)
  624. (input_tree_flattened_as_shallow_tree_paths,
  625. input_tree_flattened_as_shallow_tree) = get_paths_and_values(shallow_tree,
  626. input_tree)
  627. self.assertEqual(input_tree_flattened_as_shallow_tree_paths,
  628. [("a",), ("b",)])
  629. self.assertEqual(input_tree_flattened_as_shallow_tree,
  630. [[0, 1], 2])
  631. # Nested dicts, OrderedDicts and namedtuples.
  632. input_tree = collections.OrderedDict(
  633. [("a", ab_tuple(a=[0, {"b": 1}], b=2)),
  634. ("c", {"d": 3, "e": collections.OrderedDict([("f", 4)])})])
  635. shallow_tree = input_tree
  636. (input_tree_flattened_as_shallow_tree_paths,
  637. input_tree_flattened_as_shallow_tree) = get_paths_and_values(shallow_tree,
  638. input_tree)
  639. self.assertEqual(input_tree_flattened_as_shallow_tree_paths,
  640. [("a", "a", 0),
  641. ("a", "a", 1, "b"),
  642. ("a", "b"),
  643. ("c", "d"),
  644. ("c", "e", "f")])
  645. self.assertEqual(input_tree_flattened_as_shallow_tree, [0, 1, 2, 3, 4])
  646. shallow_tree = collections.OrderedDict([("a", 0), ("c", {"d": 3, "e": 1})])
  647. (input_tree_flattened_as_shallow_tree_paths,
  648. input_tree_flattened_as_shallow_tree) = get_paths_and_values(shallow_tree,
  649. input_tree)
  650. self.assertEqual(input_tree_flattened_as_shallow_tree_paths,
  651. [("a",),
  652. ("c", "d"),
  653. ("c", "e")])
  654. self.assertEqual(input_tree_flattened_as_shallow_tree,
  655. [ab_tuple(a=[0, {"b": 1}], b=2),
  656. 3,
  657. collections.OrderedDict([("f", 4)])])
  658. shallow_tree = collections.OrderedDict([("a", 0), ("c", 0)])
  659. (input_tree_flattened_as_shallow_tree_paths,
  660. input_tree_flattened_as_shallow_tree) = get_paths_and_values(shallow_tree,
  661. input_tree)
  662. self.assertEqual(input_tree_flattened_as_shallow_tree_paths,
  663. [("a",), ("c",)])
  664. self.assertEqual(input_tree_flattened_as_shallow_tree,
  665. [ab_tuple(a=[0, {"b": 1}], b=2),
  666. {"d": 3, "e": collections.OrderedDict([("f", 4)])}])
  667. ## Shallow non-list edge-case.
  668. # Using iterable elements.
  669. input_tree = ["input_tree"]
  670. shallow_tree = "shallow_tree"
  671. (flattened_input_tree_paths,
  672. flattened_input_tree) = get_paths_and_values(shallow_tree, input_tree)
  673. (flattened_shallow_tree_paths,
  674. flattened_shallow_tree) = get_paths_and_values(shallow_tree, shallow_tree)
  675. self.assertEqual(flattened_input_tree_paths, [()])
  676. self.assertEqual(flattened_input_tree, [input_tree])
  677. self.assertEqual(flattened_shallow_tree_paths, [()])
  678. self.assertEqual(flattened_shallow_tree, [shallow_tree])
  679. input_tree = ["input_tree_0", "input_tree_1"]
  680. shallow_tree = "shallow_tree"
  681. (flattened_input_tree_paths,
  682. flattened_input_tree) = get_paths_and_values(shallow_tree, input_tree)
  683. (flattened_shallow_tree_paths,
  684. flattened_shallow_tree) = get_paths_and_values(shallow_tree, shallow_tree)
  685. self.assertEqual(flattened_input_tree_paths, [()])
  686. self.assertEqual(flattened_input_tree, [input_tree])
  687. self.assertEqual(flattened_shallow_tree_paths, [()])
  688. self.assertEqual(flattened_shallow_tree, [shallow_tree])
  689. # Test case where len(shallow_tree) < len(input_tree)
  690. input_tree = {"a": "A", "b": "B", "c": "C"}
  691. shallow_tree = {"a": 1, "c": 2}
  692. # Using non-iterable elements.
  693. input_tree = [0]
  694. shallow_tree = 9
  695. (flattened_input_tree_paths,
  696. flattened_input_tree) = get_paths_and_values(shallow_tree, input_tree)
  697. (flattened_shallow_tree_paths,
  698. flattened_shallow_tree) = get_paths_and_values(shallow_tree, shallow_tree)
  699. self.assertEqual(flattened_input_tree_paths, [()])
  700. self.assertEqual(flattened_input_tree, [input_tree])
  701. self.assertEqual(flattened_shallow_tree_paths, [()])
  702. self.assertEqual(flattened_shallow_tree, [shallow_tree])
  703. input_tree = [0, 1]
  704. shallow_tree = 9
  705. (flattened_input_tree_paths,
  706. flattened_input_tree) = get_paths_and_values(shallow_tree, input_tree)
  707. (flattened_shallow_tree_paths,
  708. flattened_shallow_tree) = get_paths_and_values(shallow_tree, shallow_tree)
  709. self.assertEqual(flattened_input_tree_paths, [()])
  710. self.assertEqual(flattened_input_tree, [input_tree])
  711. self.assertEqual(flattened_shallow_tree_paths, [()])
  712. self.assertEqual(flattened_shallow_tree, [shallow_tree])
  713. ## Both non-list edge-case.
  714. # Using iterable elements.
  715. input_tree = "input_tree"
  716. shallow_tree = "shallow_tree"
  717. (flattened_input_tree_paths,
  718. flattened_input_tree) = get_paths_and_values(shallow_tree, input_tree)
  719. (flattened_shallow_tree_paths,
  720. flattened_shallow_tree) = get_paths_and_values(shallow_tree, shallow_tree)
  721. self.assertEqual(flattened_input_tree_paths, [()])
  722. self.assertEqual(flattened_input_tree, [input_tree])
  723. self.assertEqual(flattened_shallow_tree_paths, [()])
  724. self.assertEqual(flattened_shallow_tree, [shallow_tree])
  725. # Using non-iterable elements.
  726. input_tree = 0
  727. shallow_tree = 0
  728. (flattened_input_tree_paths,
  729. flattened_input_tree) = get_paths_and_values(shallow_tree, input_tree)
  730. (flattened_shallow_tree_paths,
  731. flattened_shallow_tree) = get_paths_and_values(shallow_tree, shallow_tree)
  732. self.assertEqual(flattened_input_tree_paths, [()])
  733. self.assertEqual(flattened_input_tree, [input_tree])
  734. self.assertEqual(flattened_shallow_tree_paths, [()])
  735. self.assertEqual(flattened_shallow_tree, [shallow_tree])
  736. ## Input non-list edge-case.
  737. # Using iterable elements.
  738. input_tree = "input_tree"
  739. shallow_tree = ["shallow_tree"]
  740. with self.assertRaisesWithLiteralMatch(
  741. TypeError,
  742. tree._IF_SHALLOW_IS_SEQ_INPUT_MUST_BE_SEQ_WITH_PATH.format(
  743. path=[], input_type=type(input_tree))):
  744. (flattened_input_tree_paths,
  745. flattened_input_tree) = get_paths_and_values(shallow_tree, input_tree)
  746. (flattened_shallow_tree_paths,
  747. flattened_shallow_tree) = get_paths_and_values(shallow_tree, shallow_tree)
  748. self.assertEqual(flattened_shallow_tree_paths, [(0,)])
  749. self.assertEqual(flattened_shallow_tree, shallow_tree)
  750. input_tree = "input_tree"
  751. shallow_tree = ["shallow_tree_9", "shallow_tree_8"]
  752. with self.assertRaisesWithLiteralMatch(
  753. TypeError,
  754. tree._IF_SHALLOW_IS_SEQ_INPUT_MUST_BE_SEQ_WITH_PATH.format(
  755. path=[], input_type=type(input_tree))):
  756. (flattened_input_tree_paths,
  757. flattened_input_tree) = get_paths_and_values(shallow_tree, input_tree)
  758. (flattened_shallow_tree_paths,
  759. flattened_shallow_tree) = get_paths_and_values(shallow_tree, shallow_tree)
  760. self.assertEqual(flattened_shallow_tree_paths, [(0,), (1,)])
  761. self.assertEqual(flattened_shallow_tree, shallow_tree)
  762. # Using non-iterable elements.
  763. input_tree = 0
  764. shallow_tree = [9]
  765. with self.assertRaisesWithLiteralMatch(
  766. TypeError,
  767. tree._IF_SHALLOW_IS_SEQ_INPUT_MUST_BE_SEQ_WITH_PATH.format(
  768. path=[], input_type=type(input_tree))):
  769. (flattened_input_tree_paths,
  770. flattened_input_tree) = get_paths_and_values(shallow_tree, input_tree)
  771. (flattened_shallow_tree_paths,
  772. flattened_shallow_tree) = get_paths_and_values(shallow_tree, shallow_tree)
  773. self.assertEqual(flattened_shallow_tree_paths, [(0,)])
  774. self.assertEqual(flattened_shallow_tree, shallow_tree)
  775. input_tree = 0
  776. shallow_tree = [9, 8]
  777. with self.assertRaisesWithLiteralMatch(
  778. TypeError,
  779. tree._IF_SHALLOW_IS_SEQ_INPUT_MUST_BE_SEQ_WITH_PATH.format(
  780. path=[], input_type=type(input_tree))):
  781. (flattened_input_tree_paths,
  782. flattened_input_tree) = get_paths_and_values(shallow_tree, input_tree)
  783. (flattened_shallow_tree_paths,
  784. flattened_shallow_tree) = get_paths_and_values(shallow_tree, shallow_tree)
  785. self.assertEqual(flattened_shallow_tree_paths, [(0,), (1,)])
  786. self.assertEqual(flattened_shallow_tree, shallow_tree)
  787. # Test that error messages include paths.
  788. input_tree = {"a": {"b": {0, 1}}}
  789. structure = {"a": {"b": [0, 1]}}
  790. with self.assertRaisesWithLiteralMatch(
  791. TypeError,
  792. tree._IF_SHALLOW_IS_SEQ_INPUT_MUST_BE_SEQ_WITH_PATH.format(
  793. path=["a", "b"], input_type=type(input_tree["a"]["b"]))):
  794. (flattened_input_tree_paths,
  795. flattened_input_tree) = get_paths_and_values(structure, input_tree)
  796. (flattened_tree_paths,
  797. flattened_tree) = get_paths_and_values(structure, structure)
  798. self.assertEqual(flattened_tree_paths, [("a", "b", 0,), ("a", "b", 1,)])
  799. self.assertEqual(flattened_tree, structure["a"]["b"])
  def testMapStructureUpTo(self):
    # Exercises `tree.map_structure_up_to` on namedtuples and on nested lists.

    # Namedtuples: each value is combined with the op_tuple found at the same
    # field as (val + add) * mul. check_types=False is passed explicitly.
    ab_tuple = collections.namedtuple("ab_tuple", "a, b")
    op_tuple = collections.namedtuple("op_tuple", "add, mul")
    values = ab_tuple(a=2, b=3)
    operations = ab_tuple(
        a=op_tuple(add=1, mul=2),
        b=op_tuple(add=2, mul=3))

    def apply_ops(val, ops):
      return (val + ops.add) * ops.mul

    combined = tree.map_structure_up_to(
        values, apply_ops, values, operations, check_types=False)
    self.assertEqual(combined.a, 6)
    self.assertEqual(combined.b, 15)

    # Lists: each name is paired with the whole sub-list found at the same
    # position in `data_list`, so the mapped function can take its length.
    data_list = [[2, 4, 6, 8], [[1, 3, 5, 7, 9], [3, 5, 7]]]
    name_list = ["evens", ["odds", "primes"]]

    def describe(name, sec):
      return "first_{}_{}".format(len(sec), name)

    labelled = tree.map_structure_up_to(
        name_list, describe, name_list, data_list)
    self.assertEqual(
        labelled, ["first_4_evens", ["first_5_odds", "first_3_primes"]])
  # Namedtuple types cannot be created inside a @parameterized argument list,
  # so they are built here, ahead of the decorators that reference them.
  # pylint: disable=invalid-name
  Foo = collections.namedtuple("Foo", ("a", "b"))
  Bar = collections.namedtuple("Bar", ("c", "d"))
  # pylint: enable=invalid-name
  @parameterized.parameters([
      {"inputs": [], "expected": []},
      {"inputs": [23, "42"], "expected": [((0,), 23), ((1,), "42")]},
      {"inputs": [[[[108]]]], "expected": [((0, 0, 0, 0), 108)]},
      {
          "inputs": Foo(a=3, b=Bar(c=23, d=42)),
          "expected": [(("a",), 3), (("b", "c"), 23), (("b", "d"), 42)],
      },
      {
          "inputs": Foo(a=Bar(c=23, d=42), b=Bar(c=0, d="thing")),
          "expected": [
              (("a", "c"), 23),
              (("a", "d"), 42),
              (("b", "c"), 0),
              (("b", "d"), "thing"),
          ],
      },
      {
          "inputs": Bar(c=42, d=43),
          "expected": [(("c",), 42), (("d",), 43)],
      },
      {
          "inputs": Bar(c=[42], d=43),
          "expected": [(("c", 0), 42), (("d",), 43)],
      },
      {
          "inputs": wrapt.ObjectProxy(Bar(c=[42], d=43)),
          "expected": [(("c", 0), 42), (("d",), 43)],
      },
  ])
  def testFlattenWithPath(self, inputs, expected):
    # Every case lists the (path, leaf) pairs that flattening `inputs` must
    # produce: list positions appear as indices, namedtuple fields as names.
    flattened = tree.flatten_with_path(inputs)
    self.assertEqual(flattened, expected)
  @parameterized.named_parameters([
      {
          "testcase_name": "Tuples",
          "s1": (1, 2),
          "s2": (3, 4),
          "check_types": True,
          "expected": (((0,), 4), ((1,), 6)),
      },
      {
          "testcase_name": "Dicts",
          "s1": {"a": 1, "b": 2},
          "s2": {"b": 4, "a": 3},
          "check_types": True,
          "expected": {"a": (("a",), 4), "b": (("b",), 6)},
      },
      {
          "testcase_name": "Mixed",
          "s1": (1, 2),
          "s2": [3, 4],
          "check_types": False,
          "expected": (((0,), 4), ((1,), 6)),
      },
      {
          "testcase_name": "Nested",
          "s1": {"a": [2, 3], "b": [1, 2, 3]},
          "s2": {"b": [5, 6, 7], "a": [8, 9]},
          "check_types": True,
          "expected": {
              "a": [(("a", 0), 10), (("a", 1), 12)],
              "b": [(("b", 0), 6), (("b", 1), 8), (("b", 2), 10)],
          },
      },
  ])
  def testMapWithPathCompatibleStructures(self, s1, s2, check_types, expected):
    # Maps over two structures at once; every leaf of the result is the pair
    # (path, sum of the values found at that path in s1 and s2).
    def tag_with_path(path, *leaves):
      total = sum(leaves)
      return path, total

    mapped = tree.map_structure_with_path(
        tag_with_path, s1, s2, check_types=check_types)
    self.assertEqual(expected, mapped)
  @parameterized.named_parameters([
      {
          "testcase_name": "Tuples",
          "s1": (1, 2, 3),
          "s2": (4, 5),
          "error_type": ValueError,
      },
      {
          "testcase_name": "Dicts",
          "s1": {"a": 1},
          "s2": {"b": 2},
          "error_type": ValueError,
      },
      {
          "testcase_name": "Nested",
          "s1": {"a": [2, 3, 4], "b": [1, 3]},
          "s2": {"b": [5, 6], "a": [8, 9]},
          "error_type": ValueError,
      },
  ])
  def testMapWithPathIncompatibleStructures(self, s1, s2, error_type):
    # Structures that differ in length or in keys must be rejected; the mapped
    # function is a constant because only the raised error matters here.
    def constant_zero(path, *s):
      return 0

    with self.assertRaises(error_type):
      tree.map_structure_with_path(constant_zero, s1, s2)
  def testMappingProxyType(self):
    # Flattening, unflattening and mapping are each run against a
    # types.MappingProxyType wrapping a dict that holds a nested tuple.
    proxy = types.MappingProxyType({"a": 1, "b": (2, 3)})
    shifted = types.MappingProxyType({"a": 4, "b": (5, 6)})

    leaves = tree.flatten(proxy)
    self.assertEqual(leaves, [1, 2, 3])

    rebuilt = tree.unflatten_as(proxy, [4, 5, 6])
    self.assertEqual(rebuilt, shifted)

    mapped = tree.map_structure(lambda v: v + 3, proxy)
    self.assertEqual(mapped, shifted)
  def testTraverseListsToTuples(self):
    # Traverses with top_down=False, turning every list node into a tuple and
    # leaving all other nodes as they are.
    def list_to_tuple(node):
      if isinstance(node, list):
        return tuple(node)
      return node

    nested = [(1, 2), [3], {"a": [4]}]
    converted = tree.traverse(list_to_tuple, nested, top_down=False)
    self.assertEqual(((1, 2), (3,), {"a": (4,)}), converted)
  def testTraverseEarlyTermination(self):
    # `visit` logs each node it receives and replaces any tuple with more than
    # two elements by "X". The expected log ends at (4, 5, 6): the elements of
    # the replaced tuple are never passed to `visit`.
    structure = [(1, [2]), [3, (4, 5, 6)]]
    seen = []

    def visit(node):
      seen.append(node)
      if isinstance(node, tuple) and len(node) > 2:
        return "X"
      return None

    replaced = tree.traverse(visit, structure)
    self.assertEqual([(1, [2]), [3, "X"]], replaced)

    expected_visits = [
        [(1, [2]), [3, (4, 5, 6)]],
        (1, [2]), 1, [2], 2, [3, (4, 5, 6)], 3, (4, 5, 6),
    ]
    self.assertEqual(expected_visits, seen)
  def testMapStructureAcrossSubtreesDict(self):
    # The deep dicts carry extra keys ("d", "e") that the shallow dict lacks;
    # the expected results contain only the keys of the shallow structure.
    shallow = {"a": 1, "b": {"c": 2}}
    deep1 = {"a": 2, "b": {"c": 3, "d": 2}, "e": 4}
    deep2 = {"a": 3, "b": {"c": 2, "d": 3}, "e": 1}

    def total(*args):
      return sum(args)

    def gather(*args):
      return args

    summed = tree.map_structure_up_to(shallow, total, deep1, deep2)
    self.assertEqual(summed, {"a": 5, "b": {"c": 5}})

    concatenated = tree.map_structure_up_to(shallow, gather, deep1, deep2)
    self.assertEqual(concatenated, {"a": (2, 3), "b": {"c": (3, 2)}})
  def testMapStructureAcrossSubtreesNoneValues(self):
    # A None entry in the shallow structure still marks a position to map:
    # the expected output has a summed value where the None was.
    shallow = [1, [None]]
    first, second = [1, [2, 3]], [2, [3, 4]]

    def total(*args):
      return sum(args)

    summed = tree.map_structure_up_to(shallow, total, first, second)
    self.assertEqual(summed, [3, [5]])
  def testMapStructureAcrossSubtreesList(self):
    # The inner shallow list has one element while the inner input lists have
    # two; the expected output keeps only the first inner position.
    shallow = [1, [1]]
    first, second = [1, [2, 3]], [2, [3, 4]]

    def total(*args):
      return sum(args)

    summed = tree.map_structure_up_to(shallow, total, first, second)
    self.assertEqual(summed, [3, [5]])
  def testMapStructureAcrossSubtreesTuple(self):
    # Tuple counterpart of the list case: the inner shallow tuple has a single
    # element, and the expected output is a tuple of the same shape.
    shallow = (1, (1,))
    first, second = (1, (2, 3)), (2, (3, 4))

    def total(*args):
      return sum(args)

    summed = tree.map_structure_up_to(shallow, total, first, second)
    self.assertEqual(summed, (3, (5,)))
  def testMapStructureAcrossSubtreesNamedTuple(self):
    # The shallow namedtuple (Bar) has only field "x" while the inputs (Foo)
    # have "x" and "y"; the expected result is a Bar holding the summed "x".
    Foo = collections.namedtuple("Foo", ["x", "y"])
    Bar = collections.namedtuple("Bar", ["x"])
    shallow = Bar(1)
    first, second = Foo(1, (1, 0)), Foo(2, (2, 0))

    def total(*args):
      return sum(args)

    summed = tree.map_structure_up_to(shallow, total, first, second)
    self.assertEqual(summed, Bar(3))
  def testMapStructureAcrossSubtreesListTuple(self):
    # Tuples and lists can be used interchangeably between the shallow
    # structure and the input structures. The output takes on the container
    # types of the shallow structure.
    def total(*args):
      return sum(args)

    # Shallow structure holds a tuple where the inputs hold lists.
    summed = tree.map_structure_up_to(
        [1, (1,)], total, [1, [2, 3]], [2, [3, 4]])
    self.assertEqual(summed, [3, (5,)])

    # Shallow structure holds a list where the inputs hold tuples.
    summed = tree.map_structure_up_to(
        [1, [1]], total, [1, (2, 3)], [2, (3, 4)])
    self.assertEqual(summed, [3, [5]])
  def testNoneNodeIncluded(self):
    # A None inside a tuple must appear in the flattened output as a leaf.
    pair_with_none = (1, None)
    leaves = tree.flatten(pair_with_none)
    self.assertEqual(leaves, [1, None])
  def testCustomClassMapWithPath(self):
    # Maps over an instance of a user-defined Mapping subclass and compares
    # the result with an ExampleClass holding the expected path-tagged leaves.

    class ExampleClass(Mapping[Any, Any]):
      """Small example custom class."""

      def __init__(self, *args, **kwargs):
        # Accepts the same arguments as dict() and stores them in a dict.
        self._mapping = dict(*args, **kwargs)

      def __iter__(self) -> Iterator[Any]:
        return iter(self._mapping)

      def __len__(self) -> int:
        return len(self._mapping)

      def __getitem__(self, k: Any) -> Any:
        return self._mapping[k]

    def mapper(path, value):
      # Joins the path components with "/" and appends "_<value>".
      return "{}_{}".format("/".join(path), value)

    source = ExampleClass({"first": 1, "nested": {"second": 2, "third": 3}})
    wanted = ExampleClass({
        "first": "first_1",
        "nested": {
            "second": "nested/second_2",
            "third": "nested/third_3",
        },
    })

    produced = tree.map_structure_with_path(mapper, source)
    self.assertEqual(produced, wanted)
  if __name__ == "__main__":
    # Hand control to the unittest runner when this file is run as a script.
    unittest.main()