123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197 |
- # @title Imports
- import numpy as np
- import xarray as xr
- import pandas as pd
- import multiprocessing as mp
- from tqdm import tqdm
# Alternative input (disabled): xr.open_dataset("/root/data/Sedi_dataset.nc")
print("load example_batch from .csv file")
df = pd.read_csv("/root/data/SEDI.csv", encoding='latin1')

# Convert the table to an xarray Dataset, drop every row whose
# "interpreted age" is NaN, then reorder all variables by ascending
# "interpreted age".
sedi_ds = (
    xr.Dataset.from_dataframe(df)
    .dropna(dim='index', subset=['interpreted age'])
    .sortby('interpreted age', ascending=True)
)

print("Rewrite the lon and lat according the resulation of dataset.")
# Grid spacing applied to longitude and latitude when snapping site positions.
resolution_Rewrite_lon_lat = 1
def Rewrite_lon_lat(data, resolution):
    """Snap site coordinates onto a regular grid, in place.

    Rounds the ``"site latitude"`` and ``"site longitude"`` values of
    *data* to the nearest multiple of *resolution* and writes the result
    back into the same variables.

    Args:
        data: Mapping-like container (an ``xarray.Dataset`` in this script)
            whose ``"site latitude"`` and ``"site longitude"`` entries
            expose a writable ``.data`` array.
        resolution: Grid spacing, in the same unit as the coordinates.
            Must be non-zero.

    Returns:
        The same *data* object that was passed in (modified in place).

    Raises:
        ZeroDivisionError: If *resolution* is 0.
    """
    # Keep the scale as a float.  Truncating it with int() turned every
    # resolution coarser than 1 into a zero scale (0 / 0 -> NaN) and
    # silently changed the grid for spacings that do not divide 1 evenly.
    scale = 1.0 / resolution
    for name in ("site latitude", "site longitude"):
        data[name].data = np.round(data[name].data * scale) / scale
    return data
# Module-level list of per-group results.
sedimentary_list = []

# Snap every site onto the grid.  Rewrite_lon_lat modifies its argument in
# place and returns it, so ``combined`` and ``sedi_ds`` are the same object.
# Open question from the author: once the coordinates are rounded, several
# samples can share one longitude/latitude -- how should those duplicates
# be handled?
combined = Rewrite_lon_lat(data=sedi_ds, resolution=resolution_Rewrite_lon_lat)

# Next step: group the dataset by lon, lat and time with groupby and
# average each group.
print("使用 groupby 方法根据 lon、lat 和 time 三个变量对数据集进行分组, 并对分组后的数据集求平均")
def groupby_and_average(sedi_ds):
    """Average all samples that share a longitude, latitude and age.

    The dataset is grouped by ``"site longitude"``, then by
    ``"site latitude"``, then by ``"interpreted age"``; ``np.mean`` is
    applied to every data variable of each innermost group and the
    per-group results are concatenated along ``index``.

    Args:
        sedi_ds: Dataset containing the three grouping variables.  It must
            contain at least one row; with no groups there is nothing to
            concatenate.

    Returns:
        A 4-tuple ``(combined, site_longitude_value, site_latitude_value,
        interpreted_age_value)``: the concatenated per-group means followed
        by the longitude, latitude and age keys of the last group visited.
    """
    # Collect into a local list.  A module-level accumulator leaks results
    # between calls: a pool worker that processes more than one part would
    # hand back the groups of its earlier parts again.
    group_means = []
    for site_longitude_value, site_longitude in sedi_ds.groupby("site longitude"):
        for site_latitude_value, site_latitude in site_longitude.groupby("site latitude"):
            for interpreted_age_value, sedi in site_latitude.groupby("interpreted age"):
                group_means.append(sedi.apply(np.mean))

    combined = xr.concat(group_means, dim='index')
    return combined, site_longitude_value, site_latitude_value, interpreted_age_value
# Split the dataset into at most ``part_number`` contiguous slices along
# ``dim`` so that the slices can be grouped and averaged in parallel.
part_number = 9
dim = 'index'  # replace with your actual dimension
dim_size = sedi_ds.sizes[dim]
if dim_size == 0:
    raise ValueError("SEDI dataset has no rows with an interpreted age")

# sedi_ds is sorted by ascending "interpreted age", so rows with the same
# age are contiguous.  Move every interior cut back to the first row of the
# age run it falls in; otherwise one (lon, lat, age) group could be split
# across two slices and averaged separately in each.
_ages = sedi_ds['interpreted age'].values
_even_cuts = np.linspace(0, dim_size, part_number + 1).astype(int)
_interior = np.searchsorted(_ages, _ages[_even_cuts[1:-1]], side='left')
# np.unique also removes repeated cuts, so no slice is empty.
indices = np.unique(np.concatenate(([0], _interior, [dim_size])))
parts = [
    sedi_ds.isel({dim: slice(start, stop)})
    for start, stop in zip(indices[:-1], indices[1:])
]

# Process each slice in a worker process, with a progress bar.
# NOTE(review): this runs at import time without an
# ``if __name__ == "__main__":`` guard, which multiprocessing requires on
# platforms that use the "spawn" start method.
print('Processing Sedi datasets, replacing duplicates with averages ...')
results = []
with mp.Pool(mp.cpu_count()) as pool, tqdm(total=len(parts)) as pbar:
    for result in pool.imap_unordered(groupby_and_average, parts):
        results.append(result)
        pbar.update(1)

