123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600 |
- '''
- Pseudocode of Pangu-Weather
- '''
- # The pseudocode can be implemented using deep learning libraries, e.g., Pytorch and Tensorflow or other high-level APIs
- # Basic operations used in our model, namely Linear, Conv3d, Conv2d, ConvTranspose3d and ConvTranspose2d
- # Linear: Linear transformation, available in all deep learning libraries
- # Conv3d and Conv2d: Convolution with 3 or 2 dimensions, available in all deep learning libraries
- # ConvTranspose3d, ConvTranspose2d: transposed convolution with 2 or 3 dimensions, see Pytorch API or Tensorflow API
- from Your_AI_Library import Linear, Conv3d, Conv2d, ConvTranspose3d, ConvTranspose2d
- # Functions in the networks, namely GeLU, DropOut, DropPath, LayerNorm, and SoftMax
- # GeLU: the GeLU activation function, see Pytorch API or Tensorflow API
- # DropOut: the dropout function, available in all deep learning libraries
- # DropPath: the DropPath function, see the implementation of vision-transformer in the timm package of Pytorch
- # A possible implementation of DropPath: from timm.models.layers import DropPath
- # LayerNorm: the layer normalization function, see Pytorch API or Tensorflow API
- # Softmax: softmax function, see Pytorch API or Tensorflow API
- from Your_AI_Library import GeLU, DropOut, DropPath, LayerNorm, SoftMax
- # Common functions for roll, pad, and crop, depends on the data structure of your software environment
- from Your_AI_Library import roll3D, pad3D, pad2D, Crop3D, Crop2D
- # Common functions for reshaping and changing the order of dimensions
- # reshape: change the shape of the data with the order unchanged, see Pytorch API or Tensorflow API
- # TransposeDimensions: change the order of the dimensions, see Pytorch API or Tensorflow API
- from Your_AI_Library import reshape, TransposeDimensions
- # Common functions for creating new tensors
- # ConstructTensor: create a new tensor with an arbitrary shape
- # TruncatedNormalInit: Initialize the tensor from a truncated normal distribution
- # RangeTensor: create a new tensor like range(a, b)
- from Your_AI_Library import ConstructTensor, TruncatedNormalInit, RangeTensor
- # Common operations for the data, you may design it or simply use deep learning APIs default operations
- # LinearSpace: a tensor version of numpy.linspace
- # MeshGrid: a tensor version of numpy.meshgrid
- # Stack: a tensor version of numpy.stack
- # Flatten: a tensor version of numpy.ndarray.flatten
- # TensorSum: a tensor version of numpy.sum
- # TensorAbs: a tensor version of numpy.abs
- # Concatenate: a tensor version of numpy.concatenate
- from Your_AI_Library import LinearSpace, MeshGrid, Stack, Flatten, TensorSum, TensorAbs, Concatenate
- # Common functions for training models
- # LoadModel and SaveModel: Load and save the model, some APIs may require further adaptation to the hardware
- # Backward: Gradient backward pass to calculate the gradient of each parameter
- # UpdateModelParametersWithAdam: Use Adam to update parameters, e.g., torch.optim.Adam
- from Your_AI_Library import LoadModel, Backward, UpdateModelParametersWithAdam, SaveModel
- # Custom functions to read your data from the disc
- # LoadData: Load the ERA5 data
- # LoadConstantMask: Load constant masks, e.g., soil type
- # LoadStatic: Load mean and std of the ERA5 training data; every field such as T850 is treated as an image, and its mean and std are calculated
- from Your_Data_Code import LoadData, LoadConstantMask, LoadStatic
def Inference(input, input_surface, forecast_range):
    '''Inference code, describing the algorithm of inference using models with different lead times.

    PanguModel24, PanguModel6, PanguModel3 and PanguModel1 share the same training algorithm but differ in lead times.
    Forecast hour i+1 is produced by the 24-hour model when it is a multiple of 24, otherwise by the 6-hour
    model when it is a multiple of 6, otherwise by the 3-hour model when it is a multiple of 3, and by the
    1-hour model in all other cases.

    Args:
        input: upper-air input tensor, need to be normalized to N(0, 1) in practice
        input_surface: surface input tensor, need to be normalized to N(0, 1) in practice
        forecast_range: iteration numbers when roll out the forecast model, one iteration per forecast hour

    Returns:
        A list with forecast_range entries. Entry i is the tuple (output, output_surface) for forecast
        hour i+1, restored from the normalized model output with the mean and std from LoadStatic().
    '''
    # Load 4 pre-trained models with different lead times
    PanguModel24 = LoadModel(ModelPath24)
    PanguModel6 = LoadModel(ModelPath6)
    PanguModel3 = LoadModel(ModelPath3)
    PanguModel1 = LoadModel(ModelPath1)
    # Load mean and std of the weather data
    weather_mean, weather_std, weather_surface_mean, weather_surface_std = LoadStatic()
    # Store the latest (normalized) state each model has to start from
    state_24 = state_6 = state_3 = state_1 = (input, input_surface)
    # Using a list to store output
    output_list = []
    # Note: the following code is implemented for fast inference of [1,forecast_range]-hour forecasts -- if only one lead time is requested, the inference can be much faster.
    for i in range(forecast_range):
        lead_time = i + 1
        # switch to the 24-hour model if the forecast time is 24 hours, 48 hours, ..., 24*N hours
        if lead_time % 24 == 0:
            output, output_surface = PanguModel24(*state_24)
            # Every model with a shorter lead time restarts from this forecast
            state_24 = state_6 = state_3 = state_1 = (output, output_surface)
        # switch to the 6-hour model if the forecast time is 6 hours, 12 hours, ..., 24*N + 6/12/18 hours
        elif lead_time % 6 == 0:
            output, output_surface = PanguModel6(*state_6)
            state_6 = state_3 = state_1 = (output, output_surface)
        # switch to the 3-hour model if the forecast time is 3 hours, 9 hours, ..., 6*N + 3 hours
        elif lead_time % 3 == 0:
            output, output_surface = PanguModel3(*state_3)
            # The 1-hour model must continue from this forecast, not from the state the 3-hour model started at
            state_3 = state_1 = (output, output_surface)
        # switch to the 1-hour model
        else:
            output, output_surface = PanguModel1(*state_1)
            state_1 = (output, output_surface)
        # Restore from the normalized output only for the saved result;
        # the stored states stay normalized because the models expect normalized input
        restored = output * weather_std + weather_mean
        restored_surface = output_surface * weather_surface_std + weather_surface_mean
        # Save the output
        output_list.append((restored, restored_surface))
    return output_list
