{ "cells": [ { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "# @title Imports\n", "\n", "import dataclasses\n", "import datetime\n", "import functools\n", "import math\n", "import re\n", "from typing import Optional\n", "\n", "import cartopy.crs as ccrs\n", "#from google.cloud import storage\n", "from graphcast import autoregressive\n", "from graphcast import casting\n", "from graphcast import checkpoint\n", "from graphcast import data_utils\n", "from graphcast import graphcast\n", "from graphcast import normalization\n", "from graphcast import rollout\n", "from graphcast import xarray_jax\n", "from graphcast import xarray_tree\n", "from IPython.display import HTML\n", "import ipywidgets as widgets\n", "import haiku as hk\n", "import jax\n", "import matplotlib\n", "import matplotlib.pyplot as plt\n", "from matplotlib import animation\n", "import numpy as np\n", "import xarray\n", "\n", "\n", "\n", "\n", "def parse_file_parts(file_name):\n", " return dict(part.split(\"-\", 1) for part in file_name.split(\"_\"))\n" ] }, { "cell_type": "code", "execution_count": 2, "metadata": { "cellView": "form", "id": "KGaJ6V9MdI2n" }, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "08eac6c9a6514a2e9d154627837d8f7d", "version_major": 2, "version_minor": 0 }, "text/plain": [ "VBox(children=(Tab(children=(VBox(children=(IntSlider(value=4, description='Mesh size:', max=6, min=4), IntSli…" ] }, "execution_count": 2, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# @title Choose the model\n", "# Rewrite by S.F. Sune, https://github.com/sfsun67.\n", "'''\n", " We have three options. Acquiring from https://console.cloud.google.com/storage/browser/dm_graphcast:\n", " GraphCast - ERA5 1979-2017 - resolution 0.25 - pressure levels 37 - mesh 2to6 - precipitation input and output.npz\n", " GraphCast_operational - ERA5-HRES 1979-2021 - resolution 0.25 - pressure levels 13 - mesh 2to6 - precipitation output only.npz\n", " GraphCast_small - ERA5 1979-2015 - resolution 1.0 - pressure levels 13 - mesh 2to5 - precipitation input and output.npz\n", "'''\n", "# find the result in this path /root/data/params, and list of names of all files in the \"params/\", with the \"params/\" perfix removed from the names.\n", "import os\n", "import glob\n", "\n", "# Define the directory path\n", "dir_path_params = \"/root/data/params\"\n", "\n", "# Use glob to get all file paths in the directory\n", "file_paths_params = glob.glob(os.path.join(dir_path_params, \"*\"))\n", "\n", "# Remove the directory path and the \".../params/\" prefix from each file name\n", "params_file_options = [os.path.basename(path) for path in file_paths_params]\n", "\n", "\n", "random_mesh_size = widgets.IntSlider(\n", " value=4, min=4, max=6, description=\"Mesh size:\")\n", "random_gnn_msg_steps = widgets.IntSlider(\n", " value=4, min=1, max=32, description=\"GNN message steps:\")\n", "random_latent_size = widgets.Dropdown(\n", " options=[int(2**i) for i in range(4, 10)], value=32,description=\"Latent size:\")\n", "random_levels = widgets.Dropdown(\n", " options=[13, 37], value=13, description=\"Pressure levels:\")\n", "\n", "\n", "params_file = widgets.Dropdown(\n", " options=params_file_options,\n", " description=\"Params file:\",\n", " layout={\"width\": \"max-content\"})\n", "\n", "source_tab = widgets.Tab([\n", " widgets.VBox([\n", " random_mesh_size,\n", " random_gnn_msg_steps,\n", " random_latent_size,\n", " random_levels,\n", " ]),\n", " params_file,\n", "])\n", "source_tab.set_title(0, \"Random\")\n", "source_tab.set_title(1, \"Checkpoint\")\n", "widgets.VBox([\n", " source_tab,\n", " widgets.Label(value=\"Run the next cell to load the model. Rerunning this cell clears your selection.\")\n", "])\n" ] }, { "cell_type": "code", "execution_count": 3, "metadata": { "cellView": "form", "id": "lYQgrPgPdI2n" }, "outputs": [ { "data": { "text/plain": [ "ModelConfig(resolution=0, mesh_size=4, latent_size=32, gnn_msg_steps=4, hidden_layers=1, radius_query_fraction_edge_length=0.6, mesh2grid_edge_normalization_factor=None)" ] }, "execution_count": 3, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# @title Load the model\n", "\n", "source = source_tab.get_title(source_tab.selected_index)\n", "\n", "if source == \"Random\":\n", " params = None # Filled in below\n", " state = {}\n", " model_config = graphcast.ModelConfig(\n", " resolution=0,\n", " mesh_size=random_mesh_size.value,\n", " latent_size=random_latent_size.value,\n", " gnn_msg_steps=random_gnn_msg_steps.value,\n", " hidden_layers=1,\n", " radius_query_fraction_edge_length=0.6)\n", " task_config = graphcast.TaskConfig(\n", " input_variables=graphcast.TASK.input_variables,\n", " target_variables=graphcast.TASK.target_variables,\n", " forcing_variables=graphcast.TASK.forcing_variables,\n", " pressure_levels=graphcast.PRESSURE_LEVELS[random_levels.value],\n", " input_duration=graphcast.TASK.input_duration,\n", " )\n", "else:\n", " assert source == \"Checkpoint\"\n", " '''with gcs_bucket.blob(f\"params/{params_file.value}\").open(\"rb\") as f:\n", " ckpt = checkpoint.load(f, graphcast.CheckPoint)'''\n", " \n", " with open(f\"{dir_path_params}/{params_file.value}\", \"rb\") as f:\n", " ckpt = checkpoint.load(f, graphcast.CheckPoint)\n", " \n", " params = ckpt.params\n", " state = {}\n", "\n", " model_config = ckpt.model_config\n", " task_config = ckpt.task_config\n", " print(\"Model description:\\n\", ckpt.description, \"\\n\")\n", " print(\"Model license:\\n\", ckpt.license, \"\\n\")\n", "\n", "model_config" ] }, { "cell_type": "code", "execution_count": 2, "metadata": { "cellView": "form", "id": "-DJzie5me2-H" }, "outputs": [ { "ename": "NameError", "evalue": "name 'glob' is not defined", "output_type": "error", "traceback": [ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", "\u001b[0;31mNameError\u001b[0m Traceback (most recent call last)", "Cell \u001b[0;32mIn[2], line 9\u001b[0m\n\u001b[1;32m 6\u001b[0m dir_path_dataset \u001b[38;5;241m=\u001b[39m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m/root/data/dataset\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 8\u001b[0m \u001b[38;5;66;03m# Use glob to get all file paths in the directory\u001b[39;00m\n\u001b[0;32m----> 9\u001b[0m file_paths_dataset \u001b[38;5;241m=\u001b[39m \u001b[43mglob\u001b[49m\u001b[38;5;241m.\u001b[39mglob(os\u001b[38;5;241m.\u001b[39mpath\u001b[38;5;241m.\u001b[39mjoin(dir_path_dataset, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m*\u001b[39m\u001b[38;5;124m\"\u001b[39m))\n\u001b[1;32m 11\u001b[0m \u001b[38;5;66;03m# Remove the directory path and the \".../params/\" prefix from each file name\u001b[39;00m\n\u001b[1;32m 12\u001b[0m dataset_file_options \u001b[38;5;241m=\u001b[39m [os\u001b[38;5;241m.\u001b[39mpath\u001b[38;5;241m.\u001b[39mbasename(path) \u001b[38;5;28;01mfor\u001b[39;00m path \u001b[38;5;129;01min\u001b[39;00m file_paths_dataset]\n", "\u001b[0;31mNameError\u001b[0m: name 'glob' is not defined" ] } ], "source": [ "# @title Get and filter the list of available example datasets\n", "# Rewrite by S.F. Sune, https://github.com/sfsun67.\n", "# find the result in this path /root/data/dataset, and list of names of all files in the \"dataset/\", with the \"dataset/\" perfix removed from the names.\n", "\n", "# Define the directory path\n", "dir_path_dataset = \"/root/data/dataset\"\n", "\n", "# Use glob to get all file paths in the directory\n", "file_paths_dataset = glob.glob(os.path.join(dir_path_dataset, \"*\"))\n", "\n", "# Remove the directory path and the \".../params/\" prefix from each file name\n", "dataset_file_options = [os.path.basename(path) for path in file_paths_dataset]\n", "#print(\"dataset_file_options: \", dataset_file_options)\n", "\n", "# Remove \"dataset-\" prefix from each file name\n", "dataset_file_options = [name.removeprefix(\"dataset-\") for name in dataset_file_options]\n", "\n", "\n", "def data_valid_for_model(\n", " file_name: str, model_config: graphcast.ModelConfig, task_config: graphcast.TaskConfig):\n", " file_parts = parse_file_parts(file_name.removesuffix(\".nc\"))\n", " #print(\"file_parts: \", file_parts)\n", " return (\n", " model_config.resolution in (0, float(file_parts[\"res\"])) and\n", " len(task_config.pressure_levels) == int(file_parts[\"levels\"]) and\n", " (\n", " (\"total_precipitation_6hr\" in task_config.input_variables and\n", " file_parts[\"source\"] in (\"era5\", \"fake\")) or\n", " (\"total_precipitation_6hr\" not in task_config.input_variables and\n", " file_parts[\"source\"] in (\"hres\", \"fake\"))\n", " )\n", " )\n", "\n", "\n", "dataset_file = widgets.Dropdown(\n", " options=[\n", " (\", \".join([f\"{k}: {v}\" for k, v in parse_file_parts(option.removesuffix(\".nc\")).items()]), option)\n", " for option in dataset_file_options\n", " if data_valid_for_model(option, model_config, task_config)\n", " ],\n", " description=\"Dataset file:\",\n", " layout={\"width\": \"max-content\"})\n", "widgets.VBox([\n", " dataset_file,\n", " widgets.Label(value=\"Run the next cell to load the dataset. Rerunning this cell clears your selection and refilters the datasets that match your model.\")\n", "])" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "cellView": "form", "id": "Yz-ekISoJxeZ" }, "outputs": [], "source": [ "# @title Load weather data\n", "\n", "if not data_valid_for_model(dataset_file.value, model_config, task_config):\n", " raise ValueError(\n", " \"Invalid dataset file, rerun the cell above and choose a valid dataset file.\")\n", "\n", "'''with gcs_bucket.blob(f\"dataset/{dataset_file.value}\").open(\"rb\") as f:\n", " example_batch = xarray.load_dataset(f).compute()'''\n", "\n", "with open(f\"{dir_path_dataset}/dataset-{dataset_file.value}\", \"rb\") as f:\n", " example_batch = xarray.load_dataset(f).compute()\n", "\n", "assert example_batch.dims[\"time\"] >= 3 # 2 for input, >=1 for targets\n", "\n", "print(\", \".join([f\"{k}: {v}\" for k, v in parse_file_parts(dataset_file.value.removesuffix(\".nc\")).items()]))\n", "\n", "example_batch" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "# load example_batch from .csv file\n", "import pandas as pd\n", "import xarray as xr\n", "\n", "'''\n", "NOTE: Forward slashes '/' are not allowed in variable and dimension names (got 'Nd143/Nd144'). Forward slashes are used as hierarchy-separators for HDF5-based files ('netcdf4'/'h5netcdf').\n", "'''\n", "\n", "# Load the data into a DataFrame\n", "df = pd.read_csv(\"/root/data/Sedi_test.csv\", encoding='latin1')\n", "\n", "# Convert the DataFrame to an xarray Dataset\n", "sedi_ds = xr.Dataset.from_dataframe(df)\n", "\n", "# 当 interpreted age 为 nan 时,删去该行\n", "sedi_ds = sedi_ds.dropna(dim='index', subset=['interpreted age'])\n", "\n", "# 按照 interpreted age 升序排序,并改变其他变量的顺序\n", "sedi_ds = sedi_ds.sortby('interpreted age', ascending=True)" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "# Rewrite the lon and lat according the resulation of dataset.\n", "\n", "# define \n", "resolution_Rewrite_lon_lat = 1 #the resolution of longitiude and latitude\n", "\n", "def Rewrite_lon_lat(data, resolution):\n", " '''\n", " 根据 xarray 数据集中的分辨率 重写 lon 和 lat \n", " Rewrite the lon and lat according the resulation of dataset.\n", " data: the original data\n", " resolution: the resolution of the data\n", " '''\n", " condition_number = int(1/resolution)\n", " data[\"site latitude\"].data = np.round(data[\"site latitude\"].data * condition_number) / condition_number\n", " data[\"site longitude\"].data = np.round(data[\"site longitude\"].data * condition_number) / condition_number\n", "\n", " return data\n" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "
<xarray.Dataset>\n",
              "Dimensions:                 (index: 3192)\n",
              "Coordinates:\n",
              "  * index                   (index) int64 171 172 175 176 ... 3543 3544 3545\n",
              "Data variables: (12/103)\n",
              "    site latitude           (index) float64 -37.0 -37.0 -36.0 ... -20.0 -20.0\n",
              "    site longitude          (index) float64 -73.0 -73.0 -73.0 ... 119.0 119.0\n",
              "    interpreted age         (index) float64 0.0 0.0 0.0 ... 3.4e+03 3.4e+03\n",
              "    Ag (ppm)                (index) float64 nan nan nan nan ... nan nan nan nan\n",
              "    Al (wt%)                (index) float64 10.09 9.96 7.28 ... 0.75 1.1 0.84\n",
              "    As (ppm)                (index) float64 nan nan nan nan ... nan nan nan nan\n",
              "    ...                      ...\n",
              "    C:N (atomic)            (index) float64 nan nan nan nan ... nan nan nan nan\n",
              "    Delta13C-org (permil)   (index) float64 nan nan nan nan ... nan nan nan nan\n",
              "    Delta15N (permil)       (index) float64 nan nan nan nan ... nan nan nan nan\n",
              "    Delta98Mo (permil)      (index) float64 nan nan nan nan ... nan nan nan nan\n",
              "    Delta34S-py (permil)    (index) float64 nan nan nan nan ... nan nan nan nan\n",
              "    Delta238U (permil)      (index) float64 nan nan nan nan ... nan nan nan nan
" ], "text/plain": [ "\n", "Dimensions: (index: 3192)\n", "Coordinates:\n", " * index (index) int64 171 172 175 176 ... 3543 3544 3545\n", "Data variables: (12/103)\n", " site latitude (index) float64 -37.0 -37.0 -36.0 ... -20.0 -20.0\n", " site longitude (index) float64 -73.0 -73.0 -73.0 ... 119.0 119.0\n", " interpreted age (index) float64 0.0 0.0 0.0 ... 3.4e+03 3.4e+03\n", " Ag (ppm) (index) float64 nan nan nan nan ... nan nan nan nan\n", " Al (wt%) (index) float64 10.09 9.96 7.28 ... 0.75 1.1 0.84\n", " As (ppm) (index) float64 nan nan nan nan ... nan nan nan nan\n", " ... ...\n", " C:N (atomic) (index) float64 nan nan nan nan ... nan nan nan nan\n", " Delta13C-org (permil) (index) float64 nan nan nan nan ... nan nan nan nan\n", " Delta15N (permil) (index) float64 nan nan nan nan ... nan nan nan nan\n", " Delta98Mo (permil) (index) float64 nan nan nan nan ... nan nan nan nan\n", " Delta34S-py (permil) (index) float64 nan nan nan nan ... nan nan nan nan\n", " Delta238U (permil) (index) float64 nan nan nan nan ... nan nan nan nan" ] }, "execution_count": 4, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Rewrite the lon and lat according the dataset of xarray.\n", "# 问题:重写了经纬度的分辨率之后,如何处理新出来的经纬度的重复值?\n", "combined = Rewrite_lon_lat(sedi_ds, resolution_Rewrite_lon_lat)\n", "combined\n" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Processing Sedi datasets, replacing duplicates with averages ...\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "100%|██████████| 6/6 [00:06<00:00, 1.09s/it]\n" ] }, { "data": { "text/html": [ "
\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "
<xarray.Dataset>\n",
              "Dimensions:                 (index: 445)\n",
              "Dimensions without coordinates: index\n",
              "Data variables: (12/103)\n",
              "    site latitude           (index) float64 34.0 -36.0 -37.0 ... -21.0 -20.0\n",
              "    site longitude          (index) float64 -120.0 -74.0 -73.0 ... 118.0 119.0\n",
              "    interpreted age         (index) float64 0.0 0.0 0.0 ... 3.4e+03 3.4e+03\n",
              "    Ag (ppm)                (index) float64 nan nan nan nan ... nan nan nan nan\n",
              "    Al (wt%)                (index) float64 nan 7.908 10.03 ... 12.09 0.7287\n",
              "    As (ppm)                (index) float64 nan nan nan nan ... nan nan nan nan\n",
              "    ...                      ...\n",
              "    C:N (atomic)            (index) float64 nan nan nan nan ... nan nan nan nan\n",
              "    Delta13C-org (permil)   (index) float64 nan nan nan nan ... nan nan nan nan\n",
              "    Delta15N (permil)       (index) float64 nan nan nan nan ... nan nan nan nan\n",
              "    Delta98Mo (permil)      (index) float64 nan nan nan nan ... nan nan nan nan\n",
              "    Delta34S-py (permil)    (index) float64 nan nan nan nan ... nan nan nan nan\n",
              "    Delta238U (permil)      (index) float64 nan nan nan nan ... nan nan nan nan
" ], "text/plain": [ "\n", "Dimensions: (index: 445)\n", "Dimensions without coordinates: index\n", "Data variables: (12/103)\n", " site latitude (index) float64 34.0 -36.0 -37.0 ... -21.0 -20.0\n", " site longitude (index) float64 -120.0 -74.0 -73.0 ... 118.0 119.0\n", " interpreted age (index) float64 0.0 0.0 0.0 ... 3.4e+03 3.4e+03\n", " Ag (ppm) (index) float64 nan nan nan nan ... nan nan nan nan\n", " Al (wt%) (index) float64 nan 7.908 10.03 ... 12.09 0.7287\n", " As (ppm) (index) float64 nan nan nan nan ... nan nan nan nan\n", " ... ...\n", " C:N (atomic) (index) float64 nan nan nan nan ... nan nan nan nan\n", " Delta13C-org (permil) (index) float64 nan nan nan nan ... nan nan nan nan\n", " Delta15N (permil) (index) float64 nan nan nan nan ... nan nan nan nan\n", " Delta98Mo (permil) (index) float64 nan nan nan nan ... nan nan nan nan\n", " Delta34S-py (permil) (index) float64 nan nan nan nan ... nan nan nan nan\n", " Delta238U (permil) (index) float64 nan nan nan nan ... nan nan nan nan" ] }, "execution_count": 5, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# 使用 groupby 方法根据 lon、lat 和 time 三个变量对数据集进行分组, 并对分组后的数据集求平均\n", "\n", "import multiprocessing as mp\n", "from tqdm import tqdm\n", "\n", "# Function to process a part of the dataset\n", "sedimentary_list = []\n", "def groupby_and_average(sedi_ds):\n", " '''\n", " # 使用 groupby 方法根据 lon、lat 和 time 三个变量对数据集进行分组, 并对分组后的数据集求平均\n", " '''\n", " for site_longitude_value, site_longitude in sedi_ds.groupby(\"site longitude\"):\n", " for site_latitude_value, site_latitude in site_longitude.groupby(\"site latitude\"):\n", " for interpreted_age_value, sedi in site_latitude.groupby(\"interpreted age\"):\n", " #sedimentary_dict = sedi.apply(np.mean).to_dict() \n", " sedimentary_list.append(sedi.apply(np.mean))\n", " \n", " # Add an identifying dimension to each xr.Dataset of sedimentary_list \n", " for i, sedi_ds in enumerate(sedimentary_list):\n", " sedi_ds = sedi_ds.expand_dims({'sample': [i]})\n", "\n", " # Concatenate the datasets\n", " combined = xr.concat(sedimentary_list, dim='index')\n", "\n", "\n", " return combined, site_longitude_value, site_latitude_value, interpreted_age_value\n", "\n", "\n", "# Divide the dataset into parts\n", "part_number = 6\n", "dim = 'index' # replace with your actual dimension\n", "dim_size = sedi_ds.dims[dim]\n", "indices = np.linspace(0, dim_size, part_number+1).astype(int)\n", "parts = [sedi_ds.isel({dim: slice(indices[i], indices[i + 1])}) for i in range(part_number)]\n", "\n", "# Create a multiprocessing Pool\n", "pool = mp.Pool(mp.cpu_count())\n", "\n", "# Process each part of the dataset in parallel with a progress bar\n", "print('Processing Sedi datasets, replacing duplicates with averages ...')\n", "results = []\n", "with tqdm(total=len(parts)) as pbar:\n", " for result in pool.imap_unordered(groupby_and_average, parts):\n", " results.append(result)\n", " pbar.update(1)\n", "\n", "# Close the pool\n", "pool.close()\n", "\n", "# To combine multiple xarray.Dataset objects\n", "result_list = [result[0] for result in results]\n", "combined = xr.concat(result_list, dim='index')\n", "# 按照 interpreted age 升序排序,并改变其他变量的顺序\n", "combined = combined.sortby('interpreted age', ascending=True)\n", "combined\n" ] }, { "cell_type": "code", "execution_count": 14, "metadata": {}, "outputs": [], "source": [ "# Create the new xr.Dataset\n", "# When copy, notice that deep copy and shallow copy.\n", "\n", "# define \n", "resolution = resolution_Rewrite_lon_lat #the resolution of longitiude and latitude\n", "batch = 0\n", "datetime_temp = np.random.rand(1, len(list(dict.fromkeys(combined['interpreted age'].data)))) # 这里要根据非重复 age 的长度来定义 xarray 的长度\n", "datetime_temp[0, :] = list(dict.fromkeys(combined['interpreted age'].data))\n", "\n", "# Create the dimensions\n", "dims = {\n", " \"lon\": int(360/resolution),\n", " \"lat\": int(181/resolution),\n", " \"level\": 13,\n", " \"time\": len(list(dict.fromkeys(combined['interpreted age'].data))),\n", "}\n", "\n", "# Create the coordinates\n", "coords_creat = {\n", " \"lon\": np.linspace(0, 359, int(dims[\"lon\"] - (1/resolution - 1))),\n", " \"lat\": np.linspace(-90, 90, int(dims[\"lat\"] - (1/resolution - 1))),\n", " \"level\": np.arange(50, 1000, 75),\n", " \"time\": datetime_temp[0, :],\n", " \"datetime\": ([\"batch\", \"time\"], datetime_temp),\n", "}\n", "\n", "\n", "# Create the new dataset\n", "Sedi_dataset = xr.Dataset(coords = coords_creat)" ] }, { "cell_type": "code", "execution_count": 15, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Sedi_dataset= \n", "Dimensions: (lon: 360, lat: 181, level: 13, time: 377, batch: 1)\n", "Coordinates:\n", " * lon (lon) float64 0.0 1.0 2.0 3.0 4.0 ... 356.0 357.0 358.0 359.0\n", " * lat (lat) float64 -90.0 -89.0 -88.0 -87.0 ... 87.0 88.0 89.0 90.0\n", " * level (level) int64 50 125 200 275 350 425 500 575 650 725 800 875 950\n", " * time (time) float64 0.0 3.74 13.96 103.0 ... 3e+03 3.25e+03 3.4e+03\n", " datetime (batch, time) float64 0.0 3.74 13.96 ... 3e+03 3.25e+03 3.4e+03\n", "Dimensions without coordinates: batch\n", "Data variables:\n", " *empty*\n" ] } ], "source": [ "print(\"Sedi_dataset=\", Sedi_dataset)" ] }, { "cell_type": "code", "execution_count": 18, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "load sedi data: 19%|█▉ | 20/103 [00:04<00:20, 4.06it/s]\n" ] }, { "ename": "KeyboardInterrupt", "evalue": "", "output_type": "error", "traceback": [ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", "\u001b[0;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)", "Cell \u001b[0;32mIn[18], line 31\u001b[0m\n\u001b[1;32m 27\u001b[0m data \u001b[38;5;241m=\u001b[39m data\u001b[38;5;241m.\u001b[39mastype(np\u001b[38;5;241m.\u001b[39mfloat16) \u001b[38;5;66;03m# Convert the data type to np.float32 有效,这段代码能少一半内存\u001b[39;00m\n\u001b[1;32m 30\u001b[0m \u001b[38;5;66;03m# [非常重要]如何测试这段代码????????????????????????????????????????????????\u001b[39;00m\n\u001b[0;32m---> 31\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m i \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mrange\u001b[39m(\u001b[38;5;28mlen\u001b[39m(\u001b[43mcombined\u001b[49m\u001b[43m[\u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mindex\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m]\u001b[49m)): \n\u001b[1;32m 32\u001b[0m \u001b[38;5;66;03m# 当 age 重复的时候,使用 i-j 来保持时间不变。\u001b[39;00m\n\u001b[1;32m 33\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m combined[\u001b[38;5;124m'\u001b[39m\u001b[38;5;124minterpreted age\u001b[39m\u001b[38;5;124m'\u001b[39m]\u001b[38;5;241m.\u001b[39mdata[i\u001b[38;5;241m-\u001b[39m\u001b[38;5;241m1\u001b[39m] \u001b[38;5;241m==\u001b[39m combined[\u001b[38;5;124m'\u001b[39m\u001b[38;5;124minterpreted age\u001b[39m\u001b[38;5;124m'\u001b[39m]\u001b[38;5;241m.\u001b[39mdata[i]:\n\u001b[1;32m 34\u001b[0m j \u001b[38;5;241m=\u001b[39m j \u001b[38;5;241m+\u001b[39m \u001b[38;5;241m1\u001b[39m \u001b[38;5;66;03m# 如果 age 重复,j 就加 1,i-j 保持时间不变\u001b[39;00m\n", "File \u001b[0;32m~/miniconda3/envs/GraphCast/lib/python3.10/site-packages/xarray/core/dataset.py:1473\u001b[0m, in \u001b[0;36mDataset.__getitem__\u001b[0;34m(self, key)\u001b[0m\n\u001b[1;32m 1469\u001b[0m \u001b[38;5;129m@overload\u001b[39m\n\u001b[1;32m 1470\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21m__getitem__\u001b[39m(\u001b[38;5;28mself\u001b[39m: T_Dataset, key: Iterable[Hashable]) \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m>\u001b[39m T_Dataset:\n\u001b[1;32m 1471\u001b[0m \u001b[38;5;241m.\u001b[39m\u001b[38;5;241m.\u001b[39m\u001b[38;5;241m.\u001b[39m\n\u001b[0;32m-> 1473\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21m__getitem__\u001b[39m(\n\u001b[1;32m 1474\u001b[0m \u001b[38;5;28mself\u001b[39m: T_Dataset, key: Mapping[Any, Any] \u001b[38;5;241m|\u001b[39m Hashable \u001b[38;5;241m|\u001b[39m Iterable[Hashable]\n\u001b[1;32m 1475\u001b[0m ) \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m>\u001b[39m T_Dataset \u001b[38;5;241m|\u001b[39m DataArray:\n\u001b[1;32m 1476\u001b[0m \u001b[38;5;250m \u001b[39m\u001b[38;5;124;03m\"\"\"Access variables or coordinates of this dataset as a\u001b[39;00m\n\u001b[1;32m 1477\u001b[0m \u001b[38;5;124;03m :py:class:`~xarray.DataArray` or a subset of variables or a indexed dataset.\u001b[39;00m\n\u001b[1;32m 1478\u001b[0m \n\u001b[1;32m 1479\u001b[0m \u001b[38;5;124;03m Indexing with a list of names will return a new ``Dataset`` object.\u001b[39;00m\n\u001b[1;32m 1480\u001b[0m \u001b[38;5;124;03m \"\"\"\u001b[39;00m\n\u001b[1;32m 1481\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m utils\u001b[38;5;241m.\u001b[39mis_dict_like(key):\n", "\u001b[0;31mKeyboardInterrupt\u001b[0m: " ] } ], "source": [ "# load sedi data into the Sedi_dataset\n", "\n", "import numpy as np\n", "import xarray as xr\n", "import unittest\n", "import numpy as np\n", "import xarray as xr\n", "j=0\n", "dims = Sedi_dataset.dims # Get the dimensions from Sedi_dataset\n", "\n", "\n", "# remove duplicate values from \n", "combined_age_remo_dupli = list(dict.fromkeys(combined['interpreted age'].data))\n", "# \n", "combined_batch = Sedi_dataset[\"batch\"].data\n", "Sedi_dataset[\"batch\"]\n", "\n", "# Add the variables from the combined dataset to the new dataset\n", "for var in tqdm(combined.data_vars, desc=\"load sedi data\"):\n", " # Skip the variables that are Coordinates.\n", " if var == \"site latitude\" or var == \"site longitude\" or var == \"interpreted age\":\n", " continue\n", "\n", " # def / 是否可以使用广播?\n", " # create a nan array with the shape of (1,664,181,360) by numpy\n", " data = np.nan * np.zeros((1, len(combined_age_remo_dupli), dims[\"lat\"], dims[\"lon\"])) # (banch, time, lat, lon)\n", " data = data.astype(np.float16) # Convert the data type to np.float32 有效,这段代码能少一半内存\n", "\n", "\n", " # [非常重要]如何测试这段代码????????????????????????????????????????????????\n", " for i in range(len(combined[\"index\"])): \n", " # 当 age 重复的时候,使用 i-j 来保持时间不变。\n", " if combined['interpreted age'].data[i-1] == combined['interpreted age'].data[i]:\n", " j = j + 1 # 如果 age 重复,j 就加 1,i-j 保持时间不变\n", " # i 指示 age,j 用来固定重复的 age,下面的代码将经纬度上的数据赋值给指定 age 。\n", " data[batch, i-j, int(combined[\"site latitude\"].values[i]), \n", " int(combined[\"site longitude\"].values[i])] = combined[var].values[i]\n", " else:\n", " # 如果 age 不重复,j 不变,i-j 在之前的基础上继续变化\n", " # i 指示 age,j 用来固定重复的 age,下面的代码将经纬度上的数据赋值给指定 age 。\n", " data[batch, i-j, int(combined[\"site latitude\"].values[i]), \n", " int(combined[\"site longitude\"].values[i])] = combined[var].values[i]\n", " j = 0 # 重置 j 的值\n", "\n", " # Create a new DataArray with the same data but new dimensions\n", " new_dataarray = xr.DataArray(\n", " data,\n", " dims=[\"batch\", \"time\", \"lat\", \"lon\"],\n", " coords={\"batch\": Sedi_dataset[\"batch\"], \"time\": combined_age_remo_dupli, \"lat\": Sedi_dataset[\"lat\"], \"lon\": Sedi_dataset[\"lon\"]}\n", " )\n", " # Add the new DataArray to the new dataset\n", " Sedi_dataset[var] = new_dataarray\n", " del data, new_dataarray\n", "\n", "Sedi_dataset" ] }, { "cell_type": "code", "execution_count": 17, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Processing variables: 0%| | 0/10 [00:00 54\u001b[0m \u001b[43mpool\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mmap\u001b[49m\u001b[43m(\u001b[49m\u001b[43mprocess_variable\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mchunk\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 56\u001b[0m \u001b[38;5;66;03m# To combine multiple xarray.Dataset objects\u001b[39;00m\n\u001b[1;32m 57\u001b[0m \n\u001b[1;32m 58\u001b[0m \n\u001b[1;32m 59\u001b[0m \n\u001b[1;32m 60\u001b[0m \u001b[38;5;66;03m# Close the pool\u001b[39;00m\n\u001b[1;32m 61\u001b[0m pool\u001b[38;5;241m.\u001b[39mclose()\n", "File \u001b[0;32m~/miniconda3/envs/GraphCast/lib/python3.10/multiprocessing/pool.py:367\u001b[0m, in \u001b[0;36mPool.map\u001b[0;34m(self, func, iterable, chunksize)\u001b[0m\n\u001b[1;32m 362\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mmap\u001b[39m(\u001b[38;5;28mself\u001b[39m, func, iterable, chunksize\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mNone\u001b[39;00m):\n\u001b[1;32m 363\u001b[0m \u001b[38;5;250m \u001b[39m\u001b[38;5;124;03m'''\u001b[39;00m\n\u001b[1;32m 364\u001b[0m \u001b[38;5;124;03m Apply `func` to each element in `iterable`, collecting the results\u001b[39;00m\n\u001b[1;32m 365\u001b[0m \u001b[38;5;124;03m in a list that is returned.\u001b[39;00m\n\u001b[1;32m 366\u001b[0m \u001b[38;5;124;03m '''\u001b[39;00m\n\u001b[0;32m--> 367\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_map_async\u001b[49m\u001b[43m(\u001b[49m\u001b[43mfunc\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43miterable\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mmapstar\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mchunksize\u001b[49m\u001b[43m)\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mget\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n", "File \u001b[0;32m~/miniconda3/envs/GraphCast/lib/python3.10/multiprocessing/pool.py:774\u001b[0m, in \u001b[0;36mApplyResult.get\u001b[0;34m(self, timeout)\u001b[0m\n\u001b[1;32m 772\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_value\n\u001b[1;32m 773\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m--> 774\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_value\n", "\u001b[0;31mNameError\u001b[0m: name 'combined_age_remo_dupli' is not defined" ] } ], "source": [ "# 未验证正确的并行化代码\n", "\n", "import numpy as np\n", "import xarray as xr\n", "from tqdm import tqdm\n", "import multiprocessing\n", "\n", "global combined, combined_age_remo_dupli, dims, Sedi_dataset\n", "\n", "# Define the function to process each variable\n", "def process_variable(var):\n", " \n", " # Skip the variables that are Coordinates.\n", " if var in [\"site latitude\", \"site longitude\", \"interpreted age\"]:\n", " return\n", " \n", " # create a nan array with the shape of (1,664,181,360) by numpy\n", " data = np.nan * np.zeros((1, len(combined_age_remo_dupli), dims[\"lat\"], dims[\"lon\"])) # (batch, time, lat, lon)\n", " data = data.astype(np.float16) # Convert the data type to np.float32\n", " \n", " # Initialize j\n", " j = 0\n", " \n", " # Iterate over the combined index\n", " for i in range(len(combined[\"index\"])):\n", " # When age is repeated, use i-j to keep time unchanged\n", " if i > 0 and combined['interpreted age'].data[i-1] == combined['interpreted age'].data[i]:\n", " j += 1\n", " \n", " # Assign values to the data array\n", " data[0, i-j, int(combined[\"site latitude\"].values[i]), int(combined[\"site longitude\"].values[i])] = combined[var].values[i]\n", " \n", " # Create a new DataArray with the same data but new dimensions\n", " new_dataarray = xr.DataArray(\n", " data,\n", " dims=[\"batch\", \"time\", \"lat\", \"lon\"],\n", " coords={\"batch\": Sedi_dataset[\"batch\"], \"time\": combined_age_remo_dupli, \"lat\": Sedi_dataset[\"lat\"], \"lon\": Sedi_dataset[\"lon\"]}\n", " )\n", " \n", " # Add the new DataArray to the new dataset\n", " Sedi_dataset[var] = new_dataarray\n", "\n", "# Define the number of processes to use\n", "num_processes = 10\n", "\n", "# Split the variables into chunks for multiprocessing\n", "variable_chunks = np.array_split(list(combined.data_vars), num_processes)\n", "\n", "# Create a pool of processes\n", "pool = multiprocessing.Pool(processes=num_processes)\n", "\n", "# Iterate over variable chunks and process them in parallel\n", "for chunk in tqdm(variable_chunks, desc=\"Processing variables\"):\n", " pool.map(process_variable, chunk)\n", "\n", "# To combine multiple xarray.Dataset objects\n", " \n", "\n", "\n", "# Close the pool\n", "pool.close()\n", "pool.join()" ] }, { "cell_type": "code", "execution_count": 19, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "
<xarray.Dataset>\n",
              "Dimensions:   (lon: 360, lat: 181, level: 13, time: 377, batch: 1)\n",
              "Coordinates:\n",
              "  * lon       (lon) float64 0.0 1.0 2.0 3.0 4.0 ... 356.0 357.0 358.0 359.0\n",
              "  * lat       (lat) float64 -90.0 -89.0 -88.0 -87.0 ... 87.0 88.0 89.0 90.0\n",
              "  * level     (level) int64 50 125 200 275 350 425 500 575 650 725 800 875 950\n",
              "  * time      (time) float64 0.0 3.74 13.96 103.0 ... 3e+03 3.25e+03 3.4e+03\n",
              "    datetime  (batch, time) float64 0.0 3.74 13.96 ... 3e+03 3.25e+03 3.4e+03\n",
              "  * batch     (batch) int64 0\n",
              "Data variables: (12/17)\n",
              "    Ag (ppm)  (batch, time, lat, lon) float16 nan nan nan nan ... nan nan nan\n",
              "    Al (wt%)  (batch, time, lat, lon) float16 nan nan nan nan ... nan nan nan\n",
              "    As (ppm)  (batch, time, lat, lon) float16 nan nan nan nan ... nan nan nan\n",
              "    Au (ppb)  (batch, time, lat, lon) float16 nan nan nan nan ... nan nan nan\n",
              "    Ba (ppm)  (batch, time, lat, lon) float16 nan nan nan nan ... nan nan nan\n",
              "    Be (ppm)  (batch, time, lat, lon) float16 nan nan nan nan ... nan nan nan\n",
              "    ...        ...\n",
              "    Cr (ppm)  (batch, time, lat, lon) float16 nan nan nan nan ... nan nan nan\n",
              "    Cs (ppm)  (batch, time, lat, lon) float16 nan nan nan nan ... nan nan nan\n",
              "    Cu (ppm)  (batch, time, lat, lon) float16 nan nan nan nan ... nan nan nan\n",
              "    Dy (ppm)  (batch, time, lat, lon) float16 nan nan nan nan ... nan nan nan\n",
              "    Er (ppm)  (batch, time, lat, lon) float16 nan nan nan nan ... nan nan nan\n",
              "    Eu (ppm)  (batch, time, lat, lon) float16 nan nan nan nan ... nan nan nan
" ], "text/plain": [ "\n", "Dimensions: (lon: 360, lat: 181, level: 13, time: 377, batch: 1)\n", "Coordinates:\n", " * lon (lon) float64 0.0 1.0 2.0 3.0 4.0 ... 356.0 357.0 358.0 359.0\n", " * lat (lat) float64 -90.0 -89.0 -88.0 -87.0 ... 87.0 88.0 89.0 90.0\n", " * level (level) int64 50 125 200 275 350 425 500 575 650 725 800 875 950\n", " * time (time) float64 0.0 3.74 13.96 103.0 ... 3e+03 3.25e+03 3.4e+03\n", " datetime (batch, time) float64 0.0 3.74 13.96 ... 3e+03 3.25e+03 3.4e+03\n", " * batch (batch) int64 0\n", "Data variables: (12/17)\n", " Ag (ppm) (batch, time, lat, lon) float16 nan nan nan nan ... nan nan nan\n", " Al (wt%) (batch, time, lat, lon) float16 nan nan nan nan ... nan nan nan\n", " As (ppm) (batch, time, lat, lon) float16 nan nan nan nan ... nan nan nan\n", " Au (ppb) (batch, time, lat, lon) float16 nan nan nan nan ... nan nan nan\n", " Ba (ppm) (batch, time, lat, lon) float16 nan nan nan nan ... nan nan nan\n", " Be (ppm) (batch, time, lat, lon) float16 nan nan nan nan ... nan nan nan\n", " ... ...\n", " Cr (ppm) (batch, time, lat, lon) float16 nan nan nan nan ... nan nan nan\n", " Cs (ppm) (batch, time, lat, lon) float16 nan nan nan nan ... nan nan nan\n", " Cu (ppm) (batch, time, lat, lon) float16 nan nan nan nan ... nan nan nan\n", " Dy (ppm) (batch, time, lat, lon) float16 nan nan nan nan ... nan nan nan\n", " Er (ppm) (batch, time, lat, lon) float16 nan nan nan nan ... nan nan nan\n", " Eu (ppm) (batch, time, lat, lon) float16 nan nan nan nan ... nan nan nan" ] }, "execution_count": 19, "metadata": {}, "output_type": "execute_result" } ], "source": [ "Sedi_dataset" ] }, { "cell_type": "code", "execution_count": 22, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "array([10.02, nan, nan, nan, nan, nan, nan, nan, nan,\n", " nan, nan, nan, nan, nan, nan, nan, nan, nan,\n", " nan, nan, nan, nan, nan, nan, nan, nan, nan,\n", " nan, nan, nan, nan, nan, nan, nan, nan, nan,\n", " nan, nan, nan, nan, nan, nan, nan, nan, nan,\n", " nan, nan, nan, nan, nan, nan, nan, nan, nan,\n", " nan, nan, nan, nan, nan, nan, nan, nan, nan,\n", " nan, nan, nan, nan, nan, nan, nan, nan, nan,\n", " nan, nan, nan, nan, nan, nan, nan, nan, nan,\n", " nan, nan, nan, nan, nan, nan, nan, nan, nan,\n", " nan, nan, nan, nan, nan, nan, nan, nan, nan,\n", " nan, nan, nan, nan, nan, nan, nan, nan, nan,\n", " nan, nan, nan, nan, nan, nan, nan, nan, nan,\n", " nan, nan, nan, nan, nan, nan, nan, nan, nan,\n", " nan, nan, nan, nan, nan, nan, nan, nan, nan,\n", " nan, nan, nan, nan, nan, nan, nan, nan, nan,\n", " nan, nan, nan, nan, nan, nan, nan, nan, nan,\n", " nan, nan, nan, nan, nan, nan, nan, nan, nan,\n", " nan, nan, nan, nan, nan, nan, nan, nan, nan,\n", " nan, nan, nan, nan, nan, nan, nan, nan, nan,\n", " nan, nan, nan, nan, nan, nan, nan, nan, nan,\n", " nan, nan, nan, nan, nan, nan, nan, nan, nan,\n", " nan, nan, nan, nan, nan, nan, nan, nan, nan,\n", " nan, nan, nan, nan, nan, nan, nan, nan, nan,\n", " nan, nan, nan, nan, nan, nan, nan, nan, nan,\n", " nan, nan, nan, nan, nan, nan, nan, nan, nan,\n", " nan, nan, nan, nan, nan, nan, nan, nan, nan,\n", " nan, nan, nan, nan, nan, nan, nan, nan, nan,\n", " nan, nan, nan, nan, nan, nan, nan, nan, nan,\n", " nan, nan, nan, nan, nan, nan, nan, nan, nan,\n", " nan, nan, nan, nan, nan, nan, nan, nan, nan,\n", " nan, nan, nan, nan, nan, nan, nan, nan, nan,\n", " nan, nan, nan, nan, nan, nan, nan, nan, nan,\n", " nan, nan, nan, nan, nan, nan, nan, nan, nan,\n", " nan, nan, nan, nan, nan, nan, nan, nan, nan,\n", " nan, nan, nan, nan, nan, nan, nan, nan, nan,\n", " nan, nan, nan, nan, nan, nan, nan, nan, nan,\n", " nan, nan, nan, nan, nan, nan, nan, nan, nan,\n", " nan, nan, nan, nan, nan, nan, nan, nan, nan,\n", " nan, nan, nan, nan, nan, nan, nan, nan, nan,\n", " nan, nan, nan, nan, nan, nan, nan, nan, nan,\n", " nan, nan, nan, nan, nan, nan, nan, nan],\n", " dtype=float16)" ] }, "execution_count": 22, "metadata": {}, "output_type": "execute_result" } ], "source": [ "Sedi_dataset[\"Al (wt%)\"].values[0,:,-37,-73]" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "example_batch = Sedi_dataset" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# The age instead of time in example_batch\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "cellView": "form", "id": "tPVy1GHokHtk" }, "outputs": [], "source": [ "# @title Choose training and eval data to extract\n", "train_steps = widgets.IntSlider(\n", " value=1, min=1, max=example_batch.sizes[\"time\"]-2, description=\"Train steps\")\n", "eval_steps = widgets.IntSlider(\n", " value=example_batch.sizes[\"time\"]-2, min=1, max=example_batch.sizes[\"time\"]-2, description=\"Eval steps\")\n", "\n", "widgets.VBox([\n", " train_steps,\n", " eval_steps,\n", " widgets.Label(value=\"Run the next cell to extract the data. Rerunning this cell clears your selection.\")\n", "])" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "cellView": "form", "id": "Ogp4vTBvsgSt" }, "outputs": [], "source": [ "# @title Extract training and eval data\n", "# banch 的存在是为了更高效的利用数据集。因此如何制作 banch 是下一个要解决的问题。\n", "\n", "train_inputs, train_targets, train_forcings = data_utils.extract_inputs_targets_forcings(\n", " example_batch, target_lead_times=slice(\"6h\", f\"{train_steps.value*6}h\"),\n", " **dataclasses.asdict(task_config))\n", "\n", "eval_inputs, eval_targets, eval_forcings = data_utils.extract_inputs_targets_forcings(\n", " example_batch, target_lead_times=slice(\"6h\", f\"{eval_steps.value*6}h\"),\n", " **dataclasses.asdict(task_config))\n", "\n", "print(\"All Examples: \", example_batch.dims.mapping)\n", "print(\"Train Inputs: \", train_inputs.dims.mapping)\n", "print(\"Train Targets: \", train_targets.dims.mapping)\n", "print(\"Train Forcings:\", train_forcings.dims.mapping)\n", "print(\"Eval Inputs: \", eval_inputs.dims.mapping)\n", "print(\"Eval Targets: \", eval_targets.dims.mapping)\n", "print(\"Eval Forcings: \", eval_forcings.dims.mapping)\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "cellView": "form", "id": "Q--ZRhpTdI2o" }, "outputs": [], "source": [ "# @title Load normalization data\n", "# Rewrite by S.F. Sune, https://github.com/sfsun67.\n", "dir_path_stats = \"/root/data/stats\"\n", "\n", "with open(f\"{dir_path_stats}/stats-diffs_stddev_by_level.nc\", \"rb\") as f:\n", " diffs_stddev_by_level = xarray.load_dataset(f).compute()\n", "with open(f\"{dir_path_stats}/stats-mean_by_level.nc\", \"rb\") as f:\n", " mean_by_level = xarray.load_dataset(f).compute()\n", "with open(f\"{dir_path_stats}/stats-stddev_by_level.nc\", \"rb\") as f:\n", " stddev_by_level = xarray.load_dataset(f).compute()" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "cellView": "form", "id": "ke2zQyuT_sMA" }, "outputs": [], "source": [ "# @title Build jitted functions, and possibly initialize random weights\n", "# Construct the model and initialize the weights.\n", "# 构建模型并初始化权重\n", "\n", "# 模型组网\n", "# Construct the model\n", "def construct_wrapped_graphcast(\n", " model_config: graphcast.ModelConfig,\n", " task_config: graphcast.TaskConfig):\n", " \"\"\"Constructs and wraps the GraphCast Predictor.\"\"\"\n", " # Deeper one-step predictor.\n", " predictor = graphcast.GraphCast(model_config, task_config)\n", "\n", " # Modify inputs/outputs to `graphcast.GraphCast` to handle conversion to\n", " # from/to float32 to/from BFloat16.\n", " predictor = casting.Bfloat16Cast(predictor)\n", "\n", " # Modify inputs/outputs to `casting.Bfloat16Cast` so the casting to/from\n", " # BFloat16 happens after applying normalization to the inputs/targets.\n", " predictor = normalization.InputsAndResiduals(\n", " predictor,\n", " diffs_stddev_by_level=diffs_stddev_by_level,\n", " mean_by_level=mean_by_level,\n", " stddev_by_level=stddev_by_level)\n", "\n", " # Wraps everything so the one-step model can produce trajectories.\n", " predictor = autoregressive.Predictor(predictor, gradient_checkpointing=True)\n", " return predictor\n", "\n", "# 前向运算\n", "# forward\n", "@hk.transform_with_state\n", "def run_forward(model_config, task_config, inputs, targets_template, forcings):\n", " predictor = construct_wrapped_graphcast(model_config, task_config)\n", " return predictor(inputs, targets_template=targets_template, forcings=forcings)\n", "\n", "# 计算损失函数\n", "# loss function\n", "@hk.transform_with_state # used to convert a pure function into a stateful function\n", "def loss_fn(model_config, task_config, inputs, targets, forcings):\n", " predictor = construct_wrapped_graphcast(model_config, task_config) # constructs and wraps a GraphCast Predictor, which is a model used for making predictions in a graph-based machine learning task.\n", " loss, diagnostics = predictor.loss(inputs, targets, forcings)\n", " return xarray_tree.map_structure(\n", " lambda x: xarray_jax.unwrap_data(x.mean(), require_jax=True),\n", " (loss, diagnostics))\n", "\n", "# 计算梯度\n", "# gradient\n", "def grads_fn(params, state, model_config, task_config, inputs, targets, forcings):\n", " def _aux(params, state, i, t, f):\n", " (loss, diagnostics), next_state = loss_fn.apply(\n", " params, state, jax.random.PRNGKey(0), model_config, task_config,\n", " i, t, f)\n", " return loss, (diagnostics, next_state)\n", " (loss, (diagnostics, next_state)), grads = jax.value_and_grad(\n", " _aux, has_aux=True)(params, state, inputs, targets, forcings)\n", " return loss, diagnostics, next_state, grads\n", "\n", "# Jax doesn't seem to like passing configs as args through the jit. Passing it\n", "# in via partial (instead of capture by closure) forces jax to invalidate the\n", "# jit cache if you change configs.\n", "def with_configs(fn):\n", " return functools.partial(\n", " fn, model_config=model_config, task_config=task_config)\n", "\n", "# Always pass params and state, so the usage below are simpler\n", "def with_params(fn):\n", " return functools.partial(fn, params=params, state=state)\n", "\n", "# Our models aren't stateful, so the state is always empty, so just return the\n", "# predictions. This is requiredy by our rollout code, and generally simpler.\n", "def drop_state(fn):\n", " return lambda **kw: fn(**kw)[0]\n", "\n", "init_jitted = jax.jit(with_configs(run_forward.init))\n", "\n", "if params is None:\n", " params, state = init_jitted(\n", " rng=jax.random.PRNGKey(0),\n", " inputs=train_inputs,\n", " targets_template=train_targets,\n", " forcings=train_forcings)\n", "\n", "loss_fn_jitted = drop_state(with_params(jax.jit(with_configs(loss_fn.apply))))\n", "grads_fn_jitted = with_params(jax.jit(with_configs(grads_fn)))\n", "run_forward_jitted = drop_state(with_params(jax.jit(with_configs(\n", " run_forward.apply))))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# 测试:1. 对数据集进行迭代 2. 权重保存与加载" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# 对数据集进行迭代 模型为原始的。数据为步长40的数据集, train step = 2\n", "for i in range(39):\n", " example_batch_slice = example_batch.isel(time=slice(i, 4+i))\n", " train_inputs, train_targets, train_forcings = data_utils.extract_inputs_targets_forcings(\n", " example_batch_slice, target_lead_times=slice(\"6h\", f\"{train_steps.value*6}h\"),\n", " **dataclasses.asdict(task_config))\n", " # @title Gradient computation (backprop through time)\n", " loss, diagnostics, next_state, grads = grads_fn_jitted(\n", " inputs=train_inputs,\n", " targets=train_targets,\n", " forcings=train_forcings)\n", " mean_grad = np.mean(jax.tree_util.tree_flatten(jax.tree_util.tree_map(lambda x: np.abs(x).mean(), grads))[0])\n", " print(f\"Loss: {loss:.4f}, Mean |grad|: {mean_grad:.6f}\")\n", "\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# 对数据集进行迭代 0-20\n", "for i in range(20):\n", " example_batch_slice = example_batch.isel(time=slice(i, 4+i))\n", " train_inputs, train_targets, train_forcings = data_utils.extract_inputs_targets_forcings(\n", " example_batch_slice, target_lead_times=slice(\"6h\", f\"{train_steps.value*6}h\"),\n", " **dataclasses.asdict(task_config))\n", " # @title Gradient computation (backprop through time)\n", " loss, diagnostics, next_state, grads = grads_fn_jitted(\n", " inputs=train_inputs,\n", " targets=train_targets,\n", " forcings=train_forcings)\n", " mean_grad = np.mean(jax.tree_util.tree_flatten(jax.tree_util.tree_map(lambda x: np.abs(x).mean(), grads))[0])\n", " print(f\"Loss: {loss:.4f}, Mean |grad|: {mean_grad:.6f}\")" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# ckpt 的已有数据\n", "params_new = grads # 由 grads_fn_jitted 计算得到\n", "model_config = model_config # 由模型加载得到 # @title Load the model\n", "task_config = task_config # 由模型加载得到 # @title Load the model \n", "description='\\nGraphCast model ...(输入你的陈述)\\n'\n", "license='\\nThe model weights are licensed 输入数据集的license\\n'\n", "\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# @title Save the model by S.F. Sune \n", "import io\n", "\n", "ckpt = save_model.Checkpoint(\n", " params = params_new,\n", " model_config = model_config,\n", " task_config = task_config,\n", " description = description,\n", " license = license\n", " )\n", "\n", "buffer = io.BytesIO() # 创建一个内存文件对象 creat a memory file object\n", "checkpoint.dump(buffer, ckpt)\n", "buffer.seek(0) # 移动文件指针到文件的开头,便于读取 move the file pointer to the beginning of the file, to facilitate reading\n", "\n", "# 保存buffer为.npy到本地 save buffer of .npy to local\n", "with open(\"/root/data/params/params-GraphCast_test.npy\", \"wb\") as f:\n", " f.write(buffer.read())\n", "f.close() # 关闭文件 close file" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# 加载训练20次的模型\n", "\n", "with open(\"/root/data/params/params-GraphCast_test.npy\", \"rb\") as f:\n", " ckpt = checkpoint.load(f, graphcast.CheckPoint)\n", "\n", "params = ckpt.params\n", "state = {}\n", "\n", "model_config = ckpt.model_config\n", "task_config = ckpt.task_config\n", "print(\"Model description:\\n\", ckpt.description, \"\\n\")\n", "print(\"Model license:\\n\", ckpt.license, \"\\n\")\n", "params" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "model_config" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#继续迭代\n", "for i in range(20,39):\n", " example_batch_slice = example_batch.isel(time=slice(i, 4+i))\n", " train_inputs, train_targets, train_forcings = data_utils.extract_inputs_targets_forcings(\n", " example_batch_slice, target_lead_times=slice(\"6h\", f\"{train_steps.value*6}h\"),\n", " **dataclasses.asdict(task_config))\n", " # @title Gradient computation (backprop through time)\n", " loss, diagnostics, next_state, grads = grads_fn_jitted(\n", " inputs=train_inputs,\n", " targets=train_targets,\n", " forcings=train_forcings)\n", " mean_grad = np.mean(jax.tree_util.tree_flatten(jax.tree_util.tree_map(lambda x: np.abs(x).mean(), grads))[0])\n", " print(f\"Loss: {loss:.4f}, Mean |grad|: {mean_grad:.6f}\")" ] }, { "cell_type": "markdown", "metadata": { "id": "VBNutliiCyqA" }, "source": [ "# Run the model\n", "\n", "Note that the cell below may take a while (possibly minutes) to run the first time you execute them, because this will include the time it takes for the code to compile. The second time running will be significantly faster.\n", "\n", "This use the python loop to iterate over prediction steps, where the 1-step prediction is jitted. This has lower memory requirements than the training steps below, and should enable making prediction with the small GraphCast model on 1 deg resolution data for 4 steps." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "cellView": "form", "id": "7obeY9i9oTtD" }, "outputs": [], "source": [ "# @title Autoregressive rollout (loop in python)\n", "\n", "assert model_config.resolution in (0, 360. / eval_inputs.sizes[\"lon\"]), (\n", " \"Model resolution doesn't match the data resolution. You likely want to \"\n", " \"re-filter the dataset list, and download the correct data.\")\n", "\n", "print(\"Inputs: \", eval_inputs.dims.mapping)\n", "print(\"Targets: \", eval_targets.dims.mapping)\n", "print(\"Forcings:\", eval_forcings.dims.mapping)\n", "\n", "predictions = rollout.chunked_prediction(\n", " run_forward_jitted,\n", " rng=jax.random.PRNGKey(0),\n", " inputs=eval_inputs,\n", " targets_template=eval_targets * np.nan,\n", " forcings=eval_forcings)\n", "predictions" ] }, { "cell_type": "markdown", "metadata": { "id": "Pa78b64bLYe1" }, "source": [ "# Train the model\n", "\n", "The following operations require a large amount of memory and, depending on the accelerator being used, will only fit the very small \"random\" model on low resolution data. It uses the number of training steps selected above.\n", "\n", "The first time executing the cell takes more time, as it include the time to jit the function." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "cellView": "form", "id": "Nv-u3dAP7IRZ" }, "outputs": [], "source": [ "# @title Loss computation (autoregressive loss over multiple steps)\n", "loss, diagnostics = loss_fn_jitted(\n", " rng=jax.random.PRNGKey(0),\n", " inputs=train_inputs,\n", " targets=train_targets,\n", " forcings=train_forcings)\n", "\n", "print(\"Loss:\", float(loss))" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "cellView": "form", "id": "mBNFq1IGZNLz" }, "outputs": [], "source": [ "# @title Gradient computation (backprop through time)\n", "loss, diagnostics, next_state, grads = grads_fn_jitted(\n", " inputs=train_inputs,\n", " targets=train_targets,\n", " forcings=train_forcings)\n", "mean_grad = np.mean(jax.tree_util.tree_flatten(jax.tree_util.tree_map(lambda x: np.abs(x).mean(), grads))[0])\n", "print(f\"Loss: {loss:.4f}, Mean |grad|: {mean_grad:.6f}\")" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "cellView": "form", "id": "J4FJFKWD8Loz" }, "outputs": [], "source": [ "# @title Autoregressive rollout (keep the loop in JAX)\n", "print(\"Inputs: \", train_inputs.dims.mapping)\n", "print(\"Targets: \", train_targets.dims.mapping)\n", "print(\"Forcings:\", train_forcings.dims.mapping)\n", "\n", "predictions = run_forward_jitted(\n", " rng=jax.random.PRNGKey(0),\n", " inputs=train_inputs,\n", " targets_template=train_targets * np.nan,\n", " forcings=train_forcings)\n", "predictions" ] } ], "metadata": { "colab": { "name": "GraphCast", "private_outputs": true, "provenance": [], "toc_visible": true }, "kernelspec": { "display_name": "Python 3", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.13" } }, "nbformat": 4, "nbformat_minor": 0 }