# Keep only the averaged dataset of each result tuple and stitch the
# slices back together.
result_list = [result[0] for result in results]
combined = xr.concat(result_list, dim='index')
# imap_unordered yields in completion order, so restore ascending age order.
combined = combined.sortby('interpreted age', ascending=True)
# Build the empty target dataset: a regular lon/lat grid, 13 levels, and one
# time step per distinct "interpreted age" found in ``combined``.
resolution = resolution_Rewrite_lon_lat  # grid spacing of longitude and latitude
batch = 0

# Distinct ages in first-seen order, computed once.
_unique_ages = list(dict.fromkeys(combined['interpreted age'].data))
# Shape (1, n_time): the same ages laid out along a leading batch axis.
# Built directly from the ages instead of overwriting a random array, which
# needlessly advanced the global NumPy random state.
datetime_temp = np.asarray(_unique_ages, dtype=float).reshape(1, -1)

# Number of grid points per axis.  Longitude covers [0, 360) and latitude
# covers [-90, 90] inclusive, both in steps of ``resolution``.  The former
# ``int(360/resolution) - (1/resolution - 1)`` form gave the right count
# only when 1/resolution is an integer.
_n_lon = int(round(360 / resolution))
_n_lat = int(round(180 / resolution)) + 1

# Create the dimensions
dims = {
    "lon": _n_lon,
    "lat": _n_lat,
    "level": 13,
    "time": len(_unique_ages),
}
# Create the coordinates
coords_creat = {
    "lon": np.linspace(0, 360 - resolution, _n_lon),
    "lat": np.linspace(-90, 90, _n_lat),
    "level": np.arange(50, 1000, 75),
    "time": datetime_temp[0, :],
    "datetime": (["batch", "time"], datetime_temp),
}
# Create the new dataset
Sedi_dataset = xr.Dataset(coords=coords_creat)
print("Create the new dataset done.")
# Load the averaged samples into Sedi_dataset: one (batch, time, lat, lon)
# grid per variable, NaN wherever no sample exists, then write to NetCDF.
dims = Sedi_dataset.sizes  # axis lengths of the target grid
# Distinct ages in first-seen order; these label the time axis.
combined_age_remo_dupli = list(dict.fromkeys(combined['interpreted age'].data))
combined_batch = Sedi_dataset["batch"].data

# --- Row -> grid-cell indices, computed once for all variables -----------
# Time index: position of the row's age among the distinct ages.  The
# earlier running-offset counter was reset after every non-duplicate row
# and carried over between variables, so rows following the first run of
# duplicate ages landed in the wrong (or an out-of-range) time slot.
_time_lookup = {age: k for k, age in enumerate(combined_age_remo_dupli)}
_time_idx = np.array(
    [_time_lookup[age] for age in combined['interpreted age'].data], dtype=int
)

_site_lat = np.asarray(combined["site latitude"].values, dtype=float)
_site_lon = np.asarray(combined["site longitude"].values, dtype=float)
if not (np.isfinite(_site_lat).all() and np.isfinite(_site_lon).all()):
    raise ValueError("site latitude/longitude contains NaN or infinite values")

# Spatial indices are derived from the grid coordinates themselves.  Using
# int(latitude) directly as an index ignored that the lat axis starts at
# -90, so negative latitudes wrapped around from the end of the axis.
# NOTE(review): assumes site latitudes are given in [-90, 90] and that the
# lon axis starts at 0 and spans [0, 360) -- confirm against the CSV.
_lat_axis = Sedi_dataset["lat"].values
_lon_axis = Sedi_dataset["lon"].values
_lat_step = _lat_axis[1] - _lat_axis[0]
_lon_step = _lon_axis[1] - _lon_axis[0]
_lat_idx = np.rint((_site_lat - _lat_axis[0]) / _lat_step).astype(int)
# Longitudes are wrapped into [0, 360); a value that rounds up to 360 maps
# back to cell 0.
_lon_idx = (
    np.rint((np.mod(_site_lon, 360.0) - _lon_axis[0]) / _lon_step).astype(int)
    % dims["lon"]
)
if ((_lat_idx < 0) | (_lat_idx >= dims["lat"])).any():
    raise ValueError("site latitude outside the range of the lat coordinate")

_grid_shape = (1, len(combined_age_remo_dupli), dims["lat"], dims["lon"])

# Add the variables from the combined dataset to the new dataset
for var in tqdm(combined.data_vars, desc="load sedi data"):
    # Skip the variables that act as coordinates.
    if var in ("site latitude", "site longitude", "interpreted age"):
        continue
    # Allocate float32 directly: netCDF has no float16 type (to_netcdf
    # raises TypeError for it), and building in float16 first would lose
    # precision and overflow above 65504 without lowering peak memory.
    data = np.full(_grid_shape, np.nan, dtype=np.float32)  # (batch, time, lat, lon)
    # One vectorised scatter per variable; if several rows map to the same
    # cell the last one wins, as with a sequential loop.
    data[batch, _time_idx, _lat_idx, _lon_idx] = combined[var].values
    # Wrap the grid with the target dataset's coordinates.
    new_dataarray = xr.DataArray(
        data,
        dims=["batch", "time", "lat", "lon"],
        coords={"batch": Sedi_dataset["batch"], "time": combined_age_remo_dupli, "lat": Sedi_dataset["lat"], "lon": Sedi_dataset["lon"]}
    )
    # Add the new DataArray to the new dataset
    Sedi_dataset[var] = new_dataarray
    del data, new_dataarray

# save the Sedi_dataset
path = "/root/autodl-fs/data/SEDI.nc"
Sedi_dataset.to_netcdf(path)
print(f"Save the Sedi_dataset done. path = {path}")
|