def _mean_absolute_error(prediction, target):
    '''Return the mean absolute error between prediction and target as a single value.'''
    absolute_error = TensorAbs(prediction - target)
    # Flatten yields a one-dimensional tensor, so its first dimension is the number of elements
    return TensorSum(absolute_error) / Flatten(absolute_error).shape[0]


def Train():
    '''Training code: train a single Pangu-Weather model for 100 epochs with a weighted MAE loss and save it.'''
    # Initialize the model, for some APIs some adaptation is needed to fit hardwares
    model = PanguModel()
    # Train single Pangu-Weather model
    epochs = 100
    for i in range(epochs):
        # For each epoch, we iterate from 1979 to 2017
        # dataset_length is the length of your training data, e.g., the sample between 1979 and 2017
        # NOTE(review): dataset_length is not defined in this file and must be provided by your data code
        for step in range(dataset_length):
            # Load weather data at time t as the input; load weather data at time t+1/3/6/24 as the output
            # Note the data need to be randomly shuffled
            # Note the input and target need to be normalized, see Inference() for details
            input, input_surface, target, target_surface = LoadData(step)
            # Call the model and get the output
            output, output_surface = model(input, input_surface)
            # We use the MAE loss to train the model
            # Each term is reduced to a single value first: the upper-air and surface tensors
            # are produced by different layers, so their element-wise errors cannot be added directly
            # The weight of surface loss is 0.25
            # Different weight can be applied for different fields if needed
            loss = _mean_absolute_error(output, target) + _mean_absolute_error(output_surface, target_surface) * 0.25
            # Call the backward algorithm and calculate the gradient of parameters
            Backward(loss)
            # Update model parameters with Adam optimizer
            # The learning rate is 5e-4 as in the paper, while the weight decay is 3e-6
            # A example solution is using torch.optim.adam
            UpdateModelParametersWithAdam()
    # Save the model at the end of the training stage
    SaveModel()
