tree_benchmark.py 1.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869
  1. # Copyright 2019 DeepMind Technologies Limited. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ==============================================================================
  15. """Benchmarks for utilities working with arbitrarily nested structures."""
  16. import collections
  17. import timeit
  18. import tree
  19. TIME_UNITS = [
  20. (1, "s"),
  21. (10**-3, "ms"),
  22. (10**-6, "us"),
  23. (10**-9, "ns"),
  24. ]
  25. def format_time(time):
  26. for d, unit in TIME_UNITS:
  27. if time > d:
  28. return "{:.2f}{}".format(time / d, unit)
  29. def run_benchmark(benchmark_fn, num_iters):
  30. times = timeit.repeat(benchmark_fn, repeat=2, number=num_iters)
  31. return times[-1] / num_iters # Discard the first half for "warmup".
  32. def map_to_list(func, *args):
  33. return list(map(func, *args))
  34. def benchmark_map(map_fn, structure):
  35. def benchmark_fn():
  36. return map_fn(lambda v: v, structure)
  37. return benchmark_fn
  38. BENCHMARKS = collections.OrderedDict([
  39. ("tree_map_1", benchmark_map(tree.map_structure, [0])),
  40. ("tree_map_8", benchmark_map(tree.map_structure, [0] * 8)),
  41. ("tree_map_64", benchmark_map(tree.map_structure, [0] * 64)),
  42. ("builtin_map_1", benchmark_map(map_to_list, [0])),
  43. ("builtin_map_8", benchmark_map(map_to_list, [0] * 8)),
  44. ("builtin_map_64", benchmark_map(map_to_list, [0] * 64)),
  45. ])
  46. def main():
  47. for name, benchmark_fn in BENCHMARKS.items():
  48. print(name, format_time(run_benchmark(benchmark_fn, num_iters=1000)))
  49. if __name__ == "__main__":
  50. main()