# all_graphcast_loadData.py
# @title Imports
import numpy as np
import xarray as xr
import pandas as pd
import multiprocessing as mp
from tqdm import tqdm
## read /root/data/Sedi_dataset.nc to Sedi_dataset
# Sedi_dataset = xr.open_dataset("/root/data/Sedi_dataset.nc")
# load example_batch from .csv file
print("load example_batch from .csv file")
# Load the data into a DataFrame
df = pd.read_csv("/root/data/SEDI.csv", encoding='latin1')
# Convert the DataFrame to an xarray Dataset (row index becomes dim 'index')
sedi_ds = xr.Dataset.from_dataframe(df)
# Drop every row whose 'interpreted age' is NaN — such rows cannot be
# placed on the time axis built from ages later in the script
sedi_ds = sedi_ds.dropna(dim='index', subset=['interpreted age'])
# Sort ascending by 'interpreted age'; all other variables are reordered along with it
sedi_ds = sedi_ds.sortby('interpreted age', ascending=True)
# Rewrite the lon and lat according the resulation of dataset.
print("Rewrite the lon and lat according the resulation of dataset.")
# define
resolution_Rewrite_lon_lat = 1  # the resolution of longitude and latitude, in degrees
  23. def Rewrite_lon_lat(data, resolution):
  24. '''
  25. 根据 xarray 数据集中的分辨率 重写 lon 和 lat
  26. Rewrite the lon and lat according the resulation of dataset.
  27. data: the original data
  28. resolution: the resolution of the data
  29. '''
  30. condition_number = int(1/resolution)
  31. data["site latitude"].data = np.round(data["site latitude"].data * condition_number) / condition_number
  32. data["site longitude"].data = np.round(data["site longitude"].data * condition_number) / condition_number
  33. return data
