# save_model.py

# Copyright 2024 S.F Sune Limited.
# https://github.com/sfsun67/GraphCast-from-Ground-Zero
"""
创建一个数据类,用于存储模型参数和配置信息。
Create data classes to store model parameters and configuration information.
"""
import dataclasses
from typing import Any, Optional, Union
  9. @dataclasses.dataclass
  10. class SubConfig:
  11. a: int
  12. b: str
  13. @dataclasses.dataclass
  14. class Config:
  15. bt: bool
  16. bf: bool
  17. i: int
  18. f: float
  19. o1: Optional[int]
  20. o2: Optional[int]
  21. o3: Union[int, None]
  22. o4: Union[int, None]
  23. o5: int | None
  24. o6: int | None
  25. li: list[int]
  26. ls: list[str]
  27. ldc: list[SubConfig]
  28. tf: tuple[float, ...]
  29. ts: tuple[str, ...]
  30. t: tuple[str, int, SubConfig]
  31. tdc: tuple[SubConfig, ...]
  32. dsi: dict[str, int]
  33. dss: dict[str, str]
  34. dis: dict[int, str]
  35. dsdis: dict[str, dict[int, str]]
  36. dc: SubConfig
  37. dco: Optional[SubConfig]
  38. ddc: dict[str, SubConfig]
  39. @dataclasses.dataclass
  40. class ModelConfig:
  41. resolution: float
  42. mesh_size: int
  43. latent_size: int
  44. gnn_msg_steps: int
  45. hidden_layers: int
  46. radius_query_fraction_edge_length: float
  47. mesh2grid_edge_normalization_factor: float
  48. @dataclasses.dataclass
  49. class TaskConfig:
  50. input_variables: tuple[str, ...]
  51. target_variables: tuple[str, ...]
  52. forcing_variables: tuple[str, ...]
  53. pressure_levels: tuple[int, ...]
  54. input_duration: str
  55. @dataclasses.dataclass
  56. class Checkpoint:
  57. params: dict[str, Any]
  58. model_config: ModelConfig
  59. task_config: TaskConfig
  60. description: str
  61. license: str