class PanguModel:
    '''Backbone of Pangu-Weather: patch embedding, a two-layer encoder, a two-layer decoder and patch recovery.'''

    def __init__(self):
        # Drop path rate is linearly increased as the depth increases
        drop_path_list = LinearSpace(0, 0.2, 8)
        # Patch embedding
        self._input_layer = PatchEmbedding((2, 4, 4), 192)
        # Four basic layers
        # The 2-block layers use the first two rates, the 6-block layers use the remaining six,
        # so that every block receives its own rate
        self.layer1 = EarthSpecificLayer(2, 192, drop_path_list[:2], 6)
        self.layer2 = EarthSpecificLayer(6, 384, drop_path_list[2:], 12)
        self.layer3 = EarthSpecificLayer(6, 384, drop_path_list[2:], 12)
        self.layer4 = EarthSpecificLayer(2, 192, drop_path_list[:2], 6)
        # Upsample and downsample
        self.upsample = UpSample(384, 192)
        self.downsample = DownSample(192)
        # Patch Recovery
        self._output_layer = PatchRecovery(384)

    def forward(self, input, input_surface):
        '''Backbone architecture

        Args:
            input: upper-air input tensor
            input_surface: surface input tensor

        Returns:
            The tuple (output, output_surface) produced by the patch recovery layer.
        '''
        # Embed the input fields into patches
        x = self._input_layer(input, input_surface)
        # Encoder, composed of two layers
        # Layer 1, shape (8, 360, 181, C), C = 192 as in the original paper
        x = self.layer1(x, 8, 360, 181)
        # Store the tensor for skip-connection
        skip = x
        # Downsample from (8, 360, 181) to (8, 180, 91)
        x = self.downsample(x, 8, 360, 181)
        # Layer 2, shape (8, 180, 91, 2C), C = 192 as in the original paper
        x = self.layer2(x, 8, 180, 91)
        # Decoder, composed of two layers
        # Layer 3, shape (8, 180, 91, 2C), C = 192 as in the original paper
        x = self.layer3(x, 8, 180, 91)
        # Upsample from (8, 180, 91) to (8, 360, 181)
        x = self.upsample(x)
        # Layer 4, shape (8, 360, 181, C), C = 192 as in the original paper
        x = self.layer4(x, 8, 360, 181)
        # Skip connect, in last dimension(C from 192 to 384)
        x = Concatenate(skip, x)
        # Recover the output fields from patches; the recovery layer needs the patch grid size
        output, output_surface = self._output_layer(x, 8, 360, 181)
        return output, output_surface
