- #import hfai
- #hfai.set_watchdog_time(21600)
- import os
- from pathlib import Path
- import numpy as np
- import torch
- import torch.nn as nn
- import torch.backends.cudnn as cudnn
- import torch.distributed as dist
- #import hfai.nccl.distributed as dist
- #from haiscale.ddp import DistributedDataParallel
- #from haiscale.pipeline import PipeDream, make_subgroups, partition
- from torch.utils.data.distributed import DistributedSampler  # shards the dataset across distributed processes
- import timm.optim
- from timm.scheduler import create_scheduler
- import torch.multiprocessing as mp
- from torch.utils.data import DataLoader
- from model.graphcast_sequential import GraphCast
- torch.backends.cuda.matmul.allow_tf32 = True
- torch.backends.cudnn.allow_tf32 = True
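- # the two flags above allow TF32 matmul/conv kernels on Ampere-and-newer GPUs for faster training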
- from data_factory.datasets import ERA5, EarthGraph
- #from model.graphcast_sequential import get_graphcast_module
- from utils.params import get_graphcast_args
- from utils.tools import load_model, save_model
- from utils.eval import graphcast_evaluate
- SAVE_PATH = Path('/root/output/graphcast-torch/')
- EarthGraph_PATH = Path('/root/code/OpenCastKit-removehfai/EarthGraph/')
- SAVE_PATH.mkdir(parents=True, exist_ok=True)
- EarthGraph_PATH.mkdir(parents=True, exist_ok=True)
- def train_one_epoch(epoch, model, criterion, data_loader, graph, optimizer, lr_scheduler, min_loss, device):
- is_last_pipeline_stage = True #(pp_group.rank() == pp_group.size() - 1)
- loss = torch.tensor(0., device="cuda")
- count = torch.tensor(0., device="cuda")
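- # running sums of per-step loss and step count; reduced across ranks at the end of the epoch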
- model.train()
- input_x = [
- None,
- graph.mesh_data.x.half().cuda(non_blocking=True),
- graph.mesh_data.edge_index.cuda(non_blocking=True),
- graph.mesh_data.edge_attr.half().cuda(non_blocking=True),
- graph.grid2mesh_data.edge_index.cuda(non_blocking=True),
- graph.grid2mesh_data.edge_attr.half().cuda(non_blocking=True),
- graph.mesh2grid_data.edge_index.cuda(non_blocking=True),
- graph.mesh2grid_data.edge_attr.half().cuda(non_blocking=True)
- ]
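- # input_x gathers the static graph tensors once per epoch; only slot 0 (the grid node
- # features) changes per batch:
- #   [grid features (set per batch), mesh node features, mesh edge_index, mesh edge_attr,
- #    grid2mesh edge_index, grid2mesh edge_attr, mesh2grid edge_index, mesh2grid edge_attr]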
- for step, batch in enumerate(data_loader):
- '''
- # Creates model and optimizer in default precision
- model = Net().cuda()
- optimizer = optim.SGD(model.parameters(), ...)
- for input, target in data:
- optimizer.zero_grad()
- # Enables autocasting for the forward pass (model + loss)
- with torch.autocast(device_type="cuda"):
- output = model(input)
- loss = loss_fn(output, target)
- # Exits the context manager before backward()
- loss.backward()
- optimizer.step()
- '''
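- # NOTE (sketch): the reference pattern above omits gradient scaling; with float16
- # autocast a torch.cuda.amp.GradScaler is usually added, roughly:
- #   scaler = torch.cuda.amp.GradScaler()   # created once, outside the loop
- #   scaler.scale(step_loss).backward()     # scaled backward pass
- #   scaler.step(optimizer)                 # unscales gradients, then steps
- #   scaler.update()
- # The loop below keeps the unscaled variant from the original script.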
- optimizer.zero_grad()
-
- x, y = [t.half().cuda(non_blocking=True) for t in batch]  # ERA5 already provides y (two-step prediction); with batch = 2, two samples go into each prediction
- input_x[0] = x
- # Note: it might seem the content of x would not matter, since the feature embedding
- # gx = self.grid_feat_embedding(gx) would map whatever is passed in, however many steps it holds.
- # It does matter: x must be consistent with --grid-node-num, because
- # self.grid_feat_embedding = nn.Sequential(nn.Linear(gdim, gemb, bias=True)) fixes the input dimension.
- with torch.cuda.amp.autocast():
- #step_loss, _ = model.forward_backward(*input_x, criterion=criterion, labels=(y,))
- out = model(*input_x)  # same as passing input_x[0] ... input_x[7] one by one
- step_loss = criterion(out, y)
- step_loss.backward()  # backward on the step loss, not the running `loss` accumulator
- optimizer.step()
- if is_last_pipeline_stage:
- loss += step_loss.sum().item()
- count += 1
- '''
- if dp_group.rank() == 0 and is_last_pipeline_stage and hfai.client.receive_suspend_command():
- save_model(model.module.module, epoch, step + 1, optimizer, lr_scheduler, min_loss, SAVE_PATH / 'latest.pt')
- hfai.go_suspend()
- '''
-
- # all-reduce in the data parallel group (the default process group here);
- # every rank must join the collective, so there is no rank/device gating
- dist.all_reduce(loss)   # group not given, uses the default group
- dist.all_reduce(count)  # group not given, uses the default group
- loss = loss / count
- # the broadcast from the last pipeline stage is no longer needed once the
- # haiscale pipeline is removed
- return loss.item()
- def train(local_rank, args):
- rank, world_size = dist.get_rank(), dist.get_world_size()
- # data parallel + pipeline parallel
- #dp_group, pp_group = make_subgroups(pp_size=args.pp_size)
- #dp_rank, dp_size = dp_group.rank(), dp_group.size()
- #pp_rank, pp_size = pp_group.rank(), pp_group.size()
- # is_last_pipeline_stage = (pp_group.rank() == pp_group.size() - 1)
- #dp_rank =  # the default is fine, no need to set it explicitly
- #dp_size =  # number of training processes; not really needed here
- #print(f"RANK {rank}: data parallel {dp_rank}/{dp_size}", flush=True)
- # model & criterion & optimizer
- print("Load model & criterion & optimizer...")
- '''
- model = get_graphcast_module(args)????
- # model = hfai.nn.to_hfai(model)
- balance = [1, 1, 1, 1, 1, 1, 1, 1]????
- model = partition(model, pp_group.rank(), pp_group.size(), balance=balance)????
- model = DistributedDataParallel(model.cuda(), process_group=dp_group)????
- model = PipeDream(model, args.chunks, process_group=pp_group)????
- '''
- model = GraphCast(args).cuda()
- #param_groups = timm.optim.optim_factory(model, args.weight_decay)
- optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr, betas=(0.9, 0.95))
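- # NOTE: args.weight_decay is not forwarded here, so AdamW falls back to its default weight decay (0.01)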
- lr_scheduler, _ = create_scheduler(args, optimizer)
- criterion = nn.MSELoss()
- # generate graph
- if os.path.exists(EarthGraph_PATH / "EarthGraph"):
- graph = torch.load(EarthGraph_PATH / "EarthGraph")
- else:
- graph = EarthGraph()
- graph.generate_graph()
- torch.save(graph, EarthGraph_PATH / "EarthGraph")
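- # the EarthGraph is built once and cached to disk; torch.load only works if the
- # EarthGraph class is importable when the cached file is read back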
- # load grid data
- print("Load data...")
- train_dataset = ERA5(split="train", check_data=True, modelname='graphcast')
- train_datasampler = DistributedSampler(train_dataset, shuffle=True)
- #train_dataloader = train_dataset.loader(args.batch_size, sampler=train_datasampler, num_workers=8, pin_memory=True, drop_last=True)
- train_dataloader = DataLoader(dataset=train_dataset,
- batch_size=args.batch_size,
- sampler=train_datasampler, drop_last=True, num_workers=8, pin_memory=True)
-
- # Problem: the test and validation splits are not actually distinguished here.
- val_dataset = ERA5(split="val", check_data=True, modelname='graphcast')
- #val_datasampler = DistributedSampler(val_dataset, num_replicas=dp_size, rank=dp_rank, shuffle=True)
- val_datasampler = DistributedSampler(val_dataset, shuffle=True)
- #val_dataloader = val_dataset.loader(args.batch_size, sampler=val_datasampler, num_workers=8, pin_memory=True, drop_last=False)
- val_dataloader = DataLoader(dataset=val_dataset,
- batch_size=args.batch_size,
- sampler=val_datasampler, drop_last=False, num_workers=8, pin_memory=True)
-
- # load
- start_epoch, start_step, min_loss = load_model(model, optimizer, lr_scheduler, SAVE_PATH / 'latest.pt')
- if local_rank == 0:
- print(f"Start training for {args.epochs} epochs")
- for epoch in range(start_epoch, args.epochs):
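- # DistributedSampler only reshuffles between epochs if set_epoch() is called (a standard PyTorch addition, not in the original script)
- train_datasampler.set_epoch(epoch)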
- train_loss = train_one_epoch(epoch, model, criterion, train_dataloader, graph, optimizer, lr_scheduler, min_loss, device="cuda:0")
- lr_scheduler.step(epoch)
- val_loss = graphcast_evaluate(val_dataloader, graph, model, criterion)  # dp_group/pp_group dropped with the haiscale pipeline (assumes graphcast_evaluate no longer takes them)
- if local_rank == 0:
- print(f"Epoch {epoch} | Train loss: {train_loss:.6f}, Val loss: {val_loss:.6f}")
- if local_rank == 0:
- save_model(model, epoch + 1, optimizer, lr_scheduler, min_loss, SAVE_PATH / 'latest.pt')  # model is not wrapped, so no .module unwrapping is needed
- if val_loss < min_loss:
- min_loss = val_loss
- save_model(model, path=SAVE_PATH / 'best.pt', only_model=True)
- # synchronize all processes
- #model.module.reducer.stop()  # only needed when the model is wrapped in haiscale DistributedDataParallel
- dist.barrier()
- def main(local_rank, args):
- # fix the seed for reproducibility
- torch.manual_seed(2023)
- np.random.seed(2023)
- cudnn.benchmark = True
- # init dist
- ip = os.environ.get("MASTER_ADDR", "127.0.0.1")
- port = os.environ.get("MASTER_PORT", "22568")
- hosts = int(os.environ.get("WORLD_SIZE", "1")) # number of nodes
- rank = int(os.environ.get("RANK", "0")) # node id
- gpus = torch.cuda.device_count() # gpus per node
- dist.init_process_group(backend="nccl", init_method=f"tcp://{ip}:{port}", world_size=hosts * gpus, rank=rank * gpus + local_rank)
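- # above: global rank = node id * GPUs per node + local GPU index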
- torch.cuda.set_device(local_rank)
- print("TRAINING STARTED...")
- train(local_rank, args)
- if __name__ == '__main__':
- args = get_graphcast_args()
- ngpus = torch.cuda.device_count()
- #hfai.multiprocessing.spawn(main, args=(args,), nprocs=ngpus, bind_numa=True)
- main(0, args)
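- # To launch one process per GPU (the role of the commented hfai spawn above),
- # torch.multiprocessing could be used instead, e.g.: mp.spawn(main, args=(args,), nprocs=ngpus)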