# pseudocode.py
'''
Pseudocode of Pangu-Weather
'''
# The pseudocode can be implemented using deep learning libraries, e.g., PyTorch and TensorFlow, or other high-level APIs

# Basic operations used in our model, namely Linear, Conv3d, Conv2d, ConvTranspose3d and ConvTranspose2d
# Linear: linear transformation, available in all deep learning libraries
# Conv3d and Conv2d: convolution with 2 or 3 dimensions, available in all deep learning libraries
# ConvTranspose3d, ConvTranspose2d: transposed convolution with 2 or 3 dimensions, see the PyTorch or TensorFlow API
from Your_AI_Library import Linear, Conv3d, Conv2d, ConvTranspose3d, ConvTranspose2d

# Functions used in the networks, namely GeLU, DropOut, DropPath, LayerNorm and SoftMax
# GeLU: the GeLU activation function, see the PyTorch or TensorFlow API
# DropOut: the dropout function, available in all deep learning libraries
# DropPath: the DropPath function, see the implementation of vision transformers, e.g., the timm package for PyTorch
# A possible implementation of DropPath: from timm.models.layers import DropPath
# LayerNorm: the layer normalization function, see the PyTorch or TensorFlow API
# SoftMax: the softmax function, see the PyTorch or TensorFlow API
from Your_AI_Library import GeLU, DropOut, DropPath, LayerNorm, SoftMax

# Common functions for roll, pad and crop; their implementations depend on the data structures of your software environment
from Your_AI_Library import roll3D, pad3D, pad2D, Crop3D, Crop2D

# Common functions for reshaping and changing the order of dimensions
# reshape: change the shape of the data with the order unchanged, see the PyTorch or TensorFlow API
# TransposeDimensions: change the order of the dimensions, see the PyTorch or TensorFlow API
from Your_AI_Library import reshape, TransposeDimensions

# Common functions for creating new tensors
# ConstructTensor: create a new tensor with an arbitrary shape
# TruncatedNormalInit: initialize a tensor with a truncated normal distribution
# RangeTensor: create a new tensor like range(a, b)
# Parameters: register a tensor as a learnable parameter, e.g., torch.nn.Parameter
from Your_AI_Library import ConstructTensor, TruncatedNormalInit, RangeTensor, Parameters

# Common operations for the data; you may design them yourself or simply use the default operations of your deep learning API
# LinearSpace: a tensor version of numpy.linspace
# MeshGrid: a tensor version of numpy.meshgrid
# Stack: a tensor version of numpy.stack
# Flatten: a tensor version of numpy.ndarray.flatten
# TensorSum: a tensor version of numpy.sum
# TensorAbs: a tensor version of numpy.abs
# Concatenate: a tensor version of numpy.concatenate
from Your_AI_Library import LinearSpace, MeshGrid, Stack, Flatten, TensorSum, TensorAbs, Concatenate

# Common functions for training models
# LoadModel and SaveModel: load and save the model; some APIs may require further adaptation to your hardware
# Backward: backpropagate to calculate the gradient of each parameter
# UpdateModelParametersWithAdam: use Adam to update the parameters, e.g., torch.optim.Adam
from Your_AI_Library import LoadModel, Backward, UpdateModelParametersWithAdam, SaveModel

# Custom functions to read your data from the disk
# LoadData: load the ERA5 data
# LoadConstantMask: load constant masks, e.g., the soil-type mask
# LoadStatic: load the mean and std of the ERA5 training data; each field (e.g., T850) is treated as an image and its mean and std are computed
from Your_Data_Code import LoadData, LoadConstantMask, LoadStatic
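
# One possible concrete realization of the placeholder imports above, using
# PyTorch (an illustrative assumption; TensorFlow or another library would work
# equally well). Shown commented out so the placeholders above remain authoritative.
# import torch
# from torch.nn import Linear, Conv3d, Conv2d, ConvTranspose3d, ConvTranspose2d, LayerNorm
# from torch.nn import GELU as GeLU, Dropout as DropOut, Softmax as SoftMax
# from timm.models.layers import DropPath
# reshape, TransposeDimensions, roll3D = torch.reshape, torch.permute, torch.roll
# LinearSpace, MeshGrid, Stack, Flatten = torch.linspace, torch.meshgrid, torch.stack, torch.flatten
# TensorSum, TensorAbs, Concatenate = torch.sum, torch.abs, torch.cat
# Parameters = torch.nn.Parameter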
def Inference(input, input_surface, forecast_range):
    '''Inference code, describing the algorithm of inference using models with different lead times.
    PanguModel24, PanguModel6, PanguModel3 and PanguModel1 share the same training algorithm but differ in their lead times.
    Args:
      input: upper-air input tensor; needs to be normalized to N(0, 1) in practice
      input_surface: surface input tensor; needs to be normalized to N(0, 1) in practice
      forecast_range: number of iterations when rolling out the forecast model
    '''
    # Load the 4 pre-trained models with different lead times
    PanguModel24 = LoadModel(ModelPath24)
    PanguModel6 = LoadModel(ModelPath6)
    PanguModel3 = LoadModel(ModelPath3)
    PanguModel1 = LoadModel(ModelPath1)

    # Load the mean and std of the weather data
    weather_mean, weather_std, weather_surface_mean, weather_surface_std = LoadStatic()

    # Store the initial input for the different models
    input_24, input_surface_24 = input, input_surface
    input_6, input_surface_6 = input, input_surface
    input_3, input_surface_3 = input, input_surface

    # Use a list to store the outputs
    output_list = []

    # Note: the following code is implemented for fast inference of [1, forecast_range]-hour forecasts -- if only one lead time is requested, the inference can be much faster.
    for i in range(forecast_range):
        # Switch to the 24-hour model if the forecast time is 24 hours, 48 hours, ..., 24*N hours
        if (i+1) % 24 == 0:
            # Switch the input back to the stored input
            input, input_surface = input_24, input_surface_24

            # Call the model pre-trained for 24-hour forecasts
            output, output_surface = PanguModel24(input, input_surface)

            # Restore from the normalized output
            output = output * weather_std + weather_mean
            output_surface = output_surface * weather_surface_std + weather_surface_mean

            # Store the output for the next forecast round; also update the
            # current state so the next hourly step continues from this output
            input_24, input_surface_24 = output, output_surface
            input_6, input_surface_6 = output, output_surface
            input_3, input_surface_3 = output, output_surface
            input, input_surface = output, output_surface

        # Switch to the 6-hour model if the forecast time is 30 hours, 36 hours, ..., 24*N + 6/12/18 hours
        elif (i+1) % 6 == 0:
            # Switch the input back to the stored input
            input, input_surface = input_6, input_surface_6

            # Call the model pre-trained for 6-hour forecasts
            output, output_surface = PanguModel6(input, input_surface)

            # Restore from the normalized output
            output = output * weather_std + weather_mean
            output_surface = output_surface * weather_surface_std + weather_surface_mean

            # Store the output for the next forecast round
            input_6, input_surface_6 = output, output_surface
            input_3, input_surface_3 = output, output_surface
            input, input_surface = output, output_surface

        # Switch to the 3-hour model if the forecast time is 3 hours, 9 hours, ..., 6*N + 3 hours
        elif (i+1) % 3 == 0:
            # Switch the input back to the stored input
            input, input_surface = input_3, input_surface_3

            # Call the model pre-trained for 3-hour forecasts
            output, output_surface = PanguModel3(input, input_surface)

            # Restore from the normalized output
            output = output * weather_std + weather_mean
            output_surface = output_surface * weather_surface_std + weather_surface_mean

            # Store the output for the next forecast round
            input_3, input_surface_3 = output, output_surface
            input, input_surface = output, output_surface

        # Otherwise, use the 1-hour model
        else:
            # Call the model pre-trained for 1-hour forecasts
            output, output_surface = PanguModel1(input, input_surface)

            # Restore from the normalized output
            output = output * weather_std + weather_mean
            output_surface = output_surface * weather_surface_std + weather_surface_mean

            # Store the output for the next forecast round
            input, input_surface = output, output_surface

        # Save the output
        output_list.append((output, output_surface))
    return output_list
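
# A minimal sketch (an illustration, not part of the original file) of the
# faster single-lead-time inference mentioned in the note above: greedily chain
# the largest model that fits the remaining hours instead of stepping hourly.
def GreedySchedule(lead_time):
    # E.g., GreedySchedule(31) -> [24, 6, 1]: run PanguModel24, then PanguModel6,
    # then PanguModel1, feeding each output into the next model.
    schedule = []
    for step in (24, 6, 3, 1):
        while lead_time >= step:
            schedule.append(step)
            lead_time -= step
    return schedule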
def Train():
    '''Training code'''
    # Initialize the model; for some APIs, adaptation may be needed to fit your hardware
    model = PanguModel()

    # Train a single Pangu-Weather model
    epochs = 100
    for i in range(epochs):
        # For each epoch, we iterate over the training samples from 1979 to 2017
        # dataset_length is the length of your training data, e.g., the number of samples between 1979 and 2017
        for step in range(dataset_length):
            # Load weather data at time t as the input; load weather data at time t+1/3/6/24 as the target
            # Note that the data need to be randomly shuffled
            # Note that the input and target need to be normalized, see Inference() for details
            input, input_surface, target, target_surface = LoadData(step)

            # Call the model and get the output
            output, output_surface = model(input, input_surface)

            # We use the MAE loss to train the model
            # The weight of the surface loss is 0.25
            # Different weights can be applied to different fields if needed
            loss = TensorAbs(output-target) + TensorAbs(output_surface-target_surface) * 0.25

            # Call the backward algorithm and calculate the gradients of the parameters
            Backward(loss)

            # Update the model parameters with the Adam optimizer
            # The learning rate is 5e-4 as in the paper, and the weight decay is 3e-6
            # An example solution is torch.optim.Adam
            UpdateModelParametersWithAdam()

    # Save the model at the end of the training stage
    SaveModel()
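
# A hedged PyTorch sketch of the optimizer configuration described above
# (lr = 5e-4, weight decay = 3e-6); an illustration, not the authors' released code.
def TorchTrainSketch(model, data_loader, epochs=100):
    import torch
    optimizer = torch.optim.Adam(model.parameters(), lr=5e-4, weight_decay=3e-6)
    for _ in range(epochs):
        for input, input_surface, target, target_surface in data_loader:
            output, output_surface = model(input, input_surface)
            # MAE loss with a weight of 0.25 on the surface term, as above
            loss = torch.mean(torch.abs(output - target)) \
                 + 0.25 * torch.mean(torch.abs(output_surface - target_surface))
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()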
class PanguModel:
    def __init__(self):
        # The drop path rate increases linearly with depth
        drop_path_list = LinearSpace(0, 0.2, 8)

        # Patch embedding
        self._input_layer = PatchEmbedding((2, 4, 4), 192)

        # Four basic layers
        self.layer1 = EarthSpecificLayer(2, 192, drop_path_list[:2], 6)
        self.layer2 = EarthSpecificLayer(6, 384, drop_path_list[2:], 12)
        self.layer3 = EarthSpecificLayer(6, 384, drop_path_list[2:], 12)
        self.layer4 = EarthSpecificLayer(2, 192, drop_path_list[:2], 6)

        # Upsample and downsample
        self.upsample = UpSample(384, 192)
        self.downsample = DownSample(192)

        # Patch recovery
        self._output_layer = PatchRecovery(384)

    def forward(self, input, input_surface):
        '''Backbone architecture'''
        # Embed the input fields into patches
        x = self._input_layer(input, input_surface)

        # Encoder, composed of two layers
        # Layer 1, shape (8, 360, 181, C), C = 192 as in the original paper
        x = self.layer1(x, 8, 360, 181)

        # Store the tensor for the skip connection
        skip = x

        # Downsample from (8, 360, 181) to (8, 180, 91)
        x = self.downsample(x, 8, 360, 181)

        # Layer 2, shape (8, 180, 91, 2C), C = 192 as in the original paper
        x = self.layer2(x, 8, 180, 91)

        # Decoder, composed of two layers
        # Layer 3, shape (8, 180, 91, 2C), C = 192 as in the original paper
        x = self.layer3(x, 8, 180, 91)

        # Upsample from (8, 180, 91) to (8, 360, 181)
        x = self.upsample(x)

        # Layer 4, shape (8, 360, 181, C), C = 192 as in the original paper
        x = self.layer4(x, 8, 360, 181)

        # Skip connection, in the last dimension (C goes from 192 to 384)
        x = Concatenate(skip, x)

        # Recover the output fields from patches
        output, output_surface = self._output_layer(x, 8, 360, 181)
        return output, output_surface
class PatchEmbedding:
    def __init__(self, patch_size, dim):
        '''Patch embedding operation'''
        # Here we use convolutions to partition the data into cubes
        self.conv = Conv3d(input_dims=5, output_dims=dim, kernel_size=patch_size, stride=patch_size)
        self.conv_surface = Conv2d(input_dims=7, output_dims=dim, kernel_size=patch_size[1:], stride=patch_size[1:])

        # Load the constant masks from the disk
        self.land_mask, self.soil_type, self.topography = LoadConstantMask()

    def forward(self, input, input_surface):
        # Zero-pad the input
        input = pad3D(input)
        input_surface = pad2D(input_surface)

        # Apply a linear projection for patch_size[0]*patch_size[1]*patch_size[2] patches; patch_size = (2, 4, 4) as in the original paper
        input = self.conv(input)

        # Add three constant fields to the surface fields
        input_surface = Concatenate(input_surface, self.land_mask, self.soil_type, self.topography)

        # Apply a linear projection for patch_size[1]*patch_size[2] patches
        input_surface = self.conv_surface(input_surface)

        # Concatenate the input in the pressure-level dimension, i.e., the Z dimension
        x = Concatenate(input, input_surface)

        # Reshape x for calculation of the linear projections
        x = TransposeDimensions(x, (0, 2, 3, 4, 1))
        x = reshape(x, target_shape=(x.shape[0], 8*360*181, x.shape[-1]))
        return x
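
# For reference (our reading of the paper, stated here as an assumption): with
# 0.25-degree ERA5 data, the upper-air input is (B, 5, 13, 721, 1440) and the
# surface input is (B, 4, 721, 1440). Zero-padding takes 13 -> 14 levels and
# 721 -> 724 latitudes, so the (2, 4, 4) convolution yields 7 upper-air levels
# plus 1 surface level = 8 levels of 181 x 360 tokens, i.e., x has shape
# (B, 8*360*181, 192) = (B, 521280, 192) after the final reshape.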
class PatchRecovery:
    def __init__(self, dim, patch_size=(2, 4, 4)):
        '''Patch recovery operation'''
        # Here we use two transposed convolutions to recover the data
        self.conv = ConvTranspose3d(input_dims=dim, output_dims=5, kernel_size=patch_size, stride=patch_size)
        self.conv_surface = ConvTranspose2d(input_dims=dim, output_dims=4, kernel_size=patch_size[1:], stride=patch_size[1:])

    def forward(self, x, Z, H, W):
        # The inverse of the patch embedding operation; patch_size = (2, 4, 4) as in the original paper
        # Reshape x back to three dimensions
        x = TransposeDimensions(x, (0, 2, 1))
        x = reshape(x, target_shape=(x.shape[0], x.shape[1], Z, H, W))

        # Call the transposed convolutions
        output = self.conv(x[:, :, 1:, :, :])
        output_surface = self.conv_surface(x[:, :, 0, :, :])

        # Crop the output to remove the zero-padding
        output = Crop3D(output)
        output_surface = Crop2D(output_surface)
        return output, output_surface
class DownSample:
    def __init__(self, dim):
        '''Down-sampling operation'''
        # A linear function and a layer normalization
        self.linear = Linear(4*dim, 2*dim, bias=False)
        self.norm = LayerNorm(4*dim)

    def forward(self, x, Z, H, W):
        # Reshape x to three dimensions for downsampling
        x = reshape(x, target_shape=(x.shape[0], Z, H, W, x.shape[-1]))

        # Pad the input to facilitate downsampling
        x = pad3D(x)

        # Reorganize x to reduce the resolution: simply change the order and downsample from (8, 360, 182) to (8, 180, 91)
        _, Z, H, W, _ = x.shape
        # Reshape x to facilitate downsampling
        x = reshape(x, target_shape=(x.shape[0], Z, H//2, 2, W//2, 2, x.shape[-1]))
        # Change the order of x
        x = TransposeDimensions(x, (0, 1, 2, 4, 3, 5, 6))
        # Reshape to get a tensor of resolution (8, 180, 91)
        x = reshape(x, target_shape=(x.shape[0], Z*(H//2)*(W//2), 4 * x.shape[-1]))

        # Call the layer normalization
        x = self.norm(x)

        # Decrease the channels of the data to reduce the computation cost
        x = self.linear(x)
        return x
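
# A toy numpy check (illustrative only) of the reshape/transpose trick above:
# it merges each 2x2 spatial neighbourhood into the channel dimension.
import numpy as np
_x = np.arange(16).reshape(1, 1, 4, 4, 1)      # (B, Z, H, W, C) with values 0..15
_x = _x.reshape(1, 1, 2, 2, 2, 2, 1)           # split H and W into (outer, inner) pairs
_x = _x.transpose(0, 1, 2, 4, 3, 5, 6)         # bring the two inner axes together
_x = _x.reshape(1, 1 * 2 * 2, 4 * 1)           # (B, Z*(H//2)*(W//2), 4C)
assert _x[0, 0].tolist() == [0, 1, 4, 5]       # row 0 holds one 2x2 neighbourhood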
class UpSample:
    def __init__(self, input_dim, output_dim):
        '''Up-sampling operation'''
        # Linear layer without bias to increase the channels of the data
        self.linear1 = Linear(input_dim, output_dim*4, bias=False)

        # Linear layer without bias to mix the data up
        self.linear2 = Linear(output_dim, output_dim, bias=False)

        # Normalization
        self.norm = LayerNorm(output_dim)

    def forward(self, x):
        # Call the linear function to increase the channels of the data
        x = self.linear1(x)

        # Reorganize x to increase the resolution: simply change the order and upsample from (8, 180, 91) to (8, 360, 182)
        # Reshape x to facilitate upsampling
        x = reshape(x, target_shape=(x.shape[0], 8, 180, 91, 2, 2, x.shape[-1]//4))
        # Change the order of x
        x = TransposeDimensions(x, (0, 1, 2, 4, 3, 5, 6))
        # Reshape to get a tensor with a resolution of (8, 360, 182)
        x = reshape(x, target_shape=(x.shape[0], 8, 360, 182, x.shape[-1]))

        # Crop the output to the input shape of the network
        x = Crop3D(x)

        # Reshape x back to two dimensions
        x = reshape(x, target_shape=(x.shape[0], x.shape[1]*x.shape[2]*x.shape[3], x.shape[-1]))

        # Call the layer normalization
        x = self.norm(x)

        # Mix up the normalized tensors
        x = self.linear2(x)
        return x
class EarthSpecificLayer:
    def __init__(self, depth, dim, drop_path_ratio_list, heads):
        '''Basic layer of our network, containing 2 or 6 blocks'''
        self.depth = depth
        self.blocks = []

        # Construct the basic blocks
        for i in range(depth):
            self.blocks.append(EarthSpecificBlock(dim, drop_path_ratio_list[i], heads))

    def forward(self, x, Z, H, W):
        for i in range(self.depth):
            # Roll the input every two blocks
            if i % 2 == 0:
                x = self.blocks[i](x, Z, H, W, roll=False)
            else:
                x = self.blocks[i](x, Z, H, W, roll=True)
        return x
class EarthSpecificBlock:
    def __init__(self, dim, drop_path_ratio, heads):
        '''
        3D transformer block with Earth-Specific bias and window attention;
        see https://github.com/microsoft/Swin-Transformer for the official implementation of 2D window attention.
        The major difference is that we expand the dimensions to 3 and replace the relative position bias with the Earth-Specific bias.
        '''
        # Define the window size of the neural network
        self.window_size = (2, 6, 12)

        # Initialize several operations
        self.drop_path = DropPath(drop_rate=drop_path_ratio)
        self.norm1 = LayerNorm(dim)
        self.norm2 = LayerNorm(dim)
        self.linear = Mlp(dim, 0)
        self.attention = EarthAttention3D(dim, heads, 0, self.window_size)

    def forward(self, x, Z, H, W, roll):
        # Save the shortcut for the skip connection
        shortcut = x

        # Reshape the input to three dimensions to calculate window attention
        x = reshape(x, target_shape=(x.shape[0], Z, H, W, x.shape[2]))

        # Zero-pad the input if needed
        x = pad3D(x)

        # Store the shape of the padded input for restoration; the window
        # partitioning below uses the padded dimensions
        ori_shape = x.shape
        _, Z, H, W, _ = ori_shape

        if roll:
            # Roll x by half of the window size in all 3 dimensions
            x = roll3D(x, shift=[self.window_size[0]//2, self.window_size[1]//2, self.window_size[2]//2])
            # Generate the attention mask
            # If two pixels are not adjacent, then mask the attention between them
            # You can set the matrix element to -1000 when the two pixels are not adjacent, then add it to the attention
            mask = gen_mask(x)
        else:
            # E.g., a zero matrix when you add the mask to the attention
            mask = no_mask

        # Reorganize the data to calculate window attention
        x_window = reshape(x, target_shape=(x.shape[0], Z//self.window_size[0], self.window_size[0], H//self.window_size[1], self.window_size[1], W//self.window_size[2], self.window_size[2], x.shape[-1]))
        x_window = TransposeDimensions(x_window, (0, 1, 3, 5, 2, 4, 6, 7))

        # Get the data stacked in 3D cubes, which will further be used to calculate attention within each cube
        x_window = reshape(x_window, target_shape=(-1, self.window_size[0]*self.window_size[1]*self.window_size[2], x.shape[-1]))

        # Apply 3D window attention with the Earth-Specific bias
        x_window = self.attention(x_window, mask)

        # Reorganize the data back to its original layout
        x = reshape(x_window, target_shape=((-1, Z//self.window_size[0], H//self.window_size[1], W//self.window_size[2], self.window_size[0], self.window_size[1], self.window_size[2], x_window.shape[-1])))
        x = TransposeDimensions(x, (0, 1, 4, 2, 5, 3, 6, 7))

        # Reshape the tensor back to its original shape
        x = reshape(x, target_shape=ori_shape)

        if roll:
            # Roll x back by half of the window size
            x = roll3D(x, shift=[-self.window_size[0]//2, -self.window_size[1]//2, -self.window_size[2]//2])

        # Crop the zero-padding
        x = Crop3D(x)

        # Reshape the tensor back to the input shape
        x = reshape(x, target_shape=(x.shape[0], x.shape[1]*x.shape[2]*x.shape[3], x.shape[4]))

        # Main calculation stages
        x = shortcut + self.drop_path(self.norm1(x))
        x = x + self.drop_path(self.norm2(self.linear(x)))
        return x
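
# A hedged numpy sketch of the gen_mask placeholder used above, following the
# 2D Swin-Transformer masking scheme extended to 3D (an assumption; this file
# does not spell gen_mask out). Each pixel is labelled by the contiguous region
# it came from before the cyclic shift; attention between pixels with different
# labels gets -1000 added before the softmax, matching the comment above.
# Assumes Z, H, W are already padded to multiples of the window size.
import numpy as np

def gen_mask_sketch(Z, H, W, window_size=(2, 6, 12)):
    wz, wh, ww = window_size
    sz, sh, sw = wz // 2, wh // 2, ww // 2
    # Label the three bands produced by the roll along each axis
    img_mask = np.zeros((Z, H, W))
    cnt = 0
    for z in (slice(0, -wz), slice(-wz, -sz), slice(-sz, None)):
        for h in (slice(0, -wh), slice(-wh, -sh), slice(-sh, None)):
            for w in (slice(0, -ww), slice(-ww, -sw), slice(-sw, None)):
                img_mask[z, h, w] = cnt
                cnt += 1
    # Partition the label volume into (wz, wh, ww) windows
    windows = img_mask.reshape(Z // wz, wz, H // wh, wh, W // ww, ww)
    windows = windows.transpose(0, 2, 4, 1, 3, 5).reshape(-1, wz * wh * ww)
    # Within each window, mask attention between pixels from different regions
    diff = windows[:, None, :] - windows[:, :, None]
    return np.where(diff != 0, -1000.0, 0.0)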
class EarthAttention3D:
    def __init__(self, dim, heads, dropout_rate, window_size):
        '''
        3D window attention with the Earth-Specific bias;
        see https://github.com/microsoft/Swin-Transformer for the official implementation of 2D window attention.
        '''
        # Initialize several operations
        self.linear1 = Linear(dim, dim*3, bias=True)
        self.linear2 = Linear(dim, dim)
        self.softmax = SoftMax(dim=-1)
        self.dropout = DropOut(dropout_rate)

        # Store several attributes
        self.head_number = heads
        self.dim = dim
        self.scale = (dim//heads)**-0.5
        self.window_size = window_size

        # input_shape is the shape of the input to self.forward
        # You can run your code once to record it, then modify the code and rerun it
        # Record the number of different window types
        self.type_of_windows = (input_shape[0]//window_size[0])*(input_shape[1]//window_size[1])

        # For each type of window, we construct a set of parameters according to the paper
        self.earth_specific_bias = ConstructTensor(shape=((2 * window_size[2] - 1) * window_size[1] * window_size[1] * window_size[0] * window_size[0], self.type_of_windows, heads))

        # Make the tensor a learnable parameter
        self.earth_specific_bias = Parameters(self.earth_specific_bias)

        # Initialize the tensor using a truncated normal distribution
        TruncatedNormalInit(self.earth_specific_bias, std=0.02)

        # Construct the position index to reuse self.earth_specific_bias
        self.position_index = self._construct_index()

    def _construct_index(self):
        '''This function constructs the position index to reuse symmetrical parameters of the position bias'''
        # Index in the pressure level of the query matrix
        coords_zi = RangeTensor(self.window_size[0])
        # Index in the pressure level of the key matrix
        coords_zj = -RangeTensor(self.window_size[0])*self.window_size[0]

        # Index in the latitude of the query matrix
        coords_hi = RangeTensor(self.window_size[1])
        # Index in the latitude of the key matrix
        coords_hj = -RangeTensor(self.window_size[1])*self.window_size[1]

        # Index in the longitude of the key-value pair
        coords_w = RangeTensor(self.window_size[2])

        # Change the order of the indices to calculate the index in total
        coords_1 = Stack(MeshGrid([coords_zi, coords_hi, coords_w]))
        coords_2 = Stack(MeshGrid([coords_zj, coords_hj, coords_w]))
        coords_flatten_1 = Flatten(coords_1, start_dimension=1)
        coords_flatten_2 = Flatten(coords_2, start_dimension=1)
        coords = coords_flatten_1[:, :, None] - coords_flatten_2[:, None, :]
        coords = TransposeDimensions(coords, (1, 2, 0))

        # Shift the index of each dimension to start from 0
        coords[:, :, 2] += self.window_size[2] - 1
        coords[:, :, 1] *= 2 * self.window_size[2] - 1
        coords[:, :, 0] *= (2 * self.window_size[2] - 1)*self.window_size[1]*self.window_size[1]

        # Sum up the indices over the three dimensions
        position_index = TensorSum(coords, dim=-1)

        # Flatten the position index to facilitate further indexing
        position_index = Flatten(position_index)
        return position_index
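
    # Worked size check (illustrative, our arithmetic): with window_size = (2, 6, 12),
    # the bias table has (2*12-1) * 6*6 * 2*2 = 3312 entries per window type and head,
    # and each window holds 2*6*12 = 144 tokens, so position_index contains
    # 144*144 = 20736 indices, each falling in [0, 3312).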
    def forward(self, x, mask):
        # Record the original shape of the input
        original_shape = x.shape

        # Linear layer to create the query, key and value
        x = self.linear1(x)

        # Reshape the data to calculate multi-head attention
        qkv = reshape(x, target_shape=(x.shape[0], x.shape[1], 3, self.head_number, self.dim // self.head_number))
        query, key, value = TransposeDimensions(qkv, (2, 0, 3, 1, 4))

        # Scale the attention
        query = query * self.scale

        # Calculate the attention; a learnable bias is added to fix the non-uniformity of the grid
        attention = query @ key.T  # @ denotes matrix multiplication

        # self.earth_specific_bias is a set of neural network parameters to optimize
        EarthSpecificBias = self.earth_specific_bias[self.position_index]

        # Reshape the learnable bias to the same shape as the attention matrix
        EarthSpecificBias = reshape(EarthSpecificBias, target_shape=(self.window_size[0]*self.window_size[1]*self.window_size[2], self.window_size[0]*self.window_size[1]*self.window_size[2], self.type_of_windows, self.head_number))
        EarthSpecificBias = TransposeDimensions(EarthSpecificBias, (2, 3, 0, 1))
        EarthSpecificBias = reshape(EarthSpecificBias, target_shape=[1]+EarthSpecificBias.shape)

        # Add the Earth-Specific bias to the attention matrix
        attention = attention + EarthSpecificBias

        # Mask the attention between non-adjacent pixels, e.g., by simply adding -1000 to the masked elements
        attention = self.mask_attention(attention, mask)
        attention = self.softmax(attention)
        attention = self.dropout(attention)

        # Calculate the tensor after spatial mixing
        x = attention @ value  # @ denotes matrix multiplication

        # Reshape the tensor back to the original shape
        x = TransposeDimensions(x, (0, 2, 1, 3))
        x = reshape(x, target_shape=original_shape)

        # Linear layer to post-process the operated tensor
        x = self.linear2(x)
        x = self.dropout(x)
        return x
class Mlp:
    def __init__(self, dim, dropout_rate):
        '''MLP layers, same as in most vision transformer architectures.'''
        self.linear1 = Linear(dim, dim * 4)
        self.linear2 = Linear(dim * 4, dim)
        self.activation = GeLU()
        self.drop = DropOut(drop_rate=dropout_rate)

    def forward(self, x):
        x = self.linear1(x)
        x = self.activation(x)
        x = self.drop(x)
        x = self.linear2(x)
        x = self.drop(x)
        return x
def PerlinNoise():
    '''Generate random Perlin noise: we follow https://github.com/pvigier/perlin-numpy/ to calculate the Perlin noise.'''
    # Define the number of octaves of the noise
    octaves = 3
    # Define the scaling factor of the noise
    noise_scale = 0.2
    # Define the number of periods of the noise along each axis
    period_number = 12
    # The size of an input slice
    H, W = 721, 1440
    # Scaling factor between two octaves
    persistence = 0.5
    # See https://github.com/pvigier/perlin-numpy/ for the implementation of GenerateFractalNoise (e.g., from perlin_numpy import generate_fractal_noise_2d)
    perlin_noise = noise_scale*GenerateFractalNoise((H, W), (period_number, period_number), octaves, persistence)
    return perlin_noise
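
# Usage note (our reading of the paper, stated as an assumption): the Perlin
# noise above is added to the normalized initial state to generate the members
# of a perturbation-based ensemble forecast.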