train_graphcast.py

#import hfai
#hfai.set_watchdog_time(21600)
import os
from pathlib import Path

import numpy as np
import torch
import torch.nn as nn
import torch.backends.cudnn as cudnn
import torch.distributed as dist
#import hfai.nccl.distributed as dist
#from haiscale.ddp import DistributedDataParallel
#from haiscale.pipeline import PipeDream, make_subgroups, partition
from torch.utils.data.distributed import DistributedSampler  # shards the dataset across ranks for data-parallel training
import timm.optim
from timm.scheduler import create_scheduler
import torch.multiprocessing as mp
from torch.utils.data import DataLoader

from model.graphcast_sequential import GraphCast

torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True

from data_factory.datasets import ERA5, EarthGraph
#from model.graphcast_sequential import get_graphcast_module
from utils.params import get_graphcast_args
from utils.tools import load_model, save_model
from utils.eval import graphcast_evaluate

SAVE_PATH = Path('/root/output/graphcast-torch/')
EarthGraph_PATH = Path('/root/code/OpenCastKit-removehfai/EarthGraph/')
SAVE_PATH.mkdir(parents=True, exist_ok=True)
EarthGraph_PATH.mkdir(parents=True, exist_ok=True)
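
# Script flow: main() initialises the NCCL process group (one process per GPU), then
# train() builds the GraphCast model, the cached EarthGraph, and the ERA5 train/val
# dataloaders, and loops over train_one_epoch() with AMP autocast, checkpointing to
# SAVE_PATH after every epoch.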
def train_one_epoch(epoch, model, criterion, data_loader, graph, optimizer, lr_scheduler, min_loss, device):
    is_last_pipeline_stage = True  # (pp_group.rank() == pp_group.size() - 1)

    loss = torch.tensor(0., device="cuda")
    count = torch.tensor(0., device="cuda")

    model.train()

    input_x = [
        None,
        graph.mesh_data.x.half().cuda(non_blocking=True),
        graph.mesh_data.edge_index.cuda(non_blocking=True),
        graph.mesh_data.edge_attr.half().cuda(non_blocking=True),
        graph.grid2mesh_data.edge_index.cuda(non_blocking=True),
        graph.grid2mesh_data.edge_attr.half().cuda(non_blocking=True),
        graph.mesh2grid_data.edge_index.cuda(non_blocking=True),
        graph.mesh2grid_data.edge_attr.half().cuda(non_blocking=True),
    ]
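
    # The eight entries above are passed positionally to GraphCast.forward: grid node
    # features (slot 0, filled per batch in the loop below), mesh node features, mesh
    # edge_index / edge_attr, grid-to-mesh edge_index / edge_attr, and mesh-to-grid
    # edge_index / edge_attr (the exact parameter names of GraphCast.forward are an
    # assumption).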

    # Reference autocast pattern (from the PyTorch AMP docs):
    #     # Creates model and optimizer in default precision
    #     model = Net().cuda()
    #     optimizer = optim.SGD(model.parameters(), ...)
    #
    #     for input, target in data:
    #         optimizer.zero_grad()
    #
    #         # Enables autocasting for the forward pass (model + loss)
    #         with torch.autocast(device_type="cuda"):
    #             output = model(input)
    #             loss = loss_fn(output, target)
    #
    #         # Exits the context manager before backward()
    #         loss.backward()
    #         optimizer.step()
    for step, batch in enumerate(data_loader):
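        # This loop follows the pattern above but without a GradScaler. A minimal
        # fp16-stable sketch (an assumption, not what the code below does) would be:
        #     scaler = torch.cuda.amp.GradScaler()
        #     with torch.cuda.amp.autocast():
        #         out = model(*input_x)
        #         step_loss = criterion(out, y)
        #     scaler.scale(step_loss).backward()
        #     scaler.step(optimizer)
        #     scaler.update()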
        optimizer.zero_grad()

        # ERA5() already returns the target y as well (a two-step prediction: with
        # batch = 2, two samples are used for the prediction).
        x, y = [t.half().cuda(non_blocking=True) for t in batch]
        input_x[0] = x
        # x must stay consistent with --grid-node-num: grid_feat_embedding is
        # nn.Sequential(nn.Linear(gdim, gemb, bias=True)), so the expected input
        # feature dimension is fixed by gdim.
        with torch.cuda.amp.autocast():
            #step_loss, _ = model.forward_backward(*input_x, criterion=criterion, labels=(y,))
            out = model(*input_x)  # equivalent to passing the eight entries positionally
            step_loss = criterion(out, y)
        step_loss.backward()
        optimizer.step()

        if is_last_pipeline_stage:
            loss += step_loss.sum().item()
            count += 1
        # if dp_group.rank() == 0 and is_last_pipeline_stage and hfai.client.receive_suspend_command():
        #     save_model(model.module.module, epoch, step + 1, optimizer, lr_scheduler, min_loss, SAVE_PATH / 'latest.pt')
        #     hfai.go_suspend()

    # all-reduce across the data-parallel group (group=dp_group is not set, so the
    # default group is used); every rank must take part in the collectives
    dist.all_reduce(loss)
    dist.all_reduce(count)
    loss = loss / count
    # broadcast from the last stage to other pipeline stages
    #dist.all_reduce(loss)  # not needed without pipeline parallelism: every rank already holds the averaged loss

    return loss.item()


def train(local_rank, args):
    rank, world_size = dist.get_rank(), dist.get_world_size()

    # data parallel + pipeline parallel
    #dp_group, pp_group = make_subgroups(pp_size=args.pp_size)
    #dp_rank, dp_size = dp_group.rank(), dp_group.size()
    #pp_rank, pp_size = pp_group.rank(), pp_group.size()
    #is_last_pipeline_stage = (pp_group.rank() == pp_group.size() - 1)
    #dp_rank =   # the default is fine, no need to set it explicitly
    #dp_size =   # number of training processes; does not seem to be needed here
    #print(f"RANK {rank}: data parallel {dp_rank}/{dp_size}", flush=True)

    # model & criterion & optimizer
    print("Load model & criterion & optimizer...")
    # Removed haiscale pipeline / data-parallel wrapping:
    #     model = get_graphcast_module(args)
    #     # model = hfai.nn.to_hfai(model)
    #     balance = [1, 1, 1, 1, 1, 1, 1, 1]
    #     model = partition(model, pp_group.rank(), pp_group.size(), balance=balance)
    #     model = DistributedDataParallel(model.cuda(), process_group=dp_group)
    #     model = PipeDream(model, args.chunks, process_group=pp_group)
    model = GraphCast(args).cuda()

    #param_groups = timm.optim.optim_factory(model, args.weight_decay)
    optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr, betas=(0.9, 0.95))
    lr_scheduler, _ = create_scheduler(args, optimizer)
    criterion = nn.MSELoss()
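
    # create_scheduler reads timm-style scheduler fields from args (e.g. args.sched,
    # args.epochs, args.warmup_epochs, args.min_lr); it is assumed that
    # get_graphcast_args() provides them.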

    # generate graph
    if os.path.exists(EarthGraph_PATH / "EarthGraph"):
        graph = torch.load(EarthGraph_PATH / "EarthGraph")
    else:
        graph = EarthGraph()
        graph.generate_graph()
        torch.save(graph, EarthGraph_PATH / "EarthGraph")
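    # The cached file stores the whole EarthGraph object via pickle, so loading it needs
    # the EarthGraph class to be importable; on newer PyTorch releases torch.load may
    # require weights_only=False for pickled objects like this (an assumption about the
    # runtime environment, not something set here).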

    # load grid data
    print("Load data...")
    train_dataset = ERA5(split="train", check_data=True, modelname='graphcast')
    train_datasampler = DistributedSampler(train_dataset, shuffle=True)
    #train_dataloader = train_dataset.loader(args.batch_size, sampler=train_datasampler, num_workers=8, pin_memory=True, drop_last=True)
    train_dataloader = DataLoader(dataset=train_dataset,
                                  batch_size=args.batch_size,
                                  sampler=train_datasampler,
                                  drop_last=True,
                                  num_workers=8,
                                  pin_memory=True)

    # TODO: the test and validation splits are not separated here.
    val_dataset = ERA5(split="val", check_data=True, modelname='graphcast')
    #val_datasampler = DistributedSampler(val_dataset, num_replicas=dp_size, rank=dp_rank, shuffle=True)
    val_datasampler = DistributedSampler(val_dataset, shuffle=True)
    #val_dataloader = val_dataset.loader(args.batch_size, sampler=val_datasampler, num_workers=8, pin_memory=True, drop_last=False)
    val_dataloader = DataLoader(dataset=val_dataset,
                                batch_size=args.batch_size,
                                sampler=val_datasampler,
                                drop_last=False,
                                num_workers=8,
                                pin_memory=True)
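
    # Note: DistributedSampler reshuffles only when set_epoch() is called; for a
    # different shuffle order each epoch, train_datasampler.set_epoch(epoch) should be
    # called at the start of every epoch (this script does not do that).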

    # load checkpoint
    start_epoch, start_step, min_loss = load_model(model, optimizer, lr_scheduler, SAVE_PATH / 'latest.pt')
    if local_rank == 0:
        print(f"Start training for {args.epochs} epochs")

    for epoch in range(start_epoch, args.epochs):
        train_loss = train_one_epoch(epoch, model, criterion, train_dataloader, graph, optimizer, lr_scheduler, min_loss, device="cuda:0")
        lr_scheduler.step(epoch)

        # the dp_group / pp_group arguments of the pipeline-parallel version are gone;
        # this assumes graphcast_evaluate no longer expects them
        val_loss = graphcast_evaluate(val_dataloader, graph, model, criterion)

        if rank == 0:
            print(f"Epoch {epoch} | Train loss: {train_loss:.6f}, Val loss: {val_loss:.6f}")

            save_model(model, epoch + 1, optimizer, lr_scheduler, min_loss, SAVE_PATH / 'latest.pt')
            if val_loss < min_loss:
                min_loss = val_loss
                save_model(model, path=SAVE_PATH / 'best.pt', only_model=True)

    # synchronize all processes
    #model.module.reducer.stop()  # haiscale DDP cleanup; not needed for the plain model
    dist.barrier()


def main(local_rank, args):
    # fix the seed for reproducibility
    torch.manual_seed(2023)
    np.random.seed(2023)
    cudnn.benchmark = True

    # init dist
    ip = os.environ.get("MASTER_ADDR", "127.0.0.1")
    port = os.environ.get("MASTER_PORT", "22568")
    hosts = int(os.environ.get("WORLD_SIZE", "1"))  # number of nodes
    rank = int(os.environ.get("RANK", "0"))  # node id
    gpus = torch.cuda.device_count()  # gpus per node

    dist.init_process_group(backend="nccl", init_method=f"tcp://{ip}:{port}",
                            world_size=hosts * gpus, rank=rank * gpus + local_rank)
    torch.cuda.set_device(local_rank)

    print("TRAINING STARTED...")
    train(local_rank, args)


if __name__ == '__main__':
    args = get_graphcast_args()
    ngpus = torch.cuda.device_count()
    #hfai.multiprocessing.spawn(main, args=(args,), nprocs=ngpus, bind_numa=True)
    main(0, args)  # single-process launch (local_rank 0); see the multi-GPU note below
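
    # Multi-GPU launch sketch (an assumption, mirroring the hfai spawn above; main()
    # computes world_size = nodes * gpus, so it expects one process per local GPU):
    #     mp.spawn(main, args=(args,), nprocs=ngpus, join=True)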