class PatchEmbedding:
    def __init__(self, patch_size, dim):
        '''Patch embedding operation

        Args:
            patch_size: size of one 3D patch; its last two entries are used as the 2D patch size of the surface fields
            dim: number of output channels of both convolutions
        '''
        # Here we use convolution to partition data into cubes
        self.conv = Conv3d(input_dims=5, output_dims=dim, kernel_size=patch_size, stride=patch_size)
        self.conv_surface = Conv2d(input_dims=7, output_dims=dim, kernel_size=patch_size[1:], stride=patch_size[1:])
        # Load constant masks from the disc
        self.land_mask, self.soil_type, self.topography = LoadConstantMask()

    def forward(self, input, input_surface):
        '''Embed the upper-air and surface fields into one sequence of patch tokens.'''
        # Zero-pad the input (pad3D and pad2D are the names imported at the top of the file)
        input = pad3D(input)
        input_surface = pad2D(input_surface)
        # Apply a linear projection for patch_size[0]*patch_size[1]*patch_size[2] patches, patch_size = (2, 4, 4) as in the original paper
        input = self.conv(input)
        # Add three constant fields to the surface fields
        input_surface = Concatenate(input_surface, self.land_mask, self.soil_type, self.topography)
        # Apply a linear projection for patch_size[1]*patch_size[2] patches
        input_surface = self.conv_surface(input_surface)
        # Concatenate the input in the pressure level, i.e., in Z dimension
        x = Concatenate(input, input_surface)
        # Reshape x for calculation of linear projections
        x = TransposeDimensions(x, (0, 2, 3, 4, 1))
        x = reshape(x, target_shape=(x.shape[0], 8*360*181, x.shape[-1]))
        return x
-
class PatchRecovery:
    def __init__(self, dim, patch_size=(2, 4, 4)):
        '''Patch recovery operation

        Args:
            dim: number of input channels of both transposed convolutions
            patch_size: size of one 3D patch, (2, 4, 4) as in the original paper; its last two
                entries are used as the 2D patch size of the surface fields
        '''
        # Here we use two transposed convolutions to recover data
        self.conv = ConvTranspose3d(input_dims=dim, output_dims=5, kernel_size=patch_size, stride=patch_size)
        self.conv_surface = ConvTranspose2d(input_dims=dim, output_dims=4, kernel_size=patch_size[1:], stride=patch_size[1:])

    def forward(self, x, Z, H, W):
        '''Recover the upper-air and surface fields from patch tokens laid out on a (Z, H, W) grid.'''
        # The inverse operation of the patch embedding operation, patch_size = (2, 4, 4) as in the original paper
        # Reshape x back to three dimensions
        x = TransposeDimensions(x, (0, 2, 1))
        x = reshape(x, target_shape=(x.shape[0], x.shape[1], Z, H, W))
        # Call the transposed convolution
        # NOTE(review): this takes the surface slab from Z index 0 and the upper-air slabs from the rest;
        # confirm this matches the concatenation order used by the patch embedding
        output = self.conv(x[:, :, 1:, :, :])
        output_surface = self.conv_surface(x[:, :, 0, :, :])
        # Crop the output to remove zero-paddings
        output = Crop3D(output)
        output_surface = Crop2D(output_surface)
        return output, output_surface