# Rewrite the lon and lat according the dataset of xarray.
# Open question: after coarsening the resolution, how should the newly
# created duplicate (lon, lat) values be handled? (They are averaged below.)
combined = Rewrite_lon_lat(sedi_ds, resolution_Rewrite_lon_lat)
# Group the dataset by lon, lat and time, then average each group
print("使用 groupby 方法根据 lon、lat 和 time 三个变量对数据集进行分组, 并对分组后的数据集求平均")
# Function to process a part of the dataset
# Module-level accumulator used by groupby_and_average below
# NOTE(review): shared mutable state — each worker process gets its own copy
sedimentary_list = []
  41. def groupby_and_average(sedi_ds):
  42. '''
  43. # 使用 groupby 方法根据 lon、lat 和 time 三个变量对数据集进行分组, 并对分组后的数据集求平均
  44. '''
  45. for site_longitude_value, site_longitude in sedi_ds.groupby("site longitude"):
  46. for site_latitude_value, site_latitude in site_longitude.groupby("site latitude"):
  47. for interpreted_age_value, sedi in site_latitude.groupby("interpreted age"):
  48. #sedimentary_dict = sedi.apply(np.mean).to_dict()
  49. sedimentary_list.append(sedi.apply(np.mean))
  50. # Add an identifying dimension to each xr.Dataset of sedimentary_list
  51. for i, sedi_ds in enumerate(sedimentary_list):
  52. sedi_ds = sedi_ds.expand_dims({'sample': [i]})
  53. # Concatenate the datasets
  54. combined = xr.concat(sedimentary_list, dim='index')
  55. return combined, site_longitude_value, site_latitude_value, interpreted_age_value
  56. # Divide the dataset into parts
  57. part_number = 9
  58. dim = 'index' # replace with your actual dimension
  59. dim_size = sedi_ds.dims[dim]
  60. indices = np.linspace(0, dim_size, part_number+1).astype(int)
  61. parts = [sedi_ds.isel({dim: slice(indices[i], indices[i + 1])}) for i in range(part_number)]
  62. # Create a multiprocessing Pool
  63. pool = mp.Pool(mp.cpu_count())
  64. # Process each part of the dataset in parallel with a progress bar
  65. print('Processing Sedi datasets, replacing duplicates with averages ...')
  66. results = []
  67. with tqdm(total=len(parts)) as pbar:
  68. for result in pool.imap_unordered(groupby_and_average, parts):
  69. results.append(result)
  70. pbar.update(1)
  71. # Close the pool
  72. pool.close()
  73. # To combine multiple xarray.Dataset objects
  74. result_list = [result[0] for result in results]
  75. combined = xr.concat(result_list, dim='index')
  76. # 按照 interpreted age 升序排序,并改变其他变量的顺序
  77. combined = combined.sortby('interpreted age', ascending=True)
  78. # Create the new xr.Dataset
  79. # When copy, notice that deep copy and shallow copy.
  80. # define
  81. resolution = resolution_Rewrite_lon_lat #the resolution of longitiude and latitude
  82. batch = 0
  83. datetime_temp = np.random.rand(1, len(list(dict.fromkeys(combined['interpreted age'].data)))) # 这里要根据非重复 age 的长度来定义 xarray 的长度
  84. datetime_temp[0, :] = list(dict.fromkeys(combined['interpreted age'].data))
  85. # Create the dimensions
  86. dims = {
  87. "lon": int(360/resolution),
  88. "lat": int(181/resolution),
  89. "level": 13,
  90. "time": len(list(dict.fromkeys(combined['interpreted age'].data))),
  91. }
  92. # Create the coordinates
  93. coords_creat = {
  94. "lon": np.linspace(0, 359, int(dims["lon"] - (1/resolution - 1))),
  95. "lat": np.linspace(-90, 90, int(dims["lat"] - (1/resolution - 1))),
  96. "level": np.arange(50, 1000, 75),
  97. "time": datetime_temp[0, :],
  98. "datetime": (["batch", "time"], datetime_temp),
  99. }
  100. # Create the new dataset
  101. Sedi_dataset = xr.Dataset(coords = coords_creat)
  102. print("Create the new dataset done.")
  103. # load sedi data into the Sedi_dataset
  104. j=0
  105. dims = Sedi_dataset.dims # Get the dimensions from Sedi_dataset
  106. # remove duplicate values from <combined['interpreted age'].data>
  107. combined_age_remo_dupli = list(dict.fromkeys(combined['interpreted age'].data))
  108. #
  109. combined_batch = Sedi_dataset["batch"].data
  110. Sedi_dataset["batch"]
  111. # Add the variables from the combined dataset to the new dataset
  112. for var in tqdm(combined.data_vars, desc="load sedi data"):
  113. # Skip the variables that are Coordinates.
  114. if var == "site latitude" or var == "site longitude" or var == "interpreted age":
  115. continue
  116. # def / 是否可以使用广播?
  117. # create a nan array with the shape of (1,664,181,360) by numpy
  118. data = np.nan * np.zeros((1, len(combined_age_remo_dupli), dims["lat"], dims["lon"])) # (banch, time, lat, lon)
  119. data = data.astype(np.float16) # Convert the data type to np.float32 有效,这段代码能少一半内存
  120. # [非常重要]如何测试这段代码????????????????????????????????????????????????
  121. for i in range(len(combined["index"])):
  122. # 当 age 重复的时候,使用 i-j 来保持时间不变。
  123. if combined['interpreted age'].data[i-1] == combined['interpreted age'].data[i]:
  124. j = j + 1 # 如果 age 重复,j 就加 1,i-j 保持时间不变
  125. # i 指示 age,j 用来固定重复的 age,下面的代码将经纬度上的数据赋值给指定 age 。
  126. data[batch, i-j, int(combined["site latitude"].values[i]),
  127. int(combined["site longitude"].values[i])] = combined[var].values[i]
  128. else:
  129. # 如果 age 不重复,j 不变,i-j 在之前的基础上继续变化
  130. # i 指示 age,j 用来固定重复的 age,下面的代码将经纬度上的数据赋值给指定 age 。
  131. data[batch, i-j, int(combined["site latitude"].values[i]),
  132. int(combined["site longitude"].values[i])] = combined[var].values[i]
  133. j = 0 # 重置 j 的值
  134. # Create a new DataArray with the same data but new dimensions
  135. new_dataarray = xr.DataArray(
  136. data,
  137. dims=["batch", "time", "lat", "lon"],
  138. coords={"batch": Sedi_dataset["batch"], "time": combined_age_remo_dupli, "lat": Sedi_dataset["lat"], "lon": Sedi_dataset["lon"]}
  139. )
  140. # Add the new DataArray to the new dataset
  141. Sedi_dataset[var] = new_dataarray
  142. del data, new_dataarray
  143. Sedi_dataset.astype(np.float32) # TypeError: Illegal primitive data type, must be one of dict_keys(['S1', 'i1', 'u1', 'i2', 'u2', 'i4', 'u4', 'i8', 'u8', 'f4', 'f8']), got float16 (variable 'Ag (ppm)', group '/')
  144. # save the Sedi_dataset
  145. path = "/root/autodl-fs/data/SEDI.nc"
  146. Sedi_dataset.to_netcdf(path)
  147. print(f"Save the Sedi_dataset done. path = {path}")