248 changes: 232 additions & 16 deletions python/paddle/distributed/fleet/meta_parallel/pipeline_parallel.py
@@ -14,6 +14,7 @@

import os
import queue
import random
import sys
import time
import warnings
@@ -23,10 +24,13 @@
from functools import partial
from typing import Callable

import numpy as np

import paddle
from paddle import framework

from ..meta_optimizers.dygraph_optimizer import HybridParallelOptimizer
from ..recompute.recompute import detach_variable, switch_rng_state_tracker
from ..utils import timer_helper as timer
from ..utils.hybrid_parallel_util import (
broadcast_dp_parameters,
@@ -49,6 +53,9 @@
from .pp_utils import p2p_communication as p2p

from paddle.distributed import fleet
from paddle.distributed.fleet.meta_parallel.parallel_layers.random import (
get_rng_state_tracker,
)
from paddle.distributed.fleet.utils.tensor_fusion_helper import (
HOOK_ACTION,
FusedCommBuffer,
@@ -495,6 +502,17 @@ def __init__(self, layers, hcg, strategy):
# only support user hooks during training
self.user_hooks_enabled = True

self.full_recompute_overlap = True
if self.full_recompute_overlap:
# preserve = kwargs.pop('preserve_rng_state', True)
# TODO: verify separately whether the behavior with preserve_rng_state=True is correct
self.preserve_rng_state = False
# offload_indices = kwargs.pop('offload_indices', [])
self.offload_indices = []
self.custom_get_state_func = lambda x=None: None
self.custom_set_state_func = lambda x=None: None
self.state_buffers = []
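
The two custom state hooks above default to no-ops, and state_buffers is consumed as a FIFO queue with one entry per saved forward micro-step. When preserve_rng_state is enabled, a caller that depends on an extra randomness source can replace the hooks so that source is captured and restored together with the framework RNG state. A minimal sketch of that idea; the pipeline handle `pp` and the augmentation RNG are illustrative, only the attribute names come from this PR:

import random

# a user-level RNG, e.g. one driving a data-augmentation pipeline (illustrative)
aug_rng = random.Random(1234)

def get_aug_state(_=None):
    return aug_rng.getstate()

def set_aug_state(state):
    if state is not None:
        aug_rng.setstate(state)

# pp is an already-constructed pipeline-parallel instance (hypothetical handle)
pp.custom_get_state_func = get_aug_state
pp.custom_set_state_func = set_aug_state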

def register_hook(
self, location: PipelineParallelMicroStepLocations, hook: Callable
):
@@ -750,6 +768,173 @@ def _flush_records(self):
f.writelines(record + '\n' for record in self._records)
self._records = []

def save_state(self, inputs, chunk_id, is_pipeline_first_stage):
state = {}
# inputs may be a single tensor or a dict of tensors
if isinstance(inputs, paddle.Tensor):
inputs = (inputs,)
else:
# if inputs is a dict, record its keys first, then use its values as the inputs tuple
state["inputs_keys"] = tuple(inputs.keys())
inputs = tuple(inputs.values())
state["chunk_id"] = chunk_id
if self.preserve_rng_state:
state["fw_rng_state"] = paddle.get_rng_state()
state["fwd_rng_state_tracker"] = (
get_rng_state_tracker().get_states_tracker()
)
state["fwd_numpy_state"] = np.random.get_state()
state["fwd_random_state"] = random.getstate()
state["fwd_custom_state"] = self.custom_get_state_func()
state["custom_get_state_func"] = self.custom_get_state_func
state["custom_set_state_func"] = self.custom_set_state_func
tracer = framework._dygraph_tracer()
state["is_fw_autocast"] = (
False if tracer._amp_level == framework.core.AmpLevel.O0 else True
)
if tracer._amp_level == framework.core.AmpLevel.O2:
state["amp_level"] = 'O2'
elif tracer._amp_level in (
framework.core.AmpLevel.O1,
framework.core.AmpLevel.O0,
):
state["amp_level"] = 'O1'
else:
raise ValueError(f"unsupported amp level: {tracer._amp_level}")

if tracer._amp_dtype == 'float16':
state["amp_dtype"] = 'float16'
elif tracer._amp_dtype in ('bfloat16', 'float32'):
state["amp_dtype"] = 'bfloat16'
else:
raise ValueError(f"unsupported amp dtype: {tracer._amp_dtype}")
state["amp_white_list"], state["amp_black_list"] = (
tracer._get_amp_op_list()
)

state["tensor_indices"] = []
state["tensor_inputs"] = []
state["inputs"] = []
for i, input_tensor in enumerate(inputs):
if paddle.is_tensor(input_tensor):
if i in self.offload_indices:
cpu_tensor = (
input_tensor.pin_memory()
if framework.core.is_compiled_with_cuda()
else input_tensor.cpu()
)
cpu_tensor._share_buffer_to(input_tensor)
if not is_pipeline_first_stage:
# after the forward ran under no_grad, every input_tensor has stop_gradient=True, so reset it here
input_tensor.stop_gradient = False
state["tensor_inputs"].append(input_tensor)
state["tensor_indices"].append(i)
state["inputs"].append(None)
elif type(input_tensor) is tuple:
assert i not in self.offload_indices, (
f"offload_indices should not contain a tensor tuple at position {i}"
)
is_tensors = [paddle.is_tensor(a) for a in input_tensor]
if all(is_tensors):
# the tuple is a tuple of tensors
tensors_stop_gradient = [
a.stop_gradient for a in input_tensor
]
if not all(tensors_stop_gradient) and any(
tensors_stop_gradient
):
# tensors in the tuple have different stop_gradient values, which PyLayer doesn't support
raise ValueError(
"Recompute received a tuple whose tensors hold different stop_gradient values."
)
state["tensor_inputs"].append(input_tensor)
state["tensor_indices"].append(i)
state["inputs"].append(None)
elif any(is_tensors):
# the tuple contains tensors and non-tensor values
raise ValueError(
"Recompute received a tuple containing both tensor and non-tensor values."
)
else:
state["inputs"].append(input_tensor)
else:
state["inputs"].append(input_tensor)
self.state_buffers.append(state)
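
The offload path above parks selected activation inputs in host memory during the forward pass and relies on the in-place `_share_buffer_to` to overwrite the device tensor; load_state_and_forward later copies them back before recomputing. A self-contained sketch of that round trip, assuming CUDA is available (it falls back to pageable CPU memory otherwise) and using plain reassignment instead of in-place buffer sharing:

import paddle

def offload_to_host(t):
    # keep the activation on the host; pinned memory makes the later copy-back cheaper
    return t.pin_memory() if paddle.is_compiled_with_cuda() else t.cpu()

def reload_to_device(t, stop_gradient):
    # copy back to the current device and restore the original stop_gradient flag,
    # since moving a tensor across devices yields a tensor with stop_gradient=True
    out = t.to(paddle.base.framework._current_expected_place())
    out.stop_gradient = stop_gradient
    return out

x = paddle.randn([4, 8])
x.stop_gradient = False
host_x = offload_to_host(x)                               # saved at forward time
device_x = reload_to_device(host_x, stop_gradient=False)  # restored before recompute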

def load_state_and_forward(self, input_tensor):
state = self.state_buffers.pop(0)
chunk_id = state["chunk_id"]
with paddle.base.dygraph.guard():
inputs = list(state["inputs"])
tensor_indices = state["tensor_indices"]
tensors = state["tensor_inputs"]
for i, idx in enumerate(tensor_indices):
inputs[idx] = (
tensors[i].to(
paddle.base.framework._current_expected_place()
)
if idx in self.offload_indices
else tensors[i]
)
if idx in self.offload_indices:
# NOTE(zhiqiu): tensor.to(device) will set stop_gradient=True, which may break the graph
inputs[idx].stop_gradient = tensors[i].stop_gradient
tracer = framework._dygraph_tracer()
tracer._has_grad = True

# NOTE support AMP
# need restore auto_cast state as well as w/b list
if self.preserve_rng_state:
with (
switch_rng_state_tracker(
state["fw_rng_state"],
state["fwd_rng_state_tracker"],
state["fwd_numpy_state"],
state["fwd_random_state"],
state["fwd_custom_state"],
state["custom_get_state_func"],
state["custom_set_state_func"],
),
paddle.amp.auto_cast(
enable=state["is_fw_autocast"],
custom_white_list=state["amp_white_list"],
custom_black_list=state["amp_black_list"],
level=state["amp_level"],
dtype=state["amp_dtype"],
),
):
detached_inputs = detach_variable(tuple(inputs))
if "inputs_keys" in state:
# rebuild the input dict: keys from state["inputs_keys"], values from detached_inputs
final_input = dict(
zip(state["inputs_keys"], detached_inputs)
)
else:
final_input = detached_inputs[0]
outputs = self._layers.forward(
final_input, chunk_id=chunk_id
)
else:
with paddle.amp.auto_cast(
enable=state["is_fw_autocast"],
custom_white_list=state["amp_white_list"],
custom_black_list=state["amp_black_list"],
level=state["amp_level"],
dtype=state["amp_dtype"],
):
detached_inputs = detach_variable(tuple(inputs))
if "inputs_keys" in state:
# rebuild the input dict: keys from state["inputs_keys"], values from detached_inputs
final_input = dict(
zip(state["inputs_keys"], detached_inputs)
)
else:
final_input = detached_inputs[0]
outputs = self._layers.forward(
final_input, chunk_id=chunk_id
)
return final_input, outputs
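
Taken together, save_state and load_state_and_forward implement full recomputation that overlaps with the pipeline schedule: each forward micro-step runs under no_grad and only queues its inputs plus the surrounding AMP context, and the matching backward step pops that state, re-enters the same autocast configuration, recomputes the activations with gradients enabled, and backpropagates through the recomputed graph. A condensed, self-contained sketch of that pairing on a single layer; the RecomputeQueue class is illustrative and not part of this PR:

import paddle
from paddle import framework

class RecomputeQueue:
    """Minimal stand-in for the save_state / load_state_and_forward pairing."""

    def __init__(self):
        self.states = []

    def forward_no_grad(self, layer, x):
        tracer = framework._dygraph_tracer()
        # remember the input and whether autocast was active for this micro-step
        self.states.append(
            {
                "input": x,
                "is_fw_autocast": tracer._amp_level != framework.core.AmpLevel.O0,
            }
        )
        with paddle.no_grad():
            return layer(x)  # intermediate activations are not kept

    def recompute_and_backward(self, layer, output_grad):
        state = self.states.pop(0)  # FIFO: oldest saved micro-step first
        x = state["input"].detach()
        x.stop_gradient = False
        with paddle.amp.auto_cast(enable=state["is_fw_autocast"]):
            out = layer(x)  # rebuild the graph for this micro-step only
        paddle.autograd.backward([out], [output_grad])
        return x.grad  # gradient to hand back to the previous pipeline stage

linear = paddle.nn.Linear(8, 8)
queue = RecomputeQueue()
inp = paddle.randn([2, 8])
out = queue.forward_no_grad(linear, inp)
grad_to_prev_stage = queue.recompute_and_backward(linear, paddle.ones_like(out))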

def forward_backward_pipeline(
self,
data,
@@ -809,7 +994,6 @@ def forward_backward_pipeline(
self.is_pipeline_first_stage(),
batch_p2p_comm=self._use_batch_p2p_comm,
)

input_tensor_dict, use_dict = tuple_to_dict_helper(input_tensor)

self._record_stamp("F", step_id, '"B"', self._forward_color)
@@ -851,6 +1035,7 @@ def forward_backward_pipeline(
continue
last_iter = i == (steady_steps - 1)

# if input_tensor is a tuple, convert it into a dict: each key is tensor.key and each value is the corresponding tensor
input_tensor_dict, use_dict = tuple_to_dict_helper(input_tensor)

self._record_stamp(
@@ -1250,27 +1435,40 @@ def _forward_step(
input_tensor=input_tensor,
step_id=step_id,
)

schedule_chunk = None
if overlap_schedule_mode:
schedule_chunk = self._layers.get_schedule_chunk(chunk_id=chunk_id)
output_tensor = schedule_chunk.forward(input_tensor)
if self.full_recompute_overlap:
self.save_state(
input_tensor, chunk_id, self.is_pipeline_first_stage()
)
with paddle.no_grad():
output_tensor = schedule_chunk.forward(input_tensor)
else:
output_tensor = schedule_chunk.forward(input_tensor)
else:
output_tensor = self._layers.forward(
input_tensor, chunk_id=chunk_id
)
if self.full_recompute_overlap:
self.save_state(
input_tensor, chunk_id, self.is_pipeline_first_stage()
)
with paddle.no_grad():
output_tensor = self._layers.forward(
input_tensor, chunk_id=chunk_id
)
else:
output_tensor = self._layers.forward(
input_tensor, chunk_id=chunk_id
)

self.callbacks.on_location(
PipelineParallelMicroStepLocations.FORWARD_END,
input_tensor=input_tensor,
output_tensor=output_tensor,
step_id=step_id,
)

backward_loss_tensor, backward_loss_fn_node = self._maybe_loss_compute(
output_tensor, micro_dataset, overlap_schedule_mode
)

if self.is_pipeline_first_stage() or self.is_pipeline_last_stage():
# Only increase micro batch id at virtual first/last pp stage.
# The micro batch id is used to load data, therefore, only increase it when load data.
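
Note that when full_recompute_overlap is enabled, the output_tensor above was produced under paddle.no_grad(), so it carries no autograd graph and cannot be backpropagated directly; that is why _backward_step (next hunk) recomputes it first. A short illustration of the underlying behavior, using an arbitrary linear layer:

import paddle

layer = paddle.nn.Linear(4, 4)
x = paddle.randn([2, 4])
x.stop_gradient = False

with paddle.no_grad():
    y = layer(x)
# y was created without a graph, so it cannot propagate gradients back to x
assert y.stop_gradient

y2 = layer(x)        # a regular forward rebuilds the graph
y2.sum().backward()  # gradients now reach x
assert x.grad is not None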
@@ -1296,6 +1494,7 @@ def _backward_step(
schedule_chunk=None,
loss_fn_node=None,
):
# TODO: check whether gradient propagation is handled correctly under overlap_schedule_mode
if self.user_hooks_enabled:
self.backward_hooks.run_hook()
if self._enable_timer:
@@ -1304,6 +1503,11 @@
profile_pipeline_details(
f"[Pipeline details] Before_backward_step_chunk_{chunk_id}_step_{step_id}"
)
if self.full_recompute_overlap:
final_input, output_tensor_recompute = self.load_state_and_forward(
input_tensor
)
output_tensor = output_tensor_recompute
with paddle.amp.auto_cast(enable=False):
self.callbacks.on_location(
PipelineParallelMicroStepLocations.BACKWARD_BEGIN,
@@ -1362,15 +1566,27 @@
input_tensor_grad = None
if input_tensor is not None:
if isinstance(input_tensor, tuple):
input_tensor_grad = tuple(
[
t.grad
for t in input_tensor
if not t.stop_gradient
]
)
if self.full_recompute_overlap:
input_tensor_grad = tuple(
[
t.grad
for t in final_input
if not t.stop_gradient
]
)
else:
input_tensor_grad = tuple(
[
t.grad
for t in input_tensor
if not t.stop_gradient
]
)
else:
input_tensor_grad = input_tensor.grad
if self.full_recompute_overlap:
input_tensor_grad = final_input.grad
else:
input_tensor_grad = input_tensor.grad
if self._enable_timer:
self.timers("backward_step").stop()
self.callbacks.on_location(
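
In the recompute-overlap branch above, the gradients returned to the previous stage are read from final_input (the detached inputs of the recomputed graph) rather than from the tensors received over p2p, because backward() only populates the recomputed leaves. A minimal sketch of that distinction, with a single linear layer standing in for the stage:

import paddle

layer = paddle.nn.Linear(4, 4)
received = paddle.randn([2, 4])  # plays the role of the p2p input_tensor
received.stop_gradient = False

# backward-time recompute: detach the received tensor and rebuild the graph on the copy
final_input = received.detach()
final_input.stop_gradient = False
out = layer(final_input)
paddle.autograd.backward([out], [paddle.ones_like(out)])

assert received.grad is None         # the original p2p tensor never entered the recomputed graph
assert final_input.grad is not None  # the recomputed leaf holds the gradient to send upstream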