class DownSample:
    def __init__(self, dim):
        '''Down-sampling operation

        Args:
            dim: number of channels of the input tokens; the output tokens have 2*dim channels
        '''
        # A linear function and a layer normalization
        self.linear = Linear(4*dim, 2*dim, bias=False)
        self.norm = LayerNorm(4*dim)

    def forward(self, x, Z, H, W):
        '''Merge every 2x2 group of horizontal neighbours of a (Z, H, W) token grid into one token.'''
        # Reshape x to three dimensions for downsampling
        x = reshape(x, target_shape=(x.shape[0], Z, H, W, x.shape[-1]))
        # Padding the input to facilitate downsampling
        x = pad3D(x)
        # Reorganize x to reduce the resolution: simply change the order and downsample from (8, 360, 182) to (8, 180, 91)
        # x has five dimensions here, so read the padded grid size from the three middle ones
        Z, H, W = x.shape[1:4]
        # Reshape x to facilitate downsampling
        x = reshape(x, target_shape=(x.shape[0], Z, H//2, 2, W//2, 2, x.shape[-1]))
        # Change the order of x
        x = TransposeDimensions(x, (0,1,2,4,3,5,6))
        # Reshape to get a tensor of resolution (8, 180, 91)
        x = reshape(x, target_shape=(x.shape[0], Z*(H//2)*(W//2), 4 * x.shape[-1]))
        # Call the layer normalization
        x = self.norm(x)
        # Decrease the channels of the data to reduce computation cost
        x = self.linear(x)
        return x
class UpSample:
    def __init__(self, input_dim, output_dim):
        '''Up-sampling operation

        Args:
            input_dim: number of channels of the incoming tokens
            output_dim: number of channels of the returned tokens
        '''
        # Bias-free projection that expands the channels to four times output_dim
        self.linear1 = Linear(input_dim, output_dim*4, bias=False)
        # Bias-free projection that mixes the channels after normalization
        self.linear2 = Linear(output_dim, output_dim, bias=False)
        # Layer normalization applied before the mixing projection
        self.norm = LayerNorm(output_dim)

    def forward(self, x):
        '''Turn every token of the (8, 180, 91) grid into a 2x2 group of tokens and crop the result.'''
        # Expand the channels so that each token carries four output tokens
        expanded = self.linear1(x)
        batch = expanded.shape[0]
        channels = expanded.shape[-1]//4
        # Split the expanded channels into a 2x2 neighbourhood per token
        grid = reshape(expanded, target_shape=(batch, 8, 180, 91, 2, 2, channels))
        # Move each neighbourhood axis next to the grid axis it refines
        grid = TransposeDimensions(grid, (0,1,2,4,3,5,6))
        # Merge the axes to obtain a resolution of (8, 360, 182)
        upsampled = reshape(grid, target_shape=(grid.shape[0], 8, 360, 182, grid.shape[-1]))
        # Crop the output to the input shape of the network
        cropped = Crop3D(upsampled)
        # Flatten the grid back into a sequence of tokens
        tokens = reshape(cropped, target_shape=(cropped.shape[0], cropped.shape[1]*cropped.shape[2]*cropped.shape[3], cropped.shape[-1]))
        # Normalize, then mix the normalized tokens
        return self.linear2(self.norm(tokens))
-
class EarthSpecificLayer:
    def __init__(self, depth, dim, drop_path_ratio_list, heads):
        '''Basic layer of our network, contains 2 or 6 blocks

        Args:
            depth: number of blocks in this layer
            dim: number of channels handed to every block
            drop_path_ratio_list: drop path ratios, entry i is used by block i
            heads: number of attention heads handed to every block
        '''
        self.depth = depth
        # Construct basic blocks
        self.blocks = [EarthSpecificBlock(dim, drop_path_ratio_list[i], heads) for i in range(depth)]

    def forward(self, x, Z, H, W):
        '''Run x through all blocks in order and return the result of the last block.'''
        for i in range(self.depth):
            # Roll the input every two blocks: blocks with an odd index work on the rolled input
            # The output of each block is the input of the next one
            x = self.blocks[i](x, Z, H, W, roll=(i % 2 != 0))
        return x
class EarthSpecificBlock:
    def __init__(self, dim, drop_path_ratio, heads):
        '''
        3D transformer block with Earth-Specific bias and window attention,
        see https://github.com/microsoft/Swin-Transformer for the official implementation of 2D window attention.
        The major difference is that we expand the dimensions to 3 and replace the relative position bias with Earth-Specific bias.

        Args:
            dim: number of channels of the tokens
            drop_path_ratio: drop rate handed to DropPath
            heads: number of attention heads
        '''
        # Define the window size of the neural network
        self.window_size = (2, 6, 12)
        # Initialize several operations
        self.drop_path = DropPath(drop_rate=drop_path_ratio)
        self.norm1 = LayerNorm(dim)
        self.norm2 = LayerNorm(dim)
        self.linear = Mlp(dim, 0)
        self.attention = EarthAttention3D(dim, heads, 0, self.window_size)

    def forward(self, x, Z, H, W, roll):
        '''Apply window attention and the MLP to tokens laid out on a (Z, H, W) grid.

        Args:
            x: tokens, reshaped here to (x.shape[0], Z, H, W, x.shape[2])
            Z, H, W: size of the token grid before padding
            roll: whether to roll the grid by half a window before the attention

        Returns:
            Tokens with the same shape as the input x.
        '''
        win_z, win_h, win_w = self.window_size
        # Save the shortcut for skip-connection
        shortcut = x
        # Reshape input to three dimensions to calculate window attention
        x = reshape(x, target_shape=(x.shape[0], Z, H, W, x.shape[2]))
        # Zero-pad input if needed
        x = pad3D(x)
        # Store the shape of the input for restoration
        ori_shape = x.shape
        # The windows are cut out of the padded grid, so use the padded grid size from here on
        Zp, Hp, Wp = ori_shape[1:4]
        if roll:
            # Roll x for half of the window for 3 dimensions
            x = roll3D(x, shift=[win_z//2, win_h//2, win_w//2])
            # Generate mask of attention masks
            # If two pixels are not adjacent, then mask the attention between them
            # Your can set the matrix element to -1000 when it is not adjacent, then add it to the attention
            # NOTE(review): gen_mask is not defined in this file and must be provided by you
            mask = gen_mask(x)
        else:
            # e.g., zero matrix when you add mask to attention
            # NOTE(review): no_mask is not defined in this file and must be provided by you
            mask = no_mask
        # Reorganize data to calculate window attention
        x_window = reshape(x, target_shape=(x.shape[0], Zp//win_z, win_z, Hp//win_h, win_h, Wp//win_w, win_w, x.shape[-1]))
        x_window = TransposeDimensions(x_window, (0, 1, 3, 5, 2, 4, 6, 7))
        # Get data stacked in 3D cubes, which will further be used to calculated attention among each cube
        x_window = reshape(x_window, target_shape=(-1, win_z*win_h*win_w, x.shape[-1]))
        # Apply 3D window attention with Earth-Specific bias
        x_window = self.attention(x_window, mask)
        # Reorganize data to original shapes
        x = reshape(x_window, target_shape=(-1, Zp//win_z, Hp//win_h, Wp//win_w, win_z, win_h, win_w, x_window.shape[-1]))
        x = TransposeDimensions(x, (0, 1, 4, 2, 5, 3, 6, 7))
        # Reshape the tensor back to its original shape
        x = reshape(x, target_shape=ori_shape)
        if roll:
            # Roll x back for half of the window
            x = roll3D(x, shift=[-(win_z//2), -(win_h//2), -(win_w//2)])
        # Crop the zero-padding
        x = Crop3D(x)
        # Reshape the tensor back to the input shape
        x = reshape(x, target_shape=(x.shape[0], x.shape[1]*x.shape[2]*x.shape[3], x.shape[4]))
        # Main calculation stages
        x = shortcut + self.drop_path(self.norm1(x))
        x = x + self.drop_path(self.norm2(self.linear(x)))
        return x
-
class EarthAttention3D:
    def __init__(self, dim, heads, dropout_rate, window_size):
        '''
        3D window attention with the Earth-Specific bias,
        see https://github.com/microsoft/Swin-Transformer for the official implementation of 2D window attention.

        Args:
            dim: number of channels of the tokens
            heads: number of attention heads
            dropout_rate: drop rate handed to DropOut
            window_size: size of one attention window along the three grid dimensions
        '''
        # Initialize several operations
        # linear1 creates query, key and value at once, so it outputs three times dim channels
        self.linear1 = Linear(dim, dim*3, bias=True)
        self.linear2 = Linear(dim, dim)
        self.softmax = SoftMax(dim=-1)
        self.dropout = DropOut(dropout_rate)
        # Store several attributes
        self.head_number = heads
        self.dim = dim
        self.scale = (dim//heads)**-0.5
        self.window_size = window_size
        # input_shape is current shape of the self.forward function
        # You can run your code to record it, modify the code and rerun it
        # NOTE(review): input_shape is not defined in this file and must be provided by you
        # Record the number of different window types
        self.type_of_windows = (input_shape[0]//window_size[0])*(input_shape[1]//window_size[1])
        # For each type of window, we will construct a set of parameters according to the paper
        self.earth_specific_bias = ConstructTensor(shape=((2 * window_size[2] - 1) * window_size[1] * window_size[1] * window_size[0] * window_size[0], self.type_of_windows, heads))
        # Making these tensors to be learnable parameters
        # NOTE(review): Parameters is not imported in this file; use the parameter wrapper of your library
        self.earth_specific_bias = Parameters(self.earth_specific_bias)
        # Initialize the tensors using Truncated normal distribution
        TruncatedNormalInit(self.earth_specific_bias, std=0.02)
        # Construct position index to reuse self.earth_specific_bias
        self.position_index = self._construct_index()

    def _construct_index(self):
        '''This function constructs the position index to reuse symmetrical parameters of the position bias.

        Returns:
            The flattened position index; it is also stored in self.position_index.
        '''
        # Index in the pressure level of query matrix
        coords_zi = RangeTensor(self.window_size[0])
        # Index in the pressure level of key matrix
        coords_zj = -RangeTensor(self.window_size[0])*self.window_size[0]
        # Index in the latitude of query matrix
        coords_hi = RangeTensor(self.window_size[1])
        # Index in the latitude of key matrix
        coords_hj = -RangeTensor(self.window_size[1])*self.window_size[1]
        # Index in the longitude of the key-value pair
        coords_w = RangeTensor(self.window_size[2])
        # Change the order of the index to calculate the index in total
        coords_1 = Stack(MeshGrid([coords_zi, coords_hi, coords_w]))
        coords_2 = Stack(MeshGrid([coords_zj, coords_hj, coords_w]))
        coords_flatten_1 = Flatten(coords_1, start_dimension=1)
        coords_flatten_2 = Flatten(coords_2, start_dimension=1)
        coords = coords_flatten_1[:, :, None] - coords_flatten_2[:, None, :]
        coords = TransposeDimensions(coords, (1, 2, 0))
        # Shift the index for each dimension to start from 0
        coords[:, :, 2] += self.window_size[2] - 1
        coords[:, :, 1] *= 2 * self.window_size[2] - 1
        coords[:, :, 0] *= (2 * self.window_size[2] - 1)*self.window_size[1]*self.window_size[1]
        # Sum up the indexes in three dimensions
        position_index = TensorSum(coords, dim=-1)
        # Flatten the position index to facilitate further indexing
        position_index = Flatten(position_index)
        # Store the index and also return it, so that assigning the result does not overwrite it with None
        self.position_index = position_index
        return position_index

    def mask_attention(self, attention, mask):
        '''Mask the attention by adding the mask to it.

        The mask is expected to hold a large negative value for pairs that must not attend to each
        other and zero elsewhere.
        NOTE(review): assumes mask can be broadcast to the shape of attention -- TODO confirm.
        '''
        return attention + mask

    def forward(self, x, mask):
        '''Apply multi-head attention with the Earth-Specific bias inside every window.

        Args:
            x: tokens of the windows; the second dimension indexes the tokens of one window
            mask: mask added to the attention, see mask_attention

        Returns:
            Tensor with the same shape as x.
        '''
        # Record the original shape of the input, before the channels are tripled for query, key and value
        original_shape = x.shape
        # Linear layer to create query, key and value
        x = self.linear1(x)
        # reshape the data to calculate multi-head attention
        qkv = reshape(x, target_shape=(x.shape[0], x.shape[1], 3, self.head_number, self.dim // self.head_number))
        query, key, value = TransposeDimensions(qkv, (2, 0, 3, 1, 4))
        # Scale the attention
        query = query * self.scale
        # Calculated the attention, a learnable bias is added to fix the nonuniformity of the grid.
        # Only the last two dimensions of key are swapped; @ denotes matrix multiplication
        attention = query @ TransposeDimensions(key, (0, 1, 3, 2))
        # self.earth_specific_bias is a set of neural network parameters to optimize.
        EarthSpecificBias = self.earth_specific_bias[self.position_index]
        # Reshape the learnable bias to the same shape as the attention matrix
        EarthSpecificBias = reshape(EarthSpecificBias, target_shape=(self.window_size[0]*self.window_size[1]*self.window_size[2], self.window_size[0]*self.window_size[1]*self.window_size[2], self.type_of_windows, self.head_number))
        EarthSpecificBias = TransposeDimensions(EarthSpecificBias, (2, 3, 0, 1))
        EarthSpecificBias = reshape(EarthSpecificBias, target_shape = [1]+EarthSpecificBias.shape)
        # Add the Earth-Specific bias to the attention matrix
        # NOTE(review): the bias carries a window-type dimension that attention does not have here;
        # confirm how the windows of x are ordered and align both tensors before this addition
        attention = attention + EarthSpecificBias
        # Mask the attention between non-adjacent pixels, e.g., simply add -100 to the masked element.
        attention = self.mask_attention(attention, mask)
        attention = self.softmax(attention)
        attention = self.dropout(attention)
        # Calculated the tensor after spatial mixing.
        x = attention @ value # @ denote matrix multiplication
        # Reshape tensor to the original shape: move the heads next to the channels, then merge them
        x = TransposeDimensions(x, (0, 2, 1, 3))
        x = reshape(x, target_shape = original_shape)
        # Linear layer to post-process operated tensor
        x = self.linear2(x)
        x = self.dropout(x)
        return x
-
class Mlp:
    def __init__(self, dim, dropout_rate):
        '''MLP layers, same as most vision transformer architectures.

        Args:
            dim: number of input and output channels; the hidden layer has 4*dim channels
            dropout_rate: drop rate handed to DropOut
        '''
        self.linear1 = Linear(dim, dim * 4)
        self.linear2 = Linear(dim * 4, dim)
        self.activation = GeLU()
        self.drop = DropOut(drop_rate=dropout_rate)

    def forward(self, x):
        '''Apply linear1, the activation, dropout, linear2 and dropout, in this order.'''
        # Expand the channels with the first linear layer
        x = self.linear1(x)
        x = self.activation(x)
        x = self.drop(x)
        # Project back with the second linear layer
        x = self.linear2(x)
        x = self.drop(x)
        return x
def PerlinNoise():
    '''Generate random Perlin noise: we follow https://github.com/pvigier/perlin-numpy/ to calculate the perlin noise.

    Returns:
        Fractal noise for one input slice of size (721, 1440), multiplied by the noise scale 0.2.
    '''
    # Scaling factor applied to the generated noise
    noise_scale = 0.2
    # Number of octaves summed up by the fractal noise
    octaves = 3
    # Scaling factor between two octaves
    persistence = 0.5
    # Number of periods of noise along each axis
    period_number = 12
    # The size of an input slice
    H, W = 721, 1440
    slice_shape = (H, W)
    periods = (period_number, period_number)
    # see https://github.com/pvigier/perlin-numpy/ for the implementation of GenerateFractalNoise (e.g., from perlin_numpy import generate_fractal_noise_3d)
    # NOTE(review): some implementations require the slice size to be a multiple of the periods
    # scaled by the octaves -- confirm that (721, 1440) is accepted by the one you use
    fractal_noise = GenerateFractalNoise(slice_shape, periods, octaves, persistence)
    return noise_scale*fractal_noise
|