{ "cells": [ { "cell_type": "markdown", "metadata": { "id": "-jAYlxeKxvAJ" }, "source": [ "# Running GraphCast from scratch (on AutoDL or any other fresh environment)\n", "-------------------------------------------------------------------\n", "**This project reproduces https://github.com/google-deepmind/graphcast. It was adapted and debugged by https://github.com/sfsun67.**\n", "\n", "**AutoDL is a cloud computing platform based in China: https://www.autodl.com**\n", "\n", "You should have a file structure similar to the one below. The data comes from the Google Cloud Bucket (https://console.cloud.google.com/storage/browser/dm_graphcast):\n", "```\n", ".\n", "├── code\n", "│   └── graphcast-main\n", "│       ├── graphcast\n", "│       ├── tree\n", "│       ├── wrapt\n", "│       ├── graphcast_demo.ipynb\n", "│       ├── README.md\n", "│       ├── setup.py\n", "│       └── ...\n", "└── data\n", "    ├── dataset\n", "    │   ├── dataset-source-era5_date-2022-01-01_res-1.0_levels-13_steps-01.nc\n", "    │   ├── dataset-source-era5_date-2022-01-01_res-1.0_levels-13_steps-04.nc\n", "    │   ├── dataset-source-era5_date-2022-01-01_res-1.0_levels-13_steps-12.nc\n", "    │   └── ...\n", "    ├── params\n", "    │   ├── params-GraphCast - ERA5 1979-2017 - resolution 0.25 - pressure levels 37 - mesh 2to6 - precipitation input and output.npz\n", "    │   ├── params-GraphCast_small - ERA5 1979-2015 - resolution 1.0 - pressure levels 13 - mesh 2to5 - precipitation input and output.npz\n", "    │   └── ...\n", "    └── stats\n", "        ├── stats-mean_by_level.nc\n", "        └── ...\n", "```\n", "\n", "PS:\n", "1. Use Python 3.10. Older versions cause some function calls to fail.\n", "2. Check package versions carefully to avoid unexpected errors. For example, xarray must be version 2023.7.0; other versions raise errors.\n", "3. Check carefully that every package is installed correctly, since missing packages cause unexpected errors. For example, tree and wrapt are both required by GraphCast but are not included in the source files, and if the .so files in tree and wrapt are not imported, a circular import occurs. The original files for both packages can be found in the Colaboratory environment (https://colab.research.google.com/github/deepmind/graphcast/blob/master/graphcast_demo.ipynb).\n", "\n", "*The code was tested on the following machines (AutoDL base image: JAX 0.3.10 / Python 3.8 (Ubuntu 18.04) / CUDA 11.1):*\n", "1. GPU: TITAN Xp 12GB; CPU: Xeon(R) E5-2680 v4\n", "2. GPU: V100-SXM2-32GB 32GB; CPU: Xeon(R) Platinum 8255C\n", "-------------------------------------------------------------------\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "
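**Optional check (added sketch, not part of the original notebook).** The code cell below lists what it finds in each of the `dataset`, `params` and `stats` folders from the layout above. It assumes the notebook runs from `code/graphcast-main`, which places `data` two levels up. `DATA_ROOT` is a hypothetical path; adjust it to match your own setup.\n" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
"# Sketch (added): list what is present in each data subfolder from the layout above.\n",
"# Assumption: DATA_ROOT is hypothetical; point it at your own data directory.\n",
"from pathlib import Path\n",
"\n",
"DATA_ROOT = Path('../../data')\n",
"\n",
"for sub in ('dataset', 'params', 'stats'):\n",
"    folder = DATA_ROOT / sub\n",
"    # A missing folder is reported instead of raising an error.\n",
"    names = sorted(p.name for p in folder.iterdir()) if folder.is_dir() else []\n",
"    if names:\n",
"        print(f'{folder}: {len(names)} entries, e.g. {names[0]}')\n",
"    else:\n",
"        print(f'{folder}: missing or empty')\n" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "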
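Copyright 2023 DeepMind Technologies Limited.\n", "\n", "Licensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0\n", "\n", "Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License.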
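\n" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "**Optional check (added sketch, not part of the original notebook).** The code cell below reports the items that the PS notes at the top depend on: the Python version, the installed xarray version, and whether `tree` and `wrapt` can be imported from the current working directory. It only prints what it finds and changes nothing. If the kernel is not running Python 3.10, follow the next section.\n" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
"# Sketch (added): report the versions and imports that the PS notes depend on.\n",
"import importlib\n",
"import sys\n",
"from importlib import metadata\n",
"\n",
"print('Python', sys.version.split()[0])\n",
"if sys.version_info[:2] != (3, 10):\n",
"    print('Warning: the notes recommend Python 3.10')\n",
"\n",
"# The notes pin xarray to 2023.7.0.\n",
"try:\n",
"    print('xarray', metadata.version('xarray'))\n",
"except metadata.PackageNotFoundError:\n",
"    print('xarray is not installed')\n",
"\n",
"# GraphCast needs both tree and wrapt to be importable.\n",
"for name in ('tree', 'wrapt'):\n",
"    try:\n",
"        module = importlib.import_module(name)\n",
"        print(name, 'OK', getattr(module, '__file__', None))\n",
"    except ImportError as err:\n",
"        print(name, 'FAILED:', err)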
\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Update Python to 3.10\n", "\n", "GraphCast requires Python >= 3.10; Python 3.10 is recommended.\n", "\n", "In a terminal, create a new environment named GraphCast.\n", "\n", "Reference commands:\n", "```\n", "\n", "# Update conda (optional)\n", "conda update -n base -c defaults conda\n", "\n", "# Create the new environment GraphCast with python=3.10\n", "conda create -n GraphCast python=3.10\n", "\n", "# Write conda's initialization into .bashrc and reload it\n", "conda init bash && source /root/.bashrc\n", "\n", "# Activate the new environment\n", "conda activate GraphCast\n", "\n", "# Verify the version\n", "python --version\n", "```\n", "\n", "\n", "Note: after verifying the version, restart Jupyter and switch to the new kernel." ] }, { "cell_type": "markdown", "metadata": { "id": "yMbbXFl4msJw" }, "source": [ "# Installation and initialization\n" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "# Academic resource acceleration (AutoDL network turbo): https://www.autodl.com/docs/network_turbo/\n", "# Source /etc/network_turbo in a bash subshell and copy the proxy variables it sets into this kernel's environment.\n", "\n", "import subprocess\n", "import os\n", "\n", "result = subprocess.run('bash -c \"source /etc/network_turbo && env | grep proxy\"', shell=True, capture_output=True, text=True)\n", "output = result.stdout\n", "for line in output.splitlines():\n", "    if '=' in line:\n", "        var, value = line.split('=', 1)\n", "        os.environ[var] = value" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\u001b[33mWARNING: Skipping shapely as it is not installed.\u001b[0m\u001b[33m\n", "\u001b[0m\u001b[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. 
It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv\u001b[0m\u001b[33m\n", "\u001b[0mChannels:\n", " - https://mirrors.ustc.edu.cn/anaconda/pkgs/main\n", " - https://mirrors.ustc.edu.cn/anaconda/pkgs/free\n", " - defaults\n", "Platform: linux-64\n", "Collecting package metadata (repodata.json): done\n", "Solving environment: done\n", "\n", "## Package Plan ##\n", "\n", " environment location: /root/miniconda3/envs/GraphCast\n", "\n", " added / updated specs:\n", " - shapely\n", "\n", "\n", "The following packages will be downloaded:\n", "\n", " package | build\n", " ---------------------------|-----------------\n", " blas-1.0 | mkl 6 KB https://mirrors.ustc.edu.cn/anaconda/pkgs/main\n", " geos-3.8.0 | he6710b0_0 961 KB https://mirrors.ustc.edu.cn/anaconda/pkgs/main\n", " intel-openmp-2023.1.0 | hdb19cb5_46306 17.2 MB https://mirrors.ustc.edu.cn/anaconda/pkgs/main\n", " mkl-2023.1.0 | h213fc3f_46344 171.5 MB https://mirrors.ustc.edu.cn/anaconda/pkgs/main\n", " mkl-service-2.4.0 | py310h5eee18b_1 54 KB https://mirrors.ustc.edu.cn/anaconda/pkgs/main\n", " mkl_fft-1.3.8 | py310h5eee18b_0 216 KB https://mirrors.ustc.edu.cn/anaconda/pkgs/main\n", " mkl_random-1.2.4 | py310hdb19cb5_0 312 KB https://mirrors.ustc.edu.cn/anaconda/pkgs/main\n", " numpy-1.26.4 | py310h5f9d8c6_0 11 KB https://mirrors.ustc.edu.cn/anaconda/pkgs/main\n", " numpy-base-1.26.4 | py310hb5e798b_0 7.2 MB https://mirrors.ustc.edu.cn/anaconda/pkgs/main\n", " shapely-2.0.1 | py310h006c72b_0 433 KB https://mirrors.ustc.edu.cn/anaconda/pkgs/main\n", " tbb-2021.8.0 | hdb19cb5_0 1.6 MB https://mirrors.ustc.edu.cn/anaconda/pkgs/main\n", " ------------------------------------------------------------\n", " Total: 199.5 MB\n", "\n", "The following NEW packages will be INSTALLED:\n", "\n", " blas anaconda/pkgs/main/linux-64::blas-1.0-mkl \n", " geos anaconda/pkgs/main/linux-64::geos-3.8.0-he6710b0_0 \n", " intel-openmp anaconda/pkgs/main/linux-64::intel-openmp-2023.1.0-hdb19cb5_46306 \n", " mkl anaconda/pkgs/main/linux-64::mkl-2023.1.0-h213fc3f_46344 \n", " mkl-service anaconda/pkgs/main/linux-64::mkl-service-2.4.0-py310h5eee18b_1 \n", " mkl_fft anaconda/pkgs/main/linux-64::mkl_fft-1.3.8-py310h5eee18b_0 \n", " mkl_random anaconda/pkgs/main/linux-64::mkl_random-1.2.4-py310hdb19cb5_0 \n", " numpy anaconda/pkgs/main/linux-64::numpy-1.26.4-py310h5f9d8c6_0 \n", " numpy-base anaconda/pkgs/main/linux-64::numpy-base-1.26.4-py310hb5e798b_0 \n", " shapely anaconda/pkgs/main/linux-64::shapely-2.0.1-py310h006c72b_0 \n", " tbb anaconda/pkgs/main/linux-64::tbb-2021.8.0-hdb19cb5_0 \n", "\n", "\n", "\n", "Downloading and Extracting Packages:\n", "
Preparing transaction: done\n", "Verifying transaction: done\n", "Executing transaction: done\n", "Found existing installation: shapely 2.0.1\n", "Uninstalling shapely-2.0.1:\n", " Successfully uninstalled shapely-2.0.1\n", "\u001b[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv\u001b[0m\u001b[33m\n", "\u001b[0m" ] } ], "source": [ "# This step prepares the environment through shapely. Installing shapely with conda automatically sets up dependencies that GraphCast relies on (conda pulls in GEOS, MKL and NumPy, as the output shows). shapely itself is then uninstalled because this version is not needed by later steps.\n", "\n", "!pip uninstall -y shapely\n", "!conda install -y shapely\n", "!pip uninstall -y shapely" ] }, { "cell_type": "code", "execution_count": 3, "metadata": { "cellView": "form", "id": "-W4K9skv9vh-" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Looking in indexes: http://mirrors.aliyun.com/pypi/simple\n", "Collecting https://github.com/deepmind/graphcast/archive/master.zip\n", " Downloading https://github.com/deepmind/graphcast/archive/master.zip\n", "\u001b[2K \u001b[32m-\u001b[0m \u001b[32m106.5 kB\u001b[0m \u001b[31m35.8 MB/s\u001b[0m \u001b[33m0:00:00\u001b[0m\n", "\u001b[?25h Preparing metadata (setup.py) ... 
\u001b[?25ldone\n", "\u001b[?25hCollecting cartopy (from graphcast==0.1)\n", " Downloading http://mirrors.aliyun.com/pypi/packages/e2/4e/ba1a6046d52888cd38e3ebdb4e2eb2e366907be7a44e8e9014f63e4786bf/Cartopy-0.22.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (11.8 MB)\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m11.8/11.8 MB\u001b[0m \u001b[31m1.1 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m00:01\u001b[0m00:01\u001b[0mm\n", "\u001b[?25hCollecting chex (from graphcast==0.1)\n", " Downloading http://mirrors.aliyun.com/pypi/packages/9a/82/257141baabfaf8b0187521ddb83e996f2a71cdd4f7796d9599ca3e3ea4a9/chex-0.1.85-py3-none-any.whl (95 kB)\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m95.1/95.1 kB\u001b[0m \u001b[31m1.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0ma \u001b[36m0:00:01\u001b[0m\n", "\u001b[?25hCollecting colabtools (from graphcast==0.1)\n", " Downloading http://mirrors.aliyun.com/pypi/packages/bb/29/9088b67e938f38885c1035b36624ed6176c73845152c5ddd603facfa3e24/colabtools-0.0.1-py3-none-any.whl (14 kB)\n", "Collecting dask (from graphcast==0.1)\n", " Downloading http://mirrors.aliyun.com/pypi/packages/ff/d3/f1dcba697c7d7e8470ffa34b31ca1e663d4a2654ef806877f1017ecc5102/dask-2024.2.1-py3-none-any.whl (1.2 MB)\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m1.2/1.2 MB\u001b[0m \u001b[31m1.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m00:01\u001b[0m00:01\u001b[0m\n", "\u001b[?25hCollecting dm-haiku (from graphcast==0.1)\n", " Downloading http://mirrors.aliyun.com/pypi/packages/1c/c2/4a32e22bad1c5c675ac53701b099ce39c286970326512d3e9b06f8866f7d/dm_haiku-0.0.12-py3-none-any.whl (371 kB)\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m371.7/371.7 kB\u001b[0m \u001b[31m21.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[?25hCollecting dm-tree (from graphcast==0.1)\n", " Downloading 
http://mirrors.aliyun.com/pypi/packages/cc/2b/a13e3a44f9121ecab0057af462baeb64dc50eb269de52648db8823bc12ae/dm_tree-0.1.8-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (152 kB)\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m152.8/152.8 kB\u001b[0m \u001b[31m1.1 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0ma \u001b[36m0:00:01\u001b[0m\n", "\u001b[?25hCollecting jax (from graphcast==0.1)\n", " Downloading http://mirrors.aliyun.com/pypi/packages/ad/29/37cc2d58775917e6da532ef59cd3a66133d4de73fce1c16852e8475e5411/jax-0.4.25-py3-none-any.whl (1.8 MB)\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m1.8/1.8 MB\u001b[0m \u001b[31m44.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m00:01\u001b[0m\n", "\u001b[?25hCollecting jraph (from graphcast==0.1)\n", " Downloading http://mirrors.aliyun.com/pypi/packages/2a/e2/f799edeb39a154560b52134cdb3a3359e2de965c76886949966e46d5c42b/jraph-0.0.6.dev0-py3-none-any.whl (90 kB)\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m90.6/90.6 kB\u001b[0m \u001b[31m41.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[?25hCollecting matplotlib (from graphcast==0.1)\n", " Downloading http://mirrors.aliyun.com/pypi/packages/c1/f2/325897d6c498278b0f8b460d44b516f5db865ddb4ba9018e9fe58a3e4633/matplotlib-3.8.3-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (11.6 MB)\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m11.6/11.6 MB\u001b[0m \u001b[31m155.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m00:01\u001b[0m00:01\u001b[0m\n", "\u001b[?25hRequirement already satisfied: numpy in /root/miniconda3/envs/GraphCast/lib/python3.10/site-packages (from graphcast==0.1) (1.26.4)\n", "Collecting pandas (from graphcast==0.1)\n", " Downloading 
http://mirrors.aliyun.com/pypi/packages/19/df/8d789d96a9e338cf28cb7978fa93ef5da53137624b7ef032f30748421c2b/pandas-2.2.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (13.0 MB)\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m13.0/13.0 MB\u001b[0m \u001b[31m160.1 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m00:01\u001b[0m00:01\u001b[0m\n", "\u001b[?25hCollecting rtree (from graphcast==0.1)\n", " Downloading http://mirrors.aliyun.com/pypi/packages/59/a5/176d27468a1b0bcd7fa9c011cadacfa364e9bca8fa649baab7fb3f15af70/Rtree-1.2.0-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl (535 kB)\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m535.2/535.2 kB\u001b[0m \u001b[31m112.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[?25hCollecting scipy (from graphcast==0.1)\n", " Downloading http://mirrors.aliyun.com/pypi/packages/f5/aa/8e6071a5e4dca4ec68b5b22e4991ee74c59c5d372112b9c236ec1faff57d/scipy-1.12.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (38.4 MB)\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m38.4/38.4 MB\u001b[0m \u001b[31m93.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m00:01\u001b[0m00:01\u001b[0m\n", "\u001b[?25hCollecting trimesh (from graphcast==0.1)\n", " Downloading http://mirrors.aliyun.com/pypi/packages/41/ca/deada79afce17e97b93dbb4802971be3330d7ecfc728fd4ed89c43776bf6/trimesh-4.1.7-py3-none-any.whl (690 kB)\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m690.4/690.4 kB\u001b[0m \u001b[31m112.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[?25hRequirement already satisfied: typing_extensions in /root/miniconda3/envs/GraphCast/lib/python3.10/site-packages (from graphcast==0.1) (4.10.0)\n", "Collecting xarray (from graphcast==0.1)\n", " Downloading 
http://mirrors.aliyun.com/pypi/packages/47/f6/f2d4a0a2a4eb6fc427f1e482987e24104c9710c4244c76ea75b55243ada0/xarray-2024.2.0-py3-none-any.whl (1.1 MB)\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m1.1/1.1 MB\u001b[0m \u001b[31m1.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0ma \u001b[36m0:00:01\u001b[0m\n", "\u001b[?25hCollecting shapely>=1.7 (from cartopy->graphcast==0.1)\n", " Downloading http://mirrors.aliyun.com/pypi/packages/3e/d2/2afa1e563d417401ac364017a4d4d2d13a9dacb7dcbb83cc13f48a1efe41/shapely-2.0.3-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (2.5 MB)\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m2.5/2.5 MB\u001b[0m \u001b[31m113.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[?25hRequirement already satisfied: packaging>=20 in /root/miniconda3/envs/GraphCast/lib/python3.10/site-packages (from cartopy->graphcast==0.1) (23.2)\n", "Collecting pyshp>=2.1 (from cartopy->graphcast==0.1)\n", " Downloading http://mirrors.aliyun.com/pypi/packages/98/2f/68116db5b36b895c0450e3072b8cb6c2fac0359279b182ea97014d3c8ac0/pyshp-2.3.1-py2.py3-none-any.whl (46 kB)\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m46.5/46.5 kB\u001b[0m \u001b[31m18.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[?25hCollecting pyproj>=3.1.0 (from cartopy->graphcast==0.1)\n", " Downloading http://mirrors.aliyun.com/pypi/packages/f6/2b/b60cf73b0720abca313bfffef34e34f7f7dae23852b2853cf0368d49426b/pyproj-3.6.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (8.3 MB)\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m8.3/8.3 MB\u001b[0m \u001b[31m137.1 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0ma \u001b[36m0:00:01\u001b[0m\n", "\u001b[?25hCollecting contourpy>=1.0.1 (from matplotlib->graphcast==0.1)\n", " Downloading 
http://mirrors.aliyun.com/pypi/packages/58/56/e2c43dcfa1f9c7db4d5e3d6f5134b24ed953f4e2133a4b12f0062148db58/contourpy-1.2.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (310 kB)\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m310.7/310.7 kB\u001b[0m \u001b[31m76.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[?25hCollecting cycler>=0.10 (from matplotlib->graphcast==0.1)\n", " Downloading http://mirrors.aliyun.com/pypi/packages/e7/05/c19819d5e3d95294a6f5947fb9b9629efb316b96de511b418c53d245aae6/cycler-0.12.1-py3-none-any.whl (8.3 kB)\n", "Collecting fonttools>=4.22.0 (from matplotlib->graphcast==0.1)\n", " Downloading http://mirrors.aliyun.com/pypi/packages/a6/ba/5eac3e9c9bbc2dea3606e46de08bcef0908d74e7ccf89a71701b95a16747/fonttools-4.49.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (4.6 MB)\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m4.6/4.6 MB\u001b[0m \u001b[31m104.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m00:01\u001b[0m\n", "\u001b[?25hCollecting kiwisolver>=1.3.1 (from matplotlib->graphcast==0.1)\n", " Downloading http://mirrors.aliyun.com/pypi/packages/6f/40/4ab1fdb57fced80ce5903f04ae1aed7c1d5939dda4fd0c0aa526c12fe28a/kiwisolver-1.4.5-cp310-cp310-manylinux_2_12_x86_64.manylinux2010_x86_64.whl (1.6 MB)\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m1.6/1.6 MB\u001b[0m \u001b[31m99.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[?25hCollecting pillow>=8 (from matplotlib->graphcast==0.1)\n", " Downloading http://mirrors.aliyun.com/pypi/packages/cb/c3/98faa3e92cf866b9446c4842f1fe847e672b2f54e000cb984157b8095797/pillow-10.2.0-cp310-cp310-manylinux_2_28_x86_64.whl (4.5 MB)\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m4.5/4.5 MB\u001b[0m \u001b[31m57.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0mta \u001b[36m0:00:01\u001b[0m\n", "\u001b[?25hCollecting 
pyparsing>=2.3.1 (from matplotlib->graphcast==0.1)\n", " Downloading http://mirrors.aliyun.com/pypi/packages/9d/ea/6d76df31432a0e6fdf81681a895f009a4bb47b3c39036db3e1b528191d52/pyparsing-3.1.2-py3-none-any.whl (103 kB)\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m103.2/103.2 kB\u001b[0m \u001b[31m35.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[?25hRequirement already satisfied: python-dateutil>=2.7 in /root/miniconda3/envs/GraphCast/lib/python3.10/site-packages (from matplotlib->graphcast==0.1) (2.9.0)\n", "Collecting absl-py>=0.9.0 (from chex->graphcast==0.1)\n", " Downloading http://mirrors.aliyun.com/pypi/packages/a2/ad/e0d3c824784ff121c03cc031f944bc7e139a8f1870ffd2845cc2dd76f6c4/absl_py-2.1.0-py3-none-any.whl (133 kB)\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m133.7/133.7 kB\u001b[0m \u001b[31m43.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[?25hCollecting jaxlib>=0.1.37 (from chex->graphcast==0.1)\n", " Downloading http://mirrors.aliyun.com/pypi/packages/cf/68/0c82f9a43fbf69ccec013ac7c95e432b0147b2dfb8b5fc5f5a5ef83f2df3/jaxlib-0.4.25-cp310-cp310-manylinux2014_x86_64.whl (79.2 MB)\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m79.2/79.2 MB\u001b[0m \u001b[31m1.1 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m00:01\u001b[0m00:02\u001b[0mm\n", "\u001b[?25hCollecting toolz>=0.9.0 (from chex->graphcast==0.1)\n", " Downloading http://mirrors.aliyun.com/pypi/packages/b7/8a/d82202c9f89eab30f9fc05380daae87d617e2ad11571ab23d7c13a29bb54/toolz-0.12.1-py3-none-any.whl (56 kB)\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m56.1/56.1 kB\u001b[0m \u001b[31m21.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[?25hCollecting ml-dtypes>=0.2.0 (from jax->graphcast==0.1)\n", " Downloading 
http://mirrors.aliyun.com/pypi/packages/71/01/7dc0e2cdead686a758810d08fd4111602088fe3f0d88064a83cbfb635593/ml_dtypes-0.3.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (2.2 MB)\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m2.2/2.2 MB\u001b[0m \u001b[31m1.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m00:01\u001b[0m00:01\u001b[0m\n", "\u001b[?25hCollecting opt-einsum (from jax->graphcast==0.1)\n", " Downloading http://mirrors.aliyun.com/pypi/packages/bc/19/404708a7e54ad2798907210462fd950c3442ea51acc8790f3da48d2bee8b/opt_einsum-3.3.0-py3-none-any.whl (65 kB)\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m65.5/65.5 kB\u001b[0m \u001b[31m30.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[?25hCollecting click>=8.1 (from dask->graphcast==0.1)\n", " Downloading http://mirrors.aliyun.com/pypi/packages/00/2e/d53fa4befbf2cfa713304affc7ca780ce4fc1fd8710527771b58311a3229/click-8.1.7-py3-none-any.whl (97 kB)\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m97.9/97.9 kB\u001b[0m \u001b[31m38.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[?25hCollecting cloudpickle>=1.5.0 (from dask->graphcast==0.1)\n", " Downloading http://mirrors.aliyun.com/pypi/packages/96/43/dae06432d0c4b1dc9e9149ad37b4ca8384cf6eb7700cd9215b177b914f0a/cloudpickle-3.0.0-py3-none-any.whl (20 kB)\n", "Collecting fsspec>=2021.09.0 (from dask->graphcast==0.1)\n", " Downloading http://mirrors.aliyun.com/pypi/packages/ad/30/2281c062222dc39328843bd1ddd30ff3005ef8e30b2fd09c4d2792766061/fsspec-2024.2.0-py3-none-any.whl (170 kB)\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m170.9/170.9 kB\u001b[0m \u001b[31m62.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[?25hCollecting partd>=1.2.0 (from dask->graphcast==0.1)\n", " Downloading 
http://mirrors.aliyun.com/pypi/packages/11/8a/b7a58e208b144a7315208a0dd627e23f5f50b47fa89c2924bb2e9238ecfb/partd-1.4.1-py3-none-any.whl (18 kB)\n", "Collecting pyyaml>=5.3.1 (from dask->graphcast==0.1)\n", " Downloading http://mirrors.aliyun.com/pypi/packages/29/61/bf33c6c85c55bc45a29eee3195848ff2d518d84735eb0e2d8cb42e0d285e/PyYAML-6.0.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (705 kB)\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m705.5/705.5 kB\u001b[0m \u001b[31m85.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[?25hRequirement already satisfied: importlib-metadata>=4.13.0 in /root/miniconda3/envs/GraphCast/lib/python3.10/site-packages (from dask->graphcast==0.1) (7.0.2)\n", "Collecting jmp>=0.0.2 (from dm-haiku->graphcast==0.1)\n", " Downloading http://mirrors.aliyun.com/pypi/packages/27/e5/cce82de2831e5aff9332d8d624bb57188f1b2af6ccf6979caf898a8a4348/jmp-0.0.4-py3-none-any.whl (18 kB)\n", "Collecting tabulate>=0.8.9 (from dm-haiku->graphcast==0.1)\n", " Downloading http://mirrors.aliyun.com/pypi/packages/40/44/4a5f08c96eb108af5cb50b41f76142f0afa346dfa99d5296fe7202a11854/tabulate-0.9.0-py3-none-any.whl (35 kB)\n", "Collecting flax>=0.7.1 (from dm-haiku->graphcast==0.1)\n", " Downloading http://mirrors.aliyun.com/pypi/packages/8d/4a/7e78abc8392ff21b0257deb79e842f80647b63b745447df94893732d60fd/flax-0.8.1-py3-none-any.whl (677 kB)\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m677.6/677.6 kB\u001b[0m \u001b[31m1.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0ma \u001b[36m0:00:01\u001b[0m\n", "\u001b[?25hCollecting pytz>=2020.1 (from pandas->graphcast==0.1)\n", " Downloading http://mirrors.aliyun.com/pypi/packages/9c/3d/a121f284241f08268b21359bd425f7d4825cffc5ac5cd0e1b3d82ffd2b10/pytz-2024.1-py2.py3-none-any.whl (505 kB)\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m505.5/505.5 kB\u001b[0m \u001b[31m97.1 MB/s\u001b[0m eta 
\u001b[36m0:00:00\u001b[0m\n", "\u001b[?25hCollecting tzdata>=2022.7 (from pandas->graphcast==0.1)\n", " Downloading http://mirrors.aliyun.com/pypi/packages/65/58/f9c9e6be752e9fcb8b6a0ee9fb87e6e7a1f6bcab2cdc73f02bb7ba91ada0/tzdata-2024.1-py2.py3-none-any.whl (345 kB)\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m345.4/345.4 kB\u001b[0m \u001b[31m80.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[?25hCollecting msgpack (from flax>=0.7.1->dm-haiku->graphcast==0.1)\n", " Downloading http://mirrors.aliyun.com/pypi/packages/d9/96/a1868dd8997d65732476dfc70fef44d046c1b4dbe36ec1481ab744d87775/msgpack-1.0.8-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (385 kB)\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m385.1/385.1 kB\u001b[0m \u001b[31m1.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0ma \u001b[36m0:00:01\u001b[0m\n", "\u001b[?25hCollecting optax (from flax>=0.7.1->dm-haiku->graphcast==0.1)\n", " Downloading http://mirrors.aliyun.com/pypi/packages/d7/62/5072ef01f30959b726cf995e4c612138acfc3e7e149bf5496c5bb2aac680/optax-0.2.0-py3-none-any.whl (209 kB)\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m209.5/209.5 kB\u001b[0m \u001b[31m1.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0ma \u001b[36m0:00:01\u001b[0m\n", "\u001b[?25hCollecting orbax-checkpoint (from flax>=0.7.1->dm-haiku->graphcast==0.1)\n", " Downloading http://mirrors.aliyun.com/pypi/packages/83/a2/0677f2ee06bdbf7b4e6be4ad931ffe58f2ea82d67bb2a277d9d7b3b1e352/orbax_checkpoint-0.5.3-py3-none-any.whl (143 kB)\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m143.0/143.0 kB\u001b[0m \u001b[31m1.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0ma \u001b[36m0:00:01\u001b[0m\n", "\u001b[?25hCollecting tensorstore (from flax>=0.7.1->dm-haiku->graphcast==0.1)\n", " Downloading 
http://mirrors.aliyun.com/pypi/packages/3c/6f/5b09e7ff2e1d4cdbcda7a99b33556871fa788e0eb638e89604503b09d681/tensorstore-0.1.54-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (14.2 MB)\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m14.2/14.2 MB\u001b[0m \u001b[31m1.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m00:01\u001b[0m00:01\u001b[0mm\n", "\u001b[?25hCollecting rich>=11.1 (from flax>=0.7.1->dm-haiku->graphcast==0.1)\n", " Downloading http://mirrors.aliyun.com/pypi/packages/87/67/a37f6214d0e9fe57f6ae54b2956d550ca8365857f42a1ce0392bb21d9410/rich-13.7.1-py3-none-any.whl (240 kB)\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m240.7/240.7 kB\u001b[0m \u001b[31m41.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[?25hRequirement already satisfied: zipp>=0.5 in /root/miniconda3/envs/GraphCast/lib/python3.10/site-packages (from importlib-metadata>=4.13.0->dask->graphcast==0.1) (3.17.0)\n", "Collecting locket (from partd>=1.2.0->dask->graphcast==0.1)\n", " Downloading http://mirrors.aliyun.com/pypi/packages/db/bc/83e112abc66cd466c6b83f99118035867cecd41802f8d044638aa78a106e/locket-1.0.0-py2.py3-none-any.whl (4.4 kB)\n", "Collecting certifi (from pyproj>=3.1.0->cartopy->graphcast==0.1)\n", " Downloading http://mirrors.aliyun.com/pypi/packages/ba/06/a07f096c664aeb9f01624f858c3add0a4e913d6c96257acb4fce61e7de14/certifi-2024.2.2-py3-none-any.whl (163 kB)\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m163.8/163.8 kB\u001b[0m \u001b[31m48.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[?25hRequirement already satisfied: six>=1.5 in /root/miniconda3/envs/GraphCast/lib/python3.10/site-packages (from python-dateutil>=2.7->matplotlib->graphcast==0.1) (1.16.0)\n", "Collecting markdown-it-py>=2.2.0 (from rich>=11.1->flax>=0.7.1->dm-haiku->graphcast==0.1)\n", " Downloading 
http://mirrors.aliyun.com/pypi/packages/42/d7/1ec15b46af6af88f19b8e5ffea08fa375d433c998b8a7639e76935c14f1f/markdown_it_py-3.0.0-py3-none-any.whl (87 kB)\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m87.5/87.5 kB\u001b[0m \u001b[31m23.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[?25hRequirement already satisfied: pygments<3.0.0,>=2.13.0 in /root/miniconda3/envs/GraphCast/lib/python3.10/site-packages (from rich>=11.1->flax>=0.7.1->dm-haiku->graphcast==0.1) (2.17.2)\n", "Collecting etils[epath,epy] (from orbax-checkpoint->flax>=0.7.1->dm-haiku->graphcast==0.1)\n", " Downloading http://mirrors.aliyun.com/pypi/packages/37/10/dd5b124f037a636783e416a2fe839edd7ec63c0dce7ce4f3c1da029aeb80/etils-1.7.0-py3-none-any.whl (152 kB)\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m152.4/152.4 kB\u001b[0m \u001b[31m1.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0ma \u001b[36m0:00:01\u001b[0m\n", "\u001b[?25hRequirement already satisfied: nest_asyncio in /root/miniconda3/envs/GraphCast/lib/python3.10/site-packages (from orbax-checkpoint->flax>=0.7.1->dm-haiku->graphcast==0.1) (1.6.0)\n", "Collecting protobuf (from orbax-checkpoint->flax>=0.7.1->dm-haiku->graphcast==0.1)\n", " Downloading http://mirrors.aliyun.com/pypi/packages/15/db/7f731524fe0e56c6b2eb57d05b55d3badd80ef7d1f1ed59db191b2fdd8ab/protobuf-4.25.3-cp37-abi3-manylinux2014_x86_64.whl (294 kB)\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m294.6/294.6 kB\u001b[0m \u001b[31m6.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m00:01\u001b[0m\n", "\u001b[?25hCollecting mdurl~=0.1 (from markdown-it-py>=2.2.0->rich>=11.1->flax>=0.7.1->dm-haiku->graphcast==0.1)\n", " Downloading http://mirrors.aliyun.com/pypi/packages/b3/38/89ba8ad64ae25be8de66a6d463314cf1eb366222074cfda9ee839c56a4b4/mdurl-0.1.2-py3-none-any.whl (10.0 kB)\n", "Collecting importlib_resources (from 
etils[epath,epy]->orbax-checkpoint->flax>=0.7.1->dm-haiku->graphcast==0.1)\n", " Downloading http://mirrors.aliyun.com/pypi/packages/f3/3e/a01c6de9853a7d73672926e01966f979723386f0ba83d62e5db3cdf4c097/importlib_resources-6.1.3-py3-none-any.whl (34 kB)\n", "Building wheels for collected packages: graphcast\n", " Building wheel for graphcast (setup.py) ... \u001b[?25ldone\n", "\u001b[?25h Created wheel for graphcast: filename=graphcast-0.1-py3-none-any.whl size=94233 sha256=816aa53de19b18ed34909733382b099801550998f5450196aa337bcc5afef8c0\n", " Stored in directory: /tmp/pip-ephem-wheel-cache-uap1nz6o/wheels/35/9a/29/d754c3682bd39d8b5375879c27bea6b1417d2fe2ea71c2a46e\n", "Successfully built graphcast\n", "Installing collected packages: pytz, dm-tree, colabtools, tzdata, trimesh, toolz, tabulate, shapely, scipy, rtree, pyyaml, pyshp, pyparsing, protobuf, pillow, opt-einsum, msgpack, ml-dtypes, mdurl, locket, kiwisolver, jmp, importlib_resources, fsspec, fonttools, etils, cycler, contourpy, cloudpickle, click, certifi, absl-py, tensorstore, pyproj, partd, pandas, matplotlib, markdown-it-py, jaxlib, jax, xarray, rich, jraph, dask, chex, cartopy, orbax-checkpoint, optax, flax, dm-haiku, graphcast\n", "Successfully installed absl-py-2.1.0 cartopy-0.22.0 certifi-2024.2.2 chex-0.1.85 click-8.1.7 cloudpickle-3.0.0 colabtools-0.0.1 contourpy-1.2.0 cycler-0.12.1 dask-2024.2.1 dm-haiku-0.0.12 dm-tree-0.1.8 etils-1.7.0 flax-0.8.1 fonttools-4.49.0 fsspec-2024.2.0 graphcast-0.1 importlib_resources-6.1.3 jax-0.4.25 jaxlib-0.4.25 jmp-0.0.4 jraph-0.0.6.dev0 kiwisolver-1.4.5 locket-1.0.0 markdown-it-py-3.0.0 matplotlib-3.8.3 mdurl-0.1.2 ml-dtypes-0.3.2 msgpack-1.0.8 opt-einsum-3.3.0 optax-0.2.0 orbax-checkpoint-0.5.3 pandas-2.2.1 partd-1.4.1 pillow-10.2.0 protobuf-4.25.3 pyparsing-3.1.2 pyproj-3.6.1 pyshp-2.3.1 pytz-2024.1 pyyaml-6.0.1 rich-13.7.1 rtree-1.2.0 scipy-1.12.0 shapely-2.0.3 tabulate-0.9.0 tensorstore-0.1.54 toolz-0.12.1 trimesh-4.1.7 tzdata-2024.1 xarray-2024.2.0\n", 
"\u001b[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv\u001b[0m\u001b[33m\n", "\u001b[0mNote: you may need to restart the kernel to use updated packages.\n" ] } ], "source": [ "# @title Pip install graphcast and other dependencies\n", "\n", "\n", "%pip install --upgrade https://github.com/deepmind/graphcast/archive/master.zip" ] }, { "cell_type": "code", "execution_count": 4, "metadata": { "cellView": "form", "id": "MA5087Vb29z2" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Found existing installation: shapely 2.0.3\n", "Uninstalling shapely-2.0.3:\n", " Successfully uninstalled shapely-2.0.3\n", "\u001b[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv\u001b[0m\u001b[33m\n", "\u001b[0mLooking in indexes: http://mirrors.aliyun.com/pypi/simple\n", "Collecting shapely\n", " Downloading http://mirrors.aliyun.com/pypi/packages/36/8f/03929218f8d7003c3eafa5ffad1fb3f185459d336fa9cc31d3e67f442f97/shapely-2.0.3.tar.gz (280 kB)\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m280.5/280.5 kB\u001b[0m \u001b[31m20.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[?25h Installing build dependencies ... \u001b[?25ldone\n", "\u001b[?25h Getting requirements to build wheel ... \u001b[?25ldone\n", "\u001b[?25h Installing backend dependencies ... \u001b[?25ldone\n", "\u001b[?25h Preparing metadata (pyproject.toml) ... \u001b[?25ldone\n", "\u001b[?25hRequirement already satisfied: numpy<2,>=1.14 in /root/miniconda3/envs/GraphCast/lib/python3.10/site-packages (from shapely) (1.26.4)\n", "Building wheels for collected packages: shapely\n", " Building wheel for shapely (pyproject.toml) ... 
\u001b[?25ldone\n", "\u001b[?25h Created wheel for shapely: filename=shapely-2.0.3-cp310-cp310-linux_x86_64.whl size=428218 sha256=290da7a19b27c6e0a4f19804686dcee97f0b24f27022919d1eb2b63176198884\n", " Stored in directory: /root/.cache/pip/wheels/9c/d8/0d/94954bf75398579fcda4f0e0e0333d7f8ea2ca4a7ce82bfb02\n", "Successfully built shapely\n", "Installing collected packages: shapely\n", "Successfully installed shapely-2.0.3\n", "\u001b[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv\u001b[0m\u001b[33m\n", "\u001b[0m" ] } ], "source": [ "# @title Workaround for cartopy crashes\n", "\n", "!pip uninstall -y shapely\n", "!pip install shapely --no-binary shapely" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Channels:\n", " - conda-forge\n", " - https://mirrors.ustc.edu.cn/anaconda/pkgs/main\n", " - https://mirrors.ustc.edu.cn/anaconda/pkgs/free\n", " - defaults\n", "Platform: linux-64\n", "Collecting package metadata (repodata.json): done\n", "Solving environment: done\n", "\n", "## Package Plan ##\n", "\n", " environment location: /root/miniconda3/envs/GraphCast\n", "\n", " added / updated specs:\n", " - ipywidgets\n", "\n", "\n", "The following packages will be downloaded:\n", "\n", " package | build\n", " ---------------------------|-----------------\n", " ipywidgets-8.1.2 | pyhd8ed1ab_0 111 KB conda-forge\n", " jupyterlab_widgets-3.0.10 | pyhd8ed1ab_0 183 KB conda-forge\n", " widgetsnbextension-4.0.10 | pyhd8ed1ab_0 866 KB conda-forge\n", " ------------------------------------------------------------\n", " Total: 1.1 MB\n", "\n", "The following NEW packages will be INSTALLED:\n", "\n", " ipywidgets conda-forge/noarch::ipywidgets-8.1.2-pyhd8ed1ab_0 \n", " jupyterlab_widgets 
conda-forge/noarch::jupyterlab_widgets-3.0.10-pyhd8ed1ab_0 \n", " widgetsnbextension conda-forge/noarch::widgetsnbextension-4.0.10-pyhd8ed1ab_0 \n", "\n", "\n", "\n", "Downloading and Extracting Packages:\n", "widgetsnbextension-4 | 866 KB | | 0% \n", "jupyterlab_widgets-3 | 183 KB | | 0% \u001b[A\n", "\n", "ipywidgets-8.1.2 | 111 KB | | 0% \u001b[A\u001b[A\n", "jupyterlab_widgets-3 | 183 KB | ###2 | 9% \u001b[A\n", "\n", "ipywidgets-8.1.2 | 111 KB | #####3 | 14% \u001b[A\u001b[A\n", "jupyterlab_widgets-3 | 183 KB | #########7 | 26% \u001b[A\n", "\n", "ipywidgets-8.1.2 | 111 KB | #####################3 | 58% \u001b[A\u001b[A\n", "widgetsnbextension-4 | 866 KB | 6 | 2% \u001b[A\n", "\n", "ipywidgets-8.1.2 | 111 KB | ##########################6 | 72% \u001b[A\u001b[A\n", "jupyterlab_widgets-3 | 183 KB | #########################9 | 70% \u001b[A\n", "widgetsnbextension-4 | 866 KB | ###4 | 9% \u001b[A\n", "jupyterlab_widgets-3 | 183 KB | ###################################6 | 96% \u001b[A\n", "widgetsnbextension-4 | 866 KB | #############6 | 37% \u001b[A\n", "\n", "widgetsnbextension-4 | 866 KB | ################4 | 44% \u001b[A\u001b[A\n", "\n", " \u001b[A\u001b[A\n", " \u001b[A\n", "\n", " \u001b[A\u001b[A\n", "Preparing transaction: done\n", "Verifying transaction: done\n", "Executing transaction: done\n", "Found existing installation: xarray 2024.2.0\n", "Uninstalling xarray-2024.2.0:\n", " Successfully uninstalled xarray-2024.2.0\n", "\u001b[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. 
It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv\u001b[0m\u001b[33m\n", "\u001b[0mLooking in indexes: http://mirrors.aliyun.com/pypi/simple\n", "Collecting xarray==2023.7.0\n", " Downloading http://mirrors.aliyun.com/pypi/packages/cf/2f/e696512aa1e4e2ee1cf1e0bdbab6042f6c782058eb0a4367184ce4343f36/xarray-2023.7.0-py3-none-any.whl (1.0 MB)\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m1.0/1.0 MB\u001b[0m \u001b[31m1.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0ma \u001b[36m0:00:01\u001b[0m\n", "\u001b[?25hRequirement already satisfied: numpy>=1.21 in /root/miniconda3/envs/GraphCast/lib/python3.10/site-packages (from xarray==2023.7.0) (1.26.4)\n", "Requirement already satisfied: pandas>=1.4 in /root/miniconda3/envs/GraphCast/lib/python3.10/site-packages (from xarray==2023.7.0) (2.2.1)\n", "Requirement already satisfied: packaging>=21.3 in /root/miniconda3/envs/GraphCast/lib/python3.10/site-packages (from xarray==2023.7.0) (23.2)\n", "Requirement already satisfied: python-dateutil>=2.8.2 in /root/miniconda3/envs/GraphCast/lib/python3.10/site-packages (from pandas>=1.4->xarray==2023.7.0) (2.9.0)\n", "Requirement already satisfied: pytz>=2020.1 in /root/miniconda3/envs/GraphCast/lib/python3.10/site-packages (from pandas>=1.4->xarray==2023.7.0) (2024.1)\n", "Requirement already satisfied: tzdata>=2022.7 in /root/miniconda3/envs/GraphCast/lib/python3.10/site-packages (from pandas>=1.4->xarray==2023.7.0) (2024.1)\n", "Requirement already satisfied: six>=1.5 in /root/miniconda3/envs/GraphCast/lib/python3.10/site-packages (from python-dateutil>=2.8.2->pandas>=1.4->xarray==2023.7.0) (1.16.0)\n", "Installing collected packages: xarray\n", "Successfully installed xarray-2023.7.0\n", "\u001b[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. 
It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv\u001b[0m\u001b[33m\n", "\u001b[0m" ] } ], "source": [ "# @title Install other dependencies and work around the xarray version issue\n", "\n", "# Downgrade xarray from the newer version pulled in by graphcast (2023.12.0 when installed on 2023-12-30; 2024.2.0 in the recorded output of this cell) to 2023.7.0, otherwise errors are raised.\n", "\n", "!conda install -y -c conda-forge ipywidgets\n", "!pip uninstall -y xarray\n", "!pip install xarray==2023.7.0" ] }, { "cell_type": "code", "execution_count": 6, "metadata": { "cellView": "form", "id": "Z_j8ej4Pyg1L" }, "outputs": [], "source": [ "# @title Imports\n", "\n", "\n", "import dataclasses\n", "import datetime\n", "import functools\n", "import math\n", "import re\n", "from typing import Optional\n", "\n", "import cartopy.crs as ccrs\n", "#from google.cloud import storage\n", "from graphcast import autoregressive\n", "from graphcast import casting\n", "from graphcast import checkpoint\n", "from graphcast import data_utils\n", "from graphcast import graphcast\n", "from graphcast import normalization\n", "from graphcast import rollout\n", "from graphcast import xarray_jax\n", "from graphcast import xarray_tree\n", "from IPython.display import HTML\n", "import ipywidgets as widgets\n", "import haiku as hk\n", "import jax\n", "import matplotlib\n", "import matplotlib.pyplot as plt\n", "from matplotlib import animation\n", "import numpy as np\n", "import xarray\n", "\n", "\n", "\n", "\n", "def parse_file_parts(file_name):\n", "  return dict(part.split(\"-\", 1) for part in file_name.split(\"_\"))\n" ] }, { "cell_type": "code", "execution_count": 7, "metadata": { "cellView": "form", "id": "5JUymx84dI2m" }, "outputs": [], "source": [ "# @title Plotting functions\n", "\n", "\n", "def select(\n", "    data: xarray.Dataset,\n", "    variable: str,\n", "    level: Optional[int] = None,\n", "    max_steps: Optional[int] = None\n", "    ) -> xarray.Dataset:\n", "  data = data[variable]\n", "  if \"batch\" in data.dims:\n", "    data = data.isel(batch=0)\n", "  if max_steps is not None and \"time\" in data.sizes and max_steps < 
data.sizes[\"time\"]:\n", "    data = data.isel(time=range(0, max_steps))\n", "  if level is not None and \"level\" in data.coords:\n", "    data = data.sel(level=level)\n", "  return data\n", "\n", "def scale(\n", "    data: xarray.Dataset,\n", "    center: Optional[float] = None,\n", "    robust: bool = False,\n", "    ) -> tuple[xarray.Dataset, matplotlib.colors.Normalize, str]:\n", "  vmin = np.nanpercentile(data, (2 if robust else 0))\n", "  vmax = np.nanpercentile(data, (98 if robust else 100))\n", "  if center is not None:\n", "    diff = max(vmax - center, center - vmin)\n", "    vmin = center - diff\n", "    vmax = center + diff\n", "  return (data, matplotlib.colors.Normalize(vmin, vmax),\n", "          (\"RdBu_r\" if center is not None else \"viridis\"))\n", "\n", "def plot_data(\n", "    data: dict[str, tuple[xarray.Dataset, matplotlib.colors.Normalize, str]],\n", "    fig_title: str,\n", "    plot_size: float = 5,\n", "    robust: bool = False,\n", "    cols: int = 4\n", "    ) -> HTML:\n", "\n", "  first_data = next(iter(data.values()))[0]\n", "  max_steps = first_data.sizes.get(\"time\", 1)\n", "  assert all(max_steps == d.sizes.get(\"time\", 1) for d, _, _ in data.values())\n", "\n", "  cols = min(cols, len(data))\n", "  rows = math.ceil(len(data) / cols)\n", "  figure = plt.figure(figsize=(plot_size * 2 * cols,\n", "                               plot_size * rows))\n", "  figure.suptitle(fig_title, fontsize=16)\n", "  figure.subplots_adjust(wspace=0, hspace=0)\n", "  figure.tight_layout()\n", "\n", "  images = []\n", "  for i, (title, (plot_data, norm, cmap)) in enumerate(data.items()):\n", "    ax = figure.add_subplot(rows, cols, i+1)\n", "    ax.set_xticks([])\n", "    ax.set_yticks([])\n", "    ax.set_title(title)\n", "    im = ax.imshow(\n", "        plot_data.isel(time=0, missing_dims=\"ignore\"), norm=norm,\n", "        origin=\"lower\", cmap=cmap)\n", "    plt.colorbar(\n", "        mappable=im,\n", "        ax=ax,\n", "        orientation=\"vertical\",\n", "        pad=0.02,\n", "        aspect=16,\n", "        shrink=0.75,\n", "        cmap=cmap,\n", "        extend=(\"both\" if robust else \"neither\"))\n", "    
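images.append(im)\n", "\n", "  def update(frame):\n", "    if \"time\" in first_data.dims:\n", "      td = datetime.timedelta(microseconds=first_data[\"time\"][frame].item() / 1000)\n", "      figure.suptitle(f\"{fig_title}, {td}\", fontsize=16)\n", "    else:\n", "      figure.suptitle(fig_title, fontsize=16)\n", "    for im, (plot_data, norm, cmap) in zip(images, data.values()):\n", "      im.set_data(plot_data.isel(time=frame, missing_dims=\"ignore\"))\n", "\n", "  ani = animation.FuncAnimation(\n", "      fig=figure, func=update, frames=max_steps, interval=250)\n", "  plt.close(figure.number)\n", "  return HTML(ani.to_jshtml())" ] }, { "cell_type": "markdown", "metadata": { "id": "WEtSV8HEkHtf" }, "source": [ "# Load the data and initialize the model" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Load the model params\n", "\n", "Choose one of the two ways of getting model params:\n", "- **random**: you'll get random predictions, but you can change the model architecture, which may make it run faster or fit on your device.\n", "- **checkpoint**: you'll get sensible predictions, but are limited to the architecture the model was trained with, which may not fit on your device. In particular, generating gradients uses a lot of memory, so you'll need at least 25GB of memory (TPUv4 or A100).\n", "\n", "Checkpoints vary across a few axes:\n", "- The mesh size specifies the internal graph representation of the earth. Smaller meshes run faster but produce worse outputs. The mesh size does not affect the number of parameters of the model.\n", "- The resolution and number of pressure levels must match the data. Lower resolution and fewer levels run faster. The data resolution only affects the encoder/decoder.\n", "- All our models predict precipitation. However, ERA5 includes precipitation, while HRES does not. 
Our models marked \"ERA5\" take precipitation as input and expect ERA5 data as input, while models marked \"ERA5-HRES\" do not take precipitation as input and are specifically trained to take HRES-fc0 as input (see the data section below).\n", "\n", "We provide three pre-trained models:\n", "1. `GraphCast`, the high-resolution model used in the GraphCast paper (0.25 degree resolution, 37 pressure levels), trained on ERA5 data from 1979 to 2017,\n", "\n", "2. `GraphCast_small`, a smaller, low-resolution version of GraphCast (1 degree resolution, 13 pressure levels, and a smaller mesh), trained on ERA5 data from 1979 to 2015, useful for running the model under tighter memory and compute constraints,\n", "\n", "3. 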
`GraphCast_operational`, a high-resolution model (0.25 degree resolution, 13 pressure levels) pre-trained on ERA5 data from 1979 to 2017 and fine-tuned on HRES data from 2016 to 2021. This model can be initialized from HRES data (it does not require precipitation inputs).\n" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "0a6232a8763d4f879d0bdae53802d75e", "version_major": 2, "version_minor": 0 }, "text/plain": [ "VBox(children=(Tab(children=(VBox(children=(IntSlider(value=4, description='Mesh size:', max=6, min=4), IntSli…" ] }, "execution_count": 8, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# @title Choose the model\n", "# Rewrite by S.F. Sune, https://github.com/sfsun67.\n", "'''\n", "    There are three trained models to choose from; download them beforehand from https://console.cloud.google.com/storage/browser/dm_graphcast:\n", "    GraphCast - ERA5 1979-2017 - resolution 0.25 - pressure levels 37 - mesh 2to6 - precipitation input and output.npz\n", "    GraphCast_operational - ERA5-HRES 1979-2021 - resolution 0.25 - pressure levels 13 - mesh 2to6 - precipitation output only.npz\n", "    GraphCast_small - ERA5 1979-2015 - resolution 1.0 - pressure levels 13 - mesh 2to5 - precipitation input and output.npz\n", "'''\n", "# List the names of all files under /root/data/params, without the \"params/\" prefix.\n", "\n", "import os\n", "import glob\n", "\n", "# Define the data directory; replace it with your own.\n", "dir_path_params = \"/root/data/params\"\n", "\n", "# Use glob to get all file paths in the directory\n", "file_paths_params = glob.glob(os.path.join(dir_path_params, \"*\"))\n", "\n", "# Keep only the file name (drop the \".../params/\" directory prefix)\n", "params_file_options = [os.path.basename(path) for path in file_paths_params]\n", "\n", "\n", "random_mesh_size = widgets.IntSlider(\n", "    value=4, min=4, max=6, description=\"Mesh size:\")\n", "random_gnn_msg_steps = widgets.IntSlider(\n", "    value=4, min=1, max=32, description=\"GNN message steps:\")\n", "random_latent_size = widgets.Dropdown(\n", "    options=[int(2**i) for i in range(4, 10)], value=32, description=\"Latent 
size:\")\n", "random_levels = widgets.Dropdown(\n", "    options=[13, 37], value=13, description=\"Pressure levels:\")\n", "\n", "\n", "params_file = widgets.Dropdown(\n", "    options=params_file_options,\n", "    description=\"Params file:\",\n", "    layout={\"width\": \"max-content\"})\n", "\n", "source_tab = widgets.Tab([\n", "    widgets.VBox([\n", "        random_mesh_size,\n", "        random_gnn_msg_steps,\n", "        random_latent_size,\n", "        random_levels,\n", "    ]),\n", "    params_file,\n", "])\n", "source_tab.set_title(0, \"Random\")\n", "source_tab.set_title(1, \"Checkpoint\")\n", "widgets.VBox([\n", "    source_tab,\n", "    widgets.Label(value=\"Run the next cell to load the model. Rerunning this cell will clear your selection.\")\n", "])\n" ] }, { "cell_type": "code", "execution_count": 9, "metadata": { "cellView": "form", "id": "lYQgrPgPdI2n" }, "outputs": [ { "data": { "text/plain": [ "ModelConfig(resolution=0, mesh_size=4, latent_size=32, gnn_msg_steps=4, hidden_layers=1, radius_query_fraction_edge_length=0.6, mesh2grid_edge_normalization_factor=None)" ] }, "execution_count": 9, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# @title Load the model\n", "\n", "source = source_tab.get_title(source_tab.selected_index)\n", "\n", "if source == \"Random\":\n", "  params = None  # Filled in below\n", "  state = {}\n", "  model_config = graphcast.ModelConfig(\n", "      resolution=0,\n", "      mesh_size=random_mesh_size.value,\n", "      latent_size=random_latent_size.value,\n", "      gnn_msg_steps=random_gnn_msg_steps.value,\n", "      hidden_layers=1,\n", "      radius_query_fraction_edge_length=0.6)\n", "  task_config = graphcast.TaskConfig(\n", "      input_variables=graphcast.TASK.input_variables,\n", "      target_variables=graphcast.TASK.target_variables,\n", "      forcing_variables=graphcast.TASK.forcing_variables,\n", "      pressure_levels=graphcast.PRESSURE_LEVELS[random_levels.value],\n", "      input_duration=graphcast.TASK.input_duration,\n", "  )\n", "else:\n", "  assert source == \"Checkpoint\"\n", "  '''with 
gcs_bucket.blob(f\"params/{params_file.value}\").open(\"rb\") as f:\n", "    ckpt = checkpoint.load(f, graphcast.CheckPoint)'''\n", "\n", "  with open(f\"{dir_path_params}/{params_file.value}\", \"rb\") as f:\n", "    ckpt = checkpoint.load(f, graphcast.CheckPoint)\n", "\n", "  params = ckpt.params\n", "  state = {}\n", "\n", "  model_config = ckpt.model_config\n", "  task_config = ckpt.task_config\n", "  print(\"Model description:\\n\", ckpt.description, \"\\n\")\n", "  print(\"Model license:\\n\", ckpt.license, \"\\n\")\n", "\n", "model_config" ] }, { "cell_type": "markdown", "metadata": { "id": "rQWk0RRuCjDN" }, "source": [ "## Load the example data\n", "\n", "Several example datasets are available, varying across a few axes:\n", "- **Source**: fake, era5, hres\n", "- **Resolution**: 0.25 deg, 1 deg, 6 deg\n", "- **Levels**: 13, 37\n", "- **Steps**: how many time steps are included\n", "\n", "Not all combinations are available.\n", "- Higher resolutions are only available with fewer steps, because of the memory needed to load them.\n", "- HRES is only available at 0.25 degree resolution with 13 pressure levels.\n", "\n", "The data resolution must match the loaded model.\n", "\n", "Some transformations were applied to the base datasets:\n", "- We accumulated precipitation over 6 hours instead of the default 1 hour.\n", "- For HRES data, each time step corresponds to the HRES forecast at lead time 0, essentially providing an \"initialization\" from HRES. See HRES-fc0 in the GraphCast paper for a detailed description. Note that HRES does not provide a 6-hour accumulation of precipitation, so our model that takes HRES inputs does not depend on precipitation. However, because our models predict precipitation, the example data includes ERA5 precipitation as an example of ground truth.\n", "- We include ERA5 \"toa_incident_solar_radiation\" in the data. Our model uses the radiation at -6h, 0h and +6h as a forcing term for each 1-step prediction. 
Operationally, if the +6h radiation is not readily available, it can be computed with a package such as `pysolar`.\n" ] }, { "cell_type": "code", "execution_count": 10, "metadata": { "cellView": "form", "id": "-DJzie5me2-H" }, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "12400763c92e4d319bdc8e8523bcc0c2", "version_major": 2, "version_minor": 0 }, "text/plain": [ "VBox(children=(Dropdown(description='数据文件:', layout=Layout(width='max-content'), options=(), value=None), Labe…" ] }, "execution_count": 10, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# @title Get and filter the list of available example datasets\n", "\n", "# Rewrite by S.F. 
Sune, https://github.com/sfsun67.\n", "# List the names of all files under /root/data/dataset, without the \"dataset/\" prefix.\n", "\n", "# Define the data directory; replace it with your own.\n", "dir_path_dataset = \"/root/data/dataset\"\n", "\n", "# Use glob to get all file paths in the directory\n", "file_paths_dataset = glob.glob(os.path.join(dir_path_dataset, \"*\"))\n", "\n", "# Keep only the file name (drop the \".../dataset/\" directory prefix)\n", "dataset_file_options = [os.path.basename(path) for path in file_paths_dataset]\n", "#print(\"dataset_file_options: \", dataset_file_options)\n", "\n", "# Remove \"dataset-\" prefix from each file name\n", "dataset_file_options = [name.removeprefix(\"dataset-\") for name in dataset_file_options]\n", "\n", "\n", "def data_valid_for_model(\n", "    file_name: str, model_config: graphcast.ModelConfig, task_config: graphcast.TaskConfig):\n", "  file_parts = parse_file_parts(file_name.removesuffix(\".nc\"))\n", "  #print(\"file_parts: \", file_parts)\n", "  return (\n", "      model_config.resolution in (0, float(file_parts[\"res\"])) and\n", "      len(task_config.pressure_levels) == int(file_parts[\"levels\"]) and\n", "      (\n", "          (\"total_precipitation_6hr\" in task_config.input_variables and\n", "           file_parts[\"source\"] in (\"era5\", \"fake\")) or\n", "          (\"total_precipitation_6hr\" not in task_config.input_variables and\n", "           file_parts[\"source\"] in (\"hres\", \"fake\"))\n", "      )\n", "  )\n", "\n", "\n", "dataset_file = widgets.Dropdown(\n", "    options=[\n", "        (\", \".join([f\"{k}: {v}\" for k, v in parse_file_parts(option.removesuffix(\".nc\")).items()]), option)\n", "        for option in dataset_file_options\n", "        if data_valid_for_model(option, model_config, task_config)\n", "    ],\n", "    description=\"Dataset file:\",\n", "    layout={\"width\": \"max-content\"})\n", "widgets.VBox([\n", "    dataset_file,\n", "    widgets.Label(value=\"Run the next cell to load the dataset. 
Rerunning this cell will clear your selection and refilter the datasets that match your model.\")\n", "])" ] }, { "cell_type": "code", "execution_count": 11, "metadata": { "cellView": "form", "id": 
"Yz-ekISoJxeZ" }, "outputs": [ { "ename": "AttributeError", "evalue": "'NoneType' object has no attribute 'removesuffix'", "output_type": "error", "traceback": [ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", "\u001b[0;31mAttributeError\u001b[0m Traceback (most recent call last)", "Cell \u001b[0;32mIn[11], line 4\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[38;5;66;03m# @title 加载气象数据\u001b[39;00m\n\u001b[0;32m----> 4\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[43mdata_valid_for_model\u001b[49m\u001b[43m(\u001b[49m\u001b[43mdataset_file\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mvalue\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mmodel_config\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mtask_config\u001b[49m\u001b[43m)\u001b[49m:\n\u001b[1;32m 5\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mValueError\u001b[39;00m(\n\u001b[1;32m 6\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mInvalid dataset file, rerun the cell above and choose a valid dataset file.\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n\u001b[1;32m 8\u001b[0m \u001b[38;5;124;03m'''with gcs_bucket.blob(f\"dataset/{dataset_file.value}\").open(\"rb\") as f:\u001b[39;00m\n\u001b[1;32m 9\u001b[0m \u001b[38;5;124;03m example_batch = xarray.load_dataset(f).compute()'''\u001b[39;00m\n", "Cell \u001b[0;32mIn[10], line 22\u001b[0m, in \u001b[0;36mdata_valid_for_model\u001b[0;34m(file_name, model_config, task_config)\u001b[0m\n\u001b[1;32m 20\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mdata_valid_for_model\u001b[39m(\n\u001b[1;32m 21\u001b[0m file_name: \u001b[38;5;28mstr\u001b[39m, model_config: graphcast\u001b[38;5;241m.\u001b[39mModelConfig, task_config: graphcast\u001b[38;5;241m.\u001b[39mTaskConfig):\n\u001b[0;32m---> 22\u001b[0m file_parts \u001b[38;5;241m=\u001b[39m 
parse_file_parts(\u001b[43mfile_name\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mremovesuffix\u001b[49m(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m.nc\u001b[39m\u001b[38;5;124m\"\u001b[39m))\n\u001b[1;32m 23\u001b[0m \u001b[38;5;66;03m#print(\"file_parts: \", file_parts)\u001b[39;00m\n\u001b[1;32m 24\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m (\n\u001b[1;32m 25\u001b[0m model_config\u001b[38;5;241m.\u001b[39mresolution \u001b[38;5;129;01min\u001b[39;00m (\u001b[38;5;241m0\u001b[39m, \u001b[38;5;28mfloat\u001b[39m(file_parts[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mres\u001b[39m\u001b[38;5;124m\"\u001b[39m])) \u001b[38;5;129;01mand\u001b[39;00m\n\u001b[1;32m 26\u001b[0m \u001b[38;5;28mlen\u001b[39m(task_config\u001b[38;5;241m.\u001b[39mpressure_levels) \u001b[38;5;241m==\u001b[39m \u001b[38;5;28mint\u001b[39m(file_parts[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mlevels\u001b[39m\u001b[38;5;124m\"\u001b[39m]) \u001b[38;5;129;01mand\u001b[39;00m\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 32\u001b[0m )\n\u001b[1;32m 33\u001b[0m )\n", "\u001b[0;31mAttributeError\u001b[0m: 'NoneType' object has no attribute 'removesuffix'" ] } ], "source": [ "# @title Load the weather data\n", "\n", "\n", "if dataset_file.value is None or not data_valid_for_model(dataset_file.value, model_config, task_config):\n", "  raise ValueError(\n", "      \"Invalid dataset file, rerun the cell above and choose a valid dataset file.\")\n", "\n", "'''with gcs_bucket.blob(f\"dataset/{dataset_file.value}\").open(\"rb\") as f:\n", "  example_batch = xarray.load_dataset(f).compute()'''\n", "\n", "with open(f\"{dir_path_dataset}/dataset-{dataset_file.value}\", \"rb\") as f:\n", "  example_batch = xarray.load_dataset(f).compute()\n", "\n", "assert example_batch.dims[\"time\"] >= 3  # 2 for input, >=1 for targets\n", "\n", "print(\", \".join([f\"{k}: {v}\" for k, v in parse_file_parts(dataset_file.value.removesuffix(\".nc\")).items()]))\n", "\n", "example_batch" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "cellView": 
"form", "id": "lXjFvdE6qStr" }, "outputs": [], "source": [ "# @title Choose data to plot\n", "\n", "plot_example_variable = widgets.Dropdown(\n", "    options=example_batch.data_vars.keys(),\n", "    value=\"2m_temperature\",\n", "    description=\"Variable\")\n", "plot_example_level = widgets.Dropdown(\n", "    options=example_batch.coords[\"level\"].values,\n", "    value=500,\n", "    description=\"Level\")\n", "plot_example_robust = widgets.Checkbox(value=True, description=\"Robust\")\n", "plot_example_max_steps = widgets.IntSlider(\n", "    min=1, max=example_batch.dims[\"time\"], value=example_batch.dims[\"time\"],\n", "    description=\"Max steps\")\n", "\n", "widgets.VBox([\n", "    plot_example_variable,\n", "    plot_example_level,\n", "    plot_example_robust,\n", "    plot_example_max_steps,\n", "    widgets.Label(value=\"Run the next cell to plot the data. Rerunning this cell will clear your selection.\")\n", "])" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "cellView": "form", "id": "kIK-EgMdkHtk" }, "outputs": [], "source": [ "# @title Plot example data\n", "\n", "\n", "plot_size = 7\n", "\n", "data = {\n", "    \" \": scale(select(example_batch, plot_example_variable.value, plot_example_level.value, plot_example_max_steps.value),\n", "              robust=plot_example_robust.value),\n", "}\n", "fig_title = plot_example_variable.value\n", "if \"level\" in example_batch[plot_example_variable.value].coords:\n", "  fig_title += f\" at {plot_example_level.value} hPa\"\n", "\n", "plot_data(data, fig_title, plot_size, plot_example_robust.value)\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "cellView": "form", "id": "tPVy1GHokHtk" }, "outputs": [], "source": [ "# @title Choose training and eval data to extract\n", "\n", "train_steps = widgets.IntSlider(\n", "    value=1, min=1, max=example_batch.sizes[\"time\"]-2, description=\"Train steps\")\n", "eval_steps = widgets.IntSlider(\n", "    value=example_batch.sizes[\"time\"]-2, min=1, max=example_batch.sizes[\"time\"]-2, description=\"Eval steps\")\n", "\n", "widgets.VBox([\n", "    train_steps,\n", "    eval_steps,\n", "    
widgets.Label(value=\"Run the next cell to extract the data. Rerunning this cell will clear your selection.\")\n", "])" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "cellView": "form", "id": "Ogp4vTBvsgSt" }, "outputs": [], "source": [ "# @title Extract training and eval data\n", "\n", "train_inputs, train_targets, train_forcings = data_utils.extract_inputs_targets_forcings(\n", "    example_batch, target_lead_times=slice(\"6h\", f\"{train_steps.value*6}h\"),\n", "    **dataclasses.asdict(task_config))\n", "\n", "eval_inputs, eval_targets, eval_forcings = data_utils.extract_inputs_targets_forcings(\n", "    example_batch, target_lead_times=slice(\"6h\", f\"{eval_steps.value*6}h\"),\n", "    **dataclasses.asdict(task_config))\n", "\n", "print(\"All Examples:  \", example_batch.dims.mapping)\n", "print(\"Train Inputs:  \", train_inputs.dims.mapping)\n", "print(\"Train Targets: \", train_targets.dims.mapping)\n", "print(\"Train Forcings:\", train_forcings.dims.mapping)\n", "print(\"Eval Inputs:   \", eval_inputs.dims.mapping)\n", "print(\"Eval Targets:  \", eval_targets.dims.mapping)\n", "print(\"Eval Forcings: \", eval_forcings.dims.mapping)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "cellView": "form", "id": "Q--ZRhpTdI2o" }, "outputs": [], "source": [ "# @title Load normalization data\n", "# Rewrite by S.F. 
Sune, https://github.com/sfsun67.\n", "dir_path_stats = \"/root/data/stats\"\n", "\n", "with open(f\"{dir_path_stats}/stats-diffs_stddev_by_level.nc\", \"rb\") as f:\n", "  diffs_stddev_by_level = xarray.load_dataset(f).compute()\n", "with open(f\"{dir_path_stats}/stats-mean_by_level.nc\", \"rb\") as f:\n", "  mean_by_level = xarray.load_dataset(f).compute()\n", "with open(f\"{dir_path_stats}/stats-stddev_by_level.nc\", \"rb\") as f:\n", "  stddev_by_level = xarray.load_dataset(f).compute()" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "cellView": "form", "id": "ke2zQyuT_sMA" }, "outputs": [], "source": [ "# @title Build jitted functions, and possibly initialize random weights\n", "# Build the model and initialize its weights\n", "\n", "# Model construction\n", "def construct_wrapped_graphcast(\n", "    model_config: graphcast.ModelConfig,\n", "    task_config: graphcast.TaskConfig):\n", "  \"\"\"Constructs and wraps the GraphCast Predictor.\"\"\"\n", "  # Deeper one-step predictor.\n", "  predictor = graphcast.GraphCast(model_config, task_config)\n", "\n", "  # Modify inputs/outputs to `graphcast.GraphCast` to handle conversion to\n", "  # from/to float32 to/from BFloat16.\n", "  predictor = casting.Bfloat16Cast(predictor)\n", "\n", "  # Modify inputs/outputs to `casting.Bfloat16Cast` so the casting to/from\n", "  # BFloat16 happens after applying normalization to the inputs/targets.\n", "  predictor = normalization.InputsAndResiduals(\n", "      predictor,\n", "      diffs_stddev_by_level=diffs_stddev_by_level,\n", "      mean_by_level=mean_by_level,\n", "      stddev_by_level=stddev_by_level)\n", "\n", "  # Wraps everything so the one-step model can produce trajectories.\n", "  predictor = autoregressive.Predictor(predictor, gradient_checkpointing=True)\n", "  return predictor\n", "\n", "# Forward pass\n", "@hk.transform_with_state\n", "def run_forward(model_config, task_config, inputs, targets_template, forcings):\n", "  predictor = construct_wrapped_graphcast(model_config, task_config)\n", "  return predictor(inputs, targets_template=targets_template, 
forcings=forcings)\n", "\n", "# Loss function\n", "@hk.transform_with_state # turns a function that uses Haiku modules into a pair of pure init/apply functions that also carry state\n", "def loss_fn(model_config, task_config, inputs, targets, forcings):\n", " predictor = construct_wrapped_graphcast(model_config, task_config) # build the wrapped GraphCast predictor\n", " loss, diagnostics = predictor.loss(inputs, targets, forcings)\n", " return xarray_tree.map_structure(\n", " lambda x: xarray_jax.unwrap_data(x.mean(), require_jax=True),\n", " (loss, diagnostics))\n", "\n", "# Gradient function\n", "def grads_fn(params, state, model_config, task_config, inputs, targets, forcings):\n", " def _aux(params, state, i, t, f):\n", " (loss, diagnostics), next_state = loss_fn.apply(\n", " params, state, jax.random.PRNGKey(0), model_config, task_config,\n", " i, t, f)\n", " return loss, (diagnostics, next_state)\n", " (loss, (diagnostics, next_state)), grads = jax.value_and_grad(\n", " _aux, has_aux=True)(params, state, inputs, targets, forcings)\n", " return loss, diagnostics, next_state, grads\n", "\n", "# Jax doesn't seem to like passing configs as args through the jit. Passing it\n", "# in via partial (instead of capture by closure) forces jax to invalidate the\n", "# jit cache if you change configs.\n", "def with_configs(fn):\n", " return functools.partial(\n", " fn, model_config=model_config, task_config=task_config)\n", "\n", "# Always pass params and state, so the usage below is simpler\n", "def with_params(fn):\n", " return functools.partial(fn, params=params, state=state)\n", "\n", "# Our models aren't stateful, so the state is always empty, so just return the\n", "# predictions.
This is required by our rollout code, and generally simpler.\n", "def drop_state(fn):\n", " return lambda **kw: fn(**kw)[0]\n", "\n", "init_jitted = jax.jit(with_configs(run_forward.init))\n", "\n", "if params is None:\n", " params, state = init_jitted(\n", " rng=jax.random.PRNGKey(0),\n", " inputs=train_inputs,\n", " targets_template=train_targets,\n", " forcings=train_forcings)\n", "\n", "loss_fn_jitted = drop_state(with_params(jax.jit(with_configs(loss_fn.apply))))\n", "grads_fn_jitted = with_params(jax.jit(with_configs(grads_fn)))\n", "run_forward_jitted = drop_state(with_params(jax.jit(with_configs(\n", " run_forward.apply))))" ] }, { "cell_type": "markdown", "metadata": { "id": "VBNutliiCyqA" }, "source": [ "# Run the model\n", "\n", "Note that the first run of the cell below can take a while (possibly a few minutes), because it includes the time to compile the code. The second run is significantly faster.\n", "\n", "This uses a python loop to iterate over the prediction steps, where the 1-step prediction is jitted. This has lower memory requirements than the training steps below, and should make it possible to run a 4-step prediction with the small GraphCast model on 1-degree resolution data." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "cellView": "form", "id": "7obeY9i9oTtD" }, "outputs": [], "source": [ "# @title Autoregressive rollout (loop in python)\n", "\n", "assert model_config.resolution in (0, 360. / eval_inputs.sizes[\"lon\"]), (\n", " \"Model resolution doesn't match the data resolution.
You likely want to \"\n", " \"re-filter the dataset list, and download the correct data.\")\n", "\n", "print(\"Inputs: \", eval_inputs.dims.mapping)\n", "print(\"Targets: \", eval_targets.dims.mapping)\n", "print(\"Forcings:\", eval_forcings.dims.mapping)\n", "\n", "predictions = rollout.chunked_prediction(\n", " run_forward_jitted,\n", " rng=jax.random.PRNGKey(0),\n", " inputs=eval_inputs,\n", " targets_template=eval_targets * np.nan,\n", " forcings=eval_forcings)\n", "predictions" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "cellView": "form", "id": "ft298eZskHtn" }, "outputs": [], "source": [ "# @title Choose predictions to plot\n", "\n", "plot_pred_variable = widgets.Dropdown(\n", " options=predictions.data_vars.keys(),\n", " value=\"2m_temperature\",\n", " description=\"Variable\")\n", "plot_pred_level = widgets.Dropdown(\n", " options=predictions.coords[\"level\"].values,\n", " value=500,\n", " description=\"Level\")\n", "plot_pred_robust = widgets.Checkbox(value=True, description=\"Robust\")\n", "plot_pred_max_steps = widgets.IntSlider(\n", " min=1,\n", " max=predictions.dims[\"time\"],\n", " value=predictions.dims[\"time\"],\n", " description=\"Max steps\")\n", "\n", "widgets.VBox([\n", " plot_pred_variable,\n", " plot_pred_level,\n", " plot_pred_robust,\n", " plot_pred_max_steps,\n", " widgets.Label(value=\"Run the next cell to plot the predictions. 
Rerunning this cell will clear your selection.\")\n", "])" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "cellView": "form", "id": "_tTdx6fmmj1I" }, "outputs": [], "source": [ "# @title Plot predictions\n", "\n", "\n", "plot_size = 5\n", "plot_max_steps = min(predictions.dims[\"time\"], plot_pred_max_steps.value)\n", "\n", "data = {\n", " \"Targets\": scale(select(eval_targets, plot_pred_variable.value, plot_pred_level.value, plot_max_steps), robust=plot_pred_robust.value),\n", " \"Predictions\": scale(select(predictions, plot_pred_variable.value, plot_pred_level.value, plot_max_steps), robust=plot_pred_robust.value),\n", " \"Diff\": scale((select(eval_targets,
plot_pred_variable.value, plot_pred_level.value, plot_max_steps) -\n", " select(predictions, plot_pred_variable.value, plot_pred_level.value, plot_max_steps)),\n", " robust=plot_pred_robust.value, center=0),\n", "}\n", "fig_title = plot_pred_variable.value\n", "if \"level\" in predictions[plot_pred_variable.value].coords:\n", " fig_title += f\" at {plot_pred_level.value} hPa\"\n", "\n", "plot_data(data, fig_title, plot_size, plot_pred_robust.value)\n" ] }, { "cell_type": "markdown", "metadata": { "id": "Pa78b64bLYe1" }, "source": [ "# Train the model\n", "\n", "The following operations require a large amount of memory and, depending on the accelerator being used, will only fit the very small \"random\" model on low-resolution data. They use the number of training steps selected above.\n", "\n", "The first execution of each cell takes longer, because it includes the time to jit the function." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "cellView": "form", "id": "Nv-u3dAP7IRZ" }, "outputs": [], "source": [ "# @title Loss computation (autoregressive loss over multiple steps)\n", "loss, diagnostics = loss_fn_jitted(\n", " rng=jax.random.PRNGKey(0),\n", " inputs=train_inputs,\n", " targets=train_targets,\n", " forcings=train_forcings)\n", "\n", "print(\"Loss:\", float(loss))" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "cellView": "form", "id": "mBNFq1IGZNLz" }, "outputs": [], "source": [ "# @title Gradient computation (backprop through time)\n", "loss, diagnostics, next_state, grads = grads_fn_jitted(\n", " inputs=train_inputs,\n", " targets=train_targets,\n", " forcings=train_forcings)\n", "mean_grad = np.mean(jax.tree_util.tree_flatten(jax.tree_util.tree_map(lambda x: np.abs(x).mean(), grads))[0])\n", "print(f\"Loss: {loss:.4f}, Mean |grad|: {mean_grad:.6f}\")" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "cellView": "form", "id": "J4FJFKWD8Loz" }, "outputs": [], "source": [ "# @title Autoregressive rollout (keep the loop in JAX)\n", 
"print(\"Inputs: \", train_inputs.dims.mapping)\n", "print(\"Targets: \", train_targets.dims.mapping)\n", "print(\"Forcings:\", train_forcings.dims.mapping)\n", "\n", "predictions = run_forward_jitted(\n", " rng=jax.random.PRNGKey(0),\n", " inputs=train_inputs,\n", " targets_template=train_targets
* np.nan,\n", " forcings=train_forcings)\n", "predictions" ] } ], "metadata": { "colab": { "name": "GraphCast", "private_outputs": true, "provenance": [], "toc_visible": true }, "kernelspec": { "display_name": "Python 3", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.13" } }, "nbformat": 4, "nbformat_minor": 0 }