# pretrain_graphcast_nonpall.py

import os
from pathlib import Path
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import random
import torch.backends.cudnn as cudnn
from torch.utils.data import Dataset, DataLoader
import torch.distributed as dist
import xarray as xr
from torch.utils.data.distributed import DistributedSampler
from torch.nn.parallel import DistributedDataParallel as DDP
import torch.multiprocessing as mp
# import timm.optim
from timm.scheduler import create_scheduler
from torch.cuda.amp import GradScaler
from data_factory.datasets import ERA5, EarthGraph
from model.graphcast_sequential import GraphCast
from utils.params import get_graphcast_args
from utils.tools import load_model, save_model
import pickle
from tqdm import tqdm
# from utils.eval import graphcast_evaluate

os.environ['CUDA_VISIBLE_DEVICES'] = '0,1,2'
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
# Find a free port that is not blocked by firewalls (to check: netstat -atlpn | grep 45549)
import socket

def find_free_port():
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        s.bind(('', 0))  # bind to a port chosen by the OS
        return s.getsockname()[1]  # return the port number
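# The free port found above is passed to get_graphcast_args / train below; in the
# multi-process variant it would serve as the TCP endpoint for torch.distributed,
# e.g. (sketch mirroring the commented-out block inside train()):
#   dist.init_process_group("nccl", init_method=f"tcp://127.0.0.1:{port}",
#                           rank=rank, world_size=args.world_size)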
SAVE_PATH = Path('/root/output/graphcast-torch/')
# SAVE_PATH.mkdir(parents=True, exist_ok=True)
# device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
# print(device)
data_dir = 'autodl-fs/data/dataset'
data_path = "/root/autodl-fs/data/dataset/dataset-source-era5_date-2022-01-01_res-1.0_levels-13_steps-40.nc"
def chunk_time(ds):
    dims = {k: v for k, v in ds.dims.items()}
    dims['time'] = 1
    ds = ds.chunk(dims)
    return ds
def load_dataset():
    ds = []
    '''
    for y in range(2007, 2017):
        data_name = os.path.join(data_dir, f'weather_round1_train_{y}')
        x = xr.open_zarr(data_name, consolidated=True)
        print(f'{data_name}, {x.time.values[0]} ~ {x.time.values[-1]}')
        ds.append(x)
    ds = xr.concat(ds, 'time')
    ds = chunk_time(ds)
    '''
    with open(data_path, "rb") as f:
        ds = xr.load_dataset(f).compute()
    return ds
def compute_rmse(out, tgt):
    rmse = torch.sqrt(((out - tgt) ** 2).mean())
    return rmse
climates = {
    't2m': 3.1084048748016357,
    'u10': 4.114771819114685,
    'v10': 4.184110546112061,
    'msl': 729.5839385986328,
    'tp': 0.49046186606089276,
}
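# Scoring used below (run_eval / run_eval_valid): for each variable with reference
# value `clim`, the per-step normalised error is (RMSE - clim) / clim, the per-variable
# score is max(0, -mean over steps), and the overall score is the mean of the five
# per-variable scores.  For example (illustrative numbers only), a t2m RMSE of 2.8
# against clim ≈ 3.108 gives (2.8 - 3.108) / 3.108 ≈ -0.099, i.e. a per-step
# contribution of about +0.099 to the t2m score.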
def run_eval(output, target):
    '''
    output: (batch x step x channel x lat x lon), e.g. N x 20 x 5 x H x W
    target: (batch x step x channel x lat x lon), e.g. N x 20 x 5 x H x W
    '''
    result = {}
    output = output.detach()
    target = target.detach()
    for cid, (name, clim) in enumerate(climates.items()):
        res = []
        for sid in range(output.shape[1]):
            out = output[:, sid, cid]  # [N, H, W], one variable at one lead time
            tgt = target[:, sid, cid]
            rmse = compute_rmse(out, tgt)
            rmse = reduce_tensor(rmse).item()  # average the RMSE across processes
            # rmse = rmse.to(torch.device("cpu"))
            nrmse = (rmse - clim) / clim
            res.append(nrmse)
        score = max(0, -np.mean(res))
        result[name] = float(score)
    score = np.mean(list(result.values()))
    result['score'] = float(score)
    return result
def run_eval_valid(output, target):
    '''
    output: (batch x step x channel x lat x lon), e.g. N x 20 x 5 x H x W
    target: (batch x step x channel x lat x lon), e.g. N x 20 x 5 x H x W
    '''
    result = {}
    output = output.detach()
    target = target.detach()
    for cid, (name, clim) in enumerate(climates.items()):
        res = []
        for sid in range(output.shape[1]):
            out = output[:, sid, cid]  # [N, H, W], one variable at one lead time
            tgt = target[:, sid, cid]
            rmse = compute_rmse(out, tgt)
            # rmse = reduce_tensor(rmse).item()
            rmse = rmse.to(torch.device("cpu"))
            nrmse = (rmse - clim) / clim
            res.append(nrmse)
        score = max(0, -np.mean(res))
        result[name] = float(score)
    score = np.mean(list(result.values()))
    result['score'] = float(score)
    return result
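# run_eval and run_eval_valid differ only in how the per-step RMSE is aggregated
# (all-reduced across processes vs. kept local).  A minimal refactor sketch that
# shares the loop is given below, assuming the climates dict and the helpers defined
# in this file; it is illustrative only and is not wired into the training code.
def _run_eval_generic(output, target, reduce_across_processes=False):
    result = {}
    output = output.detach()
    target = target.detach()
    for cid, (name, clim) in enumerate(climates.items()):
        res = []
        for sid in range(output.shape[1]):
            rmse = compute_rmse(output[:, sid, cid], target[:, sid, cid])
            rmse = reduce_tensor(rmse).item() if reduce_across_processes else rmse.item()
            res.append((rmse - clim) / clim)
        result[name] = float(max(0, -np.mean(res)))
    # overall score: mean of the per-variable scores collected so far
    result['score'] = float(np.mean(list(result.values())))
    return result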
def average_score(RMSE_list, key):
    score = 0
    for RMSE in RMSE_list:
        score += RMSE[key]
    return score / len(RMSE_list)
def train_one_epoch(epoch, model, criterion, data_loader, graph, optimizer, predict_steps, weight, lat_weight, device="cuda:0"):
    # teacher_forcing_rate = 0.5
    loss_all = torch.tensor(0.).to(device)
    count = torch.tensor(0.).to(device)
    score_all = torch.tensor(0.).to(device)
    model.train()  # switch to training mode (a torch.nn.Module method)
    # input_x is the list of tensors passed to the model: slot 0 holds the per-batch grid
    # features, the rest are the static mesh node features and the mesh / grid2mesh /
    # mesh2grid edge indices and attributes.
    input_x = [
        None,
        graph.mesh_data.x.half().cuda(non_blocking=True),
        graph.mesh_data.edge_index.cuda(non_blocking=True),
        graph.mesh_data.edge_attr.half().cuda(non_blocking=True),
        graph.grid2mesh_data.edge_index.cuda(non_blocking=True),
        graph.grid2mesh_data.edge_attr.half().cuda(non_blocking=True),
        graph.mesh2grid_data.edge_index.cuda(non_blocking=True),
        graph.mesh2grid_data.edge_attr.half().cuda(non_blocking=True)
    ]
    '''
    weight = get_weight(args)  # [batch, channel, h, w]
    weight = weight.unsqueeze(1).to(device)  # [batch, 1, channel, h, w]
    # diff_std = get_diff_std(args)
    # diff_std = diff_std.unsqueeze(1).to(device)  # [batch, 1, channel, h, w]
    '''
    scaler = GradScaler()
    for step, batch in enumerate(data_loader):
        # unpack x and y from the batch
        x, y = [b.half().cuda(non_blocking=True) for b in batch]
        y = y[:, :predict_steps, ...]
        input_x[0] = x
        bs, ts, c, h, w = x.shape
        # print(bs)
        pred_list = []
        optimizer.zero_grad()
        # autoregressive rollout: the model predicts an increment that is added to the
        # latest input frame, and the new frame is fed back as input for the next step
        for t in range(predict_steps):
            # optimizer.zero_grad()
            with torch.cuda.amp.autocast():
                # out = model(input_x)
                out = model(*input_x)
            out = out.reshape(bs, h, w, c).permute(0, 3, 1, 2)  # [bs, c, h, w]
            out = out.unsqueeze(1)
            pred_list.append(out + x[:, 1:, ...])  # [bs, 1, c, h, w]
            x = torch.concat([x[:, 1:, ...], x[:, 1:, ...] + out], dim=1)
            input_x[0] = x
        pred = torch.concat(pred_list, dim=1)
        loss = criterion(pred * weight * lat_weight, y * weight * lat_weight)
        # print(f'step {step}, loss:{loss}')
        # loss_all += loss
        loss_all += reduce_tensor(loss).item()  # with multiple processes, average the loss across them
        # print(f'step {step}, loss:{loss_all}')
        count += 1
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
        # RMSE-based score on the last 5 channels, cropped 30 grid points from each border
        score = run_eval(pred[..., -5:, 30:-30, 30:-30], y[..., -5:, 30:-30, 30:-30])
        score_all += score["score"]
        # score_all += reduce_tensor(torch.tensor(score["score"]).to(device)).item()
        if step % 200 == 0 and device == 0:
            print("Step: ", step, " | Training Aver Loss:", (loss_all / count).item(), " | Train Eval Score: ", (score_all / count).item(), flush=True)
    return loss_all / count
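# Note on the AMP pattern used in train_one_epoch: the forward rollout runs under
# torch.cuda.amp.autocast, the loss is scaled by GradScaler before backward() to avoid
# fp16 gradient underflow, and scaler.step / scaler.update then unscale the gradients,
# apply the optimizer step (skipping it if infs/NaNs are found) and adjust the scale factor.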
@torch.no_grad()
def graphcast_evaluate(data_loader, graph, model, criterion, predict_steps, device="cuda:0"):
    loss_all = torch.tensor(0.).to(device)
    count = torch.tensor(0.).to(device)
    score_all = torch.tensor(0.).to(device)
    input_x = [
        None,                                                            # gx
        graph.mesh_data.x.half().cuda(non_blocking=True),                # mx
        graph.mesh_data.edge_index.cuda(non_blocking=True),              # me_i
        graph.mesh_data.edge_attr.half().cuda(non_blocking=True),        # me_x
        graph.grid2mesh_data.edge_index.cuda(non_blocking=True),         # g2me_i
        graph.grid2mesh_data.edge_attr.half().cuda(non_blocking=True),   # g2me_x
        graph.mesh2grid_data.edge_index.cuda(non_blocking=True),         # m2ge_i
        graph.mesh2grid_data.edge_attr.half().cuda(non_blocking=True)    # m2ge_x
    ]
    # switch to evaluation mode
    model.eval()
    for step, batch in enumerate(data_loader):
        pred_list = []
        x, y = [b.half().cuda(non_blocking=True) for b in batch]
        y = y[:, :predict_steps, ...]
        input_x[0] = x
        bs, ts, c, h, w = x.shape
        # y: [batch, time(20), channel(70), h, w]
        for t in range(predict_steps):
            with torch.cuda.amp.autocast():
                out = model(*input_x)
            out = out.reshape(bs, h, w, c).permute(0, 3, 1, 2)  # [bs, c, h, w]
            out = out.unsqueeze(1)
            pred_list.append(out + x[:, 1:, ...])  # [bs, 1, c, h, w]
            x = torch.concat([x[:, 1:, ...], x[:, 1:, ...] + out], dim=1)
            input_x[0] = x
        pred = torch.concat(pred_list, dim=1)
        loss = criterion(pred[:, :, -5:, ...], y[:, :, -5:, ...])
        loss_all += loss.item()
        # loss_all += reduce_tensor(loss)  # with multiple processes, average the loss across them
        count += 1
        score = run_eval_valid(pred[..., -5:, 30:-30, 30:-30], y[..., -5:, 30:-30, 30:-30])
        score_all += score["score"]
        # score_all += reduce_tensor(torch.tensor(score["score"]).to(device)).item()
        if step % 200 == 0 and device == 0:
            print("Step: ", step, " | Valid Aver Loss:", (loss_all / count).item(), " | Valid Eval Score: ", (score_all / count).item(), flush=True)
    return loss_all / count, score_all / count
def get_weight(args):
    levels = [50, 100, 150, 200, 250, 300, 400, 500, 600, 700, 850, 925, 1000]
    k = 1 / np.sum(levels)
    weight = torch.zeros(13)
    for i, level in enumerate(levels):
        weight[i] = level * k
    weight = torch.concat([weight.repeat(5), torch.ones(5)])
    weight = torch.sqrt(weight)
    weight = weight.reshape(1, 70, 1, 1).repeat(args.batch_size, 1, 161, 161)
    return weight
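# The weighting above gives each pressure level a weight proportional to its pressure
# (the 13 levels sum to 6025 hPa, so e.g. the 1000 hPa level gets 1000 / 6025 ≈ 0.166
# and the 50 hPa level 50 / 6025 ≈ 0.0083 before the square root), repeated for the
# 5 upper-air variables, with weight 1 for the 5 surface channels (70 channels total).
# The square root is taken because the weight multiplies both prediction and target
# inside the squared (MSE) loss, so the effective per-channel weight is the value above.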
def get_lat_lon_weight(args):
    # Down-weight a 30-point frame around the domain border, ramping linearly from 0.1
    # at the outermost ring toward full weight in the interior.
    weight = torch.ones(args.grid_node_num).reshape(161, 161)
    for i in range(30):
        if i == 0:
            weight[i, :] = 0.1 + i * 0.03
            weight[-(i + 1), :] = 0.1 + i * 0.03
            weight[:, i] = 0.1 + i * 0.03
            weight[:, -(i + 1)] = 0.1 + i * 0.03
        else:
            weight[i, i:-i] = 0.1 + i * 0.03
            weight[-(i + 1), i:-i] = 0.1 + i * 0.03
            weight[i:-i, i] = 0.1 + i * 0.03
            weight[i:-i, -(i + 1)] = 0.1 + i * 0.03
    weight = torch.sqrt(weight)
    weight = weight.unsqueeze(0).unsqueeze(0).repeat(args.batch_size, 1, 1, 1)
    return weight
def get_lat_weight(lat, args):
    diff = np.diff(lat)
    if not np.all(np.isclose(diff[0], diff)):
        raise ValueError(f'Vector {diff} is not uniformly spaced.')
    delta_latitude = np.abs(diff[0])
    # print(delta_latitude)
    weights = np.cos(np.deg2rad(lat)) * np.sin(np.deg2rad(delta_latitude / 2))
    # print(weights)
    # special-cased half-cell weights for the first and last latitude rows
    weights[0] = np.sin(np.deg2rad(delta_latitude / 4)) * np.cos(np.deg2rad(50 - delta_latitude / 4))
    weights[-1] = np.sin(np.deg2rad(delta_latitude / 4)) * np.cos(np.deg2rad(10 + delta_latitude / 4))
    # print(weights)
    weights = weights / weights.mean()
    weights = np.sqrt(weights)
    weights = torch.tensor(weights).reshape(1, 1, 1, -1, 1).repeat(args.batch_size, 1, 1, 1, 161)
    return weights
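# The latitude weights above follow the usual area weighting, w_j ∝ cos(lat_j) * sin(Δlat / 2),
# normalised to mean 1, with half-cell corrections for the first and last rows (the hard-coded
# 50 and 10 presumably being the boundary latitudes of the 161 x 161 domain).  As with
# get_weight, the square root is taken because the weight multiplies both prediction and
# target inside the squared (MSE) loss.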
def get_diff_std(args):
    with open("./scaler.pkl", "rb") as f:
        pkl = pickle.load(f)
    channels = pkl["channels"]
    std_r = pkl["std"]
    std = torch.tensor(std_r)
    std = std.reshape(1, 70, 1, 1).repeat(args.batch_size, 1, 161, 161)
    return std
def reduce_tensor(tensor: torch.Tensor):
    rt = tensor.clone()
    '''
    dist.all_reduce(rt, op=dist.ReduceOp.SUM)
    rt /= dist.get_world_size()  # total number of processes
    '''
    rt /= 1  # world size, hard-coded to 1 for this single-process (non-parallel) script
    return rt
def train(local_rank, args, ds, num_data, port):
    '''
    Args:
        local_rank: local process index
        rank: global rank of the process
        local_size: number of processes per node
        world_size: total number of processes
        port: free port used for inter-process communication
    '''
    # initialisation
    print("Initializing")
    rank = local_rank
    gpu = local_rank
    torch.cuda.set_device(gpu)
    '''
    dist.init_process_group("nccl",
                            init_method=f"tcp://127.0.0.1:{port}",  # free port for inter-process communication, e.g. init_method=f"tcp://localhost:{port}" or "tcp://localhost:22355"
                            rank=rank,
                            world_size=args.world_size)
    # parser.add_argument("--world_size", default=3, type=int)
    # The port apparently cannot be chosen at random here; it must not clash with the one in args.
    # A training script should initialise the distributed backend at start-up; init_method=env:// is
    # strongly recommended.  Other init methods (e.g. tcp://) may work, but env:// is the officially
    # supported one: https://pytorch.org/docs/stable/distributed.html#launch-utility
    '''
    # generate graph
    print("Generating graph")
    if os.path.exists("./EarthGraph"):
        graph = torch.load("./EarthGraph")
    else:
        graph = EarthGraph()
        graph.generate_graph()
        torch.save(graph, "./EarthGraph")
    args.grid2mesh_edge_num = graph.grid2mesh_data.num_edges
    args.mesh2grid_edge_num = graph.mesh2grid_data.num_edges
    args.mesh_edge_num = graph.mesh_data.num_edges
    # model initialisation
    print("Initializing model")
    model = GraphCast(args).cuda()
    optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr, betas=(0.9, 0.95), weight_decay=args.weight_decay)
    lr_scheduler, _ = create_scheduler(args, optimizer)  # remove? not sure what this is for
    # lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='max', factor=0.5, patience=3, verbose=True)
    criterion = nn.MSELoss()
    start_epoch, start_step, min_loss = load_model(model, optimizer, lr_scheduler, path=SAVE_PATH / 'latest.pt')
    num_gpus = torch.cuda.device_count()
    if num_gpus > 1:
        model = nn.parallel.DistributedDataParallel(model, device_ids=[gpu])
    train_num_data = int(num_data * 0.9)
    valid_num_data = int(num_data * 0.1)
    train_ds_time = ds.time.values[slice(1, train_num_data)]
    valid_ds_time = ds.time.values[slice(train_num_data, train_num_data + valid_num_data)]
    train_ds = ds.sel(time=train_ds_time)
    valid_ds = ds.sel(time=valid_ds_time)
    # dataset initialisation
    print("Initializing datasets")
    # shuffle might be useful during training; should the training samples be shuffled?
    train_dataset = ERA5(train_ds, output_window_size=args.predict_steps)  # wraps train_ds as a torch.utils.data.Dataset
    # train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)  # restricts data loading to a per-process shard for multi-process training; commented out for now
    train_loader = DataLoader(dataset=train_dataset,
                              batch_size=args.batch_size,
                              drop_last=True, shuffle=False, num_workers=0, pin_memory=True)  # the sampler argument (sampler=train_sampler) was removed here
    valid_dataset = ERA5(valid_ds, output_window_size=args.predict_steps)
    # valid_sampler = torch.utils.data.distributed.DistributedSampler(valid_dataset)
    valid_loader = DataLoader(dataset=valid_dataset,
                              batch_size=args.batch_size,
                              drop_last=True, num_workers=0, pin_memory=True)
    # load
    # start_epoch, start_step, min_loss = load_model(model, optimizer, path=SAVE_PATH / 'latest.pt')
    max_score = 0.72
    # start_epoch = 0
    # min_loss = 100
    # compute weight and lat_weight
    print("Computing weight and lat_weight")
    weight = get_weight(args)  # [batch, channel, h, w]
    weight = weight.unsqueeze(1).to(gpu)  # [batch, 1, channel, h, w]
    lat_lon_weight = get_lat_lon_weight(args)
    lat_lon_weight = lat_lon_weight.unsqueeze(1).to(gpu)
    lat_values = ds.lat.values
    lat_weight = get_lat_weight(lat_values, args).to(gpu)
    print("Starting training")
    for epoch in range(start_epoch, args.epochs):
        # train_sampler.set_epoch(epoch)  # only needed for multi-process training; not used in OpenCastKit
        train_loss = train_one_epoch(epoch, model, criterion, train_loader, graph, optimizer, args.predict_steps, weight, lat_weight, device=gpu)
        lr_scheduler.step(epoch)
        # save_model(model, epoch + 1, optimizer=optimizer, lr_scheduler=lr_scheduler, min_loss=min_loss, path=SAVE_PATH / 'latest.pt')
        # val_loss, val_score = graphcast_evaluate(valid_loader, graph, model, criterion, args.predict_steps, device=gpu)
        # print(f"Epoch {epoch} | LR: {optimizer.param_groups[0]['lr']:.6f} | Train loss: {train_loss.item():.6f} | Val loss: {val_loss.item():.6f}, Val score: {val_score.item():.6f}")
        # save model
        if gpu == 0:
            val_loss, val_score = graphcast_evaluate(valid_loader, graph, model, criterion, args.predict_steps, device=gpu)
            print(f"Epoch {epoch} | LR: {optimizer.param_groups[0]['lr']:.6f} | Train loss: {train_loss.item():.6f} | Val loss: {val_loss.item():.6f}, Val score: {val_score.item():.6f}", flush=True)
            save_model(model, epoch + 1, optimizer=optimizer, lr_scheduler=lr_scheduler, min_loss=min_loss, path=SAVE_PATH / 'latest.pt')
            if val_score > max_score:
                max_score = val_score
                min_loss = val_loss
                save_model(model, path=SAVE_PATH / f'epoch{epoch + 1}_{val_score:.6f}_best.pt', min_loss=min_loss, only_model=True)
        # dist.barrier()
        # lr_scheduler.step(max_score)
if __name__ == "__main__":
    free_port = find_free_port()
    print(f"Free port for args: {free_port}")
    args = get_graphcast_args(free_port)
    # ds = load_dataset().x
    ds = load_dataset()
    # shape = ds.shape  # batch x channel x lat x lon
    # --- build fake data by repeating the dataset along time ---
    times = 10  # how many copies of the original time range to stack
    ds_temp1 = ds.copy(deep=True)
    ds_temp2 = ds.copy(deep=True)
    ds_fake = xr.concat([ds_temp1, ds_temp2], dim="time")
    if times > 2:
        with tqdm(total=times - 2) as pbar:
            for i in range(times - 2):
                ds_fake = xr.concat([ds_fake, ds_temp1], dim="time")
                pbar.update(1)
    # Assuming ds is your original dataset
    original_time = ds.coords['time']
    new_time_length = len(ds_fake['time'])
    # Create a new 6-hourly time coordinate of the right length.  (The original stop value of
    # start + new_time_length hours would yield too few timestamps for a 6 h step, so the
    # span is new_time_length * 6 hours here.)
    new_time = np.arange(start=original_time.values[0],
                         stop=original_time.values[0] + np.timedelta64(new_time_length * 6, 'h'),
                         step=np.timedelta64(6, 'h'))
    ds_fake = ds_fake.assign_coords(time=new_time)
    # ------------------------------------------------------------
    times = ds_fake.time.values
    # times = ds.time.values
    free_port = find_free_port()
    print(f"Free port for train: {free_port}")
    init_times = times[slice(1, -21)]
    num_data = len(init_times)
    torch.manual_seed(2023)
    np.random.seed(2023)
    cudnn.benchmark = True
    # train(args.gpuid, args)
    # mp.spawn(train, args=(args, ds, num_data, free_port), nprocs=2, join=True)  # nprocs=3
    # ----20240229----
    # Initialize the distributed training environment
    # os.environ['RANK'] = '0'
    # dist.init_process_group(backend='nccl')
    # Get the local rank
    local_rank = 0
    # If the checkpoints stored under /root/output/graphcast-torch are not deleted, the loop exits unexpectedly.
    train(local_rank, args, ds_fake, num_data, free_port)
    # ----20240229----
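# A typical single-GPU launch of this non-parallel variant (assuming the dataset and
# SAVE_PATH paths above exist) would simply be:
#   python pretrain_graphcast_nonpall.py
# The commented-out mp.spawn / dist.init_process_group lines above sketch how the
# multi-process version was started.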