diff --git a/python/paddle/distributed/fleet/meta_parallel/pipeline_parallel.py b/python/paddle/distributed/fleet/meta_parallel/pipeline_parallel.py
index 027a734eedd141..5d9fa7fe658f6c 100644
--- a/python/paddle/distributed/fleet/meta_parallel/pipeline_parallel.py
+++ b/python/paddle/distributed/fleet/meta_parallel/pipeline_parallel.py
@@ -14,6 +14,7 @@
 
 import os
 import queue
+import random
 import sys
 import time
 import warnings
@@ -23,10 +24,13 @@
 from functools import partial
 from typing import Callable
 
+import numpy as np
+
 import paddle
 from paddle import framework
 
 from ..meta_optimizers.dygraph_optimizer import HybridParallelOptimizer
+from ..recompute.recompute import detach_variable, switch_rng_state_tracker
 from ..utils import timer_helper as timer
 from ..utils.hybrid_parallel_util import (
     broadcast_dp_parameters,
@@ -49,6 +53,9 @@
 from .pp_utils import p2p_communication as p2p
 
 from paddle.distributed import fleet
+from paddle.distributed.fleet.meta_parallel.parallel_layers.random import (
+    get_rng_state_tracker,
+)
 from paddle.distributed.fleet.utils.tensor_fusion_helper import (
     HOOK_ACTION,
     FusedCommBuffer,
@@ -495,6 +502,17 @@ def __init__(self, layers, hcg, strategy):
         # only support user hooks during training
         self.user_hooks_enabled = True
 
+        self.full_recompute_overlap = True
+        if self.full_recompute_overlap:
+            # preserve = kwargs.pop('preserve_rng_state', True)
+            # need to verify separately whether the behavior is correct when preserve_rng_state is True
+            self.preserve_rng_state = False
+            # offload_indices = kwargs.pop('offload_indices', [])
+            self.offload_indices = []
+            self.custom_get_state_func = lambda x=None: None
+            self.custom_set_state_func = lambda x=None: None
+            self.state_buffers = []
+
     def register_hook(
         self, location: PipelineParallelMicroStepLocations, hook: Callable
     ):
@@ -750,6 +768,173 @@ def _flush_records(self):
             f.writelines(record + '\n' for record in self._records)
         self._records = []
 
+    def save_state(self, inputs, chunk_id, is_pipeline_first_stage):
+        state = {}
+        # inputs may be either a dict or a tensor
+        if isinstance(inputs, paddle.Tensor):
+            inputs = (inputs,)
+        else:
+            # if inputs is a dict, record its keys and use its values as the input tuple
+            # (the keys must be captured before `inputs` is overwritten)
+            state["inputs_keys"] = tuple(inputs.keys())
+            inputs = tuple(inputs.values())
+        state["chunk_id"] = chunk_id
+        if self.preserve_rng_state:
+            state["fw_rng_state"] = paddle.get_rng_state()
+            state["fwd_rng_state_tracker"] = (
+                get_rng_state_tracker().get_states_tracker()
+            )
+            state["fwd_numpy_state"] = np.random.get_state()
+            state["fwd_random_state"] = random.getstate()
+            state["fwd_custom_state"] = self.custom_get_state_func()
+            state["custom_get_state_func"] = self.custom_get_state_func
+            state["custom_set_state_func"] = self.custom_set_state_func
+        tracer = framework._dygraph_tracer()
+        state["is_fw_autocast"] = (
+            False if tracer._amp_level == framework.core.AmpLevel.O0 else True
+        )
+        if tracer._amp_level == framework.core.AmpLevel.O2:
+            state["amp_level"] = 'O2'
+        elif tracer._amp_level in (
+            framework.core.AmpLevel.O1,
+            framework.core.AmpLevel.O0,
+        ):
+            state["amp_level"] = 'O1'
+        else:
+            raise ValueError(f"unsupported amp level: {tracer._amp_level}")
+
+        if tracer._amp_dtype == 'float16':
+            state["amp_dtype"] = 'float16'
+        elif tracer._amp_dtype in ('bfloat16', 'float32'):
+            state["amp_dtype"] = 'bfloat16'
+        else:
+            raise ValueError(f"unsupported amp dtype: {tracer._amp_dtype}")
+        state["amp_white_list"], state["amp_black_list"] = (
+            tracer._get_amp_op_list()
+        )
+
+        state["tensor_indices"] = []
+        state["tensor_inputs"] = []
+        state["inputs"] = []
+        for i, input_tensor in enumerate(inputs):
+            if paddle.is_tensor(input_tensor):
+                if i in self.offload_indices:
+                    cpu_tensor = (
+                        input_tensor.pin_memory()
+                        if framework.core.is_compiled_with_cuda()
+                        else input_tensor.cpu()
+                    )
+                    cpu_tensor._share_buffer_to(input_tensor)
+                if not is_pipeline_first_stage:
+                    # every input_tensor computed under no_grad has stop_gradient=True,
+                    # so it has to be reset here
+                    input_tensor.stop_gradient = False
+                state["tensor_inputs"].append(input_tensor)
+                state["tensor_indices"].append(i)
+                state["inputs"].append(None)
+            elif type(input_tensor) is tuple:
+                assert i not in self.offload_indices, (
+                    f"offload_indices should not contain tensor tuple in position {i}"
+                )
+                is_tensors = [paddle.is_tensor(a) for a in input_tensor]
+                if all(is_tensors):
+                    # the tuple is a tuple of tensors
+                    tensors_stop_gradient = [
+                        a.stop_gradient for a in input_tensor
+                    ]
+                    if not all(tensors_stop_gradient) and any(
+                        tensors_stop_gradient
+                    ):
+                        # tensors in the tuple have different stop_gradient values, which PyLayer doesn't support
+                        raise ValueError(
+                            "Recompute receive a tuple containing tensor holds different stop gradient."
+                        )
+                    state["tensor_inputs"].append(input_tensor)
+                    state["tensor_indices"].append(i)
+                    state["inputs"].append(None)
+                elif any(is_tensors):
+                    # the tuple contains tensors and non-tensor values
+                    raise ValueError(
+                        "Recompute receive a tuple containing tensor and non-tensor at same time."
+                    )
+                else:
+                    state["inputs"].append(input_tensor)
+            else:
+                state["inputs"].append(input_tensor)
+        self.state_buffers.append(state)
+
+    def load_state_and_forward(self, input_tensor):
+        state = self.state_buffers.pop(0)
+        chunk_id = state["chunk_id"]
+        with paddle.base.dygraph.guard():
+            inputs = list(state["inputs"])
+            tensor_indices = state["tensor_indices"]
+            tensors = state["tensor_inputs"]
+            for i, idx in enumerate(tensor_indices):
+                inputs[idx] = (
+                    tensors[i].to(
+                        paddle.base.framework._current_expected_place()
+                    )
+                    if i in self.offload_indices
+                    else tensors[i]
+                )
+                if i in self.offload_indices:
+                    # NOTE(zhiqiu): tensor.to(device) will set stop_gradient=True, which may break the graph
+                    inputs[idx].stop_gradient = tensors[i].stop_gradient
+            tracer = framework._dygraph_tracer()
+            tracer._has_grad = True
+
+            # NOTE: support AMP
+            # need to restore the auto_cast state as well as the white/black op lists
+            if self.preserve_rng_state:
+                with (
+                    switch_rng_state_tracker(
+                        state["fw_rng_state"],
+                        state["fwd_rng_state_tracker"],
+                        state["fwd_numpy_state"],
+                        state["fwd_random_state"],
+                        state["fwd_custom_state"],
+                        state["custom_get_state_func"],
+                        state["custom_set_state_func"],
+                    ),
+                    paddle.amp.auto_cast(
+                        enable=state["is_fw_autocast"],
+                        custom_white_list=state["amp_white_list"],
+                        custom_black_list=state["amp_black_list"],
+                        level=state["amp_level"],
+                        dtype=state["amp_dtype"],
+                    ),
+                ):
+                    detached_inputs = detach_variable(tuple(inputs))
+                    if "inputs_keys" in state:
+                        # rebuild the dict input: keys from state["inputs_keys"], values from detached_inputs
+                        final_input = dict(
+                            zip(state["inputs_keys"], detached_inputs)
+                        )
+                    else:
+                        final_input = detached_inputs[0]
+                    outputs = self._layers.forward(
+                        final_input, chunk_id=chunk_id
+                    )
+            else:
+                with paddle.amp.auto_cast(
+                    enable=state["is_fw_autocast"],
+                    custom_white_list=state["amp_white_list"],
+                    custom_black_list=state["amp_black_list"],
+                    level=state["amp_level"],
+                    dtype=state["amp_dtype"],
+                ):
+                    detached_inputs = detach_variable(tuple(inputs))
+                    if "inputs_keys" in state:
+                        # rebuild the dict input: keys from state["inputs_keys"], values from detached_inputs
+                        final_input = dict(
+                            zip(state["inputs_keys"], detached_inputs)
+                        )
+                    else:
+                        final_input = detached_inputs[0]
+                    outputs = self._layers.forward(
+                        final_input, chunk_id=chunk_id
+                    )
+        return final_input, outputs
+
     def forward_backward_pipeline(
         self,
         data,
@@ -809,7 +994,6 @@ def forward_backward_pipeline(
                 self.is_pipeline_first_stage(),
                 batch_p2p_comm=self._use_batch_p2p_comm,
             )
-
            input_tensor_dict, use_dict = tuple_to_dict_helper(input_tensor)
 
            self._record_stamp("F", step_id, '"B"', self._forward_color)
@@ -851,6 +1035,7 @@ def forward_backward_pipeline(
                continue
 
            last_iter = i == (steady_steps - 1)
+            # if input_tensor is a tuple, convert it into a dict keyed by tensor.key with the corresponding tensor as the value
            input_tensor_dict, use_dict = tuple_to_dict_helper(input_tensor)
 
            self._record_stamp(
@@ -1250,15 +1435,30 @@ def _forward_step(
            input_tensor=input_tensor,
            step_id=step_id,
        )
-
        schedule_chunk = None
        if overlap_schedule_mode:
            schedule_chunk = self._layers.get_schedule_chunk(chunk_id=chunk_id)
-            output_tensor = schedule_chunk.forward(input_tensor)
+            if self.full_recompute_overlap:
+                self.save_state(
+                    input_tensor, chunk_id, self.is_pipeline_first_stage()
+                )
+                with paddle.no_grad():
+                    output_tensor = schedule_chunk.forward(input_tensor)
+            else:
+                output_tensor = schedule_chunk.forward(input_tensor)
        else:
-            output_tensor = self._layers.forward(
-                input_tensor, chunk_id=chunk_id
-            )
+            if self.full_recompute_overlap:
+                self.save_state(
+                    input_tensor, chunk_id, self.is_pipeline_first_stage()
+                )
+                with paddle.no_grad():
+                    output_tensor = self._layers.forward(
+                        input_tensor, chunk_id=chunk_id
+                    )
+            else:
+                output_tensor = self._layers.forward(
+                    input_tensor, chunk_id=chunk_id
+                )
 
        self.callbacks.on_location(
            PipelineParallelMicroStepLocations.FORWARD_END,
@@ -1266,11 +1466,9 @@ def _forward_step(
            output_tensor=output_tensor,
            step_id=step_id,
        )
-
        backward_loss_tensor, backward_loss_fn_node = self._maybe_loss_compute(
            output_tensor, micro_dataset, overlap_schedule_mode
        )
-
        if self.is_pipeline_first_stage() or self.is_pipeline_last_stage():
            # Only increase micro batch id at virtual first/last pp stage.
            # The micro batch id is used to load data, therefore, only increase it when load data.
@@ -1296,6 +1494,7 @@ def _backward_step(
        schedule_chunk=None,
        loss_fn_node=None,
    ):
+        # need to check whether gradient propagation is correct under overlap_schedule_mode
        if self.user_hooks_enabled:
            self.backward_hooks.run_hook()
        if self._enable_timer:
@@ -1304,6 +1503,11 @@ def _backward_step(
        profile_pipeline_details(
            f"[Pipeline details] Before_backward_step_chunk_{chunk_id}_step_{step_id}"
        )
+        if self.full_recompute_overlap:
+            final_input, output_tensor_recompute = self.load_state_and_forward(
+                input_tensor
+            )
+            output_tensor = output_tensor_recompute
        with paddle.amp.auto_cast(enable=False):
            self.callbacks.on_location(
                PipelineParallelMicroStepLocations.BACKWARD_BEGIN,
@@ -1362,15 +1566,27 @@ def _backward_step(
            input_tensor_grad = None
            if input_tensor is not None:
                if isinstance(input_tensor, tuple):
-                    input_tensor_grad = tuple(
-                        [
-                            t.grad
-                            for t in input_tensor
-                            if not t.stop_gradient
-                        ]
-                    )
+                    if self.full_recompute_overlap:
+                        input_tensor_grad = tuple(
+                            [
+                                t.grad
+                                for t in final_input.values()
+                                if not t.stop_gradient
+                            ]
+                        )
+                    else:
+                        input_tensor_grad = tuple(
+                            [
+                                t.grad
+                                for t in input_tensor
+                                if not t.stop_gradient
+                            ]
+                        )
                else:
-                    input_tensor_grad = input_tensor.grad
+                    if self.full_recompute_overlap:
+                        input_tensor_grad = final_input.grad
+                    else:
+                        input_tensor_grad = input_tensor.grad
        if self._enable_timer:
            self.timers("backward_step").stop()
        self.callbacks.on_location(