diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml new file mode 100644 index 0000000..f1d8449 --- /dev/null +++ b/.pre-commit-config.yaml @@ -0,0 +1,13 @@ +repos: + - repo: https://github.com/pre-commit/pre-commit-hooks + rev: v4.4.0 + hooks: + - id: check-yaml + - id: end-of-file-fixer + types: [python] + - id: trailing-whitespace + + - repo: https://github.com/psf/black + rev: 23.1.0 + hooks: + - id: black \ No newline at end of file diff --git a/api/codegeex-api-example-python/generation_example.py b/api/codegeex-api-example-python/generation_example.py index 2a732dc..e36f6d4 100644 --- a/api/codegeex-api-example-python/generation_example.py +++ b/api/codegeex-api-example-python/generation_example.py @@ -4,29 +4,31 @@ import requests -''' +""" Code Generation -''' +""" API_KEY = "" # Get from Tianqi console. 从控制台获取 API_SECRET = "" # Get from Tianqi console. 从控制台获取 -PROMPT = "from typing import List\n\ndef has_close_elements(numbers: List[float], threshold: float) -> bool:\n " \ - "\"\"\" Check if in given list of numbers, are any two numbers closer to each other than\n given " \ - "threshold.\n >>> has_close_elements([1.0, 2.0, 3.0], 0.5)\n False\n >>> has_close_elements(" \ - "[1.0, 2.8, 3.0, 4.0, 5.0, 2.0], 0.3)\n True\n \"\"\"\n" +PROMPT = ( + "from typing import List\n\ndef has_close_elements(numbers: List[float], threshold: float) -> bool:\n " + '""" Check if in given list of numbers, are any two numbers closer to each other than\n given ' + "threshold.\n >>> has_close_elements([1.0, 2.0, 3.0], 0.5)\n False\n >>> has_close_elements(" + '[1.0, 2.8, 3.0, 4.0, 5.0, 2.0], 0.3)\n True\n """\n' +) NUMBER = 3 LANG = "Python" request_url = "https://tianqi.aminer.cn/api/v2/" -api = 'multilingual_code_generate' +api = "multilingual_code_generate" # Request is in json format. 
指定请求参数格式为json -headers = {'Content-Type': 'application/json'} +headers = {"Content-Type": "application/json"} request_url = request_url + api data = { "apikey": API_KEY, "apisecret": API_SECRET, "prompt": PROMPT, "n": NUMBER, - "lang": LANG + "lang": LANG, } @@ -36,5 +38,5 @@ def main(): print(response.json()) -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/codegeex/__init__.py b/codegeex/__init__.py index 0bc55c6..d358bd8 100644 --- a/codegeex/__init__.py +++ b/codegeex/__init__.py @@ -13,9 +13,9 @@ def get_model( def generate( - model, - tokenizer: CodeGeeXTokenizer, - prompt: str, + model, + tokenizer: CodeGeeXTokenizer, + prompt: str, out_seq_length: int, seq_length: int = 2048, top_k: int = 0, @@ -32,7 +32,7 @@ def generate( if verbose: print(f"Current prompt:\n{prompt}") print("N_token_prompt:", n_token_prompt) - + generated_codes = [] if backend == "megatron": token_stream = get_token_stream( @@ -53,17 +53,22 @@ def generate( for j in range(micro_batch_size): if is_finished[j]: continue - - if generated_tokens[j].cpu().numpy()[-1] == tokenizer.eos_token_id or len(generated_tokens[j]) >= out_seq_length: + + if ( + generated_tokens[j].cpu().numpy()[-1] == tokenizer.eos_token_id + or len(generated_tokens[j]) >= out_seq_length + ): is_finished[j] = True generated_tokens_ = generated_tokens[j].cpu().numpy().tolist() - generated_code = tokenizer.decode_code(generated_tokens_[n_token_prompt:]) + generated_code = tokenizer.decode_code( + generated_tokens_[n_token_prompt:] + ) generated_code = "".join(generated_code) generated_codes.append(generated_code) if verbose: print(f"\nGenerated code {i}:\n{generated_code}") - + if all(is_finished): break - return generated_codes \ No newline at end of file + return generated_codes diff --git a/codegeex/benchmark/evaluate_humaneval_x.py b/codegeex/benchmark/evaluate_humaneval_x.py index 076e0e8..b84a4e0 100644 --- a/codegeex/benchmark/evaluate_humaneval_x.py +++ b/codegeex/benchmark/evaluate_humaneval_x.py @@ -16,10 +16,10 @@ from codegeex.benchmark.execution import check_correctness LANGUAGE_NAME = { - "cpp" : "CPP", - "go" : "Go", - "java" : "Java", - "js" : "JavaScript", + "cpp": "CPP", + "go": "Go", + "java": "Java", + "js": "JavaScript", "python": "Python", } @@ -29,7 +29,11 @@ def process_humaneval_test(sample, problems, example_test=False): language = task_id.split("/")[0].lower() prompt = sample["prompt"] - if example_test and "example_test" in problems[task_id] and problems[task_id]["example_test"] != "": + if ( + example_test + and "example_test" in problems[task_id] + and problems[task_id]["example_test"] != "" + ): test = problems[task_id]["example_test"] else: test = problems[task_id]["test"] @@ -39,7 +43,7 @@ def process_humaneval_test(sample, problems, example_test=False): if language == "python": code_ = [] for line in code.split("\n"): - if (len(line.strip()) > 0 and line[0] != ' ' and line[0] != '\t'): + if len(line.strip()) > 0 and line[0] != " " and line[0] != "\t": break code_.append(line) code = "\n".join(code_) @@ -68,10 +72,21 @@ def process_humaneval_test(sample, problems, example_test=False): if pkg not in test_setup: p = pkg.split("/")[-1] if p + "." 
in code: - other_pkgs.append(f"\"{pkg}\"") + other_pkgs.append(f'"{pkg}"') if other_pkgs: - import_other_pkgs = "import (\n" + " ".join([p + "\n" for p in other_pkgs]) + ")" - test_string = test_setup + "\n" + import_other_pkgs + "\n" + prompt + code + "\n" + test + import_other_pkgs = ( + "import (\n" + " ".join([p + "\n" for p in other_pkgs]) + ")" + ) + test_string = ( + test_setup + + "\n" + + import_other_pkgs + + "\n" + + prompt + + code + + "\n" + + test + ) else: test_string = test_setup + "\n" + prompt + code + "\n" + test elif language == "rust": @@ -97,21 +112,20 @@ def stream_jsonl_all(filename: str) -> Iterable[Dict]: def evaluate_functional_correctness( - input_file: str = None, - tmp_dir: str = "./", - n_workers: int = 32, - timeout: float = 500.0, - problem_file: str = "../data/humaneval_python.jsonl.gz", - out_dir: str = None, - k: List[int] = [1, 10, 100], - test_groundtruth: bool = False, - example_test: bool = False, + input_file: str = None, + tmp_dir: str = "./", + n_workers: int = 32, + timeout: float = 500.0, + problem_file: str = "../data/humaneval_python.jsonl.gz", + out_dir: str = None, + k: List[int] = [1, 10, 100], + test_groundtruth: bool = False, + example_test: bool = False, ): if example_test: print("Example test...") - problems = read_dataset(problem_file, - dataset_type="humaneval") + problems = read_dataset(problem_file, dataset_type="humaneval") sample_jsonl = stream_jsonl_all(input_file) if example_test: @@ -121,7 +135,9 @@ def evaluate_functional_correctness( if out_dir is not None: if not os.path.exists(out_dir): os.makedirs(out_dir) - out_file = os.path.join(out_dir, input_file.split('/')[-1].replace(".jsonl", suffix)) + out_file = os.path.join( + out_dir, input_file.split("/")[-1].replace(".jsonl", suffix) + ) else: out_file = os.path.join(input_file.replace(".jsonl", suffix)) @@ -149,10 +165,19 @@ def evaluate_functional_correctness( lang = "js" tmp_dir_ = os.path.join(tmp_dir, lang, "evaluation") sample["generation"] = sample["canonical_solution"] - sample["test_code"] = process_humaneval_test(sample, problems, example_test) + sample["test_code"] = process_humaneval_test( + sample, problems, example_test + ) if sample["test_code"] is None: continue - args = (task_id, sample, lang, timeout, tmp_dir_, completion_id[task_id]) + args = ( + task_id, + sample, + lang, + timeout, + tmp_dir_, + completion_id[task_id], + ) future = executor.submit(check_correctness, *args) futures.append(future) completion_id[task_id] += 1 @@ -164,7 +189,11 @@ def evaluate_functional_correctness( lang = task_id.split("/")[0].lower() if translation_mode: task_id = sample["task_id"].split("/")[-1] - lang = regex.findall("-to-.*-", input_file)[0].split("-to-")[-1].rstrip("-") + lang = ( + regex.findall("-to-.*-", input_file)[0] + .split("-to-")[-1] + .rstrip("-") + ) for l in LANGUAGE_NAME: if l in lang: lang = l @@ -174,7 +203,9 @@ def evaluate_functional_correctness( lang = "js" tmp_dir_ = os.path.join(tmp_dir, lang, "evaluation") sample["task_id"] = task_id - sample["test_code"] = process_humaneval_test(sample, problems, example_test) + sample["test_code"] = process_humaneval_test( + sample, problems, example_test + ) if sample["test_code"] is None: continue if "completion_id" in sample: @@ -208,8 +239,11 @@ def evaluate_functional_correctness( correct = np.array(correct) if evaluate_pass_at_k: ks = k - pass_at_k = {f"pass@{k}": estimate_pass_at_k(total, correct, k).mean() - for k in ks if (total >= k).all()} + pass_at_k = { + f"pass@{k}": estimate_pass_at_k(total, 
correct, k).mean() + for k in ks + if (total >= k).all() + } print(pass_at_k) else: print("Total:", np.sum(total)) @@ -222,7 +256,7 @@ def evaluate_functional_correctness( for r in res: fp.write((json.dumps(r[1]) + "\n").encode("utf-8")) else: - fp = open(out_file, 'w') + fp = open(out_file, "w") for res in results.values(): for r in res: fp.write(json.dumps(r[1]) + "\n") diff --git a/codegeex/benchmark/execution.py b/codegeex/benchmark/execution.py index cbdf14f..ae187b4 100644 --- a/codegeex/benchmark/execution.py +++ b/codegeex/benchmark/execution.py @@ -12,6 +12,7 @@ import json from typing import * + def dicts_to_jsonl(data_list: list, filename: str, compress: bool = True) -> None: """ Method saves list of dicts into jsonl file. @@ -20,34 +21,34 @@ def dicts_to_jsonl(data_list: list, filename: str, compress: bool = True) -> Non .jsonl suffix into the file. :param compress: (bool) should file be compressed into a gzip archive? """ - sjsonl = '.jsonl' - sgz = '.gz' + sjsonl = ".jsonl" + sgz = ".gz" # Check filename if not filename.endswith(sjsonl): filename = filename + sjsonl # Save data - + if compress: filename = filename + sgz - with gzip.open(filename, 'w') as compressed: + with gzip.open(filename, "w") as compressed: for ddict in data_list: - jout = json.dumps(ddict) + '\n' - jout = jout.encode('utf-8') + jout = json.dumps(ddict) + "\n" + jout = jout.encode("utf-8") compressed.write(jout) else: - with open(filename, 'w') as out: + with open(filename, "w") as out: for ddict in data_list: - jout = json.dumps(ddict) + '\n' + jout = json.dumps(ddict) + "\n" out.write(jout) def check_correctness( - task_id: str, - sample: dict, - language_type: str, - timeout: float = 3.0, - tmp_dir: str = None, - completion_id: Optional[int] = None, + task_id: str, + sample: dict, + language_type: str, + timeout: float = 3.0, + tmp_dir: str = None, + completion_id: Optional[int] = None, ) -> Dict: """ Evaluates the functional correctness of a completion by running the test @@ -62,6 +63,7 @@ def unsafe_execute(tmp_dir): # These system calls are needed when cleaning up tempdir. import os import shutil + rmtree = shutil.rmtree rmdir = os.rmdir chdir = os.chdir @@ -97,7 +99,9 @@ def unsafe_execute(tmp_dir): os.chdir = chdir elif "go" in language_type.lower(): - assert tmp_dir is not None, "Go should be evaluated in a dir where necessary module files installed." + assert ( + tmp_dir is not None + ), "Go should be evaluated in a dir where necessary module files installed." import os import shutil @@ -109,7 +113,7 @@ def unsafe_execute(tmp_dir): os.makedirs(tmp_dir) os.chdir(tmp_dir) - open(f"main_test.go", 'w').write(sample["test_code"]) + open(f"main_test.go", "w").write(sample["test_code"]) try: exec_result = None with time_limit(timeout): @@ -122,7 +126,11 @@ def unsafe_execute(tmp_dir): # does not perform destructive actions on their host or network. 
# Once you have read this disclaimer and taken appropriate precautions, # uncomment the following line and proceed at your own risk: - exec_result = subprocess.run(["go", "test", f"-timeout={timeout}s", "main_test.go"], timeout=timeout, capture_output=True) + exec_result = subprocess.run( + ["go", "test", f"-timeout={timeout}s", "main_test.go"], + timeout=timeout, + capture_output=True, + ) if exec_result.returncode == 0: result.append("passed") @@ -154,7 +162,7 @@ def unsafe_execute(tmp_dir): os.makedirs(tmp_dir) os.chdir(tmp_dir) - open(f"test.js", 'w').write(sample["test_code"]) + open(f"test.js", "w").write(sample["test_code"]) try: exec_result = None with time_limit(timeout): @@ -167,7 +175,9 @@ def unsafe_execute(tmp_dir): # does not perform destructive actions on their host or network. # Once you have read this disclaimer and taken appropriate precautions, # uncomment the following line and proceed at your own risk: - exec_result = subprocess.run(["node", "test.js"], timeout=timeout, capture_output=True) + exec_result = subprocess.run( + ["node", "test.js"], timeout=timeout, capture_output=True + ) if exec_result.stderr.decode(): err = exec_result.stderr.decode() @@ -193,14 +203,19 @@ def unsafe_execute(tmp_dir): os.makedirs(tmp_dir) os.chdir(tmp_dir) - open(f"test.cpp", 'w').write(sample["test_code"]) + open(f"test.cpp", "w").write(sample["test_code"]) if "162" in task_id: - compilation_result = subprocess.run(["/usr/bin/g++", "-std=c++11", "test.cpp", "-lcrypto", "-lssl"], - timeout=timeout, - capture_output=True) + compilation_result = subprocess.run( + ["/usr/bin/g++", "-std=c++11", "test.cpp", "-lcrypto", "-lssl"], + timeout=timeout, + capture_output=True, + ) else: - compilation_result = subprocess.run(["/usr/bin/g++", "-std=c++11", "test.cpp"], timeout=timeout, - capture_output=True) + compilation_result = subprocess.run( + ["/usr/bin/g++", "-std=c++11", "test.cpp"], + timeout=timeout, + capture_output=True, + ) if compilation_result.returncode != 0: if compilation_result.stderr: err = compilation_result.stderr.decode() @@ -220,7 +235,9 @@ def unsafe_execute(tmp_dir): # does not perform destructive actions on their host or network. 
# Once you have read this disclaimer and taken appropriate precautions, # uncomment the following line and proceed at your own risk: - exec_result = subprocess.run(["./a.out"], timeout=timeout, capture_output=True) + exec_result = subprocess.run( + ["./a.out"], timeout=timeout, capture_output=True + ) if exec_result.returncode == 0: result.append("passed") @@ -240,16 +257,16 @@ def unsafe_execute(tmp_dir): result.append("timed out") shutil.rmtree(tmp_dir) - elif "rust" in language_type.lower(): - import os - + elif "rust" in language_type.lower(): + import os + WD: str = os.path.dirname(os.path.abspath(__file__)) RUST_DIR: str = os.path.join(WD, "rust") RUST_SRC: str = os.path.join(RUST_DIR, "src") RUST_BIN: str = os.path.join(RUST_SRC, "bin") RUST_TMP_DIR: str = os.path.join(RUST_DIR, "tmp") RUST_LOGS: str = os.path.join(RUST_TMP_DIR, "logs") - RUST_EXT: str = ".rs" + RUST_EXT: str = ".rs" # Create mandatory tmp directories os.makedirs(RUST_TMP_DIR, exist_ok=True) @@ -257,18 +274,18 @@ def unsafe_execute(tmp_dir): os.makedirs(RUST_SRC, exist_ok=True) os.makedirs(RUST_BIN, exist_ok=True) - with tempfile.NamedTemporaryFile(dir = RUST_BIN, delete=False) as f: - #temporal file name + with tempfile.NamedTemporaryFile(dir=RUST_BIN, delete=False) as f: + # temporal file name file_prefix = sample["task_id"].lower().replace("/", "_") - file_name:str = file_prefix +RUST_EXT - + file_name: str = file_prefix + RUST_EXT + os.rename(f.name, os.path.join(RUST_BIN, file_name)) - + # Sample to pure Rust function rust_code: str = sample["test_code"] # dump the rust source code in the target temporal file - f.write(rust_code.encode('utf-8')) + f.write(rust_code.encode("utf-8")) # Proceed towards Rust binaries compilation. Therefore move to Rust module root dir. os.chdir(RUST_DIR) @@ -277,35 +294,44 @@ def unsafe_execute(tmp_dir): # Pass OR Fail compilation log_filename: str = file_prefix + ".jsonl" log_path: str = os.path.join(RUST_LOGS, log_filename) - cargo_check: str = "cargo check --bin " + file_prefix + " --message-format json >> " + log_path + cargo_check: str = ( + "cargo check --bin " + + file_prefix + + " --message-format json >> " + + log_path + ) # Compilation build status returned_val_compilation: int - + # Overwrite file content if os.path.exists(log_path): - if(file_size := os.path.getsize(log_path)) >= 0: + if (file_size := os.path.getsize(log_path)) >= 0: os.remove(log_path) returned_val_compilation = os.system(cargo_check) - else: + else: returned_val_compilation = os.system(cargo_check) - # 0 means success + # 0 means success if returned_val_compilation == 0: - #Execution pipeline - cargo_test: str = "cargo test --bin " +file_prefix+ " --message-format json >> " + log_path + # Execution pipeline + cargo_test: str = ( + "cargo test --bin " + + file_prefix + + " --message-format json >> " + + log_path + ) returned_val_execution = os.system(cargo_test) - + if returned_val_execution == 0: result.append("passed") else: - result.append(f"failed: execution error") + result.append(f"failed: execution error") else: result.append(f"failed: compilation error") - elif "java" in language_type.lower(): assert tmp_dir is not None, "Java should be evaluated in a temporary dir." 
@@ -319,13 +345,16 @@ def unsafe_execute(tmp_dir): os.makedirs(tmp_dir) os.chdir(tmp_dir) - open(os.path.join(tmp_dir, "Main.java"), 'w').write(sample["test_code"]) + open(os.path.join(tmp_dir, "Main.java"), "w").write(sample["test_code"]) res = "failed: unknown error" compile_returncode = -1 for _ in range(5): try: - compilation_result = subprocess.run(['javac', os.path.join(tmp_dir, "Main.java")], timeout=5, - capture_output=True) + compilation_result = subprocess.run( + ["javac", os.path.join(tmp_dir, "Main.java")], + timeout=5, + capture_output=True, + ) compile_returncode = compilation_result.returncode break except subprocess.TimeoutExpired as e: @@ -348,7 +377,9 @@ def unsafe_execute(tmp_dir): if exec_result.returncode == 0: res = "passed" elif exec_result.returncode == 1: - if "AssertionError" in exec_result.stderr.decode('unicode-escape'): + if "AssertionError" in exec_result.stderr.decode( + "unicode-escape" + ): res = "failed: wrong answer" else: res = f"failed: {exec_result.stderr.decode()}" @@ -359,7 +390,7 @@ def unsafe_execute(tmp_dir): result.append(res) shutil.rmtree(tmp_dir) - + manager = multiprocessing.Manager() result = manager.list() @@ -373,18 +404,19 @@ def unsafe_execute(tmp_dir): result.append("timed out") return { - "task_id" : task_id, + "task_id": task_id, "completion_id": completion_id, - "test_code" : sample["test_code"], - "prompt" : sample["prompt"], - "generation" : sample["generation"], - "result" : result[0], - "passed" : result[0] == "passed", - "finish" : -1 if "finish" not in sample else sample["finish"], - "file" : "" if "file" not in sample else sample["file"], - "output" : [] if "output" not in sample else sample["output"], + "test_code": sample["test_code"], + "prompt": sample["prompt"], + "generation": sample["generation"], + "result": result[0], + "passed": result[0] == "passed", + "finish": -1 if "finish" not in sample else sample["finish"], + "file": "" if "file" not in sample else sample["file"], + "output": [] if "output" not in sample else sample["output"], } + # Copyright (c) OpenAI (https://openai.com) # Permission is hereby granted, free of charge, to any person obtaining a copy @@ -439,7 +471,7 @@ class TimeoutException(Exception): class WriteOnlyStringIO(io.StringIO): - """ StringIO that throws an exception when it's read from """ + """StringIO that throws an exception when it's read from""" def read(self, *args, **kwargs): raise IOError @@ -451,12 +483,12 @@ def readlines(self, *args, **kwargs): raise IOError def readable(self, *args, **kwargs): - """ Returns True if the IO object can be read. """ + """Returns True if the IO object can be read.""" return False class redirect_stdin(contextlib._RedirectStream): # type: ignore - _stream = 'stdin' + _stream = "stdin" @contextlib.contextmanager @@ -482,26 +514,35 @@ def reliability_guard(maximum_memory_bytes: Optional[int] = None): WARNING This function is NOT a security sandbox. Untrusted code, including, model- - generated code, should not be blindly executed outside of one. See the + generated code, should not be blindly executed outside of one. See the Codex paper for more information about OpenAI's code sandbox, and proceed with caution. 
""" if maximum_memory_bytes is not None: import resource - resource.setrlimit(resource.RLIMIT_AS, (maximum_memory_bytes, maximum_memory_bytes)) - resource.setrlimit(resource.RLIMIT_DATA, (maximum_memory_bytes, maximum_memory_bytes)) - if not platform.uname().system == 'Darwin': - resource.setrlimit(resource.RLIMIT_STACK, (maximum_memory_bytes, maximum_memory_bytes)) + + resource.setrlimit( + resource.RLIMIT_AS, (maximum_memory_bytes, maximum_memory_bytes) + ) + resource.setrlimit( + resource.RLIMIT_DATA, (maximum_memory_bytes, maximum_memory_bytes) + ) + if not platform.uname().system == "Darwin": + resource.setrlimit( + resource.RLIMIT_STACK, (maximum_memory_bytes, maximum_memory_bytes) + ) faulthandler.disable() import builtins + builtins.exit = None builtins.quit = None import os - os.environ['OMP_NUM_THREADS'] = '1' + + os.environ["OMP_NUM_THREADS"] = "1" os.kill = None os.system = None @@ -532,18 +573,21 @@ def reliability_guard(maximum_memory_bytes: Optional[int] = None): os.chdir = None import shutil + shutil.rmtree = None shutil.move = None shutil.chown = None import subprocess + subprocess.Popen = None # type: ignore - __builtins__['help'] = None + __builtins__["help"] = None import sys - sys.modules['ipdb'] = None - sys.modules['joblib'] = None - sys.modules['resource'] = None - sys.modules['psutil'] = None - sys.modules['tkinter'] = None + + sys.modules["ipdb"] = None + sys.modules["joblib"] = None + sys.modules["resource"] = None + sys.modules["psutil"] = None + sys.modules["tkinter"] = None diff --git a/codegeex/benchmark/gather_output.py b/codegeex/benchmark/gather_output.py index db63986..c6fad9b 100644 --- a/codegeex/benchmark/gather_output.py +++ b/codegeex/benchmark/gather_output.py @@ -33,7 +33,12 @@ def gather_output( output_list = glob.glob(os.path.join(output_dir, output_prefix + "*")) for output_file in output_list: - if "rank" in output_file or "_unfinished" in output_file or "all" in output_file or "_result" in output_file: + if ( + "rank" in output_file + or "_unfinished" in output_file + or "all" in output_file + or "_result" in output_file + ): continue if "_finished" not in output_file: continue diff --git a/codegeex/benchmark/humaneval-x/evaluate_humaneval_x.py b/codegeex/benchmark/humaneval-x/evaluate_humaneval_x.py index 076e0e8..b84a4e0 100644 --- a/codegeex/benchmark/humaneval-x/evaluate_humaneval_x.py +++ b/codegeex/benchmark/humaneval-x/evaluate_humaneval_x.py @@ -16,10 +16,10 @@ from codegeex.benchmark.execution import check_correctness LANGUAGE_NAME = { - "cpp" : "CPP", - "go" : "Go", - "java" : "Java", - "js" : "JavaScript", + "cpp": "CPP", + "go": "Go", + "java": "Java", + "js": "JavaScript", "python": "Python", } @@ -29,7 +29,11 @@ def process_humaneval_test(sample, problems, example_test=False): language = task_id.split("/")[0].lower() prompt = sample["prompt"] - if example_test and "example_test" in problems[task_id] and problems[task_id]["example_test"] != "": + if ( + example_test + and "example_test" in problems[task_id] + and problems[task_id]["example_test"] != "" + ): test = problems[task_id]["example_test"] else: test = problems[task_id]["test"] @@ -39,7 +43,7 @@ def process_humaneval_test(sample, problems, example_test=False): if language == "python": code_ = [] for line in code.split("\n"): - if (len(line.strip()) > 0 and line[0] != ' ' and line[0] != '\t'): + if len(line.strip()) > 0 and line[0] != " " and line[0] != "\t": break code_.append(line) code = "\n".join(code_) @@ -68,10 +72,21 @@ def process_humaneval_test(sample, problems, 
example_test=False): if pkg not in test_setup: p = pkg.split("/")[-1] if p + "." in code: - other_pkgs.append(f"\"{pkg}\"") + other_pkgs.append(f'"{pkg}"') if other_pkgs: - import_other_pkgs = "import (\n" + " ".join([p + "\n" for p in other_pkgs]) + ")" - test_string = test_setup + "\n" + import_other_pkgs + "\n" + prompt + code + "\n" + test + import_other_pkgs = ( + "import (\n" + " ".join([p + "\n" for p in other_pkgs]) + ")" + ) + test_string = ( + test_setup + + "\n" + + import_other_pkgs + + "\n" + + prompt + + code + + "\n" + + test + ) else: test_string = test_setup + "\n" + prompt + code + "\n" + test elif language == "rust": @@ -97,21 +112,20 @@ def stream_jsonl_all(filename: str) -> Iterable[Dict]: def evaluate_functional_correctness( - input_file: str = None, - tmp_dir: str = "./", - n_workers: int = 32, - timeout: float = 500.0, - problem_file: str = "../data/humaneval_python.jsonl.gz", - out_dir: str = None, - k: List[int] = [1, 10, 100], - test_groundtruth: bool = False, - example_test: bool = False, + input_file: str = None, + tmp_dir: str = "./", + n_workers: int = 32, + timeout: float = 500.0, + problem_file: str = "../data/humaneval_python.jsonl.gz", + out_dir: str = None, + k: List[int] = [1, 10, 100], + test_groundtruth: bool = False, + example_test: bool = False, ): if example_test: print("Example test...") - problems = read_dataset(problem_file, - dataset_type="humaneval") + problems = read_dataset(problem_file, dataset_type="humaneval") sample_jsonl = stream_jsonl_all(input_file) if example_test: @@ -121,7 +135,9 @@ def evaluate_functional_correctness( if out_dir is not None: if not os.path.exists(out_dir): os.makedirs(out_dir) - out_file = os.path.join(out_dir, input_file.split('/')[-1].replace(".jsonl", suffix)) + out_file = os.path.join( + out_dir, input_file.split("/")[-1].replace(".jsonl", suffix) + ) else: out_file = os.path.join(input_file.replace(".jsonl", suffix)) @@ -149,10 +165,19 @@ def evaluate_functional_correctness( lang = "js" tmp_dir_ = os.path.join(tmp_dir, lang, "evaluation") sample["generation"] = sample["canonical_solution"] - sample["test_code"] = process_humaneval_test(sample, problems, example_test) + sample["test_code"] = process_humaneval_test( + sample, problems, example_test + ) if sample["test_code"] is None: continue - args = (task_id, sample, lang, timeout, tmp_dir_, completion_id[task_id]) + args = ( + task_id, + sample, + lang, + timeout, + tmp_dir_, + completion_id[task_id], + ) future = executor.submit(check_correctness, *args) futures.append(future) completion_id[task_id] += 1 @@ -164,7 +189,11 @@ def evaluate_functional_correctness( lang = task_id.split("/")[0].lower() if translation_mode: task_id = sample["task_id"].split("/")[-1] - lang = regex.findall("-to-.*-", input_file)[0].split("-to-")[-1].rstrip("-") + lang = ( + regex.findall("-to-.*-", input_file)[0] + .split("-to-")[-1] + .rstrip("-") + ) for l in LANGUAGE_NAME: if l in lang: lang = l @@ -174,7 +203,9 @@ def evaluate_functional_correctness( lang = "js" tmp_dir_ = os.path.join(tmp_dir, lang, "evaluation") sample["task_id"] = task_id - sample["test_code"] = process_humaneval_test(sample, problems, example_test) + sample["test_code"] = process_humaneval_test( + sample, problems, example_test + ) if sample["test_code"] is None: continue if "completion_id" in sample: @@ -208,8 +239,11 @@ def evaluate_functional_correctness( correct = np.array(correct) if evaluate_pass_at_k: ks = k - pass_at_k = {f"pass@{k}": estimate_pass_at_k(total, correct, k).mean() - for k in ks if 
(total >= k).all()} + pass_at_k = { + f"pass@{k}": estimate_pass_at_k(total, correct, k).mean() + for k in ks + if (total >= k).all() + } print(pass_at_k) else: print("Total:", np.sum(total)) @@ -222,7 +256,7 @@ def evaluate_functional_correctness( for r in res: fp.write((json.dumps(r[1]) + "\n").encode("utf-8")) else: - fp = open(out_file, 'w') + fp = open(out_file, "w") for res in results.values(): for r in res: fp.write(json.dumps(r[1]) + "\n") diff --git a/codegeex/benchmark/humaneval-x/generate_humaneval_x.py b/codegeex/benchmark/humaneval-x/generate_humaneval_x.py index 75eee09..0621851 100644 --- a/codegeex/benchmark/humaneval-x/generate_humaneval_x.py +++ b/codegeex/benchmark/humaneval-x/generate_humaneval_x.py @@ -91,7 +91,7 @@ def add_code_generation_args(parser): "--recompute", action="store_true", help="During generation recompute all attention " - "instead of using previously computed keys/values.", + "instead of using previously computed keys/values.", ) group.add_argument( "--load-deepspeed", @@ -180,22 +180,22 @@ def add_code_generation_args(parser): action="store_true", ) group.add_argument( - '--language-type', + "--language-type", default=None, - help='Identify the type of programming language to generate', + help="Identify the type of programming language to generate", ) group.add_argument( - '--bad-ids', + "--bad-ids", nargs="*", type=int, default=None, - help='Specify bad ids that will not be used', + help="Specify bad ids that will not be used", ) group.add_argument( "--quantize", action="store_true", ) - + return parser @@ -210,8 +210,8 @@ def main(node_rank: int, local_rank: int, master_port: int, num_devices: int): extra_args_provider=add_code_generation_args, args_defaults={ "tokenizer_type": "GPT2BPETokenizer", - "no_load_rng" : True, - "no_load_optim" : True, + "no_load_rng": True, + "no_load_optim": True, }, ) @@ -297,15 +297,17 @@ def server(): default=1, ) parser.add_argument( - '--language-type', + "--language-type", default=None, - help='Identify the type of programming language to generate', + help="Identify the type of programming language to generate", ) args = parser.parse_known_args()[0] entries = read_dataset(args.input_path, dataset_type="humaneval") - assert args.samples_per_problem % args.micro_batch_size == 0, "samples_per_problem should be divisible by micro_batch_size" + assert ( + args.samples_per_problem % args.micro_batch_size == 0 + ), "samples_per_problem should be divisible by micro_batch_size" for entry in entries.values(): entry["prompt"] = process_extra_prompt(entry["prompt"], args.language_type) @@ -352,13 +354,20 @@ def server(): else: entry = remaining_entries.pop() time_elapsed = time.perf_counter() - start_time - print(f"[ server ] Sending entry {entry['task_id']} to worker {rank}", flush=True) + print( + f"[ server ] Sending entry {entry['task_id']} to worker {rank}", + flush=True, + ) remaining = ( - len(remaining_entries) - / (len(all_entries) - len(remaining_entries)) - * time_elapsed + len(remaining_entries) + / (len(all_entries) - len(remaining_entries)) + * time_elapsed + ) + time_per_sampple = ( + 0.0 + if num_finished == 0 + else time_elapsed / num_finished / args.micro_batch_size ) - time_per_sampple = 0.0 if num_finished == 0 else time_elapsed / num_finished / args.micro_batch_size print( f"[ server ] total {len(all_entries)}, assigned {len(all_entries) - len(remaining_entries)}, " f"finished {num_finished}, " @@ -374,7 +383,7 @@ def server(): socket.send_json({"pong": 1}) else: print(f"[ server ] {msg['task_id']} is 
not finished", flush=True) - remaining_entries.append(msg['task_id']) + remaining_entries.append(msg["task_id"]) socket.send_json({"pong": 1}) break @@ -416,7 +425,7 @@ def server(): node_rank = i break assert ( - node_rank is not None + node_rank is not None ), f"Could not find hostname ({socket.gethostbyname(socket.gethostname())}) in hostlist" # launch server diff --git a/codegeex/benchmark/humaneval-x/translate_humaneval_x.py b/codegeex/benchmark/humaneval-x/translate_humaneval_x.py index cab2888..895aa93 100644 --- a/codegeex/benchmark/humaneval-x/translate_humaneval_x.py +++ b/codegeex/benchmark/humaneval-x/translate_humaneval_x.py @@ -88,7 +88,7 @@ def add_code_generate_args(parser): "--recompute", action="store_true", help="During generation recompute all attention " - "instead of using previously computed keys/values.", + "instead of using previously computed keys/values.", ) group.add_argument( "--load-deepspeed", @@ -191,11 +191,11 @@ def add_code_generate_args(parser): default=None, ) group.add_argument( - '--bad-ids', + "--bad-ids", nargs="*", type=int, default=None, - help='Identify the type of programming language to generate', + help="Identify the type of programming language to generate", ) group.add_argument( "--src-path", @@ -210,16 +210,16 @@ def add_code_generate_args(parser): help="Get target path", ) group.add_argument( - '--language-src-type', + "--language-src-type", type=str, default=None, - help='Identify the type of programming language', + help="Identify the type of programming language", ) group.add_argument( - '--language-tgt-type', + "--language-tgt-type", type=str, default=None, - help='Identify the type of programming language to translate', + help="Identify the type of programming language to translate", ) return parser @@ -236,8 +236,8 @@ def main(node_rank: int, local_rank: int, master_port: int, num_devices: int): extra_args_provider=add_code_generate_args, args_defaults={ "tokenizer_type": "GPT2BPETokenizer", - "no_load_rng" : True, - "no_load_optim" : True, + "no_load_rng": True, + "no_load_optim": True, }, ) @@ -327,27 +327,31 @@ def server(): default=1, ) parser.add_argument( - '--language-src-type', + "--language-src-type", type=str, default=None, - help='Identify the type of programming language', + help="Identify the type of programming language", ) parser.add_argument( - '--language-tgt-type', + "--language-tgt-type", type=str, default=None, - help='Identify the type of programming language to translate', + help="Identify the type of programming language to translate", ) args = parser.parse_known_args()[0] - entries = read_translation_dataset(args.src_path, - args.tgt_path, - lang_src=args.language_src_type, - lang_tgt=args.language_tgt_type, - dataset_type="humaneval") + entries = read_translation_dataset( + args.src_path, + args.tgt_path, + lang_src=args.language_src_type, + lang_tgt=args.language_tgt_type, + dataset_type="humaneval", + ) - assert args.samples_per_problem % args.micro_batch_size == 0, "samples_per_problem should be divisible by micro_batch_size" + assert ( + args.samples_per_problem % args.micro_batch_size == 0 + ), "samples_per_problem should be divisible by micro_batch_size" res = [] for entry in entries.values(): @@ -391,13 +395,20 @@ def server(): else: entry = remaining_entries.pop() time_elapsed = time.perf_counter() - start_time - print(f"[ server ] Sending entry {entry['task_id']} to worker {rank}", flush=True) + print( + f"[ server ] Sending entry {entry['task_id']} to worker {rank}", + flush=True, + ) remaining = ( 
- len(remaining_entries) - / (len(all_entries) - len(remaining_entries)) - * time_elapsed + len(remaining_entries) + / (len(all_entries) - len(remaining_entries)) + * time_elapsed + ) + time_per_sampple = ( + 0.0 + if num_finished == 0 + else time_elapsed / num_finished / args.micro_batch_size ) - time_per_sampple = 0.0 if num_finished == 0 else time_elapsed / num_finished / args.micro_batch_size print( f"[ server ] total {len(all_entries)}, assigned {len(all_entries) - len(remaining_entries)}, " f"finished {num_finished}, " @@ -413,7 +424,7 @@ def server(): socket.send_json({"pong": 1}) else: print(f"[ server ] {msg['task_id']} is not finished", flush=True) - remaining_entries.append(msg['task_id']) + remaining_entries.append(msg["task_id"]) socket.send_json({"pong": 1}) break @@ -455,7 +466,7 @@ def server(): node_rank = i break assert ( - node_rank is not None + node_rank is not None ), f"Could not find hostname ({socket.gethostbyname(socket.gethostname())}) in hostlist" # launch server diff --git a/codegeex/benchmark/inspect_result.py b/codegeex/benchmark/inspect_result.py index 52adf99..c4b38ed 100644 --- a/codegeex/benchmark/inspect_result.py +++ b/codegeex/benchmark/inspect_result.py @@ -10,7 +10,7 @@ from codegeex.benchmark.metric import estimate_pass_at_k error_types = { - "python" : [ + "python": [ "accepted", "assertion error", "undefined error", @@ -19,7 +19,7 @@ "timeout error", "type error", ], - "java" : [ + "java": [ "accepted", "compilation error", "assertion error", @@ -36,7 +36,7 @@ "arithmetic error", "others", ], - "cpp" : [ + "cpp": [ "accepted", "compilation error", "assertion error", @@ -57,7 +57,7 @@ "range error", "type error", ], - "go" : [ + "go": [ "accepted", "assertion error", "undefined error", @@ -71,10 +71,10 @@ def inspect_result( - input_dir: str = None, - input_file: str = None, - output_dir: str = None, - pass_at_k_outpath: str = None, + input_dir: str = None, + input_file: str = None, + output_dir: str = None, + pass_at_k_outpath: str = None, ): if input_dir is not None: input_files = glob.glob(input_dir + "/*_results.jsonl") @@ -122,7 +122,11 @@ def inspect_result( if language_type == "python": if "assertionerror" in error.lower(): result_stats[task_id]["assertion error"] += 1 - elif "syntax" in error.lower() or "indent" in error.lower() or "literal" in error.lower(): + elif ( + "syntax" in error.lower() + or "indent" in error.lower() + or "literal" in error.lower() + ): result_stats[task_id]["syntax error"] += 1 elif "not defined" in error.lower(): result_stats[task_id]["undefined error"] += 1 @@ -169,29 +173,32 @@ def inspect_result( elif "int main(): assertion" in error.lower(): result_stats[task_id]["assertion error"] += 1 elif "out_of_range" in error.lower(): - result_stats[task_id]['range error'] += 1 + result_stats[task_id]["range error"] += 1 elif "corrupted top size" in error.lower(): - result_stats[task_id]['range error'] += 1 + result_stats[task_id]["range error"] += 1 elif "length_error" in error.lower(): - result_stats[task_id]['range error'] += 1 + result_stats[task_id]["range error"] += 1 elif "invalid_argument" in error.lower(): - result_stats[task_id]['invalid argument'] += 1 + result_stats[task_id]["invalid argument"] += 1 elif "invalid pointer" in error.lower(): - result_stats[task_id]['pointer error'] += 1 + result_stats[task_id]["pointer error"] += 1 elif "double free" in error.lower(): - result_stats[task_id]['pointer error'] += 1 + result_stats[task_id]["pointer error"] += 1 elif "free()" in error.lower(): - 
result_stats[task_id]['pointer error'] += 1 + result_stats[task_id]["pointer error"] += 1 elif "logic_error" in error.lower(): - result_stats[task_id]['pointer error'] += 1 + result_stats[task_id]["pointer error"] += 1 elif "sysmalloc: assertion" in error.lower(): - result_stats[task_id]['pointer error'] += 1 + result_stats[task_id]["pointer error"] += 1 elif "stack smashing" in error.lower(): - result_stats[task_id]['out of memory'] += 1 + result_stats[task_id]["out of memory"] += 1 elif "bad_alloc" in error.lower(): - result_stats[task_id]['out of memory'] += 1 - elif "terminate called after throwing an instance of" in error.lower(): - result_stats[task_id]['package error'] += 1 + result_stats[task_id]["out of memory"] += 1 + elif ( + "terminate called after throwing an instance of" + in error.lower() + ): + result_stats[task_id]["package error"] += 1 else: result_stats[task_id]["others"] += 1 @@ -227,9 +234,9 @@ def inspect_result( elif "timed out" in error: result_stats[task_id]["timeout error"] += 1 elif "not used" in error: - result_stats[task_id]['notused error'] += 1 + result_stats[task_id]["notused error"] += 1 elif "type" in error: - result_stats[task_id]['type error'] += 1 + result_stats[task_id]["type error"] += 1 else: incompleted = True break @@ -252,8 +259,11 @@ def inspect_result( correct = np.array(correct) ks = [1, 10, 100, 1000] - pass_at_k = {f"pass@{k}": estimate_pass_at_k(total, correct, k).mean() - for k in ks if (total >= k).all()} + pass_at_k = { + f"pass@{k}": estimate_pass_at_k(total, correct, k).mean() + for k in ks + if (total >= k).all() + } print(pass_at_k) pass_at_k["file"] = input_file @@ -269,7 +279,7 @@ def inspect_result( except Exception as e: print(e) print(f"Data incompleted, aborted. {input_file}") - + if pass_at_k_outpath is not None: jsonl_path = os.path.join(output_dir, pass_at_k_outpath) with open(jsonl_path, "w") as f_out: diff --git a/codegeex/benchmark/metric.py b/codegeex/benchmark/metric.py index 79d6db3..7c2ac5f 100644 --- a/codegeex/benchmark/metric.py +++ b/codegeex/benchmark/metric.py @@ -25,9 +25,9 @@ def estimate_pass_at_k( - num_samples: Union[int, List[int], np.ndarray], - num_correct: Union[List[int], np.ndarray], - k: int + num_samples: Union[int, List[int], np.ndarray], + num_correct: Union[List[int], np.ndarray], + k: int, ) -> np.ndarray: """ Estimates pass@k of each problem and returns them in an array. 
@@ -47,4 +47,6 @@ def estimator(n: int, c: int, k: int) -> float: assert len(num_samples) == len(num_correct) num_samples_it = iter(num_samples) - return np.array([estimator(int(n), int(c), k) for n, c in zip(num_samples_it, num_correct)]) + return np.array( + [estimator(int(n), int(c), k) for n, c in zip(num_samples_it, num_correct)] + ) diff --git a/codegeex/benchmark/utils.py b/codegeex/benchmark/utils.py index de47797..5823ea3 100644 --- a/codegeex/benchmark/utils.py +++ b/codegeex/benchmark/utils.py @@ -22,7 +22,7 @@ "from typing import *", "from collections import *", ], - "go" : [ + "go": [ "math", "strings", "fmt", @@ -34,7 +34,7 @@ "math/rand", "crypto/md5", ], - "cpp" : [ + "cpp": [ "#include", "#include", "#include", @@ -58,7 +58,14 @@ def read_dataset( if "humaneval" in dataset_type.lower(): if data_file is None: current_path = os.path.dirname(os.path.abspath(__file__)) - data_file = os.path.join(current_path, "..", "humaneval-x", "python", "data", "humaneval_python.jsonl.gz") + data_file = os.path.join( + current_path, + "..", + "humaneval-x", + "python", + "data", + "humaneval_python.jsonl.gz", + ) dataset = {task["task_id"]: task for task in stream_jsonl(data_file)} else: raise f"Dataset: {dataset_type} not supported." @@ -75,7 +82,9 @@ def read_translation_dataset( ) -> Dict: if "humaneval" in dataset_type.lower(): dataset_src = {task["task_id"]: task for task in stream_jsonl(data_file_src)} - dataset_tgt = {task["task_id"].split("/")[-1]: task for task in stream_jsonl(data_file_tgt)} + dataset_tgt = { + task["task_id"].split("/")[-1]: task for task in stream_jsonl(data_file_tgt) + } for k, sample in dataset_src.items(): prompt = "code translation\n" if lang_src == "cpp": @@ -84,7 +93,12 @@ def read_translation_dataset( prompt += "JavaScript:\n" else: prompt += f"{lang_src}:\n".capitalize() - prompt += dataset_src[k]["declaration"] + "\n" + dataset_src[k]["canonical_solution"].rstrip() + "\n" + prompt += ( + dataset_src[k]["declaration"] + + "\n" + + dataset_src[k]["canonical_solution"].rstrip() + + "\n" + ) if lang_tgt == "cpp": prompt += "C++:\n" elif lang_tgt == "js": @@ -126,7 +140,7 @@ def is_code_generation_finished( if "humaneval" in dataset.lower(): if language_type.lower() == "python": for line in code.split("\n"): - if len(line.strip()) > 0 and line[0] != ' ' and line[0] != '\t': + if len(line.strip()) > 0 and line[0] != " " and line[0] != "\t": return True end_words = ["\ndef", "\nclass", "\nif", "\n#", "\nprint"] for w in end_words: @@ -164,27 +178,27 @@ def cleanup_code( end_words = ["\ndef", "\nclass", "\nif", "\n#", "\nprint", "\nassert"] for w in end_words: if w in code: - code = code[:code.rfind(w)] + code = code[: code.rfind(w)] elif language_type.lower() == "java": main_pos = code.find("public static void main") if main_pos != -1: - code = code[:main_pos] + '}' - if '}' in code: - code = code[:code.rfind('}')] + '}' - if code.count('{') + 1 == code.count('}'): + code = code[:main_pos] + "}" + if "}" in code: + code = code[: code.rfind("}")] + "}" + if code.count("{") + 1 == code.count("}"): code += "\n}" elif language_type.lower() == "go": end_words = ["\n//", "\nfunc main("] for w in end_words: if w in code: - code = code[:code.rfind(w)] - if '}' in code: - code = code[:code.rfind('}')] + '}' + code = code[: code.rfind(w)] + if "}" in code: + code = code[: code.rfind("}")] + "}" elif language_type.lower() == "cpp": - if '}' in code: - code = code[:code.rfind('}')] + '}' + if "}" in code: + code = code[: code.rfind("}")] + "}" elif language_type.lower() 
== "js": - if '}' in code: - code = code[:code.rfind('}')] + '}' + if "}" in code: + code = code[: code.rfind("}")] + "}" return code diff --git a/codegeex/data/data_utils.py b/codegeex/data/data_utils.py index 9d5e631..a4c9b11 100644 --- a/codegeex/data/data_utils.py +++ b/codegeex/data/data_utils.py @@ -5,62 +5,62 @@ from typing import * LANGUAGE_TAG = { - "c" : "// language: C", - "c++" : "// language: C++", - "cpp" : "// language: C++", - "c#" : "// language: C#", - "csharp" : "// language: C#", - "css" : "/* language: CSS */", - "cuda" : "// language: Cuda", - "dart" : "// language: Dart", - "lua" : "// language: Lua", - "objectivec" : "// language: Objective-C", - "objective-c" : "// language: Objective-C", + "c": "// language: C", + "c++": "// language: C++", + "cpp": "// language: C++", + "c#": "// language: C#", + "csharp": "// language: C#", + "css": "/* language: CSS */", + "cuda": "// language: Cuda", + "dart": "// language: Dart", + "lua": "// language: Lua", + "objectivec": "// language: Objective-C", + "objective-c": "// language: Objective-C", "objective-c++": "// language: Objective-C++", - "python" : "# language: Python", - "perl" : "# language: Perl", - "prolog" : f"% language: Prolog", - "swift" : "// language: swift", - "lisp" : "; language: Lisp", - "java" : "// language: Java", - "scala" : "// language: Scala", - "tex" : f"% language: TeX", - "vue" : "", - "markdown" : "", - "html" : "", - "php" : "// language: PHP", - "js" : "// language: JavaScript", - "javascript" : "// language: JavaScript", - "typescript" : "// language: TypeScript", - "go" : "// language: Go", - "shell" : "# language: Shell", - "rust" : "// language: Rust", - "sql" : "-- language: SQL", - "kotlin" : "// language: Kotlin", - "vb" : "' language: Visual Basic", - "ruby" : "# language: Ruby", - "pascal" : "// language: Pascal", - "r" : "# language: R", - "fortran" : "!language: Fortran", - "lean" : "-- language: Lean", - "matlab" : f"% language: Matlab", - "delphi" : "{language: Delphi}", - "scheme" : "; language: Scheme", - "basic" : "' language: Basic", - "assembly" : "; language: Assembly", - "groovy" : "// language: Groovy", - "abap" : "* language: Abap", - "gdscript" : "# language: GDScript", - "haskell" : "-- language: Haskell", - "julia" : "# language: Julia", - "elixir" : "# language: Elixir", - "excel" : "' language: Excel", - "clojure" : "; language: Clojure", - "actionscript" : "// language: ActionScript", - "solidity" : "// language: Solidity", - "powershell" : "# language: PowerShell", - "erlang" : f"% language: Erlang", - "cobol" : "// language: Cobol", + "python": "# language: Python", + "perl": "# language: Perl", + "prolog": f"% language: Prolog", + "swift": "// language: swift", + "lisp": "; language: Lisp", + "java": "// language: Java", + "scala": "// language: Scala", + "tex": f"% language: TeX", + "vue": "", + "markdown": "", + "html": "", + "php": "// language: PHP", + "js": "// language: JavaScript", + "javascript": "// language: JavaScript", + "typescript": "// language: TypeScript", + "go": "// language: Go", + "shell": "# language: Shell", + "rust": "// language: Rust", + "sql": "-- language: SQL", + "kotlin": "// language: Kotlin", + "vb": "' language: Visual Basic", + "ruby": "# language: Ruby", + "pascal": "// language: Pascal", + "r": "# language: R", + "fortran": "!language: Fortran", + "lean": "-- language: Lean", + "matlab": f"% language: Matlab", + "delphi": "{language: Delphi}", + "scheme": "; language: Scheme", + "basic": "' language: Basic", + "assembly": "; 
language: Assembly", + "groovy": "// language: Groovy", + "abap": "* language: Abap", + "gdscript": "# language: GDScript", + "haskell": "-- language: Haskell", + "julia": "# language: Julia", + "elixir": "# language: Elixir", + "excel": "' language: Excel", + "clojure": "; language: Clojure", + "actionscript": "// language: ActionScript", + "solidity": "// language: Solidity", + "powershell": "# language: PowerShell", + "erlang": f"% language: Erlang", + "cobol": "// language: Cobol", } @@ -99,13 +99,13 @@ def write_jsonl(filename: str, data: Iterable[Dict], append: bool = False): with open(filename, mode) as fp: for x in data: fp.write((json.dumps(x) + "\n").encode("utf-8")) - - + + def sliding_window( - prompt_tokens: list, - code_tokens: list, - seq_len: int, - sliding_stride: int, + prompt_tokens: list, + code_tokens: list, + seq_len: int, + sliding_stride: int, minimum_code_len: int = 1, ) -> Iterable[Tuple[list, list]]: """ @@ -115,16 +115,20 @@ def sliding_window( code_len = len(code_tokens) total_len = prompt_len + code_len - start_idx = max(0, prompt_len - seq_len + minimum_code_len) # at least `minimum_code_len` code token should be in the window + start_idx = max( + 0, prompt_len - seq_len + minimum_code_len + ) # at least `minimum_code_len` code token should be in the window end_idx = max(0, total_len - seq_len) start_idx = min(start_idx, end_idx) for i in range(start_idx, end_idx + 1, sliding_stride): - current_prompt = prompt_tokens[i:i + seq_len] - current_code = code_tokens[max(i - prompt_len, 0):i - prompt_len + seq_len] + current_prompt = prompt_tokens[i : i + seq_len] + current_code = code_tokens[max(i - prompt_len, 0) : i - prompt_len + seq_len] yield current_prompt, current_code if (end_idx - start_idx) % sliding_stride != 0: - current_prompt = prompt_tokens[end_idx:end_idx + seq_len] - current_code = code_tokens[max(end_idx - prompt_len, 0):end_idx - prompt_len + seq_len] + current_prompt = prompt_tokens[end_idx : end_idx + seq_len] + current_code = code_tokens[ + max(end_idx - prompt_len, 0) : end_idx - prompt_len + seq_len + ] yield current_prompt, current_code diff --git a/codegeex/data/process_pretrain_dataset.py b/codegeex/data/process_pretrain_dataset.py index 8d9f03b..d846933 100644 --- a/codegeex/data/process_pretrain_dataset.py +++ b/codegeex/data/process_pretrain_dataset.py @@ -32,7 +32,7 @@ def try_format_code(code: str): def load_pretrain_dataset(dataset_path: Union[str, List[str]]) -> Dict: if type(dataset_path) is str: dataset_path = [dataset_path] - + for p in dataset_path: if not os.path.isdir(p): if p.endswith(".gz") or p.endswith(".jsonl"): @@ -44,18 +44,18 @@ def load_pretrain_dataset(dataset_path: Union[str, List[str]]) -> Dict: if p_.endswith(".gz") or p_.endswith(".jsonl"): print(f"loading from {p_}") yield from stream_jsonl(p_) - - + + def process_sample( - sample: Dict, - language: str=None, - mode: str="pretrain", + sample: Dict, + language: str = None, + mode: str = "pretrain", ) -> Iterable[PromptSample]: if mode == "pretrain": prompt = "" else: prompt = sample["prompt"] - + try: if language is not None and language in LANGUAGE_TAG.keys(): code = LANGUAGE_TAG[language] + "\n" + sample["code"] @@ -65,12 +65,12 @@ def process_sample( print(e) print("The key 'code' is missing in data. 
Aborted") exit(0) - + yield PromptSample(prompt, code) def generate_prompt_samples( - dataset: Iterable[Dict], + dataset: Iterable[Dict], language: str = None, mode: str = "pretrain", ) -> PromptDataset: @@ -90,7 +90,7 @@ def main( seq_len: int = 2048, ): DATA_KEYS = ["input_ids", "attention_mask", "labels"] - + # create output dir os.makedirs(os.path.dirname(output_prefix), exist_ok=True) @@ -117,17 +117,18 @@ def main( # NOTE that we use seq_len + 1 instead of seq_len, since the input tokens will be shifted by one. processor = PromptDatasetProcessor( - tokenize=tokenizer.encode_code, + tokenize=tokenizer.encode_code, pad_token=pad_token_id, - max_seq_len=seq_len + 1, + max_seq_len=seq_len + 1, discard_overlong=discard_overlong, sliding_stride=sliding_stride, - eod_token=pad_token_id) - + eod_token=pad_token_id, + ) + processor.start_time = perf_counter() - doc_iter = pool.imap_unordered(processor.process_sample_strict, - prompt_dataset, - chunksize=20) + doc_iter = pool.imap_unordered( + processor.process_sample_strict, prompt_dataset, chunksize=20 + ) for doc_idx, docs in tqdm(enumerate(doc_iter, start=1)): processor.doc_processed += 1 diff --git a/codegeex/data/processor.py b/codegeex/data/processor.py index 25775da..bce2250 100644 --- a/codegeex/data/processor.py +++ b/codegeex/data/processor.py @@ -14,7 +14,7 @@ def __init__( max_seq_len: int = 2048, sliding_stride: int = 200, discard_overlong: bool = True, - eod_token: int = None, + eod_token: int = None, preprocess: Callable = None, ): super(PromptDatasetProcessor, self).__init__() @@ -31,12 +31,18 @@ def __init__( self.doc_generated = 0 self.start_time = 0 - def pad_seq(self, prompt_tokens: List[int], code_tokens: List[int], extra: dict = None) -> Dict[str, List[int]]: + def pad_seq( + self, prompt_tokens: List[int], code_tokens: List[int], extra: dict = None + ) -> Dict[str, List[int]]: total_length = len(prompt_tokens) + len(code_tokens) - assert total_length <= self._max_seq_len, f"padding sequence: {total_length} > {self._max_seq_len}" + assert ( + total_length <= self._max_seq_len + ), f"padding sequence: {total_length} > {self._max_seq_len}" pad_len = self._max_seq_len - total_length input_ids = prompt_tokens + code_tokens + [self._pad_token] * pad_len - attention_mask = [1] * len(prompt_tokens) + [1] * len(code_tokens) + [0] * pad_len + attention_mask = ( + [1] * len(prompt_tokens) + [1] * len(code_tokens) + [0] * pad_len + ) labels = [-100] * len(prompt_tokens) + code_tokens + [-100] * pad_len return { @@ -58,7 +64,13 @@ def process_sample(self, sample: PromptSample) -> Iterable[Dict[str, List[int]]] if len(prompt_tokens) + len(code_tokens) > self._max_seq_len: if self._discard_overlong: return - for p, t in sliding_window(prompt_tokens, code_tokens, self._max_seq_len, self._sliding_stride, self._sliding_stride): + for p, t in sliding_window( + prompt_tokens, + code_tokens, + self._max_seq_len, + self._sliding_stride, + self._sliding_stride, + ): yield self.pad_seq(p, t) else: yield self.pad_seq(prompt_tokens, code_tokens, extra=sample.extra) @@ -69,7 +81,7 @@ def process_sample_strict(self, sample: PromptSample) -> List[Dict[str, List[int """ if sample is None: return None - + return list(self.process_sample(sample)) def process_sample_(self, sample) -> List[Dict[str, List[int]]]: @@ -80,9 +92,12 @@ def report(self): duration = perf_counter() - self.start_time process_speed = self.doc_processed * 1.0 / duration gen_speed = self.doc_generated * 1.0 / duration - print(f">>> processed: {self.doc_processed} in 
{duration:.2f}s, speed: {process_speed:.2f} docs/s") - print(f"... generated: {self.doc_generated} in {duration:.2f}s, speed: {gen_speed:.2f} docs/s") - + print( + f">>> processed: {self.doc_processed} in {duration:.2f}s, speed: {process_speed:.2f} docs/s" + ) + print( + f"... generated: {self.doc_generated} in {duration:.2f}s, speed: {gen_speed:.2f} docs/s" + ) class LabelDatasetProcessor(object): @@ -94,7 +109,7 @@ def __init__( max_seq_len: int = 2048, sliding_stride: int = 200, discard_overlong: bool = True, - eod_token: int = None, + eod_token: int = None, preprocess: Callable = None, ): super(LabelDatasetProcessor, self).__init__() @@ -111,20 +126,25 @@ def __init__( self.doc_generated = 0 self.start_time = 0 - def pad_seq(self, prompt_tokens: List[int], label: int, extra: dict = None) -> Dict[str, List[int]]: - total_length = len(prompt_tokens) - assert total_length <= self._max_seq_len, f"padding sequence: {total_length} > {self._max_seq_len}" + def pad_seq( + self, prompt_tokens: List[int], label: int, extra: dict = None + ) -> Dict[str, List[int]]: + total_length = len(prompt_tokens) + assert ( + total_length <= self._max_seq_len + ), f"padding sequence: {total_length} > {self._max_seq_len}" pad_len = self._max_seq_len - total_length - input_ids = prompt_tokens + [self._pad_token] * pad_len + input_ids = prompt_tokens + [self._pad_token] * pad_len attention_mask = [1] * len(prompt_tokens) + [0] * pad_len label = [label] return { - "input_ids": input_ids, - "attention_mask": attention_mask, - "length": [len(prompt_tokens)], - "labels": label + "input_ids": input_ids, + "attention_mask": attention_mask, + "length": [len(prompt_tokens)], + "labels": label, } + def process_sample(self, sample: LabelSample) -> Iterable[Dict[str, List[int]]]: """ Process a sample. @@ -132,12 +152,11 @@ def process_sample(self, sample: LabelSample) -> Iterable[Dict[str, List[int]]]: prompt_tokens = self._tokenize(sample.prompt) label = sample.label - if len(prompt_tokens) > self._max_seq_len: if self._discard_overlong: return - prompt_tokens=prompt_tokens[-self._max_seq_len:] - + prompt_tokens = prompt_tokens[-self._max_seq_len :] + yield self.pad_seq(prompt_tokens, label, extra=sample.extra) def process_sample_strict(self, sample: LabelSample) -> List[Dict[str, List[int]]]: @@ -146,7 +165,7 @@ def process_sample_strict(self, sample: LabelSample) -> List[Dict[str, List[int] """ if sample is None: return None - + return list(self.process_sample(sample)) def process_sample_(self, sample) -> List[Dict[str, List[int]]]: @@ -157,5 +176,9 @@ def report(self): duration = perf_counter() - self.start_time process_speed = self.doc_processed * 1.0 / duration gen_speed = self.doc_generated * 1.0 / duration - print(f">>> processed: {self.doc_processed} in {duration:.2f}s, speed: {process_speed:.2f} docs/s") - print(f"... generated: {self.doc_generated} in {duration:.2f}s, speed: {gen_speed:.2f} docs/s") + print( + f">>> processed: {self.doc_processed} in {duration:.2f}s, speed: {process_speed:.2f} docs/s" + ) + print( + f"... 
generated: {self.doc_generated} in {duration:.2f}s, speed: {gen_speed:.2f} docs/s" + ) diff --git a/codegeex/data/types.py b/codegeex/data/types.py index 394dec9..90282cf 100644 --- a/codegeex/data/types.py +++ b/codegeex/data/types.py @@ -11,10 +11,12 @@ class PromptSample: PromptDataset = Iterable[PromptSample] + @dataclass class LabelSample: prompt: str label: int extra: dict = None -LabelDataset = Iterable[LabelSample] \ No newline at end of file + +LabelDataset = Iterable[LabelSample] diff --git a/codegeex/kernels/__init__.py b/codegeex/kernels/__init__.py index 6037536..85208da 100644 --- a/codegeex/kernels/__init__.py +++ b/codegeex/kernels/__init__.py @@ -12,7 +12,9 @@ class Kernel: def __init__(self, filename: str, function_names: List[str]): filename = filename + ".fatbin" if not pkg_resources.resource_exists(RESOURCE_PACKAGE_NAME, filename): - raise RuntimeError("File `%s` not found in `%s`" % (filename, RESOURCE_PACKAGE_NAME)) + raise RuntimeError( + "File `%s` not found in `%s`" % (filename, RESOURCE_PACKAGE_NAME) + ) self.filename = filename self.code = pkg_resources.resource_string(RESOURCE_PACKAGE_NAME, filename) self._function_names = function_names @@ -50,12 +52,19 @@ def compress_int4_weight(weight: torch.Tensor): # (n, m) blockDim, 0, stream, - [ctypes.c_void_p(weight.data_ptr()), ctypes.c_void_p(out.data_ptr()), ctypes.c_int32(n), ctypes.c_int32(m)], + [ + ctypes.c_void_p(weight.data_ptr()), + ctypes.c_void_p(out.data_ptr()), + ctypes.c_int32(n), + ctypes.c_int32(m), + ], ) return out -def extract_weight_to_half(weight: torch.Tensor, scale_list: torch.Tensor, source_bit_width: int): +def extract_weight_to_half( + weight: torch.Tensor, scale_list: torch.Tensor, source_bit_width: int +): if source_bit_width == 8: func = kernels.int8WeightExtractionHalf elif source_bit_width == 4: @@ -65,7 +74,9 @@ def extract_weight_to_half(weight: torch.Tensor, scale_list: torch.Tensor, sourc with torch.cuda.device(weight.device): n, m = weight.size(0), weight.size(1) - out = torch.empty(n, m * (8 // source_bit_width), dtype=torch.half, device="cuda") + out = torch.empty( + n, m * (8 // source_bit_width), dtype=torch.half, device="cuda" + ) stream = torch.cuda.current_stream() gridDim = (n, 1, 1) diff --git a/codegeex/megatron/arguments.py b/codegeex/megatron/arguments.py index 315f4b9..5ccfd5a 100644 --- a/codegeex/megatron/arguments.py +++ b/codegeex/megatron/arguments.py @@ -294,7 +294,7 @@ def parse_args(extra_args_provider=None, defaults={}, ignore_unknown_args=False) "for distribute-checkpointed-activations to work you " "need to enable checkpoint-activations" ) - + _print_args(args) return args @@ -392,10 +392,12 @@ def _add_network_size_args(parser): action="store_true", help="If set, use original BERT residula connection " "ordering.", ) - group.add_argument('--scaled-upper-triang-masked-softmax-fusion', - action='store_true', - help='Enable fusion of query_key_value_scaling ' - 'time (upper diagonal) masking, softmax.') + group.add_argument( + "--scaled-upper-triang-masked-softmax-fusion", + action="store_true", + help="Enable fusion of query_key_value_scaling " + "time (upper diagonal) masking, softmax.", + ) group.add_argument( "--openai-gelu", action="store_true", @@ -530,11 +532,7 @@ def _add_regularization_args(parser): default=0.05, help="Beta for GOLD tempering.", ) - group.add_argument( - "--play-tau", - type=float, - default=2.0 - ) + group.add_argument("--play-tau", type=float, default=2.0) group.add_argument( "--clip-grad", type=float, @@ -569,18 +567,18 @@ def 
_add_regularization_args(parser): action="store_true", ) group.add_argument( - "--shrink-embedding-gradient-alpha", - type=float, + "--shrink-embedding-gradient-alpha", + type=float, default=1.0, - help='Shrink embedding gradient for alpha', + help="Shrink embedding gradient for alpha", ) group.add_argument( - "--shrink-embedding-gradient-steps", - nargs='*', + "--shrink-embedding-gradient-steps", + nargs="*", default=None, - help='--shrink-embedding-gradient-steps ' - 'Shrink embedding gradient alpha for x1 steps,' - 'then warm it up to 1.0 with x2 steps', + help="--shrink-embedding-gradient-steps " + "Shrink embedding gradient alpha for x1 steps," + "then warm it up to 1.0 with x2 steps", ) return parser @@ -774,32 +772,32 @@ def _add_inference_args(parser): group = parser.add_argument_group(title="initialization") group.add_argument( - '--evaluation', + "--evaluation", action="store_true", ) group.add_argument( - '--beam-warmup', + "--beam-warmup", action="store_true", ) group.add_argument( - '--beam-warmup-length', + "--beam-warmup-length", type=int, default=0, ) group.add_argument( - '--beam-search', + "--beam-search", action="store_true", ) group.add_argument( - '--beam-search-nucleus', + "--beam-search-nucleus", action="store_true", ) group.add_argument( - '--num-beams', + "--num-beams", type=int, default=4, ) - + return parser @@ -933,7 +931,7 @@ def _add_checkpointing_args(parser): action="store_true", default=None, help="Load model checkpoint in low memory mode." - "On each machine, workers load the checkpoint one at a time." + "On each machine, workers load the checkpoint one at a time.", ) group.add_argument( "--dist-timeout", @@ -974,7 +972,9 @@ def _add_mixed_precision_args(parser): group = parser.add_argument_group(title="mixed precision") group.add_argument("--fp16", action="store_true", help="Run model in fp16 mode.") - group.add_argument("--ln-fp16", action="store_true", help="Run layernorm in fp16 mode.") + group.add_argument( + "--ln-fp16", action="store_true", help="Run layernorm in fp16 mode." + ) group.add_argument( "--bf16", action="store_true", help="Run model in bfloat16 mode." ) @@ -989,7 +989,7 @@ def _add_mixed_precision_args(parser): group.add_argument( "--initial-loss-scale", type=float, - default=2 ** 32, + default=2**32, help="Initial loss-scale for dynamic loss scaling.", ) group.add_argument( @@ -1012,10 +1012,13 @@ def _add_mixed_precision_args(parser): action="store_true", help="Move residual connections to fp32.", ) - group.add_argument('--apply-query-key-layer-scaling', action='store_true', - help='Scale Q * K^T by 1 / layer-number. If this flag ' - 'is set, then it will automatically set ' - 'attention-softmax-in-fp32 to true') + group.add_argument( + "--apply-query-key-layer-scaling", + action="store_true", + help="Scale Q * K^T by 1 / layer-number. 
If this flag " + "is set, then it will automatically set " + "attention-softmax-in-fp32 to true", + ) group.add_argument( "--attention-softmax-in-fp32", action="store_true", @@ -1143,7 +1146,7 @@ def _add_validation_args(parser): group.add_argument( "--co-evaluation", action="store_true", - help="If set, run evaluation on each part of the validation set" + help="If set, run evaluation on each part of the validation set", ) return parser @@ -1171,15 +1174,17 @@ def _add_data_args(parser): "dataset2-path ...;" "when co-evaluation is enabled, the form will be dataset1-tag dataset1-path ...", ) - group.add_argument("--index-cache-dir", type=str, default=None, help="Path to the index cache") + group.add_argument( + "--index-cache-dir", type=str, default=None, help="Path to the index cache" + ) group.add_argument( "--test-data-path", nargs="*", default=None, help="Path to the test dataset. Accepted format:" - "1) a single data path, 2) multiple datasets in the" - "form: dataset1-tag dataset1-path dataset2-tag " - "dataset2-path ...", + "1) a single data path, 2) multiple datasets in the" + "form: dataset1-tag dataset1-path dataset2-tag " + "dataset2-path ...", ) group.add_argument( "--split", @@ -1191,21 +1196,21 @@ def _add_data_args(parser): "validation and 5%% for test.", ) group.add_argument( - "--vocab-file", - type=str, - default=None, + "--vocab-file", + type=str, + default=None, help="Path to the vocab file.", ) group.add_argument( - "--merge-file", - type=str, - default=None, + "--merge-file", + type=str, + default=None, help="Path to the BPE merge file.", ) group.add_argument( - "--tokenizer-path", - type=str, - default=None, + "--tokenizer-path", + type=str, + default=None, help="Path to the tokenizer dir.", ) group.add_argument( diff --git a/codegeex/megatron/checkpointing.py b/codegeex/megatron/checkpointing.py index 53f71e2..591ed31 100644 --- a/codegeex/megatron/checkpointing.py +++ b/codegeex/megatron/checkpointing.py @@ -23,7 +23,13 @@ import torch -from codegeex.megatron import get_args, mpu, print_rank_0, update_num_microbatches, utils +from codegeex.megatron import ( + get_args, + mpu, + print_rank_0, + update_num_microbatches, + utils, +) _CHECKPOINT_VERSION = None @@ -90,7 +96,9 @@ def get_checkpoint_name(checkpoints_path, iteration, release=False): return os.path.join( checkpoints_path, directory, - "mp_rank_{:02d}_model_states.pt".format(mpu.get_tensor_model_parallel_rank()), + "mp_rank_{:02d}_model_states.pt".format( + mpu.get_tensor_model_parallel_rank() + ), ) return os.path.join( checkpoints_path, @@ -299,10 +307,10 @@ def load_deepspeed_state(model): def load_checkpoint( - model, - optimizer, - lr_scheduler, - load_arg="load", + model, + optimizer, + lr_scheduler, + load_arg="load", strict=True, ): """Load a model checkpoint and return the iteration. @@ -337,7 +345,9 @@ def load_checkpoint( # If no tracker file, return iretation zero. if not os.path.isfile(tracker_filename): print_rank_0( - "WARNING: could not find the metadata file {} ".format(tracker_filename) + "WARNING: could not find the metadata file {} ".format( + tracker_filename + ) ) iteration = 0 release = True @@ -366,14 +376,15 @@ def load_checkpoint( ) sys.exit() - assert iteration > 0 or release, "error parsing metadata file {}".format( - tracker_filename - ) + assert ( + iteration > 0 or release + ), "error parsing metadata file {}".format(tracker_filename) # Checkpoint. 
checkpoint_name = get_checkpoint_name(load_dir, iteration, release) - print_rank_0(f" loading checkpoint from {args.load} at iteration {iteration}") - + print_rank_0( + f" loading checkpoint from {args.load} at iteration {iteration}" + ) # Load the checkpoint. try: diff --git a/codegeex/megatron/code_generation_utils.py b/codegeex/megatron/code_generation_utils.py index 8ec082d..7c27a6d 100644 --- a/codegeex/megatron/code_generation_utils.py +++ b/codegeex/megatron/code_generation_utils.py @@ -131,8 +131,8 @@ def generate_samples_input_from_file(model): raw_text_len = 0 if ( - mpu.is_pipeline_first_stage() - and mpu.get_tensor_model_parallel_rank() == 0 + mpu.is_pipeline_first_stage() + and mpu.get_tensor_model_parallel_rank() == 0 ): raw_text = all_raw_text[input_pos] input_pos += 1 @@ -174,8 +174,8 @@ def generate_samples_input_from_file(model): # For pipeline parallel we send context tokens to other stages # so they get the lengths correct if ( - mpu.get_tensor_model_parallel_rank() == 0 - and args.pipeline_model_parallel_size > 1 + mpu.get_tensor_model_parallel_rank() == 0 + and args.pipeline_model_parallel_size > 1 ): if mpu.is_pipeline_first_stage(): src = mpu.get_pipeline_model_parallel_first_rank() @@ -206,8 +206,8 @@ def generate_samples_input_from_file(model): decode_tokens, _ = decode_tokens decode_tokens = decode_tokens[0].cpu().numpy().tolist() trim_decode_tokens = tokenizer.detokenize(decode_tokens)[ - raw_text_len: - ] + raw_text_len: + ] print("\nMegatron-LM:", trim_decode_tokens, flush=True) fname_out.write("\n\nMegatron-LM:") @@ -261,8 +261,8 @@ def generate_samples_interactive_code_contest(model, print_frequency=10): raw_text_len = 0 if ( - mpu.is_pipeline_first_stage() - and mpu.get_tensor_model_parallel_rank() == 0 + mpu.is_pipeline_first_stage() + and mpu.get_tensor_model_parallel_rank() == 0 ): # os.system("clear") raw_text = [] @@ -320,8 +320,8 @@ def generate_samples_interactive_code_contest(model, print_frequency=10): # For pipeline parallel we send context tokens to other stages # so they get the lengths correct if ( - mpu.get_tensor_model_parallel_rank() == 0 - and args.pipeline_model_parallel_size > 1 + mpu.get_tensor_model_parallel_rank() == 0 + and args.pipeline_model_parallel_size > 1 ): if mpu.is_pipeline_first_stage(): src = mpu.get_pipeline_model_parallel_first_rank() @@ -337,13 +337,15 @@ def generate_samples_interactive_code_contest(model, print_frequency=10): torch.distributed.broadcast(context_tokens_tensor, src, group) context_tokens = context_tokens_tensor.cpu().numpy().tolist() - token_stream = get_token_stream(model, [context_tokens for _ in range(args.micro_batch_size)]) + token_stream = get_token_stream( + model, [context_tokens for _ in range(args.micro_batch_size)] + ) for counter, decode_tokens in enumerate(token_stream): if ( - counter % print_frequency != 0 - or mpu.get_tensor_model_parallel_rank() != 0 - or not mpu.is_pipeline_first_stage() + counter % print_frequency != 0 + or mpu.get_tensor_model_parallel_rank() != 0 + or not mpu.is_pipeline_first_stage() ): continue @@ -353,11 +355,15 @@ def generate_samples_interactive_code_contest(model, print_frequency=10): decode_tokens, _ = decode_tokens decode_tokens = decode_tokens[0].cpu().numpy().tolist() trim_decode_tokens = tokenizer.detokenize(decode_tokens)[raw_text_len:] - print(f"\nMegatron-LM (gen len: {counter}):", trim_decode_tokens, flush=True) + print( + f"\nMegatron-LM (gen len: {counter}):", + trim_decode_tokens, + flush=True, + ) if ( - mpu.is_pipeline_first_stage() - and 
mpu.get_tensor_model_parallel_rank() == 0 + mpu.is_pipeline_first_stage() + and mpu.get_tensor_model_parallel_rank() == 0 ): os.system("clear") print("\nContext:", raw_text, flush=True) @@ -386,8 +392,8 @@ def generate_samples_interactive(model, print_frequency=24): raw_text_len = 0 if ( - mpu.is_pipeline_first_stage() - and mpu.get_tensor_model_parallel_rank() == 0 + mpu.is_pipeline_first_stage() + and mpu.get_tensor_model_parallel_rank() == 0 ): os.system("clear") raw_text = input("\nContext prompt (stop to exit) >>> ") @@ -430,8 +436,8 @@ def generate_samples_interactive(model, print_frequency=24): # For pipeline parallel we send context tokens to other stages # so they get the lengths correct if ( - mpu.get_tensor_model_parallel_rank() == 0 - and args.pipeline_model_parallel_size > 1 + mpu.get_tensor_model_parallel_rank() == 0 + and args.pipeline_model_parallel_size > 1 ): if mpu.is_pipeline_first_stage(): src = mpu.get_pipeline_model_parallel_first_rank() @@ -451,9 +457,9 @@ def generate_samples_interactive(model, print_frequency=24): for counter, decode_tokens in enumerate(token_stream): if ( - counter % print_frequency != 0 - or mpu.get_tensor_model_parallel_rank() != 0 - or not mpu.is_pipeline_first_stage() + counter % print_frequency != 0 + or mpu.get_tensor_model_parallel_rank() != 0 + or not mpu.is_pipeline_first_stage() ): continue @@ -466,8 +472,8 @@ def generate_samples_interactive(model, print_frequency=24): print("\nMegatron-LM:", trim_decode_tokens, flush=True) if ( - mpu.is_pipeline_first_stage() - and mpu.get_tensor_model_parallel_rank() == 0 + mpu.is_pipeline_first_stage() + and mpu.get_tensor_model_parallel_rank() == 0 ): os.system("clear") print("\nContext:", raw_text, flush=True) @@ -507,7 +513,7 @@ def generate_samples_unconditional(model): length_batch = token_stream[1].cpu().numpy().tolist() assert len(length_batch) == args.micro_batch_size for tokens, length in zip(token_batch, length_batch): - tokens = tokens[1: length - 1] + tokens = tokens[1 : length - 1] text = tokenizer.detokenize(tokens) is_finished = length < args.seq_length - 1 datum = {"text": text, "length": length - 1, "finished": is_finished} @@ -531,8 +537,8 @@ def generate_and_write_samples_unconditional(model): with open(args.genfile, "w") as f: for datum in generate_samples_unconditional(model): if ( - mpu.is_pipeline_last_stage() - and mpu.get_tensor_model_parallel_rank() == 0 + mpu.is_pipeline_last_stage() + and mpu.get_tensor_model_parallel_rank() == 0 ): f.write(json.dumps(datum) + "\n") @@ -563,7 +569,12 @@ def topk_sampling(logits: torch.FloatTensor, num_samples: int): return topk_tokens, topk_log_prob -def nuclear_sampling(logits: torch.FloatTensor, temperature: float, top_p: float = None, top_k: int = None): +def nuclear_sampling( + logits: torch.FloatTensor, + temperature: float, + top_p: float = None, + top_k: int = None, +): orig_log_probs = F.log_softmax(logits, dim=-1) logits /= temperature logits = top_k_logits(logits, top_k, top_p) @@ -576,10 +587,17 @@ def nuclear_sampling(logits: torch.FloatTensor, temperature: float, top_p: float return tokens, new_scores -def sample_topk_tokens(model, - input_tokens, attention_mask, position_ids, - context_length: int, num_samples: int): - assert context_length < input_tokens.shape[-1], "context_length must be smaller than seq_length" +def sample_topk_tokens( + model, + input_tokens, + attention_mask, + position_ids, + context_length: int, + num_samples: int, +): + assert ( + context_length < input_tokens.shape[-1] + ), "context_length must be 
smaller than seq_length" model.eval() with torch.no_grad(): @@ -597,10 +615,19 @@ def sample_topk_tokens(model, return topk_sampling(logits, num_samples) -def nuclear_sample_tokens(model, - input_tokens, attention_mask, position_ids, - context_length: int, temperature: float, top_p: float, top_k: int): - assert context_length < input_tokens.shape[-1], "context_length must be smaller than seq_length" +def nuclear_sample_tokens( + model, + input_tokens, + attention_mask, + position_ids, + context_length: int, + temperature: float, + top_p: float, + top_k: int, +): + assert ( + context_length < input_tokens.shape[-1] + ), "context_length must be smaller than seq_length" model.eval() with torch.no_grad(): @@ -642,10 +669,14 @@ def expand_beams(beams: List[Beam], num_beams: int, model) -> List[Beam]: context_tokens_tensor = torch.cuda.LongTensor(context_tokens) tokens, attention_mask, position_ids = get_batch_(context_tokens_tensor) - tokens, scores = sample_topk_tokens(model, tokens, attention_mask, position_ids, context_length, num_beams) + tokens, scores = sample_topk_tokens( + model, tokens, attention_mask, position_ids, context_length, num_beams + ) tokens = tokens.detach().cpu().tolist() scores = scores.detach().cpu().tolist() - assert len(tokens) == len(beams), "output tokens and input beams must have the same length" + assert len(tokens) == len( + beams + ), "output tokens and input beams must have the same length" all_beams = [] for i in range(len(beams)): @@ -684,7 +715,10 @@ def beam_search(model, context_tokens, num_beams: int): next_beams = [] for beam in expanded_beams: if args.beam_warmup: - if len(beam.tokens) >= org_context_len + args.beam_warmup_length or beam.tokens[-1] == tokenizer.eod: + if ( + len(beam.tokens) >= org_context_len + args.beam_warmup_length + or beam.tokens[-1] == tokenizer.eod + ): finished_beams.append(beam) else: next_beams.append(beam) @@ -713,7 +747,9 @@ def beam_search(model, context_tokens, num_beams: int): if min_score >= beams[0].score: break else: - print(f"we have got enough finished beams, but the minimal score is {min_score}") + print( + f"we have got enough finished beams, but the minimal score is {min_score}" + ) print(f"and the maximum searching score is {beams[0].score}") # return top-k finished and unfinished beams @@ -739,7 +775,9 @@ def derived(self, new_token: int, log_prob: float): return Handle(self.tokens + [new_token], self.score + log_prob) -def expand_handles(handles: List[Handle], temperature: float, top_p: float, top_k: int, model): +def expand_handles( + handles: List[Handle], temperature: float, top_p: float, top_k: int, model +): args = get_args() tokenizer = get_tokenizer() @@ -752,11 +790,21 @@ def expand_handles(handles: List[Handle], temperature: float, top_p: float, top_ context_tokens_tensor = torch.cuda.LongTensor(context_tokens) tokens, attention_mask, position_ids = get_batch_(context_tokens_tensor) - tokens, scores = nuclear_sample_tokens(model, tokens, attention_mask, position_ids, context_length, temperature, - top_p, top_k) + tokens, scores = nuclear_sample_tokens( + model, + tokens, + attention_mask, + position_ids, + context_length, + temperature, + top_p, + top_k, + ) tokens = tokens.detach().cpu().tolist() scores = scores.detach().cpu().tolist() - assert len(tokens) == len(handles), "output tokens and input must have the same length" + assert len(tokens) == len( + handles + ), "output tokens and input must have the same length" all_beams = [] for i in range(len(handles)): @@ -768,7 +816,14 @@ def 
expand_handles(handles: List[Handle], temperature: float, top_p: float, top_ return all_beams -def generate_nuclear_sampling(model, context_tokens, num_samples: int, temperature: float, top_p: float, top_k: int): +def generate_nuclear_sampling( + model, + context_tokens, + num_samples: int, + temperature: float, + top_p: float, + top_k: int, +): """Beam search. Note that this function does not support model parallel! @@ -799,16 +854,16 @@ def generate_nuclear_sampling(model, context_tokens, num_samples: int, temperatu def forward_step( - model, - tokens, - position_ids, - attention_mask, - tokentype_ids, - layer_past=None, - get_key_value=None, - forward_method_parallel_output=None, - prompt_length=None, - context_length=None, + model, + tokens, + position_ids, + attention_mask, + tokentype_ids, + layer_past=None, + get_key_value=None, + forward_method_parallel_output=None, + prompt_length=None, + context_length=None, ): # Hidden size changes when not using recompute, need to tell p2p_communicate # functions the correct size @@ -839,15 +894,15 @@ def forward_step( def get_token_stream( - model, - context_tokens, - return_scores: bool = False, - prompt_length: int = None, - micro_batch_size: int = None, - bad_ids: List = None, - temperature: float = None, - topp: float = None, - topk: int = None, + model, + context_tokens, + return_scores: bool = False, + prompt_length: int = None, + micro_batch_size: int = None, + bad_ids: List = None, + temperature: float = None, + topp: float = None, + topk: int = None, ): args = get_args() tokenizer = get_tokenizer() @@ -869,7 +924,9 @@ def get_token_stream( ) context_length = context_length_tensor.min().item() - tokens, attention_mask, position_ids = get_batch(context_tokens_tensor, micro_batch_size) + tokens, attention_mask, position_ids = get_batch( + context_tokens_tensor, micro_batch_size + ) batch_token_iterator = sample_sequence_batch( model, @@ -903,19 +960,19 @@ def switch(val1, val2, boolean): def sample_sequence_batch( - model, - context_tokens, - context_lengths, - attention_mask, - position_ids, - maxlen=None, - type_ids=None, - return_scores: bool = False, - prompt_length: int = None, - bad_ids: List = None, - temperature: float = None, - topp: float = None, - topk: int = None, + model, + context_tokens, + context_lengths, + attention_mask, + position_ids, + maxlen=None, + type_ids=None, + return_scores: bool = False, + prompt_length: int = None, + bad_ids: List = None, + temperature: float = None, + topp: float = None, + topk: int = None, ): args = get_args() tokenizer = get_tokenizer() @@ -951,12 +1008,15 @@ def sample_sequence_batch( scores = torch.zeros([batch_size]).float().cuda() if args.beam_search: - beams = beam_search(model, context_tokens=tokens.cpu().numpy().tolist()[0][:context_length], - num_beams=args.num_beams) + beams = beam_search( + model, + context_tokens=tokens.cpu().numpy().tolist()[0][:context_length], + num_beams=args.num_beams, + ) if args.beam_warmup: beam = beams[0] tokens_ = beam.tokens - tokens_ = (tokens_ if tokens_[-1] != tokenizer.eod else tokens_[:-1]) + tokens_ = tokens_ if tokens_[-1] != tokenizer.eod else tokens_[:-1] tokens_warmup = [] for i in range(batch_size): tokens_warmup.append(tokens_.copy()) @@ -976,14 +1036,15 @@ def sample_sequence_batch( else: while context_length <= (maxlen): if args.recompute: - logits = model(tokens, - position_ids, - attention_mask, - tokentype_ids=type_ids, - forward_method_parallel_output=False, - prompt_length=prompt_length, - context_length=context_length, - ) + 
logits = model( + tokens, + position_ids, + attention_mask, + tokentype_ids=type_ids, + forward_method_parallel_output=False, + prompt_length=prompt_length, + context_length=context_length, + ) logits = logits[:, context_length - 1, :] else: types2use = None @@ -993,23 +1054,25 @@ def sample_sequence_batch( if type_ids is not None: types2use = type_ids[:, :context_length] else: - tokens2use = tokens[:, context_length - 1].view( - batch_size, -1) + tokens2use = tokens[:, context_length - 1].view(batch_size, -1) positions2use = position_ids[:, context_length - 1].view( - batch_size, -1) + batch_size, -1 + ) if type_ids is not None: types2use = type_ids[:, context_length - 1].view( - batch_size, -1) - logits, layer_past = model(tokens2use, - positions2use, - attention_mask, - layer_past=layer_past, - get_key_value=True, - tokentype_ids=types2use, - forward_method_parallel_output=False, - prompt_length=prompt_length, - context_length=context_length, - ) + batch_size, -1 + ) + logits, layer_past = model( + tokens2use, + positions2use, + attention_mask, + layer_past=layer_past, + get_key_value=True, + tokentype_ids=types2use, + forward_method_parallel_output=False, + prompt_length=prompt_length, + context_length=context_length, + ) logits = logits[:, -1].view(batch_size, -1).contiguous() if mpu.is_pipeline_last_stage(): @@ -1029,7 +1092,9 @@ def sample_sequence_batch( started = context_lengths <= context_length - new_tokens = switch(tokens[:, context_length].view(-1), prev, started) + new_tokens = switch( + tokens[:, context_length].view(-1), prev, started + ) if not args.greedy and return_scores: indices = prev.view(-1, 1) diff --git a/codegeex/megatron/convert_ckpt_parallel.py b/codegeex/megatron/convert_ckpt_parallel.py index f1b0e50..1bbe24b 100644 --- a/codegeex/megatron/convert_ckpt_parallel.py +++ b/codegeex/megatron/convert_ckpt_parallel.py @@ -7,26 +7,26 @@ def get_change_ckpt_args(parser): """Provide extra arguments required for merging.""" - group = parser.add_argument_group(title='Mindspore to megatron') + group = parser.add_argument_group(title="Mindspore to megatron") group.add_argument( - '--load-ckpt-path', + "--load-ckpt-path", type=str, required=True, help='path to load ".pt" checkpoint.', ) group.add_argument( - '--save-ckpt-path', + "--save-ckpt-path", type=str, required=True, - help='dir to save converted checkpoints.', + help="dir to save converted checkpoints.", ) group.add_argument( - '--target-tensor-model-parallel-size', + "--target-tensor-model-parallel-size", type=int, default=2, - help='target tensor model parallel size', + help="target tensor model parallel size", ) - + return parser @@ -49,7 +49,7 @@ def main(): parser = argparse.ArgumentParser() parser = get_change_ckpt_args(parser) args, _ = parser.parse_known_args() - + print(f"Load ckpt from {args.load_ckpt_path}...") state_dict = torch.load(args.load_ckpt_path, map_location="cpu") @@ -57,12 +57,18 @@ def main(): output_state_dict = [] for i in range(args.target_tensor_model_parallel_size): output_state_dict.append({}) - + print("Converting Embedding layers...") - word_embeddings = state_dict['module']['language_model']['embedding']['word_embeddings']['weight'] - position_embeddings = state_dict['module']['language_model']['embedding']['position_embeddings']['weight'] - out_word_embeddings = torch.chunk(word_embeddings, args.target_tensor_model_parallel_size, dim=0) - + word_embeddings = state_dict["module"]["language_model"]["embedding"][ + "word_embeddings" + ]["weight"] + position_embeddings = 
state_dict["module"]["language_model"]["embedding"][ + "position_embeddings" + ]["weight"] + out_word_embeddings = torch.chunk( + word_embeddings, args.target_tensor_model_parallel_size, dim=0 + ) + for i in range(args.target_tensor_model_parallel_size): pos_emb_dict = get_element_from_dict_by_path( output_state_dict[i], "module.language_model.embedding.position_embeddings" @@ -73,52 +79,71 @@ def main(): output_state_dict[i], "module.language_model.embedding.word_embeddings" ) word_emb_dict["weight"] = out_word_embeddings[i].clone() - + print("Converting QueryEmbedding layers...") - query_embeddings = state_dict['module']['language_model']['topQueryEmbedding']['top_query_embeddings']['weight'] - out_query_embeddings = torch.chunk(query_embeddings, args.target_tensor_model_parallel_size, dim=0) - + query_embeddings = state_dict["module"]["language_model"]["topQueryEmbedding"][ + "top_query_embeddings" + ]["weight"] + out_query_embeddings = torch.chunk( + query_embeddings, args.target_tensor_model_parallel_size, dim=0 + ) + for i in range(args.target_tensor_model_parallel_size): query_emb_dict = get_element_from_dict_by_path( - output_state_dict[i], "module.language_model.topQueryEmbedding.top_query_embeddings" + output_state_dict[i], + "module.language_model.topQueryEmbedding.top_query_embeddings", ) query_emb_dict["weight"] = out_query_embeddings[i].clone() - + print("Converting Transformer layers...") - for layer_name in state_dict['module']['language_model']['transformer'].keys(): - params = state_dict['module']['language_model']['transformer'][layer_name] + for layer_name in state_dict["module"]["language_model"]["transformer"].keys(): + params = state_dict["module"]["language_model"]["transformer"][layer_name] if "layernorm" in layer_name: pass elif "attention" in layer_name and "weight" in layer_name: if "dense" in layer_name: - params = torch.chunk(params, args.target_tensor_model_parallel_size, dim=1) + params = torch.chunk( + params, args.target_tensor_model_parallel_size, dim=1 + ) else: - params = torch.chunk(params, args.target_tensor_model_parallel_size, dim=0) + params = torch.chunk( + params, args.target_tensor_model_parallel_size, dim=0 + ) elif "weight" in layer_name and "dense" in layer_name: if "h_to_4h" in layer_name: - params = torch.chunk(params, args.target_tensor_model_parallel_size, dim=0) + params = torch.chunk( + params, args.target_tensor_model_parallel_size, dim=0 + ) else: - params = torch.chunk(params, args.target_tensor_model_parallel_size, dim=1) + params = torch.chunk( + params, args.target_tensor_model_parallel_size, dim=1 + ) elif "bias" in layer_name: if "dense" not in layer_name or "mlp" in layer_name: if "4h_to_h" in layer_name: pass else: - params = torch.chunk(params, args.target_tensor_model_parallel_size, dim=0) - + params = torch.chunk( + params, args.target_tensor_model_parallel_size, dim=0 + ) + for i in range(args.target_tensor_model_parallel_size): - params_dict = get_element_from_dict_by_path(output_state_dict[i], "module.language_model.transformer") + params_dict = get_element_from_dict_by_path( + output_state_dict[i], "module.language_model.transformer" + ) if type(params) is tuple: params_dict[layer_name] = params[i].clone() else: params_dict[layer_name] = params - + os.makedirs(args.save_ckpt_path, exist_ok=True) for rank in range(args.target_tensor_model_parallel_size): - save_ckpt_path = os.path.join(args.save_ckpt_path, f"mp_rank_{rank:02d}_model_states.pt") + save_ckpt_path = os.path.join( + args.save_ckpt_path, 
f"mp_rank_{rank:02d}_model_states.pt" + ) torch.save(output_state_dict[rank], save_ckpt_path) print(f"Converted checkpoint saved in {save_ckpt_path}.") -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/codegeex/megatron/data/indexed_dataset.py b/codegeex/megatron/data/indexed_dataset.py index fe8d2a9..ed9d111 100644 --- a/codegeex/megatron/data/indexed_dataset.py +++ b/codegeex/megatron/data/indexed_dataset.py @@ -117,9 +117,7 @@ def __best_fitting_dtype(vocab_size=None): def make_mmap_builder(out_file, vocab_size=None): - return MMapIndexedDatasetBuilder( - out_file, dtype=__best_fitting_dtype(vocab_size) - ) + return MMapIndexedDatasetBuilder(out_file, dtype=__best_fitting_dtype(vocab_size)) def code(dtype): diff --git a/codegeex/megatron/data/prompt_dataset.py b/codegeex/megatron/data/prompt_dataset.py index af9778e..a976814 100644 --- a/codegeex/megatron/data/prompt_dataset.py +++ b/codegeex/megatron/data/prompt_dataset.py @@ -105,13 +105,25 @@ def _build_train_valid_test_datasets( """Build train, valid, and test datasets.""" # Indexed dataset. - assert os.path.exists(data_prefix + "_input_ids.bin"), f"Input tokens datafile not found: {data_prefix}_input_ids.bin" - assert os.path.exists(data_prefix + "_attention_mask.bin"), f"Attention mask datafile not found: {data_prefix}_attention_mask.bin" - assert os.path.exists(data_prefix + "_labels.bin"), f"Labels datafile not found: {data_prefix}_labels.bin" - - input_ids_indexed_dataset = get_indexed_dataset_(data_prefix + "_input_ids", data_impl, skip_warmup) - attention_mask_indexed_dataset = get_indexed_dataset_(data_prefix + "_attention_mask", data_impl, skip_warmup) - labels_indexed_dataset = get_indexed_dataset_(data_prefix + "_labels", data_impl, skip_warmup) + assert os.path.exists( + data_prefix + "_input_ids.bin" + ), f"Input tokens datafile not found: {data_prefix}_input_ids.bin" + assert os.path.exists( + data_prefix + "_attention_mask.bin" + ), f"Attention mask datafile not found: {data_prefix}_attention_mask.bin" + assert os.path.exists( + data_prefix + "_labels.bin" + ), f"Labels datafile not found: {data_prefix}_labels.bin" + + input_ids_indexed_dataset = get_indexed_dataset_( + data_prefix + "_input_ids", data_impl, skip_warmup + ) + attention_mask_indexed_dataset = get_indexed_dataset_( + data_prefix + "_attention_mask", data_impl, skip_warmup + ) + labels_indexed_dataset = get_indexed_dataset_( + data_prefix + "_labels", data_impl, skip_warmup + ) total_num_of_documents = input_ids_indexed_dataset.sizes.shape[0] splits = get_train_valid_test_split_(splits_string, total_num_of_documents) @@ -212,8 +224,14 @@ def __init__( # Checks assert np.min(documents) >= 0 assert np.max(documents) < input_ids_indexed_dataset.sizes.shape[0] - assert input_ids_indexed_dataset.sizes.shape[0] == attention_mask_index_dataset.sizes.shape[0] - assert attention_mask_index_dataset.sizes.shape[0] == labels_indexed_dataset.sizes.shape[0] + assert ( + input_ids_indexed_dataset.sizes.shape[0] + == attention_mask_index_dataset.sizes.shape[0] + ) + assert ( + attention_mask_index_dataset.sizes.shape[0] + == labels_indexed_dataset.sizes.shape[0] + ) # Build index mappings. self.doc_idx = _build_index_mappings( @@ -251,7 +269,13 @@ def __getitem__(self, idx): def _build_index_mappings( - name, data_prefix, documents, sizes, num_samples, seq_length, seed, + name, + data_prefix, + documents, + sizes, + num_samples, + seq_length, + seed, ): """Build index mappings. We only have to build doc-idx in prompt dataset. 
@@ -298,8 +322,8 @@ def _build_index_mappings( torch.distributed.all_reduce(counts, group=mpu.get_data_parallel_group()) torch.distributed.all_reduce(counts, group=mpu.get_pipeline_model_parallel_group()) assert counts[0].item() == ( - torch.distributed.get_world_size() - // torch.distributed.get_world_size(group=mpu.get_tensor_model_parallel_group()) + torch.distributed.get_world_size() + // torch.distributed.get_world_size(group=mpu.get_tensor_model_parallel_group()) ) # Load mappings. start_time = time.time() @@ -329,4 +353,4 @@ def _build_doc_idx(documents, num_epochs, np_rng, separate_last_epoch): doc_idx_first = _build_doc_idx(documents, num_epochs - 1, np_rng, False) doc_idx_last = _build_doc_idx(documents, 1, np_rng, False) - return np.concatenate((doc_idx_first, doc_idx_last)) \ No newline at end of file + return np.concatenate((doc_idx_first, doc_idx_last)) diff --git a/codegeex/megatron/global_vars.py b/codegeex/megatron/global_vars.py index 230465a..5168f54 100644 --- a/codegeex/megatron/global_vars.py +++ b/codegeex/megatron/global_vars.py @@ -112,8 +112,8 @@ def _build_num_microbatches_calculator(args): ) _GLOBAL_NUM_MICROBATCHES_CALCULATOR = build_num_microbatches_calculator(args) - - + + def _build_tokenizer(args): """Initialize tokenizer.""" global _GLOBAL_TOKENIZER diff --git a/codegeex/megatron/inference.py b/codegeex/megatron/inference.py index 53b018f..2a30404 100644 --- a/codegeex/megatron/inference.py +++ b/codegeex/megatron/inference.py @@ -18,8 +18,7 @@ def model_provider(): """Build the model.""" - model = CodeGeeXModel(num_tokentypes=0, - parallel_output=False) + model = CodeGeeXModel(num_tokentypes=0, parallel_output=False) return model @@ -43,7 +42,9 @@ def run_generation_distributed(model): socket = context.socket(zmq.REQ) socket.connect(f"tcp://{args.channel_ip}:{args.channel_port}") output_file_path = args.output_prefix + f"_finished_rank{args.gen_rank}.jsonl" - unfinished_output_file_path = args.output_prefix + f"_unfinished_rank{args.gen_rank}.jsonl" + unfinished_output_file_path = ( + args.output_prefix + f"_unfinished_rank{args.gen_rank}.jsonl" + ) problems = {} print("Building tokenizer...") tokenizer = get_tokenizer() @@ -67,7 +68,9 @@ def run_generation_distributed(model): current_spec = resp["task_id"] prompt = current_spec["prompt"] - temperature = None if "temperature" not in resp else resp["temperature"] + temperature = ( + None if "temperature" not in resp else resp["temperature"] + ) topp = None if "topp" not in resp else resp["topp"] f.flush() @@ -83,10 +86,7 @@ def run_generation_distributed(model): if args.beam_search: beams = get_token_stream( model, - [ - copy.deepcopy(tokens) - for _ in range(micro_batch_size) - ], + [copy.deepcopy(tokens) for _ in range(micro_batch_size)], return_scores=args.return_scores, prompt_length=n_token_prompt, micro_batch_size=micro_batch_size, @@ -102,29 +102,35 @@ def run_generation_distributed(model): if generated_tokens_[-1] != tokenizer.eod else generated_tokens_[:-1] ) - generated_code = tokenizer.detokenize(generated_tokens_[n_token_prompt:]) - generated_code = cleanup_code(generated_code, - language_type=language_type, - dataset=args.dataset) + generated_code = tokenizer.detokenize( + generated_tokens_[n_token_prompt:] + ) + generated_code = cleanup_code( + generated_code, + language_type=language_type, + dataset=args.dataset, + ) f.write( json.dumps( { - "task_id" : current_spec['task_id'], - "prompt" : prompt, + "task_id": current_spec["task_id"], + "prompt": prompt, "generation": generated_code, - 
"scores" : beam.score, - "finish" : 2 if generated_tokens[i].cpu().numpy()[ - -1] == tokenizer.eod else 1, - "output" : beam.tokens, + "scores": beam.score, + "finish": 2 + if generated_tokens[i].cpu().numpy()[-1] + == tokenizer.eod + else 1, + "output": beam.tokens, } ) + "\n" ) socket.send_json( { - "rank" : args.gen_rank, - "action" : "success", - "task_id": current_spec['task_id'] + "rank": args.gen_rank, + "action": "success", + "task_id": current_spec["task_id"], } ) socket.recv() @@ -132,10 +138,7 @@ def run_generation_distributed(model): token_stream = get_token_stream( model, - [ - copy.deepcopy(tokens) - for _ in range(micro_batch_size) - ], + [copy.deepcopy(tokens) for _ in range(micro_batch_size)], return_scores=args.return_scores, prompt_length=n_token_prompt, micro_batch_size=micro_batch_size, @@ -156,33 +159,47 @@ def run_generation_distributed(model): if is_finished[i]: continue - generated_tokens_ = generated_tokens[i].cpu().numpy().tolist() + generated_tokens_ = ( + generated_tokens[i].cpu().numpy().tolist() + ) generated_tokens_ = ( generated_tokens_ if generated_tokens_[-1] != tokenizer.eod else generated_tokens_[:-1] ) - generated_code = tokenizer.detokenize(generated_tokens_[n_token_prompt:]) - if generated_tokens[i].cpu().numpy()[-1] == tokenizer.eod or \ - is_code_generation_finished( - generated_code, - language_type=language_type, - dataset=args.dataset, - ): + generated_code = tokenizer.detokenize( + generated_tokens_[n_token_prompt:] + ) + if generated_tokens[i].cpu().numpy()[ + -1 + ] == tokenizer.eod or is_code_generation_finished( + generated_code, + language_type=language_type, + dataset=args.dataset, + ): is_finished[i] = True - generated_code = cleanup_code(generated_code, - language_type=language_type, - dataset=args.dataset) + generated_code = cleanup_code( + generated_code, + language_type=language_type, + dataset=args.dataset, + ) f.write( json.dumps( { - "task_id" : current_spec['task_id'], - "prompt" : prompt, + "task_id": current_spec["task_id"], + "prompt": prompt, "generation": generated_code, - "scores" : 0.0 if scores is None else scores[i].detach().cpu().item(), - "finish" : 2 if generated_tokens[i].cpu().numpy()[ - -1] == tokenizer.eod else 1, - "output" : generated_tokens[i].cpu().numpy().tolist(), + "scores": 0.0 + if scores is None + else scores[i].detach().cpu().item(), + "finish": 2 + if generated_tokens[i].cpu().numpy()[-1] + == tokenizer.eod + else 1, + "output": generated_tokens[i] + .cpu() + .numpy() + .tolist(), } ) + "\n" @@ -196,17 +213,23 @@ def run_generation_distributed(model): for i in range(micro_batch_size): if not is_finished[i]: - generated_tokens_ = generated_tokens[i].cpu().numpy().tolist() - generated_code = tokenizer.detokenize(generated_tokens_[n_token_prompt:]) + generated_tokens_ = ( + generated_tokens[i].cpu().numpy().tolist() + ) + generated_code = tokenizer.detokenize( + generated_tokens_[n_token_prompt:] + ) unfinished_f.write( json.dumps( { - "task_id" : current_spec['task_id'], - "prompt" : prompt, + "task_id": current_spec["task_id"], + "prompt": prompt, "generation": generated_code, - "scores" : 0.0 if scores is None else scores[i].detach().cpu().item(), - "finish" : 0, - "output" : generated_tokens_, + "scores": 0.0 + if scores is None + else scores[i].detach().cpu().item(), + "finish": 0, + "output": generated_tokens_, } ) + "\n" @@ -214,9 +237,9 @@ def run_generation_distributed(model): socket.send_json( { - "rank" : args.gen_rank, - "action" : "success", - "task_id": current_spec['task_id'] + "rank": 
args.gen_rank, + "action": "success", + "task_id": current_spec["task_id"], } ) socket.recv() @@ -226,18 +249,20 @@ def run_generation_distributed(model): print(f" error: {repr(e)}") traceback.print_exc() if args.dataset.lower() == "codecontest": - socket.send_json({ - "rank" : args.gen_rank, - "action" : "fail", - "contest_name" : current_spec.name, - "micro_batch_size": micro_batch_size - }) + socket.send_json( + { + "rank": args.gen_rank, + "action": "fail", + "contest_name": current_spec.name, + "micro_batch_size": micro_batch_size, + } + ) else: socket.send_json( { - "rank" : args.gen_rank, - "action" : "fail", - "task_id": current_spec['task_id'] + "rank": args.gen_rank, + "action": "fail", + "task_id": current_spec["task_id"], } ) socket.recv() diff --git a/codegeex/megatron/initialize.py b/codegeex/megatron/initialize.py index f8ab529..f5850b7 100644 --- a/codegeex/megatron/initialize.py +++ b/codegeex/megatron/initialize.py @@ -259,7 +259,7 @@ def _initialize_distributed(): world_size=args.world_size, rank=args.rank, init_method=init_method, - timeout=timeout + timeout=timeout, ) print(f" > (rank={args.rank}) process group initialized") diff --git a/codegeex/megatron/merge_ckpt_parallel.py b/codegeex/megatron/merge_ckpt_parallel.py index c85f40f..ae0aa17 100644 --- a/codegeex/megatron/merge_ckpt_parallel.py +++ b/codegeex/megatron/merge_ckpt_parallel.py @@ -12,31 +12,31 @@ def get_change_ckpt_args(parser): """Provide extra arguments required for merging.""" - group = parser.add_argument_group(title='Mindspore to megatron') + group = parser.add_argument_group(title="Mindspore to megatron") group.add_argument( - '--load-ckpt-path', + "--load-ckpt-path", type=str, required=True, - help='dir to load model parallel partitions.', + help="dir to load model parallel partitions.", ) group.add_argument( - '--save-ckpt-path', + "--save-ckpt-path", type=str, required=True, help='path to save ".pt" checkpoint.', ) group.add_argument( - '--save-name', + "--save-name", type=str, - help='name of checkpoint.', + help="name of checkpoint.", ) group.add_argument( - '--source-tensor-model-parallel-size', + "--source-tensor-model-parallel-size", type=int, default=2, - help='original tensor model parallel size', + help="original tensor model parallel size", ) - + return parser @@ -48,82 +48,193 @@ def main(): extra_args_provider=get_change_ckpt_args, args_defaults={ "tokenizer_type": "GPT2BPETokenizer", - "no_load_rng" : True, - "no_load_optim" : True, + "no_load_rng": True, + "no_load_optim": True, }, ) - + args = get_args() model = CodeGeeXModel() print(model.state_dict) # Save the model. 
sd = {} - sd['module'] = model.state_dict_for_save_checkpoint() + sd["module"] = model.state_dict_for_save_checkpoint() ensure_directory_exists(args.save_ckpt_path) - + print(f"Load ckpt from {args.load_ckpt_path}...") state_dict_list = [] for i in range(args.source_tensor_model_parallel_size): try: - state_dict_list.append(torch.load(os.path.join(args.load_ckpt_path, f"mp_rank_{i:02d}_model_states.pt"), map_location="cpu")) + state_dict_list.append( + torch.load( + os.path.join( + args.load_ckpt_path, f"mp_rank_{i:02d}_model_states.pt" + ), + map_location="cpu", + ) + ) except Exception as e: print(e) exit(0) - + print(f"Merging {len(state_dict_list)} partitions into a single ckpt...") print("Merging Embedding layers...") - vocab_parallel_size = args.make_vocab_size_divisible_by // args.source_tensor_model_parallel_size + vocab_parallel_size = ( + args.make_vocab_size_divisible_by // args.source_tensor_model_parallel_size + ) for i in range(args.source_tensor_model_parallel_size): - sd['module']['language_model']['embedding']['word_embeddings']['weight'][i * vocab_parallel_size : (i + 1) * vocab_parallel_size, :] = state_dict_list[i]['module']['language_model']['embedding']['word_embeddings']['weight'] - - sd['module']['language_model']['embedding']['position_embeddings']['weight'] = state_dict_list[0]['module']['language_model']['embedding']['position_embeddings']['weight'] - + sd["module"]["language_model"]["embedding"]["word_embeddings"]["weight"][ + i * vocab_parallel_size : (i + 1) * vocab_parallel_size, : + ] = state_dict_list[i]["module"]["language_model"]["embedding"][ + "word_embeddings" + ][ + "weight" + ] + + sd["module"]["language_model"]["embedding"]["position_embeddings"][ + "weight" + ] = state_dict_list[0]["module"]["language_model"]["embedding"][ + "position_embeddings" + ][ + "weight" + ] + print("Merging QueryEmbedding layers...") - query_parallel_size = args.max_position_embeddings // args.source_tensor_model_parallel_size + query_parallel_size = ( + args.max_position_embeddings // args.source_tensor_model_parallel_size + ) for i in range(args.source_tensor_model_parallel_size): - sd['module']['language_model']['topQueryEmbedding']['top_query_embeddings']['weight'][i * query_parallel_size : (i + 1) * query_parallel_size, :] = state_dict_list[i]['module']['language_model']['topQueryEmbedding']['top_query_embeddings'].pop('weight', None) - + sd["module"]["language_model"]["topQueryEmbedding"]["top_query_embeddings"][ + "weight" + ][i * query_parallel_size : (i + 1) * query_parallel_size, :] = state_dict_list[ + i + ][ + "module" + ][ + "language_model" + ][ + "topQueryEmbedding" + ][ + "top_query_embeddings" + ].pop( + "weight", None + ) + print("Merging Transformer layers...") - for layer_name in sd['module']['language_model']['transformer'].keys(): + for layer_name in sd["module"]["language_model"]["transformer"].keys(): if "layernorm" in layer_name: - sd['module']['language_model']['transformer'][layer_name] = state_dict_list[0]['module']['language_model']['transformer'].pop(layer_name, None) + sd["module"]["language_model"]["transformer"][layer_name] = state_dict_list[ + 0 + ]["module"]["language_model"]["transformer"].pop(layer_name, None) elif "attention" in layer_name and "weight" in layer_name: if "dense" in layer_name: - hidden_parallel_size = sd['module']['language_model']['transformer'][layer_name].shape[1] // args.source_tensor_model_parallel_size + hidden_parallel_size = ( + sd["module"]["language_model"]["transformer"][layer_name].shape[1] + // 
args.source_tensor_model_parallel_size + ) for i in range(args.source_tensor_model_parallel_size): - sd['module']['language_model']['transformer'][layer_name][:, i * hidden_parallel_size : (i + 1) * hidden_parallel_size] = state_dict_list[i]['module']['language_model']['transformer'].pop(layer_name, None) + sd["module"]["language_model"]["transformer"][layer_name][ + :, i * hidden_parallel_size : (i + 1) * hidden_parallel_size + ] = state_dict_list[i]["module"]["language_model"][ + "transformer" + ].pop( + layer_name, None + ) else: - hidden_parallel_size = sd['module']['language_model']['transformer'][layer_name].shape[0] // args.source_tensor_model_parallel_size + hidden_parallel_size = ( + sd["module"]["language_model"]["transformer"][layer_name].shape[0] + // args.source_tensor_model_parallel_size + ) for i in range(args.source_tensor_model_parallel_size): - sd['module']['language_model']['transformer'][layer_name][i * hidden_parallel_size : (i + 1) * hidden_parallel_size, :] = state_dict_list[i]['module']['language_model']['transformer'].pop(layer_name, None) + sd["module"]["language_model"]["transformer"][layer_name][ + i * hidden_parallel_size : (i + 1) * hidden_parallel_size, : + ] = state_dict_list[i]["module"]["language_model"][ + "transformer" + ].pop( + layer_name, None + ) elif "weight" in layer_name and "dense" in layer_name: if "h_to_4h" in layer_name: - hidden_parallel_size = sd['module']['language_model']['transformer'][layer_name].shape[0] // args.source_tensor_model_parallel_size + hidden_parallel_size = ( + sd["module"]["language_model"]["transformer"][layer_name].shape[0] + // args.source_tensor_model_parallel_size + ) for i in range(args.source_tensor_model_parallel_size): - sd['module']['language_model']['transformer'][layer_name][i * hidden_parallel_size : (i + 1) * hidden_parallel_size, :] = state_dict_list[i]['module']['language_model']['transformer'].pop(layer_name, None) + sd["module"]["language_model"]["transformer"][layer_name][ + i * hidden_parallel_size : (i + 1) * hidden_parallel_size, : + ] = state_dict_list[i]["module"]["language_model"][ + "transformer" + ].pop( + layer_name, None + ) else: - hidden_parallel_size = sd['module']['language_model']['transformer'][layer_name].shape[1] // args.source_tensor_model_parallel_size + hidden_parallel_size = ( + sd["module"]["language_model"]["transformer"][layer_name].shape[1] + // args.source_tensor_model_parallel_size + ) for i in range(args.source_tensor_model_parallel_size): - sd['module']['language_model']['transformer'][layer_name][:, i * hidden_parallel_size : (i + 1) * hidden_parallel_size] = state_dict_list[i]['module']['language_model']['transformer'].pop(layer_name, None) + sd["module"]["language_model"]["transformer"][layer_name][ + :, i * hidden_parallel_size : (i + 1) * hidden_parallel_size + ] = state_dict_list[i]["module"]["language_model"][ + "transformer" + ].pop( + layer_name, None + ) elif "bias" in layer_name: if "mlp" in layer_name: if "4h_to_h" in layer_name: - sd['module']['language_model']['transformer'][layer_name] = state_dict_list[0]['module']['language_model']['transformer'].pop(layer_name, None) + sd["module"]["language_model"]["transformer"][ + layer_name + ] = state_dict_list[0]["module"]["language_model"][ + "transformer" + ].pop( + layer_name, None + ) else: - hidden_parallel_size = sd['module']['language_model']['transformer'][layer_name].shape[0] // args.source_tensor_model_parallel_size + hidden_parallel_size = ( + sd["module"]["language_model"]["transformer"][layer_name].shape[ 
+ 0 + ] + // args.source_tensor_model_parallel_size + ) for i in range(args.source_tensor_model_parallel_size): - sd['module']['language_model']['transformer'][layer_name][i * hidden_parallel_size : (i + 1) * hidden_parallel_size] = state_dict_list[i]['module']['language_model']['transformer'].pop(layer_name, None) + sd["module"]["language_model"]["transformer"][layer_name][ + i * hidden_parallel_size : (i + 1) * hidden_parallel_size + ] = state_dict_list[i]["module"]["language_model"][ + "transformer" + ].pop( + layer_name, None + ) elif "attention" in layer_name: if "dense" in layer_name: - sd['module']['language_model']['transformer'][layer_name] = state_dict_list[0]['module']['language_model']['transformer'].pop(layer_name, None) + sd["module"]["language_model"]["transformer"][ + layer_name + ] = state_dict_list[0]["module"]["language_model"][ + "transformer" + ].pop( + layer_name, None + ) else: - hidden_parallel_size = sd['module']['language_model']['transformer'][layer_name].shape[0] // args.source_tensor_model_parallel_size + hidden_parallel_size = ( + sd["module"]["language_model"]["transformer"][layer_name].shape[ + 0 + ] + // args.source_tensor_model_parallel_size + ) for i in range(args.source_tensor_model_parallel_size): - sd['module']['language_model']['transformer'][layer_name][i * hidden_parallel_size : (i + 1) * hidden_parallel_size] = state_dict_list[i]['module']['language_model']['transformer'].pop(layer_name, None) + sd["module"]["language_model"]["transformer"][layer_name][ + i * hidden_parallel_size : (i + 1) * hidden_parallel_size + ] = state_dict_list[i]["module"]["language_model"][ + "transformer" + ].pop( + layer_name, None + ) else: - sd['module']['language_model']['transformer'][layer_name] = state_dict_list[0]['module']['language_model']['transformer'].pop(layer_name, None) - + sd["module"]["language_model"]["transformer"][layer_name] = state_dict_list[ + 0 + ]["module"]["language_model"]["transformer"].pop(layer_name, None) + if args.save_ckpt_path.endswith(".pt"): save_ckpt_path = args.save_ckpt_path else: @@ -131,11 +242,13 @@ def main(): if args.save_name: save_ckpt_path = os.path.join(args.save_ckpt_path, args.save_name) else: - save_ckpt_path = os.path.join(args.save_ckpt_path, "mp_rank_00_model_states.pt") - + save_ckpt_path = os.path.join( + args.save_ckpt_path, "mp_rank_00_model_states.pt" + ) + torch.save(sd, save_ckpt_path) print(f"Converted checkpoint saved in {save_ckpt_path}.") -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/codegeex/megatron/mindspore_to_megatron.py b/codegeex/megatron/mindspore_to_megatron.py index da66c79..176186d 100644 --- a/codegeex/megatron/mindspore_to_megatron.py +++ b/codegeex/megatron/mindspore_to_megatron.py @@ -22,8 +22,9 @@ import numpy as np import torch -sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), - os.path.pardir))) +sys.path.append( + os.path.abspath(os.path.join(os.path.dirname(__file__), os.path.pardir)) +) from codegeex.megatron import get_args from codegeex.megatron.model import CodeGeeXModel @@ -33,18 +34,18 @@ def get_change_ckpt_args(parser): """Provide extra arguments required for merging.""" - group = parser.add_argument_group(title='Mindspore to megatron') + group = parser.add_argument_group(title="Mindspore to megatron") group.add_argument( - '--npy-ckpt-path', + "--npy-ckpt-path", type=str, required=True, - help='path of npy checkpoint.', + help="path of npy checkpoint.", ) group.add_argument( - '--save-ckpt-path', + "--save-ckpt-path", type=str, 
required=True, - help='path to save checkpoint.', + help="path to save checkpoint.", ) return parser @@ -53,238 +54,248 @@ def get_change_ckpt_args(parser): def loadModelFromNp(sd, args): num_layers = args.num_layers npCkptPath = args.npy_ckpt_path - languageModel = sd['module']['language_model'] + languageModel = sd["module"]["language_model"] loadEmbeddingFromNp(npCkptPath, languageModel) - transformer = sd['module']['language_model']['transformer'] + transformer = sd["module"]["language_model"]["transformer"] for layerID in range(num_layers): loadAttentionLayerFromNp(npCkptPath, transformer, layerID) loadQueryLayerFromNp(npCkptPath, transformer) - transformer['final_layernorm.weight'][:] = \ - torch.tensor( - np.load(npCkptPath + f'backbone.layernorm.gamma.npy') - ).float() - transformer['final_layernorm.bias'][:] = \ - torch.tensor( - np.load(npCkptPath + f'backbone.layernorm.beta.npy') - ).float() + transformer["final_layernorm.weight"][:] = torch.tensor( + np.load(npCkptPath + f"backbone.layernorm.gamma.npy") + ).float() + transformer["final_layernorm.bias"][:] = torch.tensor( + np.load(npCkptPath + f"backbone.layernorm.beta.npy") + ).float() def loadEmbeddingFromNp(npCkptPath, languageModel, vocabSize=52224): - word_embedding_np = \ - np.load(npCkptPath + 'backbone.embedding.word_embedding.embedding_table.npy') - languageModel['embedding']['word_embeddings']['weight'][:vocabSize, :] = \ - torch.tensor(word_embedding_np).float() + word_embedding_np = np.load( + npCkptPath + "backbone.embedding.word_embedding.embedding_table.npy" + ) + languageModel["embedding"]["word_embeddings"]["weight"][ + :vocabSize, : + ] = torch.tensor(word_embedding_np).float() - position_embeddings_np = \ - np.load(npCkptPath + 'backbone.embedding.position_embedding.embedding_table.npy') - languageModel['embedding']['position_embeddings']['weight'][:, :] = \ - torch.tensor(position_embeddings_np).float() + position_embeddings_np = np.load( + npCkptPath + "backbone.embedding.position_embedding.embedding_table.npy" + ) + languageModel["embedding"]["position_embeddings"]["weight"][:, :] = torch.tensor( + position_embeddings_np + ).float() - topQueryEmbedding_np = \ - np.load(npCkptPath + 'backbone.top_query_embedding.embedding_table.npy') - languageModel['topQueryEmbedding']['top_query_embeddings']['weight'][:, :] = \ - torch.tensor(topQueryEmbedding_np).float() + topQueryEmbedding_np = np.load( + npCkptPath + "backbone.top_query_embedding.embedding_table.npy" + ) + languageModel["topQueryEmbedding"]["top_query_embeddings"]["weight"][ + :, : + ] = torch.tensor(topQueryEmbedding_np).float() def loadAttentionLayerFromNp(npCkptPath, transformer, layerID): - attention_dense1_weight_np = \ - np.load(npCkptPath + f'backbone.blocks.{layerID}.attention.dense1.weight.npy') - attention_dense2_weight_np = \ - np.load(npCkptPath + f'backbone.blocks.{layerID}.attention.dense2.weight.npy') - attention_dense3_weight_np = \ - np.load(npCkptPath + f'backbone.blocks.{layerID}.attention.dense3.weight.npy') - - attention_dense1_bias_np = \ - np.load(npCkptPath + f'backbone.blocks.{layerID}.attention.dense1.bias.npy') - attention_dense2_bias_np = \ - np.load(npCkptPath + f'backbone.blocks.{layerID}.attention.dense2.bias.npy') - attention_dense3_bias_np = \ - np.load(npCkptPath + f'backbone.blocks.{layerID}.attention.dense3.bias.npy') - - query_weight = transformer[f'layers.{layerID}.attention.query.weight'] - key_weight = transformer[f'layers.{layerID}.attention.key.weight'] - value_weight = 
transformer[f'layers.{layerID}.attention.value.weight'] + attention_dense1_weight_np = np.load( + npCkptPath + f"backbone.blocks.{layerID}.attention.dense1.weight.npy" + ) + attention_dense2_weight_np = np.load( + npCkptPath + f"backbone.blocks.{layerID}.attention.dense2.weight.npy" + ) + attention_dense3_weight_np = np.load( + npCkptPath + f"backbone.blocks.{layerID}.attention.dense3.weight.npy" + ) + + attention_dense1_bias_np = np.load( + npCkptPath + f"backbone.blocks.{layerID}.attention.dense1.bias.npy" + ) + attention_dense2_bias_np = np.load( + npCkptPath + f"backbone.blocks.{layerID}.attention.dense2.bias.npy" + ) + attention_dense3_bias_np = np.load( + npCkptPath + f"backbone.blocks.{layerID}.attention.dense3.bias.npy" + ) + + query_weight = transformer[f"layers.{layerID}.attention.query.weight"] + key_weight = transformer[f"layers.{layerID}.attention.key.weight"] + value_weight = transformer[f"layers.{layerID}.attention.value.weight"] query_weight[:] = torch.tensor(attention_dense1_weight_np).float() key_weight[:] = torch.tensor(attention_dense2_weight_np).float() value_weight[:] = torch.tensor(attention_dense3_weight_np).float() - query_bias = transformer[f'layers.{layerID}.attention.query.bias'] - key_bias = transformer[f'layers.{layerID}.attention.key.bias'] - value_bias = transformer[f'layers.{layerID}.attention.value.bias'] + query_bias = transformer[f"layers.{layerID}.attention.query.bias"] + key_bias = transformer[f"layers.{layerID}.attention.key.bias"] + value_bias = transformer[f"layers.{layerID}.attention.value.bias"] query_bias[:] = torch.tensor(attention_dense1_bias_np).float() key_bias[:] = torch.tensor(attention_dense2_bias_np).float() value_bias[:] = torch.tensor(attention_dense3_bias_np).float() - att_dense_weight = transformer[f'layers.{layerID}.attention.dense.weight'] - att_dense_weight[:, :] = \ - torch.tensor( - np.load(npCkptPath + f'backbone.blocks.{layerID}.attention.projection.weight.npy').transpose() - ).float() - att_dense_bias = transformer[f'layers.{layerID}.attention.dense.bias'] - att_dense_bias[:] = \ - torch.tensor( - np.load(npCkptPath + f'backbone.blocks.{layerID}.attention.projection.bias.npy') - ).float() - - mlp_dense_h_to_4h_weight = transformer[f'layers.{layerID}.mlp.dense_h_to_4h.weight'] - mlp_dense_h_to_4h_weight[:, :] = \ - torch.tensor( - np.load(npCkptPath + f'backbone.blocks.{layerID}.output.mapping.weight.npy').transpose() - ).float() - mlp_dense_h_to_4h_bias = transformer[f'layers.{layerID}.mlp.dense_h_to_4h.bias'] - mlp_dense_h_to_4h_bias[:] = \ - torch.tensor( - np.load(npCkptPath + f'backbone.blocks.{layerID}.output.mapping.bias.npy') - ).float() - - mlp_dense_4h_to_h_weight = transformer[f'layers.{layerID}.mlp.dense_4h_to_h.weight'] - mlp_dense_4h_to_h_weight[:, :] = \ - torch.tensor( - np.load(npCkptPath + f'backbone.blocks.{layerID}.output.projection.weight.npy').transpose() - ).float() - mlp_dense_4h_to_h_bias = transformer[f'layers.{layerID}.mlp.dense_4h_to_h.bias'] - mlp_dense_4h_to_h_bias[:] = \ - torch.tensor( - np.load(npCkptPath + f'backbone.blocks.{layerID}.output.projection.bias.npy') - ).float() - - input_layernorm_weight = transformer[f'layers.{layerID}.input_layernorm.weight'] - input_layernorm_weight[:] = \ - torch.tensor( - np.load(npCkptPath + f'backbone.blocks.{layerID}.layernorm1.gamma.npy') - ).float() - input_layernorm_bias = transformer[f'layers.{layerID}.input_layernorm.bias'] - input_layernorm_bias[:] = \ - torch.tensor( - np.load(npCkptPath + f'backbone.blocks.{layerID}.layernorm1.beta.npy') - ).float() 
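Every loader in this file repeats one pattern: read a MindSpore-exported .npy array, transpose it if it is a projection or mapping weight, and copy it in place into the matching Megatron parameter so the state dict being assembled keeps its reference. A minimal sketch of that pattern; the helper name copy_npy_ and its arguments are illustrative, not part of this script:

import numpy as np
import torch


def copy_npy_(dest: torch.Tensor, npy_path: str, transpose: bool = False) -> None:
    # Load the MindSpore-exported array.
    arr = np.load(npy_path)
    if transpose:
        # Projection / mapping weights are flipped before the copy,
        # mirroring the .transpose() calls in this file.
        arr = arr.transpose()
    # In-place assignment keeps `dest` bound to the state dict being filled.
    dest[...] = torch.tensor(arr).float()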
- - post_attention_layernorm_weight = transformer[f'layers.{layerID}.post_attention_layernorm.weight'] - post_attention_layernorm_weight[:] = \ - torch.tensor( - np.load(npCkptPath + f'backbone.blocks.{layerID}.layernorm2.gamma.npy') - ).float() - post_attention_layernorm_bias = transformer[f'layers.{layerID}.post_attention_layernorm.bias'] - post_attention_layernorm_bias[:] = \ - torch.tensor( - np.load(npCkptPath + f'backbone.blocks.{layerID}.layernorm2.beta.npy') - ).float() - - input_layernorm_weight = transformer[f'layers.{layerID}.input_layernorm.weight'] - input_layernorm_weight[:] = \ - torch.tensor( - np.load(npCkptPath + f'backbone.blocks.{layerID}.layernorm1.gamma.npy') - ).float() - input_layernorm_bias = transformer[f'layers.{layerID}.input_layernorm.bias'] - input_layernorm_bias[:] = \ - torch.tensor( - np.load(npCkptPath + f'backbone.blocks.{layerID}.layernorm1.beta.npy') - ).float() - - post_attention_layernorm_weight = transformer[f'layers.{layerID}.post_attention_layernorm.weight'] - post_attention_layernorm_weight[:] = \ - torch.tensor( - np.load(npCkptPath + f'backbone.blocks.{layerID}.layernorm2.gamma.npy') - ).float() - post_attention_layernorm_bias = transformer[f'layers.{layerID}.post_attention_layernorm.bias'] - post_attention_layernorm_bias[:] = \ - torch.tensor( - np.load(npCkptPath + f'backbone.blocks.{layerID}.layernorm2.beta.npy') - ).float() + att_dense_weight = transformer[f"layers.{layerID}.attention.dense.weight"] + att_dense_weight[:, :] = torch.tensor( + np.load( + npCkptPath + f"backbone.blocks.{layerID}.attention.projection.weight.npy" + ).transpose() + ).float() + att_dense_bias = transformer[f"layers.{layerID}.attention.dense.bias"] + att_dense_bias[:] = torch.tensor( + np.load(npCkptPath + f"backbone.blocks.{layerID}.attention.projection.bias.npy") + ).float() + + mlp_dense_h_to_4h_weight = transformer[f"layers.{layerID}.mlp.dense_h_to_4h.weight"] + mlp_dense_h_to_4h_weight[:, :] = torch.tensor( + np.load( + npCkptPath + f"backbone.blocks.{layerID}.output.mapping.weight.npy" + ).transpose() + ).float() + mlp_dense_h_to_4h_bias = transformer[f"layers.{layerID}.mlp.dense_h_to_4h.bias"] + mlp_dense_h_to_4h_bias[:] = torch.tensor( + np.load(npCkptPath + f"backbone.blocks.{layerID}.output.mapping.bias.npy") + ).float() + + mlp_dense_4h_to_h_weight = transformer[f"layers.{layerID}.mlp.dense_4h_to_h.weight"] + mlp_dense_4h_to_h_weight[:, :] = torch.tensor( + np.load( + npCkptPath + f"backbone.blocks.{layerID}.output.projection.weight.npy" + ).transpose() + ).float() + mlp_dense_4h_to_h_bias = transformer[f"layers.{layerID}.mlp.dense_4h_to_h.bias"] + mlp_dense_4h_to_h_bias[:] = torch.tensor( + np.load(npCkptPath + f"backbone.blocks.{layerID}.output.projection.bias.npy") + ).float() + + input_layernorm_weight = transformer[f"layers.{layerID}.input_layernorm.weight"] + input_layernorm_weight[:] = torch.tensor( + np.load(npCkptPath + f"backbone.blocks.{layerID}.layernorm1.gamma.npy") + ).float() + input_layernorm_bias = transformer[f"layers.{layerID}.input_layernorm.bias"] + input_layernorm_bias[:] = torch.tensor( + np.load(npCkptPath + f"backbone.blocks.{layerID}.layernorm1.beta.npy") + ).float() + + post_attention_layernorm_weight = transformer[ + f"layers.{layerID}.post_attention_layernorm.weight" + ] + post_attention_layernorm_weight[:] = torch.tensor( + np.load(npCkptPath + f"backbone.blocks.{layerID}.layernorm2.gamma.npy") + ).float() + post_attention_layernorm_bias = transformer[ + f"layers.{layerID}.post_attention_layernorm.bias" + ] + 
post_attention_layernorm_bias[:] = torch.tensor( + np.load(npCkptPath + f"backbone.blocks.{layerID}.layernorm2.beta.npy") + ).float() + + input_layernorm_weight = transformer[f"layers.{layerID}.input_layernorm.weight"] + input_layernorm_weight[:] = torch.tensor( + np.load(npCkptPath + f"backbone.blocks.{layerID}.layernorm1.gamma.npy") + ).float() + input_layernorm_bias = transformer[f"layers.{layerID}.input_layernorm.bias"] + input_layernorm_bias[:] = torch.tensor( + np.load(npCkptPath + f"backbone.blocks.{layerID}.layernorm1.beta.npy") + ).float() + + post_attention_layernorm_weight = transformer[ + f"layers.{layerID}.post_attention_layernorm.weight" + ] + post_attention_layernorm_weight[:] = torch.tensor( + np.load(npCkptPath + f"backbone.blocks.{layerID}.layernorm2.gamma.npy") + ).float() + post_attention_layernorm_bias = transformer[ + f"layers.{layerID}.post_attention_layernorm.bias" + ] + post_attention_layernorm_bias[:] = torch.tensor( + np.load(npCkptPath + f"backbone.blocks.{layerID}.layernorm2.beta.npy") + ).float() def loadQueryLayerFromNp(npCkptPath, transformer): - attention_dense1_weight_np = \ - np.load(npCkptPath + f'backbone.top_query_layer.attention.dense1.weight.npy') - attention_dense1_bias_np = \ - np.load(npCkptPath + f'backbone.top_query_layer.attention.dense1.bias.npy') - attention_dense2_weight_np = \ - np.load(npCkptPath + f'backbone.top_query_layer.attention.dense2.weight.npy') - attention_dense2_bias_np = \ - np.load(npCkptPath + f'backbone.top_query_layer.attention.dense2.bias.npy') - attention_dense3_weight_np = \ - np.load(npCkptPath + f'backbone.top_query_layer.attention.dense3.weight.npy') - attention_dense3_bias_np = \ - np.load(npCkptPath + f'backbone.top_query_layer.attention.dense3.bias.npy') - - query_weight = transformer[f'topQueryLayer.attention.query.weight'] - query_weight[:, :] = \ - torch.tensor(attention_dense1_weight_np).float() - query_bias = transformer[f'topQueryLayer.attention.query.bias'] + attention_dense1_weight_np = np.load( + npCkptPath + f"backbone.top_query_layer.attention.dense1.weight.npy" + ) + attention_dense1_bias_np = np.load( + npCkptPath + f"backbone.top_query_layer.attention.dense1.bias.npy" + ) + attention_dense2_weight_np = np.load( + npCkptPath + f"backbone.top_query_layer.attention.dense2.weight.npy" + ) + attention_dense2_bias_np = np.load( + npCkptPath + f"backbone.top_query_layer.attention.dense2.bias.npy" + ) + attention_dense3_weight_np = np.load( + npCkptPath + f"backbone.top_query_layer.attention.dense3.weight.npy" + ) + attention_dense3_bias_np = np.load( + npCkptPath + f"backbone.top_query_layer.attention.dense3.bias.npy" + ) + + query_weight = transformer[f"topQueryLayer.attention.query.weight"] + query_weight[:, :] = torch.tensor(attention_dense1_weight_np).float() + query_bias = transformer[f"topQueryLayer.attention.query.bias"] query_bias[:] = torch.tensor(attention_dense1_bias_np).float() - key_weight = transformer[f'topQueryLayer.attention.key.weight'] - key_weight[:, :] = \ - torch.tensor(attention_dense2_weight_np).float() - key_bias = transformer[f'topQueryLayer.attention.key.bias'] + key_weight = transformer[f"topQueryLayer.attention.key.weight"] + key_weight[:, :] = torch.tensor(attention_dense2_weight_np).float() + key_bias = transformer[f"topQueryLayer.attention.key.bias"] key_bias[:] = torch.tensor(attention_dense2_bias_np).float() - value_weight = transformer[f'topQueryLayer.attention.value.weight'] - value_weight[:, :] = \ - torch.tensor(attention_dense3_weight_np).float() - value_bias = 
transformer[f'topQueryLayer.attention.value.bias'] + value_weight = transformer[f"topQueryLayer.attention.value.weight"] + value_weight[:, :] = torch.tensor(attention_dense3_weight_np).float() + value_bias = transformer[f"topQueryLayer.attention.value.bias"] value_bias[:] = torch.tensor(attention_dense3_bias_np).float() - att_dense_weight = transformer[f'topQueryLayer.attention.dense.weight'] - att_dense_weight[:, :] = \ - torch.tensor( - np.load(npCkptPath + f'backbone.top_query_layer.attention.projection.weight.npy') - .transpose() - ).float() - att_dense_bias = transformer[f'topQueryLayer.attention.dense.bias'] - att_dense_bias[:] = \ - torch.tensor( - np.load(npCkptPath + f'backbone.top_query_layer.attention.projection.bias.npy') - ).float() - - mlp_dense_h_to_4h_weight = transformer[f'topQueryLayer.mlp.dense_h_to_4h.weight'] - mlp_dense_h_to_4h_weight[:, :] = \ - torch.tensor( - np.load(npCkptPath + f'backbone.top_query_layer.output.mapping.weight.npy') - .transpose() - ).float() - mlp_dense_h_to_4h_bias = transformer[f'topQueryLayer.mlp.dense_h_to_4h.bias'] - mlp_dense_h_to_4h_bias[:] = \ - torch.tensor( - np.load(npCkptPath + f'backbone.top_query_layer.output.mapping.bias.npy') - ).float() - - mlp_dense_4h_to_h_weight = transformer[f'topQueryLayer.mlp.dense_4h_to_h.weight'] - mlp_dense_4h_to_h_weight[:, :] = \ - torch.tensor( - np.load(npCkptPath + f'backbone.top_query_layer.output.projection.weight.npy') - .transpose() - ).float() - mlp_dense_4h_to_h_bias = transformer[f'topQueryLayer.mlp.dense_4h_to_h.bias'] - mlp_dense_4h_to_h_bias[:] = \ - torch.tensor( - np.load(npCkptPath + f'backbone.top_query_layer.output.projection.bias.npy') - ).float() - - input_layernorm_weight = transformer[f'topQueryLayer.input_layernorm.weight'] - input_layernorm_weight[:] = \ - torch.tensor( - np.load(npCkptPath + f'backbone.top_query_layer.layernorm1.gamma.npy') - ).float() - input_layernorm_bias = transformer[f'topQueryLayer.input_layernorm.bias'] - input_layernorm_bias[:] = \ - torch.tensor( - np.load(npCkptPath + f'backbone.top_query_layer.layernorm1.beta.npy') - ).float() - - post_attention_layernorm_weight = transformer[f'topQueryLayer.post_attention_layernorm.weight'] - post_attention_layernorm_weight[:] = \ - torch.tensor( - np.load(npCkptPath + f'backbone.top_query_layer.layernorm2.gamma.npy') - ).float() - post_attention_layernorm_bias = transformer[f'topQueryLayer.post_attention_layernorm.bias'] - post_attention_layernorm_bias[:] = \ - torch.tensor( - np.load(npCkptPath + f'backbone.top_query_layer.layernorm2.beta.npy') - ).float() + att_dense_weight = transformer[f"topQueryLayer.attention.dense.weight"] + att_dense_weight[:, :] = torch.tensor( + np.load( + npCkptPath + f"backbone.top_query_layer.attention.projection.weight.npy" + ).transpose() + ).float() + att_dense_bias = transformer[f"topQueryLayer.attention.dense.bias"] + att_dense_bias[:] = torch.tensor( + np.load(npCkptPath + f"backbone.top_query_layer.attention.projection.bias.npy") + ).float() + + mlp_dense_h_to_4h_weight = transformer[f"topQueryLayer.mlp.dense_h_to_4h.weight"] + mlp_dense_h_to_4h_weight[:, :] = torch.tensor( + np.load( + npCkptPath + f"backbone.top_query_layer.output.mapping.weight.npy" + ).transpose() + ).float() + mlp_dense_h_to_4h_bias = transformer[f"topQueryLayer.mlp.dense_h_to_4h.bias"] + mlp_dense_h_to_4h_bias[:] = torch.tensor( + np.load(npCkptPath + f"backbone.top_query_layer.output.mapping.bias.npy") + ).float() + + mlp_dense_4h_to_h_weight = transformer[f"topQueryLayer.mlp.dense_4h_to_h.weight"] + 
mlp_dense_4h_to_h_weight[:, :] = torch.tensor( + np.load( + npCkptPath + f"backbone.top_query_layer.output.projection.weight.npy" + ).transpose() + ).float() + mlp_dense_4h_to_h_bias = transformer[f"topQueryLayer.mlp.dense_4h_to_h.bias"] + mlp_dense_4h_to_h_bias[:] = torch.tensor( + np.load(npCkptPath + f"backbone.top_query_layer.output.projection.bias.npy") + ).float() + + input_layernorm_weight = transformer[f"topQueryLayer.input_layernorm.weight"] + input_layernorm_weight[:] = torch.tensor( + np.load(npCkptPath + f"backbone.top_query_layer.layernorm1.gamma.npy") + ).float() + input_layernorm_bias = transformer[f"topQueryLayer.input_layernorm.bias"] + input_layernorm_bias[:] = torch.tensor( + np.load(npCkptPath + f"backbone.top_query_layer.layernorm1.beta.npy") + ).float() + + post_attention_layernorm_weight = transformer[ + f"topQueryLayer.post_attention_layernorm.weight" + ] + post_attention_layernorm_weight[:] = torch.tensor( + np.load(npCkptPath + f"backbone.top_query_layer.layernorm2.gamma.npy") + ).float() + post_attention_layernorm_bias = transformer[ + f"topQueryLayer.post_attention_layernorm.bias" + ] + post_attention_layernorm_bias[:] = torch.tensor( + np.load(npCkptPath + f"backbone.top_query_layer.layernorm2.beta.npy") + ).float() def main(): @@ -295,8 +306,8 @@ def main(): extra_args_provider=get_change_ckpt_args, args_defaults={ "tokenizer_type": "GPT2BPETokenizer", - "no_load_rng" : True, - "no_load_optim" : True, + "no_load_rng": True, + "no_load_optim": True, }, ) @@ -307,13 +318,13 @@ def main(): # Save the model. sd = {} - sd['module'] = model.state_dict_for_save_checkpoint() + sd["module"] = model.state_dict_for_save_checkpoint() ensure_directory_exists(args.save_ckpt_path) loadModelFromNp(sd, args) - print('> saving merged model to {}'.format(args.save_ckpt_path)) + print("> saving merged model to {}".format(args.save_ckpt_path)) torch.save(sd, args.save_ckpt_path) print(f"Converted checkpoint saved in {args.save_ckpt_path}.") -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/codegeex/megatron/model/__init__.py b/codegeex/megatron/model/__init__.py index f4cc049..cd03b5b 100644 --- a/codegeex/megatron/model/__init__.py +++ b/codegeex/megatron/model/__init__.py @@ -17,4 +17,4 @@ from .distributed import DistributedDataParallel from .codegeex_model import CodeGeeXModel from .language_model import get_language_model -from .module import Float16Module \ No newline at end of file +from .module import Float16Module diff --git a/codegeex/megatron/model/codegeex_model.py b/codegeex/megatron/model/codegeex_model.py index c41168e..0e998b5 100644 --- a/codegeex/megatron/model/codegeex_model.py +++ b/codegeex/megatron/model/codegeex_model.py @@ -18,9 +18,17 @@ from codegeex.megatron.model import LayerNorm from codegeex.megatron.enums import AttnMaskType from codegeex.megatron.model.module import MegatronModule -from codegeex.megatron.model.language_model import parallel_lm_logits, get_language_model, EmbeddingPipe, QueryEmbeddingPipe +from codegeex.megatron.model.language_model import ( + parallel_lm_logits, + get_language_model, + EmbeddingPipe, + QueryEmbeddingPipe, +) from codegeex.megatron.model.utils import init_method_normal, scaled_init_method_normal -from codegeex.megatron.model.transformer import ParallelTransformerLayerPipe, ParallelTopQueryLayerPipe +from codegeex.megatron.model.transformer import ( + ParallelTransformerLayerPipe, + ParallelTopQueryLayerPipe, +) from deepspeed.pipe import PipelineModule, LayerSpec, TiedLayerSpec @@ -38,36 
+46,40 @@ def __init__(self, num_tokentypes=0, parallel_output=False): num_tokentypes=num_tokentypes, add_pooler=False, init_method=init_method_normal(args.init_method_std), - scaled_init_method=scaled_init_method_normal(args.init_method_std, - args.num_layers)) + scaled_init_method=scaled_init_method_normal( + args.init_method_std, args.num_layers + ), + ) def set_input_tensor(self, input_tensor): """See megatron.model.transformer.set_input_tensor()""" self.language_model.set_input_tensor(input_tensor) - + def forward( - self, - input_ids, - position_ids, - attention_mask, - labels=None, - tokentype_ids=None, - layer_past=None, - get_key_value=False, - forward_method_parallel_output=None, - prompt_length=None, - context_length=None, + self, + input_ids, + position_ids, + attention_mask, + labels=None, + tokentype_ids=None, + layer_past=None, + get_key_value=False, + forward_method_parallel_output=None, + prompt_length=None, + context_length=None, ): # Language model. - lm_output = self.language_model(input_ids, - position_ids, - attention_mask, - tokentype_ids=tokentype_ids, - layer_past=layer_past, - get_key_value=get_key_value, - prompt_length=prompt_length, - context_length=context_length) + lm_output = self.language_model( + input_ids, + position_ids, + attention_mask, + tokentype_ids=tokentype_ids, + layer_past=layer_past, + get_key_value=get_key_value, + prompt_length=prompt_length, + context_length=context_length, + ) if get_key_value: lm_output, presents = lm_output @@ -80,7 +92,8 @@ def forward( output = parallel_lm_logits( lm_output, self.language_model.embedding.word_embeddings.weight, - parallel_output) + parallel_output, + ) if get_key_value: output = [output, presents] @@ -96,13 +109,16 @@ def forward( return loss - def state_dict_for_save_checkpoint(self, destination=None, prefix='', - keep_vars=False): + def state_dict_for_save_checkpoint( + self, destination=None, prefix="", keep_vars=False + ): state_dict_ = {} - state_dict_[self._language_model_key] \ - = self.language_model.state_dict_for_save_checkpoint( - destination, prefix, keep_vars) + state_dict_[ + self._language_model_key + ] = self.language_model.state_dict_for_save_checkpoint( + destination, prefix, keep_vars + ) return state_dict_ def load_state_dict(self, state_dict, strict=True): @@ -134,7 +150,7 @@ def __init__(self, num_tokentypes=0, parallel_output=True): init_method = init_method_normal(args.init_method_std) self.specs = [] - + # Embedding layer self.specs.append( TiedLayerSpec( diff --git a/codegeex/megatron/model/distributed.py b/codegeex/megatron/model/distributed.py index 079aac7..df7f083 100644 --- a/codegeex/megatron/model/distributed.py +++ b/codegeex/megatron/model/distributed.py @@ -67,7 +67,7 @@ def state_dict(self, destination=None, prefix="", keep_vars=False): return self.module.state_dict(destination, prefix, keep_vars) def state_dict_for_save_checkpoint( - self, destination=None, prefix="", keep_vars=False + self, destination=None, prefix="", keep_vars=False ): return self.module.state_dict_for_save_checkpoint( destination, prefix, keep_vars @@ -94,7 +94,7 @@ class DistributedDataParallel(DistributedDataParallelBase): """ def __init__( - self, module, accumulate_allreduce_grads_in_fp32, use_contiguous_buffers + self, module, accumulate_allreduce_grads_in_fp32, use_contiguous_buffers ): super(DistributedDataParallel, self).__init__(module) @@ -128,7 +128,7 @@ def _get_buffer_type(param): if param.requires_grad: dtype = _get_buffer_type(param) type_num_elements[dtype] = ( - 
type_num_elements.get(dtype, 0) + param.data.nelement() + type_num_elements.get(dtype, 0) + param.data.nelement() ) # Allocate the buffer. @@ -210,6 +210,6 @@ def allreduce_gradients(self): coalesced, group=mpu.get_data_parallel_group() ) for buf, synced in zip( - grads, _unflatten_dense_tensors(coalesced, grads) + grads, _unflatten_dense_tensors(coalesced, grads) ): buf.copy_(synced) diff --git a/codegeex/megatron/model/language_model.py b/codegeex/megatron/model/language_model.py index e0b5a9c..aa0b12b 100644 --- a/codegeex/megatron/model/language_model.py +++ b/codegeex/megatron/model/language_model.py @@ -41,7 +41,7 @@ def get_shrink_embedding_gradient_alpha(iteration): return 1.0 else: return alpha + (1 - alpha) * (args.iteration - x1) / x2 - + def parallel_lm_logits(input_, word_embeddings_weight, parallel_output, bias=None): """LM logits using word embedding weights.""" @@ -50,14 +50,17 @@ def parallel_lm_logits(input_, word_embeddings_weight, parallel_output, bias=Non # Matrix multiply. args = get_args() if args.shrink_logit_embedding_gradient: - if hasattr(args, 'iteration'): + if hasattr(args, "iteration"): alpha = get_shrink_embedding_gradient_alpha(args.iteration + 1) else: alpha = args.shrink_embedding_gradient_alpha - word_embeddings_weight = word_embeddings_weight if alpha == 1.0 \ + word_embeddings_weight = ( + word_embeddings_weight + if alpha == 1.0 else ( - word_embeddings_weight * alpha + - word_embeddings_weight.detach() * (1 - alpha) + word_embeddings_weight * alpha + + word_embeddings_weight.detach() * (1 - alpha) + ) ) if bias is None: logits_parallel = F.linear(input_parallel, word_embeddings_weight.half()) @@ -71,10 +74,10 @@ def parallel_lm_logits(input_, word_embeddings_weight, parallel_output, bias=Non def get_language_model( - num_tokentypes, - add_pooler, - init_method=None, - scaled_init_method=None, + num_tokentypes, + add_pooler, + init_method=None, + scaled_init_method=None, ): """Build language model and return along with the key to save.""" args = get_args() @@ -83,16 +86,19 @@ def get_language_model( init_method = init_method_normal(args.init_method_std) if scaled_init_method is None: - scaled_init_method = scaled_init_method_normal(args.init_method_std, args.num_layers) + scaled_init_method = scaled_init_method_normal( + args.init_method_std, args.num_layers + ) # Language model. language_model = TransformerLanguageModel( init_method=init_method, output_layer_init_method=scaled_init_method, num_tokentypes=num_tokentypes, - add_pooler=add_pooler) + add_pooler=add_pooler, + ) # key used for checkpoints. - language_model_key = 'language_model' + language_model_key = "language_model" return language_model, language_model_key @@ -121,27 +127,29 @@ def __init__( num_tokentypes=0, ): super(Embedding, self).__init__() - + args = get_args() - + self.hidden_size = hidden_size self.init_method = init_method self.num_tokentypes = num_tokentypes self.max_sequence_length = max_sequence_length - + # Word embeddings (parallel). self.word_embeddings = mpu.VocabParallelEmbedding( - vocab_size, self.hidden_size, init_method=self.init_method) - self._word_embeddings_key = 'word_embeddings' - + vocab_size, self.hidden_size, init_method=self.init_method + ) + self._word_embeddings_key = "word_embeddings" + self.vocab_size = vocab_size # Position embedding (serial). 
self.position_embeddings = torch.nn.Embedding( - max_sequence_length, self.hidden_size) + max_sequence_length, self.hidden_size + ) self.position_embeddings = self.position_embeddings.half() - self._position_embeddings_key = 'position_embeddings' - + self._position_embeddings_key = "position_embeddings" + # Initialize the position embeddings. self.init_method(self.position_embeddings.weight) @@ -149,10 +157,11 @@ def __init__( # Add this as an optional field that can be added through # method call so we can load a pretrain model without # token types and add them as needed. - self._tokentype_embeddings_key = 'tokentype_embeddings' + self._tokentype_embeddings_key = "tokentype_embeddings" if self.num_tokentypes > 0: - self.tokentype_embeddings = torch.nn.Embedding(self.num_tokentypes, - self.hidden_size) + self.tokentype_embeddings = torch.nn.Embedding( + self.num_tokentypes, self.hidden_size + ) # Initialize the token-type embeddings. self.init_method(self.tokentype_embeddings.weight) else: @@ -167,13 +176,13 @@ def add_tokentype_embeddings(self, num_tokentypes): This allows us to load the model normally and then add this embedding. """ if self.tokentype_embeddings is not None: - raise Exception('tokentype embeddings is already initialized') + raise Exception("tokentype embeddings is already initialized") if torch.distributed.get_rank() == 0: - print('adding embedding for {} tokentypes'.format(num_tokentypes), - flush=True) + print( + "adding embedding for {} tokentypes".format(num_tokentypes), flush=True + ) self.num_tokentypes = num_tokentypes - self.tokentype_embeddings = torch.nn.Embedding(num_tokentypes, - self.hidden_size) + self.tokentype_embeddings = torch.nn.Embedding(num_tokentypes, self.hidden_size) # Initialize the token-type embeddings. self.init_method(self.tokentype_embeddings.weight) @@ -194,20 +203,24 @@ def forward(self, input_ids, position_ids, tokentype_ids=None): return embeddings def state_dict_for_save_checkpoint( - self, destination=None, prefix='', keep_vars=False, + self, + destination=None, + prefix="", + keep_vars=False, ): """For easy load.""" state_dict_ = {} - state_dict_[self._word_embeddings_key] \ - = self.word_embeddings.state_dict(destination, prefix, keep_vars) - state_dict_[self._position_embeddings_key] \ - = self.position_embeddings.state_dict( - destination, prefix, keep_vars) + state_dict_[self._word_embeddings_key] = self.word_embeddings.state_dict( + destination, prefix, keep_vars + ) + state_dict_[ + self._position_embeddings_key + ] = self.position_embeddings.state_dict(destination, prefix, keep_vars) if self.num_tokentypes > 0: - state_dict_[self._tokentype_embeddings_key] \ - = self.tokentype_embeddings.state_dict( - destination, prefix, keep_vars) + state_dict_[ + self._tokentype_embeddings_key + ] = self.tokentype_embeddings.state_dict(destination, prefix, keep_vars) return state_dict_ @@ -221,11 +234,12 @@ def load_state_dict(self, state_dict, strict=True): # for backward compatibility. 
state_dict_ = {} for key in state_dict.keys(): - if 'word_embeddings' in key: - state_dict_[key.split('word_embeddings.')[1]] \ - = state_dict[key] - vocab_len = state_dict_['weight'].shape[0] - state_dict_["weight"] = state_dict_["weight"][:self.vocab_size // get_tensor_model_parallel_world_size()] + if "word_embeddings" in key: + state_dict_[key.split("word_embeddings.")[1]] = state_dict[key] + vocab_len = state_dict_["weight"].shape[0] + state_dict_["weight"] = state_dict_["weight"][ + : self.vocab_size // get_tensor_model_parallel_world_size() + ] self.word_embeddings.load_state_dict(state_dict_, strict=strict) # Position embedding. @@ -235,18 +249,20 @@ def load_state_dict(self, state_dict, strict=True): # for backward compatibility. state_dict_ = {} for key in state_dict.keys(): - if 'position_embeddings' in key: - state_dict_[key.split('position_embeddings.')[1]] \ - = state_dict[key] - - pos_len = state_dict_['weight'].shape[0] + if "position_embeddings" in key: + state_dict_[key.split("position_embeddings.")[1]] = state_dict[key] + + pos_len = state_dict_["weight"].shape[0] max_seq_len = self.max_sequence_length if pos_len < max_seq_len: print_rank_0(f"Position embedding padded {pos_len} -> {max_seq_len}.") position_embeddings_padded = torch.nn.Embedding( - max_seq_len - pos_len, self.hidden_size).half() + max_seq_len - pos_len, self.hidden_size + ).half() self.init_method(position_embeddings_padded.weight) - state_dict_['weight'] = torch.cat([state_dict_['weight'], position_embeddings_padded.weight], dim=0) + state_dict_["weight"] = torch.cat( + [state_dict_["weight"], position_embeddings_padded.weight], dim=0 + ) # self.position_embeddings = self.position_embeddings.half() self.position_embeddings.load_state_dict(state_dict_, strict=strict) @@ -259,15 +275,18 @@ def load_state_dict(self, state_dict, strict=True): else: # for backward compatibility. for key in state_dict.keys(): - if 'tokentype_embeddings' in key: - state_dict_[key.split('tokentype_embeddings.')[1]] \ - = state_dict[key] + if "tokentype_embeddings" in key: + state_dict_[key.split("tokentype_embeddings.")[1]] = state_dict[ + key + ] if len(state_dict_.keys()) > 0: - self.tokentype_embeddings.load_state_dict(state_dict_, - strict=strict) + self.tokentype_embeddings.load_state_dict(state_dict_, strict=strict) else: - print('***WARNING*** expected tokentype embeddings in the ' - 'checkpoint but could not find it', flush=True) + print( + "***WARNING*** expected tokentype embeddings in the " + "checkpoint but could not find it", + flush=True, + ) class EmbeddingPipe(Embedding): @@ -318,26 +337,29 @@ class QueryEmbedding(MegatronModule): will ignore this embedding """ - def __init__(self, - hidden_size, - vocab_size, - max_sequence_length, - embedding_dropout_prob, - init_method, - num_tokentypes=0): + def __init__( + self, + hidden_size, + vocab_size, + max_sequence_length, + embedding_dropout_prob, + init_method, + num_tokentypes=0, + ): super(QueryEmbedding, self).__init__() self.hidden_size = hidden_size self.init_method = init_method self.num_tokentypes = num_tokentypes self.max_sequence_length = max_sequence_length - + # Top query position embedding (serial). 
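The load_state_dict logic above adapts checkpoint embeddings to the current model: the word-embedding weight is cut down to this rank's vocabulary slice (vocab_size divided by the tensor-parallel world size), and the position-embedding weight is padded with freshly initialized rows when the checkpoint stores fewer positions than max_sequence_length. A self-contained sketch of the padding step, with illustrative sizes and a plain normal init standing in for init_method:

import torch

hidden_size, ckpt_positions, max_seq_len = 16, 8, 12
ckpt_weight = torch.randn(ckpt_positions, hidden_size)  # weight found in the checkpoint

if ckpt_positions < max_seq_len:
    # Newly created rows cover the positions the checkpoint never saw.
    padding = torch.nn.Embedding(max_seq_len - ckpt_positions, hidden_size)
    torch.nn.init.normal_(padding.weight, mean=0.0, std=0.02)  # stand-in for init_method
    ckpt_weight = torch.cat([ckpt_weight, padding.weight], dim=0)

assert ckpt_weight.shape == (max_seq_len, hidden_size)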
self.top_query_embeddings = mpu.VocabParallelEmbedding( - max_sequence_length, self.hidden_size, init_method=self.init_method) + max_sequence_length, self.hidden_size, init_method=self.init_method + ) self.top_query_embeddings = self.top_query_embeddings.half() - self._top_query_embeddings_key = 'top_query_embeddings' - + self._top_query_embeddings_key = "top_query_embeddings" + # Initialize the top query position embeddings. self.init_method(self.top_query_embeddings.weight) @@ -345,10 +367,11 @@ def __init__(self, # Add this as an optional field that can be added through # method call so we can load a pretrain model without # token types and add them as needed. - self._tokentype_embeddings_key = 'tokentype_embeddings' + self._tokentype_embeddings_key = "tokentype_embeddings" if self.num_tokentypes > 0: - self.tokentype_embeddings = torch.nn.Embedding(self.num_tokentypes, - self.hidden_size) + self.tokentype_embeddings = torch.nn.Embedding( + self.num_tokentypes, self.hidden_size + ) # Initialize the token-type embeddings. self.init_method(self.tokentype_embeddings.weight) else: @@ -363,13 +386,13 @@ def add_tokentype_embeddings(self, num_tokentypes): This allows us to load the model normally and then add this embedding. """ if self.tokentype_embeddings is not None: - raise Exception('tokentype embeddings is already initialized') + raise Exception("tokentype embeddings is already initialized") if torch.distributed.get_rank() == 0: - print('adding embedding for {} tokentypes'.format(num_tokentypes), - flush=True) + print( + "adding embedding for {} tokentypes".format(num_tokentypes), flush=True + ) self.num_tokentypes = num_tokentypes - self.tokentype_embeddings = torch.nn.Embedding(num_tokentypes, - self.hidden_size) + self.tokentype_embeddings = torch.nn.Embedding(num_tokentypes, self.hidden_size) # Initialize the token-type embeddings. self.init_method(self.tokentype_embeddings.weight) @@ -388,18 +411,19 @@ def forward(self, position_ids, tokentype_ids=None): return embeddings - def state_dict_for_save_checkpoint(self, destination=None, prefix='', - keep_vars=False): + def state_dict_for_save_checkpoint( + self, destination=None, prefix="", keep_vars=False + ): """For easy load.""" state_dict_ = {} - state_dict_[self._top_query_embeddings_key] \ - = self.top_query_embeddings.state_dict( - destination, prefix, keep_vars) + state_dict_[ + self._top_query_embeddings_key + ] = self.top_query_embeddings.state_dict(destination, prefix, keep_vars) if self.num_tokentypes > 0: - state_dict_[self._tokentype_embeddings_key] \ - = self.tokentype_embeddings.state_dict( - destination, prefix, keep_vars) + state_dict_[ + self._tokentype_embeddings_key + ] = self.tokentype_embeddings.state_dict(destination, prefix, keep_vars) return state_dict_ @@ -413,17 +437,19 @@ def load_state_dict(self, state_dict, strict=True): # for backward compatibility. 
state_dict_ = {} for key in state_dict.keys(): - if 'top_query_embeddings' in key: - state_dict_[key.split('top_query_embeddings.')[1]] \ - = state_dict[key] - pos_len = state_dict_['weight'].shape[0] + if "top_query_embeddings" in key: + state_dict_[key.split("top_query_embeddings.")[1]] = state_dict[key] + pos_len = state_dict_["weight"].shape[0] max_seq_len = self.max_sequence_length // get_tensor_model_parallel_world_size() if pos_len < max_seq_len: print_rank_0(f"Top query embedding padded {pos_len} -> {max_seq_len}.") top_query_embeddings_padded = torch.nn.Embedding( - max_seq_len - pos_len, self.hidden_size).half() + max_seq_len - pos_len, self.hidden_size + ).half() self.init_method(top_query_embeddings_padded.weight) - state_dict_['weight'] = torch.cat([state_dict_['weight'], top_query_embeddings_padded.weight], dim=0) + state_dict_["weight"] = torch.cat( + [state_dict_["weight"], top_query_embeddings_padded.weight], dim=0 + ) self.top_query_embeddings.load_state_dict(state_dict_, strict=strict) # Tokentype embedding. @@ -434,15 +460,18 @@ def load_state_dict(self, state_dict, strict=True): else: # for backward compatibility. for key in state_dict.keys(): - if 'tokentype_embeddings' in key: - state_dict_[key.split('tokentype_embeddings.')[1]] \ - = state_dict[key] + if "tokentype_embeddings" in key: + state_dict_[key.split("tokentype_embeddings.")[1]] = state_dict[ + key + ] if len(state_dict_.keys()) > 0: - self.tokentype_embeddings.load_state_dict(state_dict_, - strict=strict) + self.tokentype_embeddings.load_state_dict(state_dict_, strict=strict) else: - print('***WARNING*** expected tokentype embeddings in the ' - 'checkpoint but could not find it', flush=True) + print( + "***WARNING*** expected tokentype embeddings in the " + "checkpoint but could not find it", + flush=True, + ) class QueryEmbeddingPipe(QueryEmbedding): @@ -462,7 +491,8 @@ def forward(self, inputs, **kwargs): tokentype_ids = None embeddings = super().forward( - position_ids, tokentype_ids=tokentype_ids, + position_ids, + tokentype_ids=tokentype_ids, ) # If cmd args has attn_mask, we don't forward it as an activation. @@ -476,8 +506,8 @@ def forward(self, inputs, **kwargs): def word_embeddings_weight(self): """Easy accessory for the DeepSpeed pipeline engine to tie embeddings across stages.""" return self.top_query_embeddings.weight - - + + class TransformerLanguageModel(MegatronModule): """Transformer language model. 
@@ -497,11 +527,9 @@ class TransformerLanguageModel(MegatronModule): will ignore this embedding """ - def __init__(self, - init_method, - output_layer_init_method, - num_tokentypes=0, - add_pooler=False): + def __init__( + self, init_method, output_layer_init_method, num_tokentypes=0, add_pooler=False + ): super(TransformerLanguageModel, self).__init__() args = get_args() @@ -511,82 +539,97 @@ def __init__(self, self.add_pooler = add_pooler # Embeddings - self.embedding = Embedding(self.hidden_size, - args.padded_vocab_size, - args.max_position_embeddings, - args.hidden_dropout, - self.init_method, - self.num_tokentypes) - self._embedding_key = 'embedding' + self.embedding = Embedding( + self.hidden_size, + args.padded_vocab_size, + args.max_position_embeddings, + args.hidden_dropout, + self.init_method, + self.num_tokentypes, + ) + self._embedding_key = "embedding" # Query embeddings - self.topQueryEmbedding = QueryEmbedding(self.hidden_size, - args.padded_vocab_size, - args.max_position_embeddings, - args.hidden_dropout, - self.init_method, - self.num_tokentypes) - self._topQueryEmbedding_key = 'topQueryEmbedding' + self.topQueryEmbedding = QueryEmbedding( + self.hidden_size, + args.padded_vocab_size, + args.max_position_embeddings, + args.hidden_dropout, + self.init_method, + self.num_tokentypes, + ) + self._topQueryEmbedding_key = "topQueryEmbedding" # Transformer self.transformer = ParallelTransformer( - self.init_method, - output_layer_init_method) - self._transformer_key = 'transformer' + self.init_method, output_layer_init_method + ) + self._transformer_key = "transformer" def set_input_tensor(self, input_tensor): """See megatron.model.transformer.set_input_tensor()""" self.transformer.set_input_tensor(input_tensor) - + def forward( - self, - input_ids, - position_ids, - attention_mask, - tokentype_ids=None, - layer_past=None, - get_key_value=False, - pooling_sequence_index=0, - prompt_length=None, - context_length=None, + self, + input_ids, + position_ids, + attention_mask, + tokentype_ids=None, + layer_past=None, + get_key_value=False, + pooling_sequence_index=0, + prompt_length=None, + context_length=None, ): # Embeddings. - embedding_output = self.embedding(input_ids, position_ids, - tokentype_ids=tokentype_ids) + embedding_output = self.embedding( + input_ids, position_ids, tokentype_ids=tokentype_ids + ) query_position_ids = position_ids - queryEmbedding_out = self.topQueryEmbedding(query_position_ids, - tokentype_ids=tokentype_ids) + queryEmbedding_out = self.topQueryEmbedding( + query_position_ids, tokentype_ids=tokentype_ids + ) # Transformer. 
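For orientation, the forward pass assembled above runs two embedding lookups from the same position_ids and feeds both into the transformer stack, where the top-query embedding is consumed by the dedicated top-query layer. A hedged sketch of that flow; the callables are stand-ins for the real modules, not their actual signatures:

def language_model_forward(embedding, top_query_embedding, transformer,
                           input_ids, position_ids, attention_mask):
    hidden = embedding(input_ids, position_ids)    # word + position embeddings
    top_query = top_query_embedding(position_ids)  # extra query stream for the top-query layer
    return transformer(hidden, top_query, attention_mask)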
- transformer_output = self.transformer(embedding_output, - queryEmbedding_out, - attention_mask, - layer_past=layer_past, - get_key_value=get_key_value, - prompt_length=prompt_length, - context_length=context_length, ) + transformer_output = self.transformer( + embedding_output, + queryEmbedding_out, + attention_mask, + layer_past=layer_past, + get_key_value=get_key_value, + prompt_length=prompt_length, + context_length=context_length, + ) return transformer_output - def state_dict_for_save_checkpoint(self, destination=None, prefix='', - keep_vars=False): + def state_dict_for_save_checkpoint( + self, destination=None, prefix="", keep_vars=False + ): """For easy load.""" state_dict_ = {} - state_dict_[self._embedding_key] \ - = self.embedding.state_dict_for_save_checkpoint( - destination, prefix, keep_vars) - state_dict_[self._topQueryEmbedding_key] \ - = self.topQueryEmbedding.state_dict_for_save_checkpoint( - destination, prefix, keep_vars) - state_dict_[self._transformer_key] \ - = self.transformer.state_dict_for_save_checkpoint( - destination, prefix, keep_vars) + state_dict_[ + self._embedding_key + ] = self.embedding.state_dict_for_save_checkpoint( + destination, prefix, keep_vars + ) + state_dict_[ + self._topQueryEmbedding_key + ] = self.topQueryEmbedding.state_dict_for_save_checkpoint( + destination, prefix, keep_vars + ) + state_dict_[ + self._transformer_key + ] = self.transformer.state_dict_for_save_checkpoint( + destination, prefix, keep_vars + ) if self.add_pooler: - state_dict_[self._pooler_key] \ - = self.pooler.state_dict_for_save_checkpoint( - destination, prefix, keep_vars) + state_dict_[self._pooler_key] = self.pooler.state_dict_for_save_checkpoint( + destination, prefix, keep_vars + ) return state_dict_ @@ -600,7 +643,7 @@ def load_state_dict(self, state_dict, strict=True): # for backward compatibility. state_dict_ = {} for key in state_dict.keys(): - if '_embeddings' in key: + if "_embeddings" in key: state_dict_[key] = state_dict[key] self.embedding.load_state_dict(state_dict_, strict=strict) @@ -610,7 +653,7 @@ def load_state_dict(self, state_dict, strict=True): # for backward compatibility. state_dict_ = {} for key in state_dict.keys(): - if '_embeddings' in key: + if "_embeddings" in key: state_dict_[key] = state_dict[key] self.topQueryEmbedding.load_state_dict(state_dict_, strict=strict) @@ -621,13 +664,13 @@ def load_state_dict(self, state_dict, strict=True): # for backward compatibility. state_dict_ = {} for key in state_dict.keys(): - if 'transformer.' in key: - state_dict_[key.split('transformer.')[1]] = state_dict[key] + if "transformer." in key: + state_dict_[key.split("transformer.")[1]] = state_dict[key] self.transformer.load_state_dict(state_dict_, strict=strict) # Pooler. if self.add_pooler: - assert 'pooler' in state_dict, \ - 'could not find data for pooler in the checkpoint' - self.pooler.load_state_dict(state_dict[self._pooler_key], - strict=strict) + assert ( + "pooler" in state_dict + ), "could not find data for pooler in the checkpoint" + self.pooler.load_state_dict(state_dict[self._pooler_key], strict=strict) diff --git a/codegeex/megatron/model/transformer.py b/codegeex/megatron/model/transformer.py index 02e66aa..0f2ef61 100644 --- a/codegeex/megatron/model/transformer.py +++ b/codegeex/megatron/model/transformer.py @@ -107,8 +107,7 @@ class ParallelSelfAttention(MegatronModule): and returns output of the same size. 
""" - def __init__(self, init_method, - output_layer_init_method, layer_number): + def __init__(self, init_method, output_layer_init_method, layer_number): super(ParallelSelfAttention, self).__init__() args = get_args() self.fp16 = args.fp16 @@ -118,13 +117,16 @@ def __init__(self, init_method, # Per attention head and per partition values. world_size = mpu.get_model_parallel_world_size() self.hidden_size_per_partition = mpu.divide( - args.hidden_size // 2 if args.compress else args.hidden_size, - world_size) + args.hidden_size // 2 if args.compress else args.hidden_size, world_size + ) self.hidden_size_per_attention_head = mpu.divide( - args.hidden_size // 2 if args.compress else args.hidden_size, args.num_attention_heads) + args.hidden_size // 2 if args.compress else args.hidden_size, + args.num_attention_heads, + ) self.num_attention_heads_per_partition = mpu.divide( - args.num_attention_heads, world_size) - if hasattr(args, 'attention_upweight'): + args.num_attention_heads, world_size + ) + if hasattr(args, "attention_upweight"): self.attention_upweight = args.attention_upweight else: self.attention_upweight = None @@ -133,17 +135,20 @@ def __init__(self, init_method, args.hidden_size, args.hidden_size // 2 if args.compress else args.hidden_size, gather_output=False, - init_method=init_method) + init_method=init_method, + ) self.key = mpu.ColumnParallelLinear( args.hidden_size, args.hidden_size // 2 if args.compress else args.hidden_size, gather_output=False, - init_method=init_method) + init_method=init_method, + ) self.value = mpu.ColumnParallelLinear( args.hidden_size, args.hidden_size // 2 if args.compress else args.hidden_size, gather_output=False, - init_method=init_method) + init_method=init_method, + ) self.norm_factor = math.sqrt(self.hidden_size_per_attention_head) self.softmax = torch.nn.Softmax(dim=-1) @@ -159,16 +164,17 @@ def __init__(self, init_method, args.hidden_size, input_is_parallel=True if args.tensor_model_parallel_size > 1 else False, init_method=output_layer_init_method, - skip_bias_add=True) + skip_bias_add=True, + ) def forward( - self, - hidden_states, - attention_mask, - layer_past=None, - get_key_value=False, - prompt_length=None, - context_length=None, + self, + hidden_states, + attention_mask, + layer_past=None, + get_key_value=False, + prompt_length=None, + context_length=None, ): # hidden_states: [sq, b, h] @@ -180,19 +186,22 @@ def forward( key_layer, _ = self.key(hidden_states) value_layer, _ = self.value(hidden_states) - new_query_layer_shape = query_layer.size()[:-1] + \ - (self.num_attention_heads_per_partition, - self.hidden_size_per_attention_head) + new_query_layer_shape = query_layer.size()[:-1] + ( + self.num_attention_heads_per_partition, + self.hidden_size_per_attention_head, + ) query_layer = query_layer.view(*new_query_layer_shape) - new_query_layer_shape = key_layer.size()[:-1] + \ - (self.num_attention_heads_per_partition, - self.hidden_size_per_attention_head) + new_query_layer_shape = key_layer.size()[:-1] + ( + self.num_attention_heads_per_partition, + self.hidden_size_per_attention_head, + ) key_layer = key_layer.view(*new_query_layer_shape) - new_query_layer_shape = value_layer.size()[:-1] + \ - (self.num_attention_heads_per_partition, - self.hidden_size_per_attention_head) + new_query_layer_shape = value_layer.size()[:-1] + ( + self.num_attention_heads_per_partition, + self.hidden_size_per_attention_head, + ) value_layer = value_layer.view(*new_query_layer_shape) # ================================== @@ -201,10 +210,10 @@ def 
forward( if layer_past is not None: past_key, past_value = layer_past - key_layer = torch.cat((past_key.type_as(key_layer), - key_layer), dim=0) - value_layer = torch.cat((past_value.type_as(value_layer), - value_layer), dim=0) + key_layer = torch.cat((past_key.type_as(key_layer), key_layer), dim=0) + value_layer = torch.cat( + (past_value.type_as(value_layer), value_layer), dim=0 + ) if get_key_value: present = (key_layer, value_layer) @@ -213,30 +222,45 @@ def forward( # =================================== # [b, np, sq, sk] - output_size = (query_layer.size(1), - query_layer.size(2), - query_layer.size(0), - key_layer.size(0)) + output_size = ( + query_layer.size(1), + query_layer.size(2), + query_layer.size(0), + key_layer.size(0), + ) # [sq, b, np, hn] -> [sq, b * np, hn] - query_layer = query_layer.contiguous().view(output_size[2], output_size[0] * output_size[1], -1) - key_layer = key_layer.contiguous().view(output_size[3], output_size[0] * output_size[1], -1) + query_layer = query_layer.contiguous().view( + output_size[2], output_size[0] * output_size[1], -1 + ) + key_layer = key_layer.contiguous().view( + output_size[3], output_size[0] * output_size[1], -1 + ) # Raw attention scores. [b * np, sq, sk] - matmul_result = torch.matmul(query_layer.transpose(0, 1), - key_layer.transpose(0, 1).transpose(1, 2)) / self.norm_factor + matmul_result = ( + torch.matmul( + query_layer.transpose(0, 1), key_layer.transpose(0, 1).transpose(1, 2) + ) + / self.norm_factor + ) # change view to [b, np, sq, sk] attention_scores = matmul_result.view(*output_size) if self.attention_upweight is not None and layer_past is None: - log_attention_weights = torch.zeros(attention_scores.size(3), attention_scores.size(3), - device=torch.cuda.current_device(), - dtype=torch.half if self.fp16 else torch.float32) + log_attention_weights = torch.zeros( + attention_scores.size(3), + attention_scores.size(3), + device=torch.cuda.current_device(), + dtype=torch.half if self.fp16 else torch.float32, + ) if prompt_length is None: log_attention_weights = self.attention_upweight else: - log_attention_weights[:prompt_length, :prompt_length] = self.attention_upweight + log_attention_weights[ + :prompt_length, :prompt_length + ] = self.attention_upweight attention_scores += log_attention_weights # ================================================== @@ -247,14 +271,12 @@ def forward( with torch.no_grad(): if layer_past is not None: attention_mask = attention_mask[ - ..., - attention_scores.size(3) - 1, - :attention_scores.size(3)].unsqueeze(2) + ..., attention_scores.size(3) - 1, : attention_scores.size(3) + ].unsqueeze(2) else: attention_mask = attention_mask[ - ..., - :attention_scores.size(3), - :attention_scores.size(3)] + ..., : attention_scores.size(3), : attention_scores.size(3) + ] # =========================== # Attention probs and dropout @@ -285,19 +307,26 @@ def forward( # [sq, b, np, hn] --> [b, np, sq, hn] # context layer shape: [b, np, sq, hn] - output_size = (value_layer.size(1), - value_layer.size(2), - query_layer.size(0), - value_layer.size(3)) + output_size = ( + value_layer.size(1), + value_layer.size(2), + query_layer.size(0), + value_layer.size(3), + ) - # change view [sq, b * np, hn] - value_layer = value_layer.view(value_layer.size(0), output_size[0] * output_size[1], -1) + # change view [sq, b * np, hn] + value_layer = value_layer.view( + value_layer.size(0), output_size[0] * output_size[1], -1 + ) # change view [b * np, sq, sk] - attention_probs = attention_probs.view(output_size[0] * output_size[1], - 
output_size[2], -1) + attention_probs = attention_probs.view( + output_size[0] * output_size[1], output_size[2], -1 + ) - context_layer = torch.bmm(attention_probs, value_layer.unsqueeze(0).transpose(1, 2).squeeze(0)) + context_layer = torch.bmm( + attention_probs, value_layer.unsqueeze(0).transpose(1, 2).squeeze(0) + ) # change view [b, np, sq, hn] context_layer = context_layer.view(*output_size) @@ -306,8 +335,9 @@ def forward( context_layer = context_layer.permute(2, 0, 1, 3).contiguous() # # [sq, b, np, hn] --> [sq, b, hp] - new_context_layer_shape = context_layer.size()[:-2] + \ - (self.hidden_size_per_partition,) + new_context_layer_shape = context_layer.size()[:-2] + ( + self.hidden_size_per_partition, + ) context_layer = context_layer.view(*new_context_layer_shape) # ================= @@ -329,44 +359,47 @@ class ParallelTopQuerySelfAttention(MegatronModule): and returns output of the same size. """ - def __init__(self, init_method, - output_layer_init_method, layer_number): + def __init__(self, init_method, output_layer_init_method, layer_number): super(ParallelTopQuerySelfAttention, self).__init__() args = get_args() self.fp16 = args.fp16 self.attention_softmax_in_fp32 = args.attention_softmax_in_fp32 self.layer_number = max(1, layer_number) - if hasattr(args, 'attention_upweight_top'): + if hasattr(args, "attention_upweight_top"): self.attention_upweight = args.attention_upweight_top else: self.attention_upweight = None # Per attention head and per partition values. world_size = mpu.get_model_parallel_world_size() - self.hidden_size_per_partition = mpu.divide(args.hidden_size, - world_size) + self.hidden_size_per_partition = mpu.divide(args.hidden_size, world_size) self.hidden_size_per_attention_head = mpu.divide( - args.hidden_size, args.num_attention_heads) + args.hidden_size, args.num_attention_heads + ) self.num_attention_heads_per_partition = mpu.divide( - args.num_attention_heads, world_size) + args.num_attention_heads, world_size + ) self.query = mpu.ColumnParallelLinear( args.hidden_size, args.hidden_size, gather_output=False, - init_method=init_method) + init_method=init_method, + ) self.key = mpu.ColumnParallelLinear( args.hidden_size, args.hidden_size, gather_output=False, - init_method=init_method) + init_method=init_method, + ) self.value = mpu.ColumnParallelLinear( args.hidden_size, args.hidden_size, gather_output=False, - init_method=init_method) + init_method=init_method, + ) self.norm_factor = math.sqrt(self.hidden_size_per_attention_head) self.softmax = torch.nn.Softmax(dim=-1) @@ -382,17 +415,18 @@ def __init__(self, init_method, args.hidden_size, input_is_parallel=True if args.tensor_model_parallel_size > 1 else False, init_method=output_layer_init_method, - skip_bias_add=True) + skip_bias_add=True, + ) def forward( - self, - hidden_states, - query_hidden_state, - attention_mask, - layer_past=None, - get_key_value=False, - prompt_length=None, - context_length=None, + self, + hidden_states, + query_hidden_state, + attention_mask, + layer_past=None, + get_key_value=False, + prompt_length=None, + context_length=None, ): # hidden_states: [sq, b, h] @@ -401,19 +435,22 @@ def forward( key_layer, _ = self.key(hidden_states) value_layer, _ = self.value(hidden_states) - new_query_layer_shape = query_layer.size()[:-1] + \ - (self.num_attention_heads_per_partition, - self.hidden_size_per_attention_head) + new_query_layer_shape = query_layer.size()[:-1] + ( + self.num_attention_heads_per_partition, + self.hidden_size_per_attention_head, + ) query_layer = 
query_layer.view(*new_query_layer_shape) - new_query_layer_shape = key_layer.size()[:-1] + \ - (self.num_attention_heads_per_partition, - self.hidden_size_per_attention_head) + new_query_layer_shape = key_layer.size()[:-1] + ( + self.num_attention_heads_per_partition, + self.hidden_size_per_attention_head, + ) key_layer = key_layer.view(*new_query_layer_shape) - new_query_layer_shape = value_layer.size()[:-1] + \ - (self.num_attention_heads_per_partition, - self.hidden_size_per_attention_head) + new_query_layer_shape = value_layer.size()[:-1] + ( + self.num_attention_heads_per_partition, + self.hidden_size_per_attention_head, + ) value_layer = value_layer.view(*new_query_layer_shape) # ================================== @@ -422,10 +459,10 @@ def forward( if layer_past is not None: past_key, past_value = layer_past - key_layer = torch.cat((past_key.type_as(key_layer), - key_layer), dim=0) - value_layer = torch.cat((past_value.type_as(value_layer), - value_layer), dim=0) + key_layer = torch.cat((past_key.type_as(key_layer), key_layer), dim=0) + value_layer = torch.cat( + (past_value.type_as(value_layer), value_layer), dim=0 + ) if get_key_value: present = (key_layer, value_layer) @@ -434,30 +471,45 @@ def forward( # =================================== # [b, np, sq, sk] - output_size = (query_layer.size(1), - query_layer.size(2), - query_layer.size(0), - key_layer.size(0)) + output_size = ( + query_layer.size(1), + query_layer.size(2), + query_layer.size(0), + key_layer.size(0), + ) # [s, b, np, hn] -> [s, b * np, hn] - query_layer = query_layer.contiguous().view(output_size[2], output_size[0] * output_size[1], -1) - key_layer = key_layer.contiguous().view(output_size[3], output_size[0] * output_size[1], -1) + query_layer = query_layer.contiguous().view( + output_size[2], output_size[0] * output_size[1], -1 + ) + key_layer = key_layer.contiguous().view( + output_size[3], output_size[0] * output_size[1], -1 + ) # Raw attention scores. 
[b * np, sq, sk] - matmul_result = torch.matmul(query_layer.transpose(0, 1), - key_layer.transpose(0, 1).transpose(1, 2)) / self.norm_factor + matmul_result = ( + torch.matmul( + query_layer.transpose(0, 1), key_layer.transpose(0, 1).transpose(1, 2) + ) + / self.norm_factor + ) # change view to [b, np, s, s] attention_scores = matmul_result.view(*output_size) if self.attention_upweight is not None and layer_past is None: - log_attention_weights = torch.zeros(attention_scores.size(3), attention_scores.size(3), - device=torch.cuda.current_device(), - dtype=torch.half if self.fp16 else torch.float32) + log_attention_weights = torch.zeros( + attention_scores.size(3), + attention_scores.size(3), + device=torch.cuda.current_device(), + dtype=torch.half if self.fp16 else torch.float32, + ) if prompt_length is None: log_attention_weights = self.attention_upweight else: - log_attention_weights[:prompt_length, :prompt_length] = self.attention_upweight + log_attention_weights[ + :prompt_length, :prompt_length + ] = self.attention_upweight attention_scores += log_attention_weights # ================================================== @@ -468,14 +520,12 @@ def forward( with torch.no_grad(): if layer_past is not None: attention_mask = attention_mask[ - ..., - attention_scores.size(3) - 1, - :attention_scores.size(3)].unsqueeze(2) + ..., attention_scores.size(3) - 1, : attention_scores.size(3) + ].unsqueeze(2) else: attention_mask = attention_mask[ - ..., - :attention_scores.size(3), - :attention_scores.size(3)] + ..., : attention_scores.size(3), : attention_scores.size(3) + ] # =========================== # Attention probs and dropout @@ -506,20 +556,27 @@ def forward( # [sq, b, np, hn] --> [b, np, sq, hn] # context layer shape: [b, np, sq, hn] - output_size = (value_layer.size(1), - value_layer.size(2), - query_layer.size(0), - value_layer.size(3)) + output_size = ( + value_layer.size(1), + value_layer.size(2), + query_layer.size(0), + value_layer.size(3), + ) # change view [sq, b * np, hn] - value_layer = value_layer.view(value_layer.size(0), output_size[0] * output_size[1], -1) + value_layer = value_layer.view( + value_layer.size(0), output_size[0] * output_size[1], -1 + ) # change view [b * np, sq, sk] - attention_probs = attention_probs.view(output_size[0] * output_size[1], - output_size[2], -1) + attention_probs = attention_probs.view( + output_size[0] * output_size[1], output_size[2], -1 + ) # matmul: [b * np, sq, hn] - context_layer = torch.bmm(attention_probs, value_layer.unsqueeze(0).transpose(1, 2).squeeze(0)) + context_layer = torch.bmm( + attention_probs, value_layer.unsqueeze(0).transpose(1, 2).squeeze(0) + ) # change view [b, np, sq, hn] context_layer = context_layer.view(*output_size) @@ -528,8 +585,9 @@ def forward( context_layer = context_layer.permute(2, 0, 1, 3).contiguous() # [sq, b, np, hn] --> [sq, b, hp] - new_context_layer_shape = context_layer.size()[:-2] + \ - (self.hidden_size_per_partition,) + new_context_layer_shape = context_layer.size()[:-2] + ( + self.hidden_size_per_partition, + ) context_layer = context_layer.view(*new_context_layer_shape) # ================= @@ -577,53 +635,51 @@ class ParallelTransformerLayer(MegatronModule): output of the same size. 
""" - def __init__(self, init_method, - output_layer_init_method, layer_number): + def __init__(self, init_method, output_layer_init_method, layer_number): args = get_args() super(ParallelTransformerLayer, self).__init__() self.layer_number = layer_number - self.apply_residual_connection_post_layernorm \ - = args.apply_residual_connection_post_layernorm + self.apply_residual_connection_post_layernorm = ( + args.apply_residual_connection_post_layernorm + ) # Layernorm on the input data. - self.input_layernorm = LayerNorm( - args.hidden_size, - eps=args.layernorm_epsilon) + self.input_layernorm = LayerNorm(args.hidden_size, eps=args.layernorm_epsilon) # Self attention. - self.attention = ParallelSelfAttention(init_method, - output_layer_init_method, - layer_number) + self.attention = ParallelSelfAttention( + init_method, output_layer_init_method, layer_number + ) self.hidden_dropout = args.hidden_dropout self.bias_dropout_fusion = args.bias_dropout_fusion # Layernorm on the input data. self.post_attention_layernorm = LayerNorm( - args.hidden_size, - eps=args.layernorm_epsilon) - if hasattr(args, 'attention_upweight'): + args.hidden_size, eps=args.layernorm_epsilon + ) + if hasattr(args, "attention_upweight"): self.attention_upweight = args.attention_upweight else: self.attention_upweight = None - if hasattr(args, 'ln_fp16'): + if hasattr(args, "ln_fp16"): self.ln_fp16 = args.ln_fp16 else: self.ln_fp16 = False # MLP - self.mlp = ParallelMLP(init_method, - output_layer_init_method, - scale=2 if args.compress else 4) + self.mlp = ParallelMLP( + init_method, output_layer_init_method, scale=2 if args.compress else 4 + ) def forward( - self, - hidden_states, - attention_mask, - layer_past=None, - get_key_value=False, - prompt_length=None, - context_length=None, + self, + hidden_states, + attention_mask, + layer_past=None, + get_key_value=False, + prompt_length=None, + context_length=None, ): # hidden_states: [b, s, h] if self.ln_fp16: @@ -632,13 +688,14 @@ def forward( layernorm_output = self.input_layernorm(hidden_states.float()).half() # Self attention. - attention_output, attention_bias = \ - self.attention(layernorm_output, - attention_mask, - layer_past=layer_past, - get_key_value=get_key_value, - prompt_length=prompt_length, - context_length=context_length) + attention_output, attention_bias = self.attention( + layernorm_output, + attention_mask, + layer_past=layer_past, + get_key_value=get_key_value, + prompt_length=prompt_length, + context_length=context_length, + ) if get_key_value: attention_output, presents = attention_output @@ -649,8 +706,8 @@ def forward( else: residual = hidden_states - # jit scripting for a nn.module (with dropout) is not - # trigerring the fusion kernel. For now, we use two + # jit scripting for a nn.module (with dropout) is not + # trigerring the fusion kernel. For now, we use two # different nn.functional routines to account for varying # dropout semantics during training and inference phases. if self.bias_dropout_fusion: @@ -667,13 +724,16 @@ def forward( attention_output, attention_bias.expand_as(residual), residual, - self.hidden_dropout) + self.hidden_dropout, + ) # Layer norm post the self attention. 
if self.ln_fp16: layernorm_output = self.post_attention_layernorm(layernorm_input) else: - layernorm_output = self.post_attention_layernorm(layernorm_input.float()).half() + layernorm_output = self.post_attention_layernorm( + layernorm_input.float() + ).half() mlp_output, _ = self.mlp(layernorm_output) @@ -736,52 +796,49 @@ class ParallelTopQueryLayer(MegatronModule): output of the same size. """ - def __init__(self, init_method, - output_layer_init_method, layer_number): + def __init__(self, init_method, output_layer_init_method, layer_number): args = get_args() super(ParallelTopQueryLayer, self).__init__() self.layer_number = layer_number - self.apply_residual_connection_post_layernorm \ - = args.apply_residual_connection_post_layernorm + self.apply_residual_connection_post_layernorm = ( + args.apply_residual_connection_post_layernorm + ) # Layernorm on the input data. - self.input_layernorm = LayerNorm( - args.hidden_size, - eps=args.layernorm_epsilon) + self.input_layernorm = LayerNorm(args.hidden_size, eps=args.layernorm_epsilon) # Self attention. - self.attention = ParallelTopQuerySelfAttention(init_method, - output_layer_init_method, - layer_number) + self.attention = ParallelTopQuerySelfAttention( + init_method, output_layer_init_method, layer_number + ) self.hidden_dropout = args.hidden_dropout self.bias_dropout_fusion = args.bias_dropout_fusion # Layernorm on the input data. self.post_attention_layernorm = LayerNorm( - args.hidden_size, - eps=args.layernorm_epsilon) + args.hidden_size, eps=args.layernorm_epsilon + ) - if hasattr(args, 'ln_fp16'): + if hasattr(args, "ln_fp16"): self.ln_fp16 = args.ln_fp16 else: self.ln_fp16 = False # MLP - self.mlp = ParallelMLP(init_method, - output_layer_init_method) + self.mlp = ParallelMLP(init_method, output_layer_init_method) def forward( - self, - hidden_states, - query_hidden_state, - attention_mask, - layer_past=None, - get_key_value=False, - prompt_length=None, - context_length=None, + self, + hidden_states, + query_hidden_state, + attention_mask, + layer_past=None, + get_key_value=False, + prompt_length=None, + context_length=None, ): # hidden_states: [b, s, h] assert query_hidden_state != None @@ -793,14 +850,15 @@ def forward( layernorm_output = self.input_layernorm(hidden_states.float()).half() # Self attention. - attention_output, attention_bias = \ - self.attention(layernorm_output, - query_hidden_state, - attention_mask, - layer_past=layer_past, - get_key_value=get_key_value, - prompt_length=prompt_length, - context_length=context_length) + attention_output, attention_bias = self.attention( + layernorm_output, + query_hidden_state, + attention_mask, + layer_past=layer_past, + get_key_value=get_key_value, + prompt_length=prompt_length, + context_length=context_length, + ) if get_key_value: attention_output, presents = attention_output @@ -829,13 +887,16 @@ def forward( attention_output, attention_bias.expand_as(residual), residual, - self.hidden_dropout) + self.hidden_dropout, + ) # Layer norm post the self attention. if self.ln_fp16: layernorm_output = self.post_attention_layernorm(layernorm_input) else: - layernorm_output = self.post_attention_layernorm(layernorm_input.float()).half() + layernorm_output = self.post_attention_layernorm( + layernorm_input.float() + ).half() # MLP. 
mlp_output, _ = self.mlp(layernorm_output) @@ -883,15 +944,17 @@ def forward(self, inputs, **kwargs): self._args = get_args() hidden_states, query_hidden_state = inputs attention_mask = self._args.attn_mask - return super().forward(hidden_states, query_hidden_state, attention_mask, **kwargs) + return super().forward( + hidden_states, query_hidden_state, attention_mask, **kwargs + ) elif len(inputs) == 3: # Attention mask is an activation. hidden_states, query_hidden_state, attention_mask = inputs[0], inputs[1] return super().forward(*inputs, **kwargs), attention_mask else: raise RuntimeError("Received more inputs than understood.") - - + + class ParallelTransformer(MegatronModule): """Transformer class.""" @@ -913,39 +976,39 @@ def __init__(self, init_method, output_layer_init_method): if self.num_unique_layers is None: self.num_unique_layers = self.num_layers - assert self.num_layers % self.num_unique_layers == 0, \ - 'number of layers should be divisible by number of unique layers' - self.param_sharing_style = 'grouped' + assert ( + self.num_layers % self.num_unique_layers == 0 + ), "number of layers should be divisible by number of unique layers" + self.param_sharing_style = "grouped" # Transformer layers. def build_layer(layer_number): return ParallelTransformerLayer( - init_method, - output_layer_init_method, layer_number) + init_method, output_layer_init_method, layer_number + ) self.layers = torch.nn.ModuleList( - [build_layer(i + 1) for i in range(self.num_unique_layers)]) + [build_layer(i + 1) for i in range(self.num_unique_layers)] + ) self.topQueryLayer = ParallelTopQueryLayer( - init_method, - output_layer_init_method, self.num_unique_layers) + init_method, output_layer_init_method, self.num_unique_layers + ) # Final layer norm before output. 
- if hasattr(args, 'ln_fp16'): + if hasattr(args, "ln_fp16"): self.ln_fp16 = args.ln_fp16 else: self.ln_fp16 = False - self.final_layernorm = LayerNorm( - args.hidden_size, - eps=args.layernorm_epsilon) + self.final_layernorm = LayerNorm(args.hidden_size, eps=args.layernorm_epsilon) def _get_layer_index(self, layer_number): - if self.param_sharing_style == 'grouped': + if self.param_sharing_style == "grouped": return layer_number % self.num_unique_layers - if self.param_sharing_style == 'spaced': + if self.param_sharing_style == "spaced": return layer_number // (self.num_layers // self.num_unique_layers) - assert False, 'should not be here' + assert False, "should not be here" def _get_layer(self, layer_number): return self.layers[self._get_layer_index(layer_number)] @@ -968,8 +1031,8 @@ def custom_forward(*inputs): l = 0 while l < self.num_layers: hidden_states = mpu.checkpoint( - custom(l, l + self.checkpoint_num_layers), - hidden_states, attention_mask) + custom(l, l + self.checkpoint_num_layers), hidden_states, attention_mask + ) l += self.checkpoint_num_layers return hidden_states @@ -983,35 +1046,34 @@ def set_input_tensor(self, input_tensor): used by internal code to bypass the input provided by the forward_step_func""" self.input_tensor = input_tensor - + def forward( - self, - hidden_states, - query_hidden_state, - attention_mask, - layer_past=None, - get_key_value=False, - prompt_length=None, - context_length=None, + self, + hidden_states, + query_hidden_state, + attention_mask, + layer_past=None, + get_key_value=False, + prompt_length=None, + context_length=None, ): # Checks if layer_past is not None: - assert get_key_value, \ - 'for not None values in layer_past, ' \ - 'expected get_key_value to be set' + assert get_key_value, ( + "for not None values in layer_past, " "expected get_key_value to be set" + ) if get_key_value: - assert not self.checkpoint_activations, \ - 'get_key_value does not work with ' \ - 'activation checkpointing' + assert not self.checkpoint_activations, ( + "get_key_value does not work with " "activation checkpointing" + ) # data format change to avoid explicit tranposes : [b s h] --> [s b h] hidden_states = hidden_states.transpose(0, 1).contiguous() query_hidden_state = query_hidden_state.transpose(0, 1).contiguous() if self.checkpoint_activations: - hidden_states = self._checkpointed_forward(hidden_states, - attention_mask) + hidden_states = self._checkpointed_forward(hidden_states, attention_mask) else: if get_key_value: presents = [] @@ -1020,12 +1082,14 @@ def forward( past = None if layer_past is not None: past = layer_past[index] - hidden_states = layer(hidden_states, - attention_mask, - layer_past=past, - get_key_value=get_key_value, - prompt_length=prompt_length, - context_length=context_length) + hidden_states = layer( + hidden_states, + attention_mask, + layer_past=past, + get_key_value=get_key_value, + prompt_length=prompt_length, + context_length=context_length, + ) if get_key_value: hidden_states, present = hidden_states presents.append(present) @@ -1041,13 +1105,15 @@ def forward( past = None if layer_past is not None: past = layer_past[self.num_layers] - hidden_states = self.topQueryLayer(hidden_states_, - query_hidden_state, - attention_mask, - layer_past=past, - get_key_value=get_key_value, - prompt_length=prompt_length, - context_length=context_length) + hidden_states = self.topQueryLayer( + hidden_states_, + query_hidden_state, + attention_mask, + layer_past=past, + get_key_value=get_key_value, + prompt_length=prompt_length, + 
context_length=context_length, + ) if get_key_value: hidden_states, present = hidden_states diff --git a/codegeex/megatron/model/utils.py b/codegeex/megatron/model/utils.py index 6543f7b..fc2d68a 100644 --- a/codegeex/megatron/model/utils.py +++ b/codegeex/megatron/model/utils.py @@ -55,14 +55,18 @@ def get_linear_layer(rows, columns, init_method): def fast_gelu(x): """Mindspore's fast gelu implementation.""" - return x / (1 + torch.exp(-1.702 * torch.abs(x))) * torch.exp(0.851 * (x - torch.abs(x))) + return ( + x + / (1 + torch.exp(-1.702 * torch.abs(x))) + * torch.exp(0.851 * (x - torch.abs(x))) + ) @torch.jit.script def gelu_impl(x): """OpenAI's gelu implementation.""" return ( - 0.5 * x * (1.0 + torch.tanh(0.7978845608028654 * x * (1.0 + 0.044715 * x * x))) + 0.5 * x * (1.0 + torch.tanh(0.7978845608028654 * x * (1.0 + 0.044715 * x * x))) ) @@ -74,10 +78,10 @@ def openai_gelu(x): @torch.jit.script def erf_gelu(x): return ( - x - * 0.5 - * ( - torch.erf(x / 1.41421).to(dtype=x.dtype) - + torch.ones_like(x).to(dtype=x.dtype) - ) + x + * 0.5 + * ( + torch.erf(x / 1.41421).to(dtype=x.dtype) + + torch.ones_like(x).to(dtype=x.dtype) + ) ) diff --git a/codegeex/megatron/mpu/layers.py b/codegeex/megatron/mpu/layers.py index 1e32624..c76c77a 100644 --- a/codegeex/megatron/mpu/layers.py +++ b/codegeex/megatron/mpu/layers.py @@ -287,7 +287,7 @@ def __init__( self.skip_bias_add = skip_bias_add self.params_dtype = params_dtype self.device = device - + # Parameters. # Note: torch.nn.functional.linear performs XA^T + b and as a result # we allocate the transpose. @@ -299,7 +299,9 @@ def __init__( torch.empty( self.output_size_per_partition, self.input_size, - dtype=self.params_dtype if self.params_dtype is not None else args.params_dtype, + dtype=self.params_dtype + if self.params_dtype is not None + else args.params_dtype, ) ) self.master_weight = _initialize_affine_weight_cpu( @@ -317,8 +319,12 @@ def __init__( torch.empty( self.output_size_per_partition, self.input_size, - device=self.device if self.device is not None else torch.cuda.current_device(), - dtype=self.params_dtype if self.params_dtype is not None else args.params_dtype, + device=self.device + if self.device is not None + else torch.cuda.current_device(), + dtype=self.params_dtype + if self.params_dtype is not None + else args.params_dtype, ) ) _initialize_affine_weight_gpu( @@ -330,15 +336,23 @@ def __init__( if bias and not skip_init: if args.use_cpu_initialization: self.bias = Parameter( - torch.empty(self.output_size_per_partition, - dtype=self.params_dtype if self.params_dtype is not None else args.params_dtype) + torch.empty( + self.output_size_per_partition, + dtype=self.params_dtype + if self.params_dtype is not None + else args.params_dtype, + ) ) else: self.bias = Parameter( torch.empty( self.output_size_per_partition, - device=self.device if self.device is not None else torch.cuda.current_device(), - dtype=self.params_dtype if self.params_dtype is not None else args.params_dtype, + device=self.device + if self.device is not None + else torch.cuda.current_device(), + dtype=self.params_dtype + if self.params_dtype is not None + else args.params_dtype, ) ) set_tensor_model_parallel_attributes(self.bias, True, 0, stride) @@ -420,7 +434,7 @@ def __init__( self.skip_bias_add = skip_bias_add self.params_dtype = params_dtype self.device = device - + # Parameters. # Note: torch.nn.functional.linear performs XA^T + b and as a result # we allocate the transpose. 
@@ -432,7 +446,9 @@ def __init__( torch.empty( self.output_size, self.input_size_per_partition, - dtype=self.params_dtype if self.params_dtype is not None else args.params_dtype, + dtype=self.params_dtype + if self.params_dtype is not None + else args.params_dtype, ) ) self.master_weight = _initialize_affine_weight_cpu( @@ -450,8 +466,12 @@ def __init__( torch.empty( self.output_size, self.input_size_per_partition, - device=self.device if self.device is not None else torch.cuda.current_device(), - dtype=self.params_dtype if self.params_dtype is not None else args.params_dtype, + device=self.device + if self.device is not None + else torch.cuda.current_device(), + dtype=self.params_dtype + if self.params_dtype is not None + else args.params_dtype, ) ) _initialize_affine_weight_gpu( @@ -459,19 +479,27 @@ def __init__( ) else: self.register_parameter("weight", None) - + if bias and not skip_init: if args.use_cpu_initialization: self.bias = Parameter( - torch.empty(self.output_size, - dtype=self.params_dtype if self.params_dtype is not None else args.params_dtype) + torch.empty( + self.output_size, + dtype=self.params_dtype + if self.params_dtype is not None + else args.params_dtype, + ) ) else: self.bias = Parameter( torch.empty( self.output_size, - device=self.device if self.device is not None else torch.cuda.current_device(), - dtype=self.params_dtype if self.params_dtype is not None else args.params_dtype, + device=self.device + if self.device is not None + else torch.cuda.current_device(), + dtype=self.params_dtype + if self.params_dtype is not None + else args.params_dtype, ) ) # Always initialize bias to zero. diff --git a/codegeex/megatron/optimizer/__init__.py b/codegeex/megatron/optimizer/__init__.py index b84bfe9..4176e66 100644 --- a/codegeex/megatron/optimizer/__init__.py +++ b/codegeex/megatron/optimizer/__init__.py @@ -60,9 +60,9 @@ def get_megatron_optimizer(model): if args.cpu_optimizer: raise NotImplementedError("need to add cpu adam") - + param_groups = _get_params_for_weight_decay_optimization(model) - + if args.optimizer == "adam": optimizer = Adam( param_groups, diff --git a/codegeex/megatron/optimizer/clip_grads.py b/codegeex/megatron/optimizer/clip_grads.py index 281d84d..6733040 100644 --- a/codegeex/megatron/optimizer/clip_grads.py +++ b/codegeex/megatron/optimizer/clip_grads.py @@ -97,12 +97,12 @@ def clip_grad_norm_fp32(parameters, max_norm, norm_type=2): ) # Since we will be summing across data parallel groups, # we need the pow(norm-type). - total_norm = grad_norm ** norm_type + total_norm = grad_norm**norm_type else: for grad in grads_for_norm: grad_norm = torch.norm(grad, norm_type) - total_norm += grad_norm ** norm_type + total_norm += grad_norm**norm_type # Sum across all model-parallel GPUs. 
torch.distributed.all_reduce( diff --git a/codegeex/megatron/tokenizer/gpt2_tokenization.py b/codegeex/megatron/tokenizer/gpt2_tokenization.py index bcf8cf6..2c22d11 100644 --- a/codegeex/megatron/tokenizer/gpt2_tokenization.py +++ b/codegeex/megatron/tokenizer/gpt2_tokenization.py @@ -69,10 +69,10 @@ def bytes_to_unicode(): ) cs = bs[:] n = 0 - for b in range(2 ** 8): + for b in range(2**8): if b not in bs: bs.append(b) - cs.append(2 ** 8 + n) + cs.append(2**8 + n) n += 1 cs = [_chr(n) for n in cs] return dict(zip(bs, cs)) diff --git a/codegeex/megatron/tokenizer/tokenizer.py b/codegeex/megatron/tokenizer/tokenizer.py index da55e2a..923a4e6 100644 --- a/codegeex/megatron/tokenizer/tokenizer.py +++ b/codegeex/megatron/tokenizer/tokenizer.py @@ -23,7 +23,7 @@ from transformers import AutoTokenizer -def encode_whitespaces(text: str, start_extra_id: int=10, max_len: int=10): +def encode_whitespaces(text: str, start_extra_id: int = 10, max_len: int = 10): """Encode whitespaces to extra tokens. >>> encode_whitespaces('a\\n b\\n c', 10, 10) @@ -34,7 +34,7 @@ def encode_whitespaces(text: str, start_extra_id: int=10, max_len: int=10): return text -def decode_whitespaces(text: str, start_extra_id: int=10, max_len: int=10): +def decode_whitespaces(text: str, start_extra_id: int = 10, max_len: int = 10): """Decode the whitespace-encoded strings produced by encode_whitespace. >>> text = 'a\\n b\\n c' @@ -63,9 +63,7 @@ def build_hgf_tokenizer(args): ws_start_id = args.ws_encoding_start_id if "ws_encoding_start_id" in args else None ws_len = args.ws_encoding_length if "ws_encoding_length" in args else None - return HgfTokenizerWrapper( - tokenizer, ws_start=ws_start_id, ws_len=ws_len - ) + return HgfTokenizerWrapper(tokenizer, ws_start=ws_start_id, ws_len=ws_len) def build_tokenizer(args): @@ -218,10 +216,10 @@ class HgfTokenizerWrapper(AbstractTokenizer): """Wrapper for Hugging Face tokenizer.""" def __init__( - self, - tokenizer, - ws_start: int = None, - ws_len: int = None, + self, + tokenizer, + ws_start: int = None, + ws_len: int = None, ): super(HgfTokenizerWrapper, self).__init__(tokenizer.__class__.__name__) self.tokenizer = tokenizer diff --git a/codegeex/megatron/tools/collect_env.py b/codegeex/megatron/tools/collect_env.py index 975f067..76b1ece 100644 --- a/codegeex/megatron/tools/collect_env.py +++ b/codegeex/megatron/tools/collect_env.py @@ -1,7 +1,13 @@ import os -ENV_NAMES = ["CUDA_HOME", "LD_LIBRARY_PATH", "PATH", "TORCH_EXTENSIONS_DIR", "CUDA_LAUNCH_BLOCKING"] +ENV_NAMES = [ + "CUDA_HOME", + "LD_LIBRARY_PATH", + "PATH", + "TORCH_EXTENSIONS_DIR", + "CUDA_LAUNCH_BLOCKING", +] def main(): diff --git a/codegeex/megatron/tools/finetune_codegeex.py b/codegeex/megatron/tools/finetune_codegeex.py index 4a56602..0cec07c 100644 --- a/codegeex/megatron/tools/finetune_codegeex.py +++ b/codegeex/megatron/tools/finetune_codegeex.py @@ -8,7 +8,7 @@ from deepspeed.runtime.utils import see_memory_usage from functools import partial -from codegeex.megatron import get_args, print_rank_0, get_timers,get_tokenizer, mpu +from codegeex.megatron import get_args, print_rank_0, get_timers, get_tokenizer, mpu from codegeex.megatron.data.prompt_dataset import build_train_valid_test_datasets from codegeex.megatron.model import CodeGeeXModel from codegeex.megatron.training import pretrain @@ -61,7 +61,7 @@ def model_provider(pre_process=True, post_process=True): num_tokentypes=0, parallel_output=True, ) - + if args.load_state is not None: timers = get_timers() print_rank_0("Loading warmstarting model states ...") 
@@ -69,7 +69,8 @@ def model_provider(pre_process=True, post_process=True): mp_rank = mpu.get_tensor_model_parallel_rank() if os.path.isdir(args.load_state): model_path = os.path.join( - args.load_state, "mp_rank_{:02d}_model_states.pt".format(mp_rank) + args.load_state, + "mp_rank_{:02d}_model_states.pt".format(mp_rank), ) else: model_path = args.load_state @@ -81,7 +82,7 @@ def model_provider(pre_process=True, post_process=True): timers("load-model-states").stop() timers.log(["load-model-states"]) see_memory_usage(f"After Building Model", force=True) - + return model @@ -150,7 +151,7 @@ def get_batch_pipe(data): args.reset_attention_mask, args.eod_mask_loss, ) - + return (tokens, position_ids, attention_mask), (labels, loss_mask) @@ -166,7 +167,9 @@ def compute_lm_loss(losses: torch.Tensor, loss_mask: torch.Tensor): losses = prob * losses loss_mask = loss_mask.view(-1).float() - loss = torch.sum(losses.view(-1) * loss_mask) / torch.clamp_min(loss_mask.sum(), 1e-8) + loss = torch.sum(losses.view(-1) * loss_mask) / torch.clamp_min( + loss_mask.sum(), 1e-8 + ) return loss @@ -184,7 +187,9 @@ def valid_loss_func(loss_mask, output_tensor): def compute_lm_loss(losses: torch.Tensor, loss_mask: torch.Tensor): loss_mask = loss_mask.view(-1).float() - loss = torch.sum(losses.view(-1) * loss_mask) / torch.clamp_min(loss_mask.sum(), 1e-8) + loss = torch.sum(losses.view(-1) * loss_mask) / torch.clamp_min( + loss_mask.sum(), 1e-8 + ) return loss @@ -193,7 +198,7 @@ def compute_lm_loss(losses: torch.Tensor, loss_mask: torch.Tensor): # Reduce loss for logging. averaged_loss = average_losses_across_data_parallel_group([loss]) - + return loss, {"lm loss": averaged_loss[0]} @@ -230,15 +235,19 @@ def valid_forward_step(data_iterator, model): def train_valid_test_datasets_provider(train_val_test_num_samples): """Build train, valid, and test datasets.""" args = get_args() - + print_rank_0("> building train, validation, and test datasets " "for GPT ...") if args.co_evaluation: + def dataset_partition_path_parsing(data_path): dataset_path = {} for index in range(len(data_path)): dataset_path[data_path[index]] = data_path[index] return dataset_path - assert args.valid_data_path is not None, "Valid data path must be given when --co-evaluation is turned on." + + assert ( + args.valid_data_path is not None + ), "Valid data path must be given when --co-evaluation is turned on." 
valid_data_path = dataset_partition_path_parsing(args.valid_data_path) if args.test_data_path is not None: test_data_path = dataset_partition_path_parsing(args.test_data_path) @@ -323,4 +332,4 @@ def dataset_partition_path_parsing(data_path): forward_step, valid_forward_step, args_defaults={"tokenizer_type": "GPT2BPETokenizer"}, - ) \ No newline at end of file + ) diff --git a/codegeex/megatron/tools/pretrain_codegeex.py b/codegeex/megatron/tools/pretrain_codegeex.py index 2569d49..b2df0b0 100644 --- a/codegeex/megatron/tools/pretrain_codegeex.py +++ b/codegeex/megatron/tools/pretrain_codegeex.py @@ -9,9 +9,9 @@ from deepspeed.runtime.utils import see_memory_usage from functools import partial -from codegeex.megatron import get_args, print_rank_0, get_timers,get_tokenizer, mpu +from codegeex.megatron import get_args, print_rank_0, get_timers, get_tokenizer, mpu from codegeex.megatron.data.prompt_dataset import build_train_valid_test_datasets -from codegeex.megatron.model import CodeGeeXModel #, CodeGeeXModelPipe +from codegeex.megatron.model import CodeGeeXModel # , CodeGeeXModelPipe from codegeex.megatron.training import pretrain from codegeex.megatron.utils import get_ltor_masks_and_position_ids from codegeex.megatron.utils import average_losses_across_data_parallel_group @@ -62,7 +62,7 @@ def model_provider(pre_process=True, post_process=True): num_tokentypes=0, parallel_output=True, ) - + if args.load_state is not None: timers = get_timers() print_rank_0("Loading warmstarting model states ...") @@ -70,7 +70,8 @@ def model_provider(pre_process=True, post_process=True): mp_rank = mpu.get_tensor_model_parallel_rank() if os.path.isdir(args.load_state): model_path = os.path.join( - args.load_state, "mp_rank_{:02d}_model_states.pt".format(mp_rank) + args.load_state, + "mp_rank_{:02d}_model_states.pt".format(mp_rank), ) else: model_path = args.load_state @@ -82,7 +83,7 @@ def model_provider(pre_process=True, post_process=True): timers("load-model-states").stop() timers.log(["load-model-states"]) see_memory_usage(f"After Building Model", force=True) - + return model diff --git a/codegeex/megatron/training.py b/codegeex/megatron/training.py index 593649d..5bda88c 100644 --- a/codegeex/megatron/training.py +++ b/codegeex/megatron/training.py @@ -66,7 +66,6 @@ import pathlib - def print_datetime(string): """Note that this call will sync across all ranks.""" torch.distributed.barrier() @@ -191,11 +190,22 @@ def pretrain( if args.co_evaluation: for key, value in valid_data_iterator.items(): evaluate_and_print_results( - prefix, valid_forward_step_func, value, model, iteration, False, tag=key + prefix, + valid_forward_step_func, + value, + model, + iteration, + False, + tag=key, ) else: evaluate_and_print_results( - prefix, valid_forward_step_func, valid_data_iterator, model, iteration, False + prefix, + valid_forward_step_func, + valid_data_iterator, + model, + iteration, + False, ) if args.save and iteration != 0: @@ -406,7 +416,7 @@ def setup_model_and_optimizer(model_provider_func): print_rank_0("DeepSpeed is enabled.") pp = mpu.get_pipeline_model_parallel_world_size() print_rank_0(pp) - + model, optimizer, _, lr_scheduler = deepspeed.initialize( model=model[0], optimizer=optimizer, @@ -440,13 +450,17 @@ def setup_model_and_optimizer(model_provider_func): timers("load-checkpoint").start() if args.low_memory_load: load_start = time.perf_counter() - with FileLock(os.path.join(pathlib.Path.home(), "checkpoint_lock"), timeout=-1): + with FileLock( + os.path.join(pathlib.Path.home(), 
"checkpoint_lock"), timeout=-1 + ): this_rank_load_start = time.perf_counter() print(f"Rank {args.rank} is loading checkpoint ...") args.iteration = load_checkpoint(model, optimizer, lr_scheduler) this_rank_load_time = time.perf_counter() - this_rank_load_start load_time = time.perf_counter() - load_start - print(f"Rank {args.rank} loaded checkpoint, this rank time: {this_rank_load_time}, total time: {load_time}") + print( + f"Rank {args.rank} loaded checkpoint, this rank time: {this_rank_load_time}, total time: {load_time}" + ) else: args.iteration = load_checkpoint(model, optimizer, lr_scheduler) print(f"Rank {args.rank} loaded checkpoint and waiting for other ranks") @@ -918,7 +932,7 @@ def train( timers("interval-time").start() print_datetime("before the start of training step") report_memory_flag = True - + while iteration < args.train_iters and ( args.train_tokens is None or args.consumed_train_tokens < args.train_tokens ): @@ -979,20 +993,38 @@ def train( if args.co_evaluation: for key, value in valid_data_iterator.items(): evaluate_and_print_results( - prefix, valid_forward_step_func, value, model, iteration, False, tag=key + prefix, + valid_forward_step_func, + value, + model, + iteration, + False, + tag=key, ) else: if args.gold: evaluate_and_print_results_gold( - prefix, forward_step_func, valid_data_iterator, model, iteration, False + prefix, + forward_step_func, + valid_data_iterator, + model, + iteration, + False, ) evaluate_and_print_results( - prefix, valid_forward_step_func, valid_data_iterator, model, iteration, False + prefix, + valid_forward_step_func, + valid_data_iterator, + model, + iteration, + False, ) # Checkpointing saved_checkpoint = False - if args.save and args.save_interval and (iteration % args.save_interval == 0): # debugging + if ( + args.save and args.save_interval and (iteration % args.save_interval == 0) + ): # debugging save_checkpoint_and_time(iteration, model, optimizer, lr_scheduler) saved_checkpoint = True @@ -1209,7 +1241,7 @@ def evaluate_and_print_results_gold( print_rank_last("-" * length) print_rank_last(string) print_rank_last("-" * length) - + def cyclic_iter(iter): while True: @@ -1326,9 +1358,7 @@ def build_train_valid_test_data_iterators(build_train_valid_test_datasets_provid valid_data_iterator = {} for key, value in valid_dataloader.items(): valid_data_iterator[key] = ( - iter(value) - if dl_type == "single" - else iter(cyclic_iter(value)) + iter(value) if dl_type == "single" else iter(cyclic_iter(value)) ) else: valid_data_iterator = ( @@ -1344,9 +1374,7 @@ def build_train_valid_test_data_iterators(build_train_valid_test_datasets_provid test_data_iterator = {} for key, value in test_dataloader.items(): test_data_iterator[key] = ( - iter(value) - if dl_type == "single" - else iter(cyclic_iter(value)) + iter(value) if dl_type == "single" else iter(cyclic_iter(value)) ) else: test_data_iterator = ( diff --git a/codegeex/megatron/utils.py b/codegeex/megatron/utils.py index 5f88e18..c32cd1c 100644 --- a/codegeex/megatron/utils.py +++ b/codegeex/megatron/utils.py @@ -239,4 +239,4 @@ def flops_calculator(model, args, iteration_time): print_rank_0( f"Effective Tera Flops per GPU: {round(effective_tera_flops_per_gpu, 2)} and total parameters {round(approx_parameters_in_billions, 3)} B" - ) \ No newline at end of file + ) diff --git a/codegeex/mindspore/convertion_1p.py b/codegeex/mindspore/convertion_1p.py index 1e8c746..151fbba 100644 --- a/codegeex/mindspore/convertion_1p.py +++ b/codegeex/mindspore/convertion_1p.py @@ -39,12 +39,12 @@ def 
load_model(args_opt): r""" - The main function for load model + The main function for load model """ # Set execution mode - context.set_context(save_graphs=False, - mode=context.GRAPH_MODE, - device_target=args_opt.device_target) + context.set_context( + save_graphs=False, mode=context.GRAPH_MODE, device_target=args_opt.device_target + ) context.set_context(variable_memory_max_size="30GB") # Set parallel context if args_opt.distribute == "true": @@ -59,7 +59,8 @@ def load_model(args_opt): full_batch=True, loss_repeated_mean=True, enable_parallel_optimizer=False, - pipeline_stages=args_opt.stage_num) + pipeline_stages=args_opt.stage_num, + ) set_algo_parameters(elementwise_op_strategy_follow=True) _set_multi_subgraphs() @@ -68,13 +69,14 @@ def load_model(args_opt): device_num = 1 context.reset_auto_parallel_context() context.set_auto_parallel_context( - strategy_ckpt_load_file=args_opt.strategy_load_ckpt_path) + strategy_ckpt_load_file=args_opt.strategy_load_ckpt_path + ) context.set_context( save_graphs=False, save_graphs_path="/cache/graphs_of_device_id_" + str(rank), ) - use_past = (args_opt.use_past == "true") - print('local_rank:{}, start to run...'.format(rank), flush=True) + use_past = args_opt.use_past == "true" + print("local_rank:{}, start to run...".format(rank), flush=True) if args_opt.export: use_past = True # Set model property @@ -85,13 +87,15 @@ def load_model(args_opt): data_parallel_num = int(device_num / model_parallel_num) print("===data_parallel_num is: ", data_parallel_num, flush=True) - parallel_config = TransformerOpParallelConfig(data_parallel=data_parallel_num, - model_parallel=model_parallel_num, - pipeline_stage=args_opt.stage_num, - micro_batch_num=args_opt.micro_size, - optimizer_shard=False, - vocab_emb_dp=bool(args_opt.word_emb_dp), - recompute=True) + parallel_config = TransformerOpParallelConfig( + data_parallel=data_parallel_num, + model_parallel=model_parallel_num, + pipeline_stage=args_opt.stage_num, + micro_batch_num=args_opt.micro_size, + optimizer_shard=False, + vocab_emb_dp=bool(args_opt.word_emb_dp), + recompute=True, + ) per_batch_size = args_opt.per_batch_size batch_size = per_batch_size * data_parallel_num @@ -114,7 +118,7 @@ def load_model(args_opt): parallel_config=parallel_config, load_ckpt_path=args_opt.load_ckpt_path, param_init_type=mstype.float32 - if args_opt.param_init_type == 'fp32' + if args_opt.param_init_type == "fp32" else mstype.float16, ) print("===config is: ", config, flush=True) @@ -126,7 +130,9 @@ def load_model(args_opt): eval_net.set_train(False) model_predict = Model(eval_net) # Compile network and obtain tensor layout for loading ckpt - inputs_np = Tensor(np.ones(shape=(config.batch_size, config.seq_length)), mstype.int32) + inputs_np = Tensor( + np.ones(shape=(config.batch_size, config.seq_length)), mstype.int32 + ) current_index = Tensor(np.array([0]), mstype.int32) if args_opt.distribute == "false": @@ -138,11 +144,15 @@ def load_model(args_opt): inputs_np_1 = Tensor(np.ones(shape=(config.batch_size, 1)), mstype.int32) model_predict.predict_network.add_flags_recursive(is_first_iteration=True) print("is_first_iteration=True", flush=True) - predict_layout = model_predict.infer_predict_layout(inputs_np, current_index, init_true, batch_valid_length) + predict_layout = model_predict.infer_predict_layout( + inputs_np, current_index, init_true, batch_valid_length + ) model_predict.predict_network.add_flags_recursive(is_first_iteration=False) print("is_first_iteration=False", flush=True) init_false = Tensor([False], 
mstype.bool_) - _ = model_predict.infer_predict_layout(inputs_np_1, current_index, init_false, batch_valid_length) + _ = model_predict.infer_predict_layout( + inputs_np_1, current_index, init_false, batch_valid_length + ) else: predict_layout = model_predict.infer_predict_layout(inputs_np, current_index) @@ -154,14 +164,20 @@ def load_model(args_opt): ckpt_name = f"code-13B0-50.ckpt" # TODO: set to current ckpt name if not mox.file.exists(os.path.join(args_opt.load_ckpt_path, "rank_0", ckpt_name)): print(f"Checkpoint from rank {rank} doesn't exist!") - mox.file.copy(os.path.join(args_opt.load_ckpt_path, "rank_0", ckpt_name), - os.path.join(args_opt.save_checkpoint_path, f"rank_{rank}", ckpt_name)) - param_dict = load_checkpoint(os.path.join(args_opt.save_checkpoint_path, f"rank_{rank}", ckpt_name)) + mox.file.copy( + os.path.join(args_opt.load_ckpt_path, "rank_0", ckpt_name), + os.path.join(args_opt.save_checkpoint_path, f"rank_{rank}", ckpt_name), + ) + param_dict = load_checkpoint( + os.path.join(args_opt.save_checkpoint_path, f"rank_{rank}", ckpt_name) + ) # TODO: add them back if not for the 1st run! if param_dict.get("epoch_num") and param_dict.get("step_num"): args_opt.has_trained_epoches = int(param_dict["epoch_num"].data.asnumpy()) args_opt.has_trained_steps = int(param_dict["step_num"].data.asnumpy()) - if not os.path.exists(f'/home/work/sfs/cache/{os.environ["BATCH_JOB_ID"]}/1/rank_{rank}'): + if not os.path.exists( + f'/home/work/sfs/cache/{os.environ["BATCH_JOB_ID"]}/1/rank_{rank}' + ): os.mkdir(f'/home/work/sfs/cache/{os.environ["BATCH_JOB_ID"]}/1/rank_{rank}') while True: num = len(os.listdir(f'/home/work/sfs/cache/{os.environ["BATCH_JOB_ID"]}/1')) @@ -171,12 +187,19 @@ def load_model(args_opt): print("Loaded ckpt in step 1: ", num) time.sleep(1) net_not_load = load_param_into_net(pangu_alpha, param_dict) - print("====== load_distributed checkpoint done, net_not_load: ", net_not_load, flush=True) + print( + "====== load_distributed checkpoint done, net_not_load: ", + net_not_load, + flush=True, + ) if not os.path.exists("/home/work/sfs/cache/ckpts_npy/"): os.mkdir("/home/work/sfs/cache/ckpts_npy/") for k, weight in pangu_alpha.parameters_dict().items(): print(k) - np.save(os.path.join("/home/work/sfs/cache/ckpts_npy/", f"{k}.npy"), weight.asnumpy()) + np.save( + os.path.join("/home/work/sfs/cache/ckpts_npy/", f"{k}.npy"), + weight.asnumpy(), + ) rank_obs_save_path = "./" # TODO: set to current obs path for saving if not mox.file.exists(rank_obs_save_path): mox.file.make_dirs(rank_obs_save_path) @@ -187,7 +210,9 @@ def load_model(args_opt): def export_mindir(model_predict, config): """Export mindir model""" - inputs_np = Tensor(np.ones(shape=(config.batch_size, config.seq_length)), mstype.int32) + inputs_np = Tensor( + np.ones(shape=(config.batch_size, config.seq_length)), mstype.int32 + ) current_index = Tensor(np.array([0]), mstype.int32) batch_valid_length = Tensor(np.array([0]), mstype.int32) @@ -195,26 +220,41 @@ def export_mindir(model_predict, config): inputs_np_1 = Tensor(np.ones(shape=(config.batch_size, 1)), mstype.int32) model_predict.predict_network.add_flags_recursive(is_first_iteration=True) - export(model_predict.predict_network, inputs_np, current_index, - init_true, batch_valid_length, file_name='pangu_alpha_1024', file_format='MINDIR') + export( + model_predict.predict_network, + inputs_np, + current_index, + init_true, + batch_valid_length, + file_name="pangu_alpha_1024", + file_format="MINDIR", + ) 
model_predict.predict_network.add_flags_recursive(is_first_iteration=False) - export(model_predict.predict_network, inputs_np_1, current_index, - init_true, batch_valid_length, file_name='pangu_alpha_1', file_format='MINDIR') + export( + model_predict.predict_network, + inputs_np_1, + current_index, + init_true, + batch_valid_length, + file_name="pangu_alpha_1", + file_format="MINDIR", + ) print("Export finished and now exit.") def run_predict(model_predict, config, args_opt, rank): """run predict""" from src.generate import generate, generate_increment + # Define tokenizer - tokenizer = CodeTokenizer(mode='6b') + tokenizer = CodeTokenizer(mode="6b") # Tokenize input sentence to ids samples = [ "# language: Python\ndef add(a, b):\n '''\n Find the sum of a and b.\n '''\n", "def add(a, b):\n '''\n Find the sum of a and b.\n '''\n", "# language: Python\ndef optimization():\n '''\n Find the maximum of P=E**2*R/(R + r)**2 if E and r are fixed but R varies. Import sympy. Use sympy. Find where the derivative is equal to zero. Substitute the value of R into P.\n '''\n", - "from typing import List\n\n\ndef has_close_elements(numbers: List[float], threshold: float) -> bool:\n \"\"\" Check if in given list of numbers, are any two numbers closer to each other than\n given threshold.\n >>> has_close_elements([1.0, 2.0, 3.0], 0.5)\n False\n >>> has_close_elements([1.0, 2.8, 3.0, 4.0, 5.0, 2.0], 0.3)\n True\n \"\"\"\n", + 'from typing import List\n\n\ndef has_close_elements(numbers: List[float], threshold: float) -> bool:\n """ Check if in given list of numbers, are any two numbers closer to each other than\n given threshold.\n >>> has_close_elements([1.0, 2.0, 3.0], 0.5)\n False\n >>> has_close_elements([1.0, 2.8, 3.0, 4.0, 5.0, 2.0], 0.3)\n True\n """\n', "// language: C++\nint add(int a, int b) {\n /* Find the sum of a and b. */\n", "int add(int a, int b) {\n /* Find the sum of a and b. */\n", "bool prime(int n) {\n // Find whether n is a prime number\n", @@ -240,7 +280,8 @@ def run_predict(model_predict, config, args_opt, rank): print(f"=================== generation {i} ====================") print(output_samples_str, flush=True) print( - f"=== Total time (s): {t1 - t0}, {output_ids.shape[-1] - input_ids.shape[-1]} tokens, {(output_ids.shape[-1] - input_ids.shape[-1]) / (t1 - t0)} token/s") + f"=== Total time (s): {t1 - t0}, {output_ids.shape[-1] - input_ids.shape[-1]} tokens, {(output_ids.shape[-1] - input_ids.shape[-1]) / (t1 - t0)} token/s" + ) def main(): diff --git a/codegeex/mindspore/finetune.py b/codegeex/mindspore/finetune.py index 093ea6a..aed2dd2 100644 --- a/codegeex/mindspore/finetune.py +++ b/codegeex/mindspore/finetune.py @@ -36,7 +36,11 @@ from mindspore.train.callback import ModelCheckpoint, CheckpointConfig from mindspore.train.callback import TimeMonitor from mindspore.train.model import Model -from mindspore.train.serialization import load_distributed_checkpoint, load_checkpoint, load_param_into_net +from mindspore.train.serialization import ( + load_distributed_checkpoint, + load_checkpoint, + load_param_into_net, +) from tensorboardX import SummaryWriter from src.adam import AdamWeightDecayOp @@ -50,15 +54,18 @@ from src.utils import download_data project_root = os.path.abspath( - os.path.dirname(os.path.realpath(__file__)) + os.path.sep + "..") -print('project_root:', project_root) + os.path.dirname(os.path.realpath(__file__)) + os.path.sep + ".." 
+) +print("project_root:", project_root) def set_weight_decay(params): """ Set weight decay coefficient, zero for bias and layernorm, 1e-1 for rest """ - decay_filter = lambda x: 'layernorm' not in x.name.lower() and "bias" not in x.name.lower() + decay_filter = ( + lambda x: "layernorm" not in x.name.lower() and "bias" not in x.name.lower() + ) decay_params = list(filter(decay_filter, params)) other_params = list(filter(lambda x: not decay_filter(x), params)) group_params = [ @@ -75,7 +82,12 @@ def add_checkpoint_callback_policy(args_param, callback, rank_id): """ if args_param.save_checkpoint: # checkpoint store epoch_num and step_num info - ckpt_append_info = [{"epoch_num": args_param.has_trained_epoches, "step_num": args_param.has_trained_steps}] + ckpt_append_info = [ + { + "epoch_num": args_param.has_trained_epoches, + "step_num": args_param.has_trained_steps, + } + ] ckpt_config = CheckpointConfig( save_checkpoint_steps=args_param.save_checkpoint_steps, keep_checkpoint_max=args_param.keep_checkpoint_max, @@ -84,18 +96,22 @@ def add_checkpoint_callback_policy(args_param, callback, rank_id): ) # save checkpoint into rank directory - ckpoint_cb = ModelCheckpoint(prefix=args_param.ckpt_name_prefix + str(rank_id), - directory=os.path.join(args_param.save_checkpoint_path, f"rank_{rank_id}"), - config=ckpt_config) + ckpoint_cb = ModelCheckpoint( + prefix=args_param.ckpt_name_prefix + str(rank_id), + directory=os.path.join(args_param.save_checkpoint_path, f"rank_{rank_id}"), + config=ckpt_config, + ) callback.append(ckpoint_cb) - saveckpt_cb = SaveCheckpointCallback(cache_dir=args_param.save_checkpoint_path, - bucket=args_param.save_checkpoint_obs_path, - local_rank=rank_id, - has_trained_epoch=args_param.has_trained_epoches, - has_trained_step=args_param.has_trained_steps, - syn_times=args_param.save_checkpoint_steps) + saveckpt_cb = SaveCheckpointCallback( + cache_dir=args_param.save_checkpoint_path, + bucket=args_param.save_checkpoint_obs_path, + local_rank=rank_id, + has_trained_epoch=args_param.has_trained_epoches, + has_trained_step=args_param.has_trained_steps, + syn_times=args_param.save_checkpoint_steps, + ) callback.append(saveckpt_cb) @@ -109,9 +125,12 @@ def set_parallel_context(args_opt): args_opt.optimizer_shard = 0 context.reset_auto_parallel_context() context.set_auto_parallel_context( - parallel_mode=ParallelMode.SEMI_AUTO_PARALLEL, gradients_mean=False, - full_batch=bool(args_opt.full_batch), strategy_ckpt_load_file=args_opt.strategy_load_ckpt_path, - enable_parallel_optimizer=bool(args_opt.optimizer_shard), strategy_ckpt_save_file='strategy.ckpt', + parallel_mode=ParallelMode.SEMI_AUTO_PARALLEL, + gradients_mean=False, + full_batch=bool(args_opt.full_batch), + strategy_ckpt_load_file=args_opt.strategy_load_ckpt_path, + enable_parallel_optimizer=bool(args_opt.optimizer_shard), + strategy_ckpt_save_file="strategy.ckpt", ) set_algo_parameters(elementwise_op_strategy_follow=True) _set_multi_subgraphs() @@ -122,9 +141,7 @@ def run_train(args_opt): r"""The main training process.""" os.environ["HCCL_CONNECT_TIMEOUT"] = "2000" # Set execution mode - context.set_context( - mode=context.GRAPH_MODE, device_target=args_opt.device_target - ) + context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.device_target) if args_opt.profiling: profiler = Profiler(output_path="/cache/profiler_data") context.set_context(variable_memory_max_size="30GB") @@ -139,21 +156,29 @@ def run_train(args_opt): ) # copy data from the cloud to the /cache/Data - cache_url = '/cache/Data/' - 
eval_cache_url = '/cache/EvalData/' + cache_url = "/cache/Data/" + eval_cache_url = "/cache/EvalData/" if not args_opt.offline: - download_data(src_data_url=args_opt.data_url, tgt_data_path=cache_url, rank=rank) - download_data(src_data_url=args_opt.eval_data_url, tgt_data_path=eval_cache_url, rank=rank) + download_data( + src_data_url=args_opt.data_url, tgt_data_path=cache_url, rank=rank + ) + download_data( + src_data_url=args_opt.eval_data_url, tgt_data_path=eval_cache_url, rank=rank + ) # Set model property model_parallel_num = args_opt.op_level_model_parallel_num data_parallel_num = int(device_num / model_parallel_num) batch_size = args_opt.per_batch_size * data_parallel_num - parallel_config = TransformerOpParallelConfig(data_parallel=data_parallel_num, model_parallel=model_parallel_num, - pipeline_stage=args_opt.stage_num, - micro_batch_num=args_opt.micro_size, - optimizer_shard=bool(args_opt.optimizer_shard), - vocab_emb_dp=bool(args_opt.word_emb_dp), recompute=True, - gradient_aggregation_group=args_opt.gradient_aggregation_group) + parallel_config = TransformerOpParallelConfig( + data_parallel=data_parallel_num, + model_parallel=model_parallel_num, + pipeline_stage=args_opt.stage_num, + micro_batch_num=args_opt.micro_size, + optimizer_shard=bool(args_opt.optimizer_shard), + vocab_emb_dp=bool(args_opt.word_emb_dp), + recompute=True, + gradient_aggregation_group=args_opt.gradient_aggregation_group, + ) micro_interleaved_size = args_opt.micro_interleaved_size config = PanguAlphaConfig( @@ -182,24 +207,40 @@ def run_train(args_opt): loss = CrossEntropyLoss(config.parallel_config.dp_mp_config) if micro_interleaved_size > 1: print("===using MicroBatchInterleaved", flush=True) - pangu_alpha_with_loss_net = MicroBatchInterleaved(PanGUAlphaWithFinetuneLoss(config, pangu_alpha, loss), - micro_interleaved_size) + pangu_alpha_with_loss_net = MicroBatchInterleaved( + PanGUAlphaWithFinetuneLoss(config, pangu_alpha, loss), + micro_interleaved_size, + ) else: - pangu_alpha_with_loss_net = PanGUAlphaWithFinetuneLoss(config, pangu_alpha, loss) + pangu_alpha_with_loss_net = PanGUAlphaWithFinetuneLoss( + config, pangu_alpha, loss + ) pangu_alpha_with_loss = _VirtualDatasetCell(pangu_alpha_with_loss_net) print("=====args_opt is: ", args_opt, flush=True) # Warm-up and cosine decay learning rate - lr = LearningRate(learning_rate=args_opt.start_lr, end_learning_rate=args_opt.end_lr, - warmup_steps=args_opt.warmup_step, decay_steps=args_opt.decay_steps) + lr = LearningRate( + learning_rate=args_opt.start_lr, + end_learning_rate=args_opt.end_lr, + warmup_steps=args_opt.warmup_step, + decay_steps=args_opt.decay_steps, + ) params = pangu_alpha_with_loss.trainable_params() group_params = set_weight_decay(params) if args_opt.optimizer == "lamb": optimizer = nn.Lamb(group_params, learning_rate=lr) elif args_opt.opt_offload: - optimizer = AdamWeightDecayOp(group_params, learning_rate=lr, eps=1e-8, beta1=0.9, beta2=0.95, - param_init_type=config.param_init_type) + optimizer = AdamWeightDecayOp( + group_params, + learning_rate=lr, + eps=1e-8, + beta1=0.9, + beta2=0.95, + param_init_type=config.param_init_type, + ) else: - optimizer = FP32StateAdamWeightDecay(group_params, learning_rate=lr, eps=1e-8, beta1=0.9, beta2=0.95) + optimizer = FP32StateAdamWeightDecay( + group_params, learning_rate=lr, eps=1e-8, beta1=0.9, beta2=0.95 + ) # Initial scaling sens loss_scale_value = math.pow(2, 32) epoch_num = args_opt.epoch_size @@ -208,13 +249,19 @@ def run_train(args_opt): time.sleep(rank * 0.05) 
os.mkdir(os.path.join(args_opt.save_checkpoint_path, f"rank_{rank}")) ckpt_name = f"code-13B{rank}_21-{args_opt.load_ckpt_epoch}_2.ckpt" - if not mox.file.exists(os.path.join(args_opt.load_ckpt_path, f"rank_{rank}", ckpt_name)): + if not mox.file.exists( + os.path.join(args_opt.load_ckpt_path, f"rank_{rank}", ckpt_name) + ): print(f"Checkpoint from rank {rank} doesn't exist!") - mox.file.copy(os.path.join(args_opt.load_ckpt_path, f"rank_{rank}", ckpt_name), - os.path.join(args_opt.save_checkpoint_path, f"rank_{rank}", ckpt_name)) - param_dict = load_checkpoint(os.path.join(args_opt.save_checkpoint_path, f"rank_{rank}", ckpt_name)) + mox.file.copy( + os.path.join(args_opt.load_ckpt_path, f"rank_{rank}", ckpt_name), + os.path.join(args_opt.save_checkpoint_path, f"rank_{rank}", ckpt_name), + ) + param_dict = load_checkpoint( + os.path.join(args_opt.save_checkpoint_path, f"rank_{rank}", ckpt_name) + ) # TODO: remove after warming-up! - param_dict.pop('global_step') + param_dict.pop("global_step") # TODO: add them back if not for the 1st run! # if param_dict.get("epoch_num") and param_dict.get("step_num"): # args_opt.has_trained_epoches = int(param_dict["epoch_num"].data.asnumpy()) @@ -223,7 +270,9 @@ def run_train(args_opt): os.mkdir(f'/home/work/sfs/cache/{os.environ["BATCH_JOB_ID"]}/1/rank_{rank}') while True: - num = len(os.listdir(f'/home/work/sfs/cache/{os.environ["BATCH_JOB_ID"]}/1')) + num = len( + os.listdir(f'/home/work/sfs/cache/{os.environ["BATCH_JOB_ID"]}/1') + ) if num == device_num: break if rank % 64 == 0: @@ -234,30 +283,47 @@ def run_train(args_opt): if args_opt.tb_dir is not None and rank == 0: os.makedirs(args_opt.tb_dir, exist_ok=True) summary_writer = SummaryWriter(args_opt.tb_dir) - os.system(f'chmod 777 -R {args_opt.tb_dir}') + os.system(f"chmod 777 -R {args_opt.tb_dir}") else: summary_writer = None # Dataset loading mindrecord files - ds, ds_eval = create_dataset(config.batch_size * micro_interleaved_size, data_path=args_opt.code_data, - args_opt=args_opt, data_start_index=0, - eod_reset=config.eod_reset, full_batch=bool(args_opt.full_batch), - eod_id=args_opt.eod_id, - device_num=device_num, rank=rank, epoch=epoch_num, - train_and_eval=bool(args_opt.train_and_eval_mode), val_ratio=0.001) + ds, ds_eval = create_dataset( + config.batch_size * micro_interleaved_size, + data_path=args_opt.code_data, + args_opt=args_opt, + data_start_index=0, + eod_reset=config.eod_reset, + full_batch=bool(args_opt.full_batch), + eod_id=args_opt.eod_id, + device_num=device_num, + rank=rank, + epoch=epoch_num, + train_and_eval=bool(args_opt.train_and_eval_mode), + val_ratio=0.001, + ) actual_epoch_num = int(ds.get_dataset_size() / args_opt.sink_size) callback = [ TimeMonitor(args_opt.sink_size), ] - update_cell = DynamicLossScaleUpdateCell(loss_scale_value=loss_scale_value, scale_factor=2, scale_window=1000) + update_cell = DynamicLossScaleUpdateCell( + loss_scale_value=loss_scale_value, scale_factor=2, scale_window=1000 + ) pangu_alpha_with_grads = PanguAlphaTrainOneStepWithLossScaleCell( - pangu_alpha_with_loss, optimizer=optimizer, scale_update_cell=update_cell, enable_global_norm=True, - config=config) + pangu_alpha_with_loss, + optimizer=optimizer, + scale_update_cell=update_cell, + enable_global_norm=True, + config=config, + ) if ds_eval: ppl_metric = PPLMetric(config.seq_length) validation_loss = ValidationLoss(config.seq_length) - model = Model(pangu_alpha_with_grads, eval_network=pangu_alpha_with_loss, - metrics={"ppl": ppl_metric, "validation_loss": validation_loss}) + model = 
Model( + pangu_alpha_with_grads, + eval_network=pangu_alpha_with_loss, + metrics={"ppl": ppl_metric, "validation_loss": validation_loss}, + ) callback.append( EvalCallBack( model=model, @@ -268,7 +334,7 @@ def run_train(args_opt): has_trained_step=args_opt.has_trained_steps, local_rank=rank, rank_size=device_num, - tb_writer=summary_writer + tb_writer=summary_writer, ) ) else: @@ -276,15 +342,25 @@ def run_train(args_opt): if args_opt.load_ckpt_epoch > 0: print("===build model and load ckpt") time_stamp = datetime.datetime.now() - print(f"time stamp {time_stamp.strftime('%Y.%m.%d-%H:%M:%S')} before building", flush=True) - model.build(train_dataset=ds, sink_size=args_opt.sink_size, epoch=actual_epoch_num) + print( + f"time stamp {time_stamp.strftime('%Y.%m.%d-%H:%M:%S')} before building", + flush=True, + ) + model.build( + train_dataset=ds, sink_size=args_opt.sink_size, epoch=actual_epoch_num + ) time_stamp = datetime.datetime.now() - print(f"time stamp {time_stamp.strftime('%Y.%m.%d-%H:%M:%S')} before loading ckpt", flush=True) + print( + f"time stamp {time_stamp.strftime('%Y.%m.%d-%H:%M:%S')} before loading ckpt", + flush=True, + ) load_param_into_net(pangu_alpha_with_loss, param_dict) os.mkdir(f'/home/work/sfs/cache/{os.environ["BATCH_JOB_ID"]}/2/rank_{rank}') while True: - num = len(os.listdir(f'/home/work/sfs/cache/{os.environ["BATCH_JOB_ID"]}/2')) + num = len( + os.listdir(f'/home/work/sfs/cache/{os.environ["BATCH_JOB_ID"]}/2') + ) if num == device_num: break if rank % 64 == 0: @@ -306,18 +382,32 @@ def run_train(args_opt): if not args_opt.profiling: add_checkpoint_callback_policy(args_opt, callback, rank) if args_opt.incremental_training: - strategy = model.infer_train_layout(train_dataset=ds, sink_size=args_opt.sink_size) + strategy = model.infer_train_layout( + train_dataset=ds, sink_size=args_opt.sink_size + ) print("======start load_distributed checkpoint", flush=True) # For 2.6B and 13B models, the number of ckpt files is 512. 
- ckpt_file_list = [os.path.join(args_opt.load_ckpt_path, f"filerted_{ckpt_rank}.ckpt") for ckpt_rank in - range(0, 512)] + ckpt_file_list = [ + os.path.join(args_opt.load_ckpt_path, f"filerted_{ckpt_rank}.ckpt") + for ckpt_rank in range(0, 512) + ] print(f"Loading from path {ckpt_file_list[0]}", flush=True) load_distributed_checkpoint(model.train_network, ckpt_file_list, strategy) - print("Dataset size: {}, actual_epoch_num: {}".format(ds.get_dataset_size(), actual_epoch_num), flush=True) + print( + "Dataset size: {}, actual_epoch_num: {}".format( + ds.get_dataset_size(), actual_epoch_num + ), + flush=True, + ) try: - model.train(10 if args_opt.profiling else actual_epoch_num, ds, callbacks=callback, - sink_size=args_opt.sink_size, dataset_sink_mode=True) + model.train( + 10 if args_opt.profiling else actual_epoch_num, + ds, + callbacks=callback, + sink_size=args_opt.sink_size, + dataset_sink_mode=True, + ) finally: if args_opt.profiling: jobid = os.environ["BATCH_JOB_ID"] @@ -325,12 +415,16 @@ def run_train(args_opt): rank_id = rank if context.get_context("save_graphs"): mox.file.make_dirs("s3://wudao-1/yyf/graphs_" + jobid) - mox.file.copy_parallel(src_url="/cache/graphs_of_device_id_" + str(rank_id), - dst_url="s3://wudao-1/yyf/graphs_" + jobid + "/" + str(rank_id)) + mox.file.copy_parallel( + src_url="/cache/graphs_of_device_id_" + str(rank_id), + dst_url="s3://wudao-1/yyf/graphs_" + jobid + "/" + str(rank_id), + ) if rank_id % 8 == 0: mox.file.make_dirs("s3://wudao-1/yyf/profiler_" + jobid) - mox.file.copy_parallel(src_url="/cache/profiler_data", - dst_url="s3://wudao-1/yyf/profiler_" + jobid + "/" + str(rank_id)) + mox.file.copy_parallel( + src_url="/cache/profiler_data", + dst_url="s3://wudao-1/yyf/profiler_" + jobid + "/" + str(rank_id), + ) if __name__ == "__main__": @@ -340,6 +434,8 @@ def run_train(args_opt): raise ValueError("The per_batch_size has not been configured.") if opt.stage_num > 1: if bool(opt.use_moe) or bool(opt.opt_offload): - raise ValueError("Currently, moe and host device mode is not supported in pipeline parallel.") + raise ValueError( + "Currently, moe and host device mode is not supported in pipeline parallel." 
+ ) else: run_train(opt) diff --git a/codegeex/mindspore/generation.py b/codegeex/mindspore/generation.py index 422f877..66a82d9 100644 --- a/codegeex/mindspore/generation.py +++ b/codegeex/mindspore/generation.py @@ -39,12 +39,12 @@ def load_model(args_opt): r""" - The main function for load model + The main function for load model """ # Set execution mode - context.set_context(save_graphs=False, - mode=context.GRAPH_MODE, - device_target=args_opt.device_target) + context.set_context( + save_graphs=False, mode=context.GRAPH_MODE, device_target=args_opt.device_target + ) context.set_context(variable_memory_max_size="30GB") # Set parallel context if args_opt.distribute == "true": @@ -59,7 +59,8 @@ def load_model(args_opt): full_batch=True, loss_repeated_mean=True, enable_parallel_optimizer=False, - pipeline_stages=args_opt.stage_num) + pipeline_stages=args_opt.stage_num, + ) set_algo_parameters(elementwise_op_strategy_follow=True) _set_multi_subgraphs() @@ -68,26 +69,29 @@ def load_model(args_opt): device_num = 1 context.reset_auto_parallel_context() context.set_auto_parallel_context( - strategy_ckpt_load_file=args_opt.strategy_load_ckpt_path) + strategy_ckpt_load_file=args_opt.strategy_load_ckpt_path + ) context.set_context( save_graphs=False, save_graphs_path="/cache/graphs_of_device_id_" + str(rank), ) - use_past = (args_opt.use_past == "true") - print('local_rank:{}, start to run...'.format(rank), flush=True) + use_past = args_opt.use_past == "true" + print("local_rank:{}, start to run...".format(rank), flush=True) if args_opt.export: use_past = True # Set model property model_parallel_num = args_opt.op_level_model_parallel_num data_parallel_num = int(device_num / model_parallel_num) - parallel_config = TransformerOpParallelConfig(data_parallel=data_parallel_num, - model_parallel=model_parallel_num, - pipeline_stage=args_opt.stage_num, - micro_batch_num=args_opt.micro_size, - optimizer_shard=False, - vocab_emb_dp=bool(args_opt.word_emb_dp), - recompute=True) + parallel_config = TransformerOpParallelConfig( + data_parallel=data_parallel_num, + model_parallel=model_parallel_num, + pipeline_stage=args_opt.stage_num, + micro_batch_num=args_opt.micro_size, + optimizer_shard=False, + vocab_emb_dp=bool(args_opt.word_emb_dp), + recompute=True, + ) per_batch_size = args_opt.per_batch_size batch_size = per_batch_size * data_parallel_num @@ -109,7 +113,7 @@ def load_model(args_opt): parallel_config=parallel_config, load_ckpt_path=args_opt.load_ckpt_path, param_init_type=mstype.float32 - if args_opt.param_init_type == 'fp32' + if args_opt.param_init_type == "fp32" else mstype.float16, ) print("===config is: ", config, flush=True) @@ -121,23 +125,31 @@ def load_model(args_opt): eval_net.set_train(False) model_predict = Model(eval_net) # Compile network and obtain tensor layout for loading ckpt - inputs_np = Tensor(np.ones(shape=(config.batch_size, config.seq_length)), mstype.int32) + inputs_np = Tensor( + np.ones(shape=(config.batch_size, config.seq_length)), mstype.int32 + ) current_index = Tensor(np.array([0 for _ in range(batch_size)]), mstype.int32) if args_opt.distribute == "false": predict_layout = None elif config.use_past: - batch_valid_length = Tensor(np.array([0 for _ in range(batch_size)]), mstype.int32) + batch_valid_length = Tensor( + np.array([0 for _ in range(batch_size)]), mstype.int32 + ) init_true = Tensor([True], mstype.bool_) print("Input shape:", inputs_np.shape, flush=True) inputs_np_1 = Tensor(np.ones(shape=(config.batch_size, 1)), mstype.int32) 
model_predict.predict_network.add_flags_recursive(is_first_iteration=True) print("is_first_iteration=True", flush=True) - predict_layout = model_predict.infer_predict_layout(inputs_np, current_index, init_true, batch_valid_length) + predict_layout = model_predict.infer_predict_layout( + inputs_np, current_index, init_true, batch_valid_length + ) model_predict.predict_network.add_flags_recursive(is_first_iteration=False) print("is_first_iteration=False", flush=True) init_false = Tensor([False], mstype.bool_) - _ = model_predict.infer_predict_layout(inputs_np_1, current_index, init_false, batch_valid_length) + _ = model_predict.infer_predict_layout( + inputs_np_1, current_index, init_false, batch_valid_length + ) else: predict_layout = model_predict.infer_predict_layout(inputs_np, current_index) @@ -146,37 +158,53 @@ def load_model(args_opt): jobid = os.environ["BATCH_JOB_ID"] rank_id = rank mox.file.make_dirs("s3://wudao-1/yyf/graphs_" + jobid) - mox.file.copy_parallel(src_url="/cache/graphs_of_device_id_" + str(rank_id), - dst_url="s3://wudao-1/yyf/graphs_" + jobid + "/" + str(rank_id)) + mox.file.copy_parallel( + src_url="/cache/graphs_of_device_id_" + str(rank_id), + dst_url="s3://wudao-1/yyf/graphs_" + jobid + "/" + str(rank_id), + ) print("======start load_distributed checkpoint", flush=True) if args_opt.load_ckpt_epoch > 0: time.sleep(rank * 0.1) os.mkdir(os.path.join(args_opt.save_checkpoint_path, f"rank_{rank}")) ckpt_name = f"code-13B{rank}-{args_opt.load_ckpt_epoch}.ckpt" - if not mox.file.exists(os.path.join(args_opt.load_ckpt_path, f"rank_{rank}", ckpt_name)): + if not mox.file.exists( + os.path.join(args_opt.load_ckpt_path, f"rank_{rank}", ckpt_name) + ): print(f"Checkpoint from rank {rank} doesn't exist!") - mox.file.copy(os.path.join(args_opt.load_ckpt_path, f"rank_{rank}", ckpt_name), - os.path.join(args_opt.save_checkpoint_path, f"rank_{rank}", ckpt_name)) - param_dict = load_checkpoint(os.path.join(args_opt.save_checkpoint_path, f"rank_{rank}", ckpt_name)) + mox.file.copy( + os.path.join(args_opt.load_ckpt_path, f"rank_{rank}", ckpt_name), + os.path.join(args_opt.save_checkpoint_path, f"rank_{rank}", ckpt_name), + ) + param_dict = load_checkpoint( + os.path.join(args_opt.save_checkpoint_path, f"rank_{rank}", ckpt_name) + ) if param_dict.get("epoch_num") and param_dict.get("step_num"): args_opt.has_trained_epoches = int(param_dict["epoch_num"].data.asnumpy()) args_opt.has_trained_steps = int(param_dict["step_num"].data.asnumpy()) os.mkdir(f'/home/work/sfs/cache/{os.environ["BATCH_JOB_ID"]}/1/rank_{rank}') while True: - num = len(os.listdir(f'/home/work/sfs/cache/{os.environ["BATCH_JOB_ID"]}/1')) + num = len( + os.listdir(f'/home/work/sfs/cache/{os.environ["BATCH_JOB_ID"]}/1') + ) if num == device_num: break if rank % 8 == 0: print("Loaded ckpt in step 1: ", num) time.sleep(1) net_not_load = load_param_into_net(pangu_alpha, param_dict) - print("====== load_distributed checkpoint done, net_not_load: ", net_not_load, flush=True) + print( + "====== load_distributed checkpoint done, net_not_load: ", + net_not_load, + flush=True, + ) return model_predict, config, rank def export_mindir(model_predict, config): """Export mindir model""" - inputs_np = Tensor(np.ones(shape=(config.batch_size, config.seq_length)), mstype.int32) + inputs_np = Tensor( + np.ones(shape=(config.batch_size, config.seq_length)), mstype.int32 + ) current_index = Tensor(np.array([0]), mstype.int32) batch_valid_length = Tensor(np.array([0]), mstype.int32) @@ -184,19 +212,34 @@ def export_mindir(model_predict, 
config): inputs_np_1 = Tensor(np.ones(shape=(config.batch_size, 1)), mstype.int32) model_predict.predict_network.add_flags_recursive(is_first_iteration=True) - export(model_predict.predict_network, inputs_np, current_index, - init_true, batch_valid_length, file_name='pangu_alpha_1024', file_format='MINDIR') + export( + model_predict.predict_network, + inputs_np, + current_index, + init_true, + batch_valid_length, + file_name="pangu_alpha_1024", + file_format="MINDIR", + ) model_predict.predict_network.add_flags_recursive(is_first_iteration=False) - export(model_predict.predict_network, inputs_np_1, current_index, - init_true, batch_valid_length, file_name='pangu_alpha_1', file_format='MINDIR') + export( + model_predict.predict_network, + inputs_np_1, + current_index, + init_true, + batch_valid_length, + file_name="pangu_alpha_1", + file_format="MINDIR", + ) print("Export finished and now exit.") def run_predict(model_predict, config, args_opt, rank): """run predict""" from src.generate import generate, generate_increment + # Define tokenizer - tokenizer = CodeTokenizer(mode='6b') + tokenizer = CodeTokenizer(mode="6b") # Tokenize input sentence to ids samples = [ @@ -204,7 +247,7 @@ def run_predict(model_predict, config, args_opt, rank): "# language: Python\ndef add(a, b):\n '''\n Find the sum of a and b.\n '''\n", "def add(a, b):\n '''\n Find the sum of a and b.\n '''\n", "# language: Python\ndef optimization():\n '''\n Find the maximum of P=E**2*R/(R + r)**2 if E and r are fixed but R varies. Import sympy. Use sympy. Find where the derivative is equal to zero. Substitute the value of R into P.\n '''\n", - "from typing import List\n\n\ndef has_close_elements(numbers: List[float], threshold: float) -> bool:\n \"\"\" Check if in given list of numbers, are any two numbers closer to each other than\n given threshold.\n >>> has_close_elements([1.0, 2.0, 3.0], 0.5)\n False\n >>> has_close_elements([1.0, 2.8, 3.0, 4.0, 5.0, 2.0], 0.3)\n True\n \"\"\"\n", + 'from typing import List\n\n\ndef has_close_elements(numbers: List[float], threshold: float) -> bool:\n """ Check if in given list of numbers, are any two numbers closer to each other than\n given threshold.\n >>> has_close_elements([1.0, 2.0, 3.0], 0.5)\n False\n >>> has_close_elements([1.0, 2.8, 3.0, 4.0, 5.0, 2.0], 0.3)\n True\n """\n', "// language: JavaScript\nfunction prime(n) {\n // Find whether n is a prime number.\n", "string morse_encoder(string text) {\n // Translate text into Morse code\n", "def morse_encoder(text):\n # Translate text into Morse code separated by spaces\n", @@ -243,7 +286,8 @@ def run_predict(model_predict, config, args_opt, rank): print(f"=================== generation {i} ====================") print(output_samples_str, flush=True) print( - f"=== Total time (s): {t1 - t0}, {output_ids.shape[-1] - input_ids.shape[-1]} tokens, {(output_ids.shape[-1] - input_ids.shape[-1]) / (t1 - t0)} token/s") + f"=== Total time (s): {t1 - t0}, {output_ids.shape[-1] - input_ids.shape[-1]} tokens, {(output_ids.shape[-1] - input_ids.shape[-1]) / (t1 - t0)} token/s" + ) def main(): diff --git a/codegeex/mindspore/generation_1p.py b/codegeex/mindspore/generation_1p.py index 9706e11..bcf5788 100644 --- a/codegeex/mindspore/generation_1p.py +++ b/codegeex/mindspore/generation_1p.py @@ -39,12 +39,12 @@ def load_model(args_opt): r""" - The main function for load model + The main function for load model """ # Set execution mode - context.set_context(save_graphs=False, - mode=context.GRAPH_MODE, - device_target=args_opt.device_target) + 
context.set_context( + save_graphs=False, mode=context.GRAPH_MODE, device_target=args_opt.device_target + ) context.set_context(variable_memory_max_size="30GB") # Set parallel context if args_opt.distribute == "true": @@ -59,7 +59,8 @@ def load_model(args_opt): full_batch=True, loss_repeated_mean=True, enable_parallel_optimizer=False, - pipeline_stages=args_opt.stage_num) + pipeline_stages=args_opt.stage_num, + ) set_algo_parameters(elementwise_op_strategy_follow=True) _set_multi_subgraphs() @@ -68,13 +69,14 @@ def load_model(args_opt): device_num = 1 context.reset_auto_parallel_context() context.set_auto_parallel_context( - strategy_ckpt_load_file=args_opt.strategy_load_ckpt_path) + strategy_ckpt_load_file=args_opt.strategy_load_ckpt_path + ) context.set_context( save_graphs=False, save_graphs_path="/cache/graphs_of_device_id_" + str(rank), ) - use_past = (args_opt.use_past == "true") - print('local_rank:{}, start to run...'.format(rank), flush=True) + use_past = args_opt.use_past == "true" + print("local_rank:{}, start to run...".format(rank), flush=True) if args_opt.export: use_past = True # Set model property @@ -85,13 +87,15 @@ def load_model(args_opt): data_parallel_num = int(device_num / model_parallel_num) print("===data_parallel_num is: ", data_parallel_num, flush=True) - parallel_config = TransformerOpParallelConfig(data_parallel=data_parallel_num, - model_parallel=model_parallel_num, - pipeline_stage=args_opt.stage_num, - micro_batch_num=args_opt.micro_size, - optimizer_shard=False, - vocab_emb_dp=bool(args_opt.word_emb_dp), - recompute=True) + parallel_config = TransformerOpParallelConfig( + data_parallel=data_parallel_num, + model_parallel=model_parallel_num, + pipeline_stage=args_opt.stage_num, + micro_batch_num=args_opt.micro_size, + optimizer_shard=False, + vocab_emb_dp=bool(args_opt.word_emb_dp), + recompute=True, + ) per_batch_size = args_opt.per_batch_size batch_size = per_batch_size * data_parallel_num @@ -114,7 +118,7 @@ def load_model(args_opt): parallel_config=parallel_config, load_ckpt_path=args_opt.load_ckpt_path, param_init_type=mstype.float32 - if args_opt.param_init_type == 'fp32' + if args_opt.param_init_type == "fp32" else mstype.float16, ) print("===config is: ", config, flush=True) @@ -126,7 +130,9 @@ def load_model(args_opt): eval_net.set_train(False) model_predict = Model(eval_net) # Compile network and obtain tensor layout for loading ckpt - inputs_np = Tensor(np.ones(shape=(config.batch_size, config.seq_length)), mstype.int32) + inputs_np = Tensor( + np.ones(shape=(config.batch_size, config.seq_length)), mstype.int32 + ) current_index = Tensor(np.array([0]), mstype.int32) if args_opt.distribute == "false": @@ -138,11 +144,15 @@ def load_model(args_opt): inputs_np_1 = Tensor(np.ones(shape=(config.batch_size, 1)), mstype.int32) model_predict.predict_network.add_flags_recursive(is_first_iteration=True) print("is_first_iteration=True", flush=True) - predict_layout = model_predict.infer_predict_layout(inputs_np, current_index, init_true, batch_valid_length) + predict_layout = model_predict.infer_predict_layout( + inputs_np, current_index, init_true, batch_valid_length + ) model_predict.predict_network.add_flags_recursive(is_first_iteration=False) print("is_first_iteration=False", flush=True) init_false = Tensor([False], mstype.bool_) - _ = model_predict.infer_predict_layout(inputs_np_1, current_index, init_false, batch_valid_length) + _ = model_predict.infer_predict_layout( + inputs_np_1, current_index, init_false, batch_valid_length + ) else: predict_layout 
= model_predict.infer_predict_layout(inputs_np, current_index) @@ -151,38 +161,54 @@ def load_model(args_opt): jobid = os.environ["BATCH_JOB_ID"] rank_id = rank mox.file.make_dirs("s3://wudao-1/yyf/graphs_" + jobid) - mox.file.copy_parallel(src_url="/cache/graphs_of_device_id_" + str(rank_id), - dst_url="s3://wudao-1/yyf/graphs_" + jobid + "/" + str(rank_id)) + mox.file.copy_parallel( + src_url="/cache/graphs_of_device_id_" + str(rank_id), + dst_url="s3://wudao-1/yyf/graphs_" + jobid + "/" + str(rank_id), + ) print("======start load_distributed checkpoint", flush=True) if args_opt.load_ckpt_epoch > 0: time.sleep(rank * 0.5) os.mkdir(os.path.join(args_opt.save_checkpoint_path, f"rank_{rank}")) ckpt_name = f"code-13B0-{args_opt.load_ckpt_epoch}.ckpt" - if not mox.file.exists(os.path.join(args_opt.load_ckpt_path, f"rank_{rank}", ckpt_name)): + if not mox.file.exists( + os.path.join(args_opt.load_ckpt_path, f"rank_{rank}", ckpt_name) + ): print(f"Checkpoint from rank {rank} doesn't exist!") - mox.file.copy(os.path.join(args_opt.load_ckpt_path, f"rank_{rank}", ckpt_name), - os.path.join(args_opt.save_checkpoint_path, f"rank_{rank}", ckpt_name)) - param_dict = load_checkpoint(os.path.join(args_opt.save_checkpoint_path, f"rank_{rank}", ckpt_name)) + mox.file.copy( + os.path.join(args_opt.load_ckpt_path, f"rank_{rank}", ckpt_name), + os.path.join(args_opt.save_checkpoint_path, f"rank_{rank}", ckpt_name), + ) + param_dict = load_checkpoint( + os.path.join(args_opt.save_checkpoint_path, f"rank_{rank}", ckpt_name) + ) # TODO: add them back if not for the 1st run! if param_dict.get("epoch_num") and param_dict.get("step_num"): args_opt.has_trained_epoches = int(param_dict["epoch_num"].data.asnumpy()) args_opt.has_trained_steps = int(param_dict["step_num"].data.asnumpy()) os.mkdir(f'/home/work/sfs/cache/{os.environ["BATCH_JOB_ID"]}/1/rank_{rank}') while True: - num = len(os.listdir(f'/home/work/sfs/cache/{os.environ["BATCH_JOB_ID"]}/1')) + num = len( + os.listdir(f'/home/work/sfs/cache/{os.environ["BATCH_JOB_ID"]}/1') + ) if num == device_num: break if rank % 8 == 0: print("Loaded ckpt in step 1: ", num) time.sleep(1) net_not_load = load_param_into_net(pangu_alpha, param_dict) - print("====== load_distributed checkpoint done, net_not_load: ", net_not_load, flush=True) + print( + "====== load_distributed checkpoint done, net_not_load: ", + net_not_load, + flush=True, + ) return model_predict, config, rank def export_mindir(model_predict, config): """Export mindir model""" - inputs_np = Tensor(np.ones(shape=(config.batch_size, config.seq_length)), mstype.int32) + inputs_np = Tensor( + np.ones(shape=(config.batch_size, config.seq_length)), mstype.int32 + ) current_index = Tensor(np.array([0]), mstype.int32) batch_valid_length = Tensor(np.array([0]), mstype.int32) @@ -190,26 +216,41 @@ def export_mindir(model_predict, config): inputs_np_1 = Tensor(np.ones(shape=(config.batch_size, 1)), mstype.int32) model_predict.predict_network.add_flags_recursive(is_first_iteration=True) - export(model_predict.predict_network, inputs_np, current_index, - init_true, batch_valid_length, file_name='pangu_alpha_1024', file_format='MINDIR') + export( + model_predict.predict_network, + inputs_np, + current_index, + init_true, + batch_valid_length, + file_name="pangu_alpha_1024", + file_format="MINDIR", + ) model_predict.predict_network.add_flags_recursive(is_first_iteration=False) - export(model_predict.predict_network, inputs_np_1, current_index, - init_true, batch_valid_length, file_name='pangu_alpha_1', file_format='MINDIR') 
+ export( + model_predict.predict_network, + inputs_np_1, + current_index, + init_true, + batch_valid_length, + file_name="pangu_alpha_1", + file_format="MINDIR", + ) print("Export finished and now exit.") def run_predict(model_predict, config, args_opt, rank): """run predict""" from src.generate import generate, generate_increment + # Define tokenizer - tokenizer = CodeTokenizer(mode='6b') + tokenizer = CodeTokenizer(mode="6b") # Tokenize input sentence to ids samples = [ "# language: Python\ndef add(a, b):\n '''\n Find the sum of a and b.\n '''\n", "def add(a, b):\n '''\n Find the sum of a and b.\n '''\n", "# language: Python\ndef optimization():\n '''\n Find the maximum of P=E**2*R/(R + r)**2 if E and r are fixed but R varies. Import sympy. Use sympy. Find where the derivative is equal to zero. Substitute the value of R into P.\n '''\n", - "from typing import List\n\n\ndef has_close_elements(numbers: List[float], threshold: float) -> bool:\n \"\"\" Check if in given list of numbers, are any two numbers closer to each other than\n given threshold.\n >>> has_close_elements([1.0, 2.0, 3.0], 0.5)\n False\n >>> has_close_elements([1.0, 2.8, 3.0, 4.0, 5.0, 2.0], 0.3)\n True\n \"\"\"\n", + 'from typing import List\n\n\ndef has_close_elements(numbers: List[float], threshold: float) -> bool:\n """ Check if in given list of numbers, are any two numbers closer to each other than\n given threshold.\n >>> has_close_elements([1.0, 2.0, 3.0], 0.5)\n False\n >>> has_close_elements([1.0, 2.8, 3.0, 4.0, 5.0, 2.0], 0.3)\n True\n """\n', "// language: C++\nint add(int a, int b) {\n /* Find the sum of a and b. */\n", "int add(int a, int b) {\n /* Find the sum of a and b. */\n", "bool prime(int n) {\n // Find whether n is a prime number\n", @@ -235,7 +276,8 @@ def run_predict(model_predict, config, args_opt, rank): print(f"=================== generation {i} ====================") print(output_samples_str, flush=True) print( - f"=== Total time (s): {t1 - t0}, {output_ids.shape[-1] - input_ids.shape[-1]} tokens, {(output_ids.shape[-1] - input_ids.shape[-1]) / (t1 - t0)} token/s") + f"=== Total time (s): {t1 - t0}, {output_ids.shape[-1] - input_ids.shape[-1]} tokens, {(output_ids.shape[-1] - input_ids.shape[-1]) / (t1 - t0)} token/s" + ) break diff --git a/codegeex/mindspore/generation_batch.py b/codegeex/mindspore/generation_batch.py index 06798fc..ed06747 100644 --- a/codegeex/mindspore/generation_batch.py +++ b/codegeex/mindspore/generation_batch.py @@ -39,12 +39,12 @@ def load_model(args_opt): r""" - The main function for load model + The main function for load model """ # Set execution mode - context.set_context(save_graphs=False, - mode=context.GRAPH_MODE, - device_target=args_opt.device_target) + context.set_context( + save_graphs=False, mode=context.GRAPH_MODE, device_target=args_opt.device_target + ) context.set_context(variable_memory_max_size="30GB") # Set parallel context if args_opt.distribute == "true": @@ -59,7 +59,8 @@ def load_model(args_opt): full_batch=True, loss_repeated_mean=True, enable_parallel_optimizer=False, - pipeline_stages=args_opt.stage_num) + pipeline_stages=args_opt.stage_num, + ) set_algo_parameters(elementwise_op_strategy_follow=True) _set_multi_subgraphs() @@ -68,26 +69,29 @@ def load_model(args_opt): device_num = 1 context.reset_auto_parallel_context() context.set_auto_parallel_context( - strategy_ckpt_load_file=args_opt.strategy_load_ckpt_path) + strategy_ckpt_load_file=args_opt.strategy_load_ckpt_path + ) context.set_context( save_graphs=False, 
save_graphs_path="/cache/graphs_of_device_id_" + str(rank), ) - use_past = (args_opt.use_past == "true") - print('local_rank:{}, start to run...'.format(rank), flush=True) + use_past = args_opt.use_past == "true" + print("local_rank:{}, start to run...".format(rank), flush=True) if args_opt.export: use_past = True # Set model property model_parallel_num = args_opt.op_level_model_parallel_num data_parallel_num = int(device_num / model_parallel_num) - parallel_config = TransformerOpParallelConfig(data_parallel=data_parallel_num, - model_parallel=model_parallel_num, - pipeline_stage=args_opt.stage_num, - micro_batch_num=args_opt.micro_size, - optimizer_shard=False, - vocab_emb_dp=bool(args_opt.word_emb_dp), - recompute=True) + parallel_config = TransformerOpParallelConfig( + data_parallel=data_parallel_num, + model_parallel=model_parallel_num, + pipeline_stage=args_opt.stage_num, + micro_batch_num=args_opt.micro_size, + optimizer_shard=False, + vocab_emb_dp=bool(args_opt.word_emb_dp), + recompute=True, + ) per_batch_size = args_opt.per_batch_size batch_size = per_batch_size * data_parallel_num @@ -107,7 +111,7 @@ def load_model(args_opt): parallel_config=parallel_config, load_ckpt_path=args_opt.load_ckpt_path, param_init_type=mstype.float32 - if args_opt.param_init_type == 'fp32' + if args_opt.param_init_type == "fp32" else mstype.float16, ) print("===config is: ", config, flush=True) @@ -119,23 +123,31 @@ def load_model(args_opt): eval_net.set_train(False) model_predict = Model(eval_net) # Compile network and obtain tensor layout for loading ckpt - inputs_np = Tensor(np.ones(shape=(config.batch_size, config.seq_length)), mstype.int32) + inputs_np = Tensor( + np.ones(shape=(config.batch_size, config.seq_length)), mstype.int32 + ) current_index = Tensor(np.array([0 for _ in range(batch_size)]), mstype.int32) if args_opt.distribute == "false": predict_layout = None elif config.use_past: - batch_valid_length = Tensor(np.array([0 for _ in range(batch_size)]), mstype.int32) + batch_valid_length = Tensor( + np.array([0 for _ in range(batch_size)]), mstype.int32 + ) init_true = Tensor([True], mstype.bool_) print("Input shape:", inputs_np.shape, flush=True) inputs_np_1 = Tensor(np.ones(shape=(config.batch_size, 1)), mstype.int32) model_predict.predict_network.add_flags_recursive(is_first_iteration=True) print("is_first_iteration=True", flush=True) - predict_layout = model_predict.infer_predict_layout(inputs_np, current_index, init_true, batch_valid_length) + predict_layout = model_predict.infer_predict_layout( + inputs_np, current_index, init_true, batch_valid_length + ) model_predict.predict_network.add_flags_recursive(is_first_iteration=False) print("is_first_iteration=False", flush=True) init_false = Tensor([False], mstype.bool_) - _ = model_predict.infer_predict_layout(inputs_np_1, current_index, init_false, batch_valid_length) + _ = model_predict.infer_predict_layout( + inputs_np_1, current_index, init_false, batch_valid_length + ) else: predict_layout = model_predict.infer_predict_layout(inputs_np, current_index) @@ -144,37 +156,53 @@ def load_model(args_opt): jobid = os.environ["BATCH_JOB_ID"] rank_id = rank mox.file.make_dirs("s3://wudao-1/yyf/graphs_" + jobid) - mox.file.copy_parallel(src_url="/cache/graphs_of_device_id_" + str(rank_id), - dst_url="s3://wudao-1/yyf/graphs_" + jobid + "/" + str(rank_id)) + mox.file.copy_parallel( + src_url="/cache/graphs_of_device_id_" + str(rank_id), + dst_url="s3://wudao-1/yyf/graphs_" + jobid + "/" + str(rank_id), + ) print("======start load_distributed 
checkpoint", flush=True) if args_opt.load_ckpt_epoch > 0: time.sleep(rank * 0.1) os.mkdir(os.path.join(args_opt.save_checkpoint_path, f"rank_{rank}")) ckpt_name = f"code-13B{rank}-{args_opt.load_ckpt_epoch}.ckpt" - if not mox.file.exists(os.path.join(args_opt.load_ckpt_path, f"rank_{rank}", ckpt_name)): + if not mox.file.exists( + os.path.join(args_opt.load_ckpt_path, f"rank_{rank}", ckpt_name) + ): print(f"Checkpoint from rank {rank} doesn't exist!") - mox.file.copy(os.path.join(args_opt.load_ckpt_path, f"rank_{rank}", ckpt_name), - os.path.join(args_opt.save_checkpoint_path, f"rank_{rank}", ckpt_name)) - param_dict = load_checkpoint(os.path.join(args_opt.save_checkpoint_path, f"rank_{rank}", ckpt_name)) + mox.file.copy( + os.path.join(args_opt.load_ckpt_path, f"rank_{rank}", ckpt_name), + os.path.join(args_opt.save_checkpoint_path, f"rank_{rank}", ckpt_name), + ) + param_dict = load_checkpoint( + os.path.join(args_opt.save_checkpoint_path, f"rank_{rank}", ckpt_name) + ) if param_dict.get("epoch_num") and param_dict.get("step_num"): args_opt.has_trained_epoches = int(param_dict["epoch_num"].data.asnumpy()) args_opt.has_trained_steps = int(param_dict["step_num"].data.asnumpy()) os.mkdir(f'/home/work/sfs/cache/{os.environ["BATCH_JOB_ID"]}/1/rank_{rank}') while True: - num = len(os.listdir(f'/home/work/sfs/cache/{os.environ["BATCH_JOB_ID"]}/1')) + num = len( + os.listdir(f'/home/work/sfs/cache/{os.environ["BATCH_JOB_ID"]}/1') + ) if num == device_num: break if rank % 8 == 0: print("Loaded ckpt in step 1: ", num) time.sleep(1) net_not_load = load_param_into_net(pangu_alpha, param_dict) - print("====== load_distributed checkpoint done, net_not_load: ", net_not_load, flush=True) + print( + "====== load_distributed checkpoint done, net_not_load: ", + net_not_load, + flush=True, + ) return model_predict, config, rank def export_mindir(model_predict, config): """Export mindir model""" - inputs_np = Tensor(np.ones(shape=(config.batch_size, config.seq_length)), mstype.int32) + inputs_np = Tensor( + np.ones(shape=(config.batch_size, config.seq_length)), mstype.int32 + ) current_index = Tensor(np.array([0]), mstype.int32) batch_valid_length = Tensor(np.array([0]), mstype.int32) @@ -182,19 +210,34 @@ def export_mindir(model_predict, config): inputs_np_1 = Tensor(np.ones(shape=(config.batch_size, 1)), mstype.int32) model_predict.predict_network.add_flags_recursive(is_first_iteration=True) - export(model_predict.predict_network, inputs_np, current_index, - init_true, batch_valid_length, file_name='pangu_alpha_1024', file_format='MINDIR') + export( + model_predict.predict_network, + inputs_np, + current_index, + init_true, + batch_valid_length, + file_name="pangu_alpha_1024", + file_format="MINDIR", + ) model_predict.predict_network.add_flags_recursive(is_first_iteration=False) - export(model_predict.predict_network, inputs_np_1, current_index, - init_true, batch_valid_length, file_name='pangu_alpha_1', file_format='MINDIR') + export( + model_predict.predict_network, + inputs_np_1, + current_index, + init_true, + batch_valid_length, + file_name="pangu_alpha_1", + file_format="MINDIR", + ) print("Export finished and now exit.") def run_predict(model_predict, config, args_opt, rank): """run predict""" from src.generate_finetune import generate_increment + # Define tokenizer - tokenizer = CodeTokenizer(mode='6b') + tokenizer = CodeTokenizer(mode="6b") # Tokenize input sentence to ids samples = [ @@ -202,7 +245,7 @@ def run_predict(model_predict, config, args_opt, rank): "# language: Python\ndef add(a, b):\n 
'''\n Find the sum of a and b.\n '''\n", "def add(a, b):\n '''\n Find the sum of a and b.\n '''\n", "# language: Python\ndef optimization():\n '''\n Find the maximum of P=E**2*R/(R + r)**2 if E and r are fixed but R varies. Import sympy. Use sympy. Find where the derivative is equal to zero. Substitute the value of R into P.\n '''\n", - "from typing import List\n\n\ndef has_close_elements(numbers: List[float], threshold: float) -> bool:\n \"\"\" Check if in given list of numbers, are any two numbers closer to each other than\n given threshold.\n >>> has_close_elements([1.0, 2.0, 3.0], 0.5)\n False\n >>> has_close_elements([1.0, 2.8, 3.0, 4.0, 5.0, 2.0], 0.3)\n True\n \"\"\"\n", + 'from typing import List\n\n\ndef has_close_elements(numbers: List[float], threshold: float) -> bool:\n """ Check if in given list of numbers, are any two numbers closer to each other than\n given threshold.\n >>> has_close_elements([1.0, 2.0, 3.0], 0.5)\n False\n >>> has_close_elements([1.0, 2.8, 3.0, 4.0, 5.0, 2.0], 0.3)\n True\n """\n', "// language: JavaScript\nfunction prime(n) {\n // Find whether n is a prime number.\n", "string morse_encoder(string text) {\n // Translate text into Morse code\n", "def morse_encoder(text):\n # Translate text into Morse code separated by spaces\n", @@ -225,16 +268,16 @@ def run_predict(model_predict, config, args_opt, rank): samples = [tokenizer.encode_code(l) for l in samples] generations = [] batch_size = config.batch_size - verbose = (rank % 8 == 0) - save_path = f'/home/work/sfs/xx/pangu_alpha_code/generation_batch/{args_opt.temperature}.txt' # TODO: set as current save path + verbose = rank % 8 == 0 + save_path = f"/home/work/sfs/xx/pangu_alpha_code/generation_batch/{args_opt.temperature}.txt" # TODO: set as current save path save_dir = os.path.split(save_path)[0] if rank == 0: if not os.path.exists(save_dir): os.makedirs(save_dir) if not os.path.exists(save_path): - f = open(save_path, 'w') + f = open(save_path, "w") f.close() - os.system(f'sudo chmod 777 -R {save_dir}') + os.system(f"sudo chmod 777 -R {save_dir}") batch = [] input_length = [] sample_ids = [] @@ -247,23 +290,31 @@ def run_predict(model_predict, config, args_opt, rank): if (i + 1) % batch_size == 0: valid_length = max(input_length) for j in range(len(batch)): - batch[j] = np.pad(batch[j], ((0, 0), (0, valid_length - input_length[j])), - 'constant', constant_values=(args_opt.end_token, args_opt.end_token)) + batch[j] = np.pad( + batch[j], + ((0, 0), (0, valid_length - input_length[j])), + "constant", + constant_values=(args_opt.end_token, args_opt.end_token), + ) input_ids = np.concatenate(batch, axis=0) t0 = time.perf_counter() - output_ids = generate_increment(model_predict, input_ids, input_length, args_opt, tokenizer, verbose) + output_ids = generate_increment( + model_predict, input_ids, input_length, args_opt, tokenizer, verbose + ) t1 = time.perf_counter() batch, input_length = [], [] if rank % 8 == 0: print(f"=== Batch time: {t1 - t0}s") for k, out in enumerate(output_ids): - if not out.endswith('\n'): - out = out + '\n' - print(f"=================== generation {sample_ids[k]} ====================") + if not out.endswith("\n"): + out = out + "\n" + print( + f"=================== generation {sample_ids[k]} ====================" + ) print(out, flush=True) generations.append(out) if rank == 0: - f = open(save_path, 'a') + f = open(save_path, "a") f.write(generations[-1]) f.close() sample_ids = [] @@ -273,24 +324,32 @@ def run_predict(model_predict, config, args_opt, rank): input_length.append(-1) 
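# Editorial sketch (not part of the patch): the right-padding scheme used by the
# batched generation loop, in plain NumPy.  end_token and the two toy prompts are
# assumed values; the real script takes them from args_opt and the tokenizer, and
# rows added only to fill the last partial batch carry input_length == -1 so they
# can be skipped when results are written out.
import numpy as np

end_token = 50256
batch = [np.array([[5, 6, 7]]), np.array([[9, 10]])]   # (1, len_i) encoded prompts
input_length = [b.shape[1] for b in batch]
valid_length = max(input_length)
for j in range(len(batch)):
    batch[j] = np.pad(
        batch[j],
        ((0, 0), (0, valid_length - input_length[j])),
        "constant",
        constant_values=(end_token, end_token),
    )
input_ids = np.concatenate(batch, axis=0)              # shape (batch_size, valid_length)
print(input_ids)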
valid_length = max(input_length) for j in range(len(batch)): - batch[j] = np.pad(batch[j], ((0, 0), (0, valid_length - batch[j].shape[1])), - 'constant', constant_values=(args_opt.end_token, args_opt.end_token)) + batch[j] = np.pad( + batch[j], + ((0, 0), (0, valid_length - batch[j].shape[1])), + "constant", + constant_values=(args_opt.end_token, args_opt.end_token), + ) input_ids = np.concatenate(batch, axis=0) t0 = time.perf_counter() - output_ids = generate_increment(model_predict, input_ids, input_length, args_opt, tokenizer, verbose) + output_ids = generate_increment( + model_predict, input_ids, input_length, args_opt, tokenizer, verbose + ) t1 = time.perf_counter() if rank % 8 == 0: print(f"=== Batch time: {t1 - t0}s") for k, out in enumerate(output_ids): if input_length[k] == -1: break - if not out.endswith('\n'): - out = out + '\n' - print(f"=================== generation {sample_ids[k]} ====================") + if not out.endswith("\n"): + out = out + "\n" + print( + f"=================== generation {sample_ids[k]} ====================" + ) print(out, flush=True) generations.append(out) if rank == 0: - f = open(save_path, 'a') + f = open(save_path, "a") f.write(generations[-1]) f.close() diff --git a/codegeex/mindspore/generation_finetune.py b/codegeex/mindspore/generation_finetune.py index 9f2a96e..5efa519 100644 --- a/codegeex/mindspore/generation_finetune.py +++ b/codegeex/mindspore/generation_finetune.py @@ -40,12 +40,12 @@ def load_model(args_opt): r""" - The main function for load model + The main function for load model """ # Set execution mode - context.set_context(save_graphs=False, - mode=context.GRAPH_MODE, - device_target=args_opt.device_target) + context.set_context( + save_graphs=False, mode=context.GRAPH_MODE, device_target=args_opt.device_target + ) context.set_context(variable_memory_max_size="30GB") # Set parallel context if args_opt.distribute == "true": @@ -60,7 +60,8 @@ def load_model(args_opt): full_batch=True, loss_repeated_mean=True, enable_parallel_optimizer=False, - pipeline_stages=args_opt.stage_num) + pipeline_stages=args_opt.stage_num, + ) set_algo_parameters(elementwise_op_strategy_follow=True) _set_multi_subgraphs() @@ -69,13 +70,14 @@ def load_model(args_opt): device_num = 1 context.reset_auto_parallel_context() context.set_auto_parallel_context( - strategy_ckpt_load_file=args_opt.strategy_load_ckpt_path) + strategy_ckpt_load_file=args_opt.strategy_load_ckpt_path + ) context.set_context( save_graphs=False, save_graphs_path="/cache/graphs_of_device_id_" + str(rank), ) - use_past = (args_opt.use_past == "true") - print('local_rank:{}, start to run...'.format(rank), flush=True) + use_past = args_opt.use_past == "true" + print("local_rank:{}, start to run...".format(rank), flush=True) if args_opt.export: use_past = True # Set model property @@ -83,13 +85,15 @@ def load_model(args_opt): data_parallel_num = int(device_num / model_parallel_num) # data_parallel_num = 1 - parallel_config = TransformerOpParallelConfig(data_parallel=data_parallel_num, - model_parallel=model_parallel_num, - pipeline_stage=args_opt.stage_num, - micro_batch_num=args_opt.micro_size, - optimizer_shard=False, - vocab_emb_dp=bool(args_opt.word_emb_dp), - recompute=True) + parallel_config = TransformerOpParallelConfig( + data_parallel=data_parallel_num, + model_parallel=model_parallel_num, + pipeline_stage=args_opt.stage_num, + micro_batch_num=args_opt.micro_size, + optimizer_shard=False, + vocab_emb_dp=bool(args_opt.word_emb_dp), + recompute=True, + ) per_batch_size = 
args_opt.per_batch_size batch_size = per_batch_size * data_parallel_num @@ -109,7 +113,7 @@ def load_model(args_opt): parallel_config=parallel_config, load_ckpt_path=args_opt.load_ckpt_path, param_init_type=mstype.float32 - if args_opt.param_init_type == 'fp32' + if args_opt.param_init_type == "fp32" else mstype.float16, ) print("===config is: ", config, flush=True) @@ -121,23 +125,31 @@ def load_model(args_opt): eval_net.set_train(False) model_predict = Model(eval_net) # Compile network and obtain tensor layout for loading ckpt - inputs_np = Tensor(np.ones(shape=(config.batch_size, config.seq_length)), mstype.int32) + inputs_np = Tensor( + np.ones(shape=(config.batch_size, config.seq_length)), mstype.int32 + ) current_index = Tensor(np.array([0 for _ in range(batch_size)]), mstype.int32) if args_opt.distribute == "false": predict_layout = None elif config.use_past: - batch_valid_length = Tensor(np.array([0 for _ in range(batch_size)]), mstype.int32) + batch_valid_length = Tensor( + np.array([0 for _ in range(batch_size)]), mstype.int32 + ) init_true = Tensor([True], mstype.bool_) print("Input shape:", inputs_np.shape, flush=True) inputs_np_1 = Tensor(np.ones(shape=(config.batch_size, 1)), mstype.int32) model_predict.predict_network.add_flags_recursive(is_first_iteration=True) print("is_first_iteration=True", flush=True) - predict_layout = model_predict.infer_predict_layout(inputs_np, current_index, init_true, batch_valid_length) + predict_layout = model_predict.infer_predict_layout( + inputs_np, current_index, init_true, batch_valid_length + ) model_predict.predict_network.add_flags_recursive(is_first_iteration=False) print("is_first_iteration=False", flush=True) init_false = Tensor([False], mstype.bool_) - _ = model_predict.infer_predict_layout(inputs_np_1, current_index, init_false, batch_valid_length) + _ = model_predict.infer_predict_layout( + inputs_np_1, current_index, init_false, batch_valid_length + ) else: predict_layout = model_predict.infer_predict_layout(inputs_np, current_index) @@ -146,38 +158,54 @@ def load_model(args_opt): jobid = os.environ["BATCH_JOB_ID"] rank_id = rank mox.file.make_dirs("s3://wudao-1/yyf/graphs_" + jobid) - mox.file.copy_parallel(src_url="/cache/graphs_of_device_id_" + str(rank_id), - dst_url="s3://wudao-1/yyf/graphs_" + jobid + "/" + str(rank_id)) + mox.file.copy_parallel( + src_url="/cache/graphs_of_device_id_" + str(rank_id), + dst_url="s3://wudao-1/yyf/graphs_" + jobid + "/" + str(rank_id), + ) print("======start load_distributed checkpoint", flush=True) if args_opt.load_ckpt_epoch > 0: time.sleep(rank * 0.1) os.mkdir(os.path.join(args_opt.save_checkpoint_path, f"rank_{rank}")) ckpt_name = f"code-13B{rank}-{args_opt.load_ckpt_epoch}.ckpt" - if not mox.file.exists(os.path.join(args_opt.load_ckpt_path, f"rank_{rank}", ckpt_name)): + if not mox.file.exists( + os.path.join(args_opt.load_ckpt_path, f"rank_{rank}", ckpt_name) + ): print(f"Checkpoint from rank {rank} doesn't exist!") - mox.file.copy(os.path.join(args_opt.load_ckpt_path, f"rank_{rank}", ckpt_name), - os.path.join(args_opt.save_checkpoint_path, f"rank_{rank}", ckpt_name)) - param_dict = load_checkpoint(os.path.join(args_opt.save_checkpoint_path, f"rank_{rank}", ckpt_name)) + mox.file.copy( + os.path.join(args_opt.load_ckpt_path, f"rank_{rank}", ckpt_name), + os.path.join(args_opt.save_checkpoint_path, f"rank_{rank}", ckpt_name), + ) + param_dict = load_checkpoint( + os.path.join(args_opt.save_checkpoint_path, f"rank_{rank}", ckpt_name) + ) if param_dict.get("epoch_num") and 
param_dict.get("step_num"): args_opt.has_trained_epoches = int(param_dict["epoch_num"].data.asnumpy()) args_opt.has_trained_steps = int(param_dict["step_num"].data.asnumpy()) os.mkdir(f'/home/work/sfs/cache/{os.environ["BATCH_JOB_ID"]}/1/rank_{rank}') while True: - num = len(os.listdir(f'/home/work/sfs/cache/{os.environ["BATCH_JOB_ID"]}/1')) + num = len( + os.listdir(f'/home/work/sfs/cache/{os.environ["BATCH_JOB_ID"]}/1') + ) if num == device_num: break if rank % 8 == 0: print("Loaded ckpt in step 1: ", num) time.sleep(1) net_not_load = load_param_into_net(pangu_alpha, param_dict) - print("====== load_distributed checkpoint done, net_not_load: ", net_not_load, flush=True) + print( + "====== load_distributed checkpoint done, net_not_load: ", + net_not_load, + flush=True, + ) return model_predict, config, rank def export_mindir(model_predict, config): """Export mindir model""" - inputs_np = Tensor(np.ones(shape=(config.batch_size, config.seq_length)), mstype.int32) + inputs_np = Tensor( + np.ones(shape=(config.batch_size, config.seq_length)), mstype.int32 + ) current_index = Tensor(np.array([0]), mstype.int32) batch_valid_length = Tensor(np.array([0]), mstype.int32) @@ -185,23 +213,38 @@ def export_mindir(model_predict, config): inputs_np_1 = Tensor(np.ones(shape=(config.batch_size, 1)), mstype.int32) model_predict.predict_network.add_flags_recursive(is_first_iteration=True) - export(model_predict.predict_network, inputs_np, current_index, - init_true, batch_valid_length, file_name='pangu_alpha_1024', file_format='MINDIR') + export( + model_predict.predict_network, + inputs_np, + current_index, + init_true, + batch_valid_length, + file_name="pangu_alpha_1024", + file_format="MINDIR", + ) model_predict.predict_network.add_flags_recursive(is_first_iteration=False) - export(model_predict.predict_network, inputs_np_1, current_index, - init_true, batch_valid_length, file_name='pangu_alpha_1', file_format='MINDIR') + export( + model_predict.predict_network, + inputs_np_1, + current_index, + init_true, + batch_valid_length, + file_name="pangu_alpha_1", + file_format="MINDIR", + ) print("Export finished and now exit.") def run_predict(model_predict, config, args_opt, rank): """run predict""" from src.generate_finetune import generate_increment + # Define tokenizer - tokenizer = CodeTokenizer(mode='6b') + tokenizer = CodeTokenizer(mode="6b") # Tokenize input sentence to ids lang = args_opt.language - data_path = os.path.join(args_opt.code_data, lang, 'test') + data_path = os.path.join(args_opt.code_data, lang, "test") dataset = LMDBDataset(data_path) samples = [] for i in range(len(dataset)): @@ -209,16 +252,16 @@ def run_predict(model_predict, config, args_opt, rank): samples.append(prompt[:length]) generations = [] batch_size = config.batch_size - verbose = (rank % 8 == 0) - save_path = f'/home/work/sfs/xx/pangu_alpha_code/generation_finetune/code_translation/{lang}/temp_{args_opt.temperature}.txt' # TODO: set as current save path + verbose = rank % 8 == 0 + save_path = f"/home/work/sfs/xx/pangu_alpha_code/generation_finetune/code_translation/{lang}/temp_{args_opt.temperature}.txt" # TODO: set as current save path save_dir = os.path.split(save_path)[0] if rank == 0: if not os.path.exists(save_dir): os.makedirs(save_dir) if not os.path.exists(save_path): - f = open(save_path, 'w') + f = open(save_path, "w") f.close() - os.system(f'sudo chmod 777 -R {os.path.split(save_dir)[0]}') + os.system(f"sudo chmod 777 -R {os.path.split(save_dir)[0]}") batch = [] input_length = [] sample_ids = [] @@ -231,24 
+274,32 @@ def run_predict(model_predict, config, args_opt, rank): if (i + 1) % batch_size == 0: valid_length = max(input_length) for j in range(len(batch)): - batch[j] = np.pad(batch[j], ((0, 0), (0, valid_length - input_length[j])), - 'constant', constant_values=(args_opt.end_token, args_opt.end_token)) + batch[j] = np.pad( + batch[j], + ((0, 0), (0, valid_length - input_length[j])), + "constant", + constant_values=(args_opt.end_token, args_opt.end_token), + ) input_ids = np.concatenate(batch, axis=0) t0 = time.perf_counter() - output_ids = generate_increment(model_predict, input_ids, input_length, args_opt, tokenizer, verbose) + output_ids = generate_increment( + model_predict, input_ids, input_length, args_opt, tokenizer, verbose + ) t1 = time.perf_counter() batch, input_length = [], [] if rank % 8 == 0: print(f"=== Batch time: {t1 - t0}s") for k, out in enumerate(output_ids): - print(f"=================== generation {sample_ids[k]} ====================") + print( + f"=================== generation {sample_ids[k]} ====================" + ) print(out, flush=True) generations.append(out) if rank == 0: - f = open(save_path, 'a') + f = open(save_path, "a") f.write(generations[-1]) - if not generations[-1].endswith('\n'): - f.write('\n') + if not generations[-1].endswith("\n"): + f.write("\n") f.close() sample_ids = [] if len(batch) > 0: @@ -257,11 +308,17 @@ def run_predict(model_predict, config, args_opt, rank): batch.append(np.zeros((1, 1))) input_length.append(-1) for j in range(len(batch)): - batch[j] = np.pad(batch[j], ((0, 0), (0, valid_length - batch[j].shape[1])), - 'constant', constant_values=(args_opt.end_token, args_opt.end_token)) + batch[j] = np.pad( + batch[j], + ((0, 0), (0, valid_length - batch[j].shape[1])), + "constant", + constant_values=(args_opt.end_token, args_opt.end_token), + ) input_ids = np.concatenate(batch, axis=0) t0 = time.perf_counter() - output_ids = generate_increment(model_predict, input_ids, input_length, args_opt, tokenizer, verbose) + output_ids = generate_increment( + model_predict, input_ids, input_length, args_opt, tokenizer, verbose + ) t1 = time.perf_counter() if rank % 8 == 0: print(f"=== Batch time: {t1 - t0}s") @@ -271,14 +328,16 @@ def run_predict(model_predict, config, args_opt, rank): for k, out in enumerate(output_ids): if input_length[k] == -1: break - print(f"=================== generation {sample_ids[k]} ====================") + print( + f"=================== generation {sample_ids[k]} ====================" + ) print(out, flush=True) generations.append(out) if rank == 0: - f = open(save_path, 'a') + f = open(save_path, "a") f.write(generations[-1]) - if not generations[-1].endswith('\n'): - f.write('\n') + if not generations[-1].endswith("\n"): + f.write("\n") f.close() diff --git a/codegeex/mindspore/generation_humaneval.py b/codegeex/mindspore/generation_humaneval.py index 030ff47..8b18bd2 100644 --- a/codegeex/mindspore/generation_humaneval.py +++ b/codegeex/mindspore/generation_humaneval.py @@ -41,12 +41,12 @@ def load_model(args_opt): r""" - The main function for load model + The main function for load model """ # Set execution mode - context.set_context(save_graphs=False, - mode=context.GRAPH_MODE, - device_target=args_opt.device_target) + context.set_context( + save_graphs=False, mode=context.GRAPH_MODE, device_target=args_opt.device_target + ) context.set_context(variable_memory_max_size="30GB") # Set parallel context if args_opt.distribute == "true": @@ -61,7 +61,8 @@ def load_model(args_opt): full_batch=True, 
loss_repeated_mean=True, enable_parallel_optimizer=False, - pipeline_stages=args_opt.stage_num) + pipeline_stages=args_opt.stage_num, + ) set_algo_parameters(elementwise_op_strategy_follow=True) _set_multi_subgraphs() @@ -70,26 +71,29 @@ def load_model(args_opt): device_num = 1 context.reset_auto_parallel_context() context.set_auto_parallel_context( - strategy_ckpt_load_file=args_opt.strategy_load_ckpt_path) + strategy_ckpt_load_file=args_opt.strategy_load_ckpt_path + ) context.set_context( save_graphs=False, save_graphs_path="/cache/graphs_of_device_id_" + str(rank), ) - use_past = (args_opt.use_past == "true") - print('local_rank:{}, start to run...'.format(rank), flush=True) + use_past = args_opt.use_past == "true" + print("local_rank:{}, start to run...".format(rank), flush=True) if args_opt.export: use_past = True # Set model property model_parallel_num = args_opt.op_level_model_parallel_num data_parallel_num = int(device_num / model_parallel_num) - parallel_config = TransformerOpParallelConfig(data_parallel=data_parallel_num, - model_parallel=model_parallel_num, - pipeline_stage=args_opt.stage_num, - micro_batch_num=args_opt.micro_size, - optimizer_shard=False, - vocab_emb_dp=bool(args_opt.word_emb_dp), - recompute=True) + parallel_config = TransformerOpParallelConfig( + data_parallel=data_parallel_num, + model_parallel=model_parallel_num, + pipeline_stage=args_opt.stage_num, + micro_batch_num=args_opt.micro_size, + optimizer_shard=False, + vocab_emb_dp=bool(args_opt.word_emb_dp), + recompute=True, + ) per_batch_size = args_opt.per_batch_size batch_size = per_batch_size * data_parallel_num @@ -109,7 +113,7 @@ def load_model(args_opt): parallel_config=parallel_config, load_ckpt_path=args_opt.load_ckpt_path, param_init_type=mstype.float32 - if args_opt.param_init_type == 'fp32' + if args_opt.param_init_type == "fp32" else mstype.float16, ) print("===config is: ", config, flush=True) @@ -121,23 +125,31 @@ def load_model(args_opt): eval_net.set_train(False) model_predict = Model(eval_net) # Compile network and obtain tensor layout for loading ckpt - inputs_np = Tensor(np.ones(shape=(config.batch_size, config.seq_length)), mstype.int32) + inputs_np = Tensor( + np.ones(shape=(config.batch_size, config.seq_length)), mstype.int32 + ) current_index = Tensor(np.array([0 for _ in range(batch_size)]), mstype.int32) if args_opt.distribute == "false": predict_layout = None elif config.use_past: - batch_valid_length = Tensor(np.array([0 for _ in range(batch_size)]), mstype.int32) + batch_valid_length = Tensor( + np.array([0 for _ in range(batch_size)]), mstype.int32 + ) init_true = Tensor([True], mstype.bool_) print("Input shape:", inputs_np.shape, flush=True) inputs_np_1 = Tensor(np.ones(shape=(config.batch_size, 1)), mstype.int32) model_predict.predict_network.add_flags_recursive(is_first_iteration=True) print("is_first_iteration=True", flush=True) - predict_layout = model_predict.infer_predict_layout(inputs_np, current_index, init_true, batch_valid_length) + predict_layout = model_predict.infer_predict_layout( + inputs_np, current_index, init_true, batch_valid_length + ) model_predict.predict_network.add_flags_recursive(is_first_iteration=False) print("is_first_iteration=False", flush=True) init_false = Tensor([False], mstype.bool_) - _ = model_predict.infer_predict_layout(inputs_np_1, current_index, init_false, batch_valid_length) + _ = model_predict.infer_predict_layout( + inputs_np_1, current_index, init_false, batch_valid_length + ) else: predict_layout = 
model_predict.infer_predict_layout(inputs_np, current_index) @@ -146,37 +158,53 @@ def load_model(args_opt): jobid = os.environ["BATCH_JOB_ID"] rank_id = rank mox.file.make_dirs("s3://wudao-1/yyf/graphs_" + jobid) - mox.file.copy_parallel(src_url="/cache/graphs_of_device_id_" + str(rank_id), - dst_url="s3://wudao-1/yyf/graphs_" + jobid + "/" + str(rank_id)) + mox.file.copy_parallel( + src_url="/cache/graphs_of_device_id_" + str(rank_id), + dst_url="s3://wudao-1/yyf/graphs_" + jobid + "/" + str(rank_id), + ) print("======start load_distributed checkpoint", flush=True) if args_opt.load_ckpt_epoch > 0: time.sleep(rank * 0.1) os.mkdir(os.path.join(args_opt.save_checkpoint_path, f"rank_{rank}")) ckpt_name = f"code-13B{rank}-{args_opt.load_ckpt_epoch}.ckpt" - if not mox.file.exists(os.path.join(args_opt.load_ckpt_path, f"rank_{rank}", ckpt_name)): + if not mox.file.exists( + os.path.join(args_opt.load_ckpt_path, f"rank_{rank}", ckpt_name) + ): print(f"Checkpoint from rank {rank} doesn't exist!") - mox.file.copy(os.path.join(args_opt.load_ckpt_path, f"rank_{rank}", ckpt_name), - os.path.join(args_opt.save_checkpoint_path, f"rank_{rank}", ckpt_name)) - param_dict = load_checkpoint(os.path.join(args_opt.save_checkpoint_path, f"rank_{rank}", ckpt_name)) + mox.file.copy( + os.path.join(args_opt.load_ckpt_path, f"rank_{rank}", ckpt_name), + os.path.join(args_opt.save_checkpoint_path, f"rank_{rank}", ckpt_name), + ) + param_dict = load_checkpoint( + os.path.join(args_opt.save_checkpoint_path, f"rank_{rank}", ckpt_name) + ) if param_dict.get("epoch_num") and param_dict.get("step_num"): args_opt.has_trained_epoches = int(param_dict["epoch_num"].data.asnumpy()) args_opt.has_trained_steps = int(param_dict["step_num"].data.asnumpy()) os.mkdir(f'/home/work/sfs/cache/{os.environ["BATCH_JOB_ID"]}/1/rank_{rank}') while True: - num = len(os.listdir(f'/home/work/sfs/cache/{os.environ["BATCH_JOB_ID"]}/1')) + num = len( + os.listdir(f'/home/work/sfs/cache/{os.environ["BATCH_JOB_ID"]}/1') + ) if num == device_num: break if rank % 8 == 0: print("Loaded ckpt in step 1: ", num) time.sleep(1) net_not_load = load_param_into_net(pangu_alpha, param_dict) - print("====== load_distributed checkpoint done, net_not_load: ", net_not_load, flush=True) + print( + "====== load_distributed checkpoint done, net_not_load: ", + net_not_load, + flush=True, + ) return model_predict, config, rank def export_mindir(model_predict, config): """Export mindir model""" - inputs_np = Tensor(np.ones(shape=(config.batch_size, config.seq_length)), mstype.int32) + inputs_np = Tensor( + np.ones(shape=(config.batch_size, config.seq_length)), mstype.int32 + ) current_index = Tensor(np.array([0]), mstype.int32) batch_valid_length = Tensor(np.array([0]), mstype.int32) @@ -184,37 +212,52 @@ def export_mindir(model_predict, config): inputs_np_1 = Tensor(np.ones(shape=(config.batch_size, 1)), mstype.int32) model_predict.predict_network.add_flags_recursive(is_first_iteration=True) - export(model_predict.predict_network, inputs_np, current_index, - init_true, batch_valid_length, file_name='pangu_alpha_1024', file_format='MINDIR') + export( + model_predict.predict_network, + inputs_np, + current_index, + init_true, + batch_valid_length, + file_name="pangu_alpha_1024", + file_format="MINDIR", + ) model_predict.predict_network.add_flags_recursive(is_first_iteration=False) - export(model_predict.predict_network, inputs_np_1, current_index, - init_true, batch_valid_length, file_name='pangu_alpha_1', file_format='MINDIR') + export( + model_predict.predict_network, 
+ inputs_np_1, + current_index, + init_true, + batch_valid_length, + file_name="pangu_alpha_1", + file_format="MINDIR", + ) print("Export finished and now exit.") def run_predict(model_predict, config, args_opt, rank): """run predict""" from src.generate_humaneval import generate_increment + # Define tokenizer - tokenizer = CodeTokenizer(mode='6b') + tokenizer = CodeTokenizer(mode="6b") # Tokenize input sentence to ids - humaneval_path = '/home/work/sfs/xx/human_eval_x/data/humaneval_cpp.jsonl' # TODO: set as current humaneval path - humaneval = open(humaneval_path, 'r').readlines() + humaneval_path = "/home/work/sfs/xx/human_eval_x/data/humaneval_cpp.jsonl" # TODO: set as current humaneval path + humaneval = open(humaneval_path, "r").readlines() humaneval = [json.loads(task) for task in humaneval if len(task) != 0] - samples = [task['prompt'] for task in humaneval] + samples = [task["prompt"] for task in humaneval] generations = [] batch_size = config.batch_size - verbose = (rank % 8 == 0) + verbose = rank % 8 == 0 part = int(args_opt.part) - gen_times = 12 # TODO: set as generation times of current task + gen_times = 12 # TODO: set as generation times of current task print(f"gen times: {gen_times}, part: {part}") - save_path = f'/home/work/sfs/xx/pangu_alpha_code/generation_humanevalx/cpp/temp_{args_opt.temperature}/samples_{args_opt.load_ckpt_epoch}_part_{part}.jsonl' # TODO: set as current save path + save_path = f"/home/work/sfs/xx/pangu_alpha_code/generation_humanevalx/cpp/temp_{args_opt.temperature}/samples_{args_opt.load_ckpt_epoch}_part_{part}.jsonl" # TODO: set as current save path if rank == 0 and not os.path.exists(save_path): os.makedirs(os.path.split(save_path)[0], exist_ok=True) - f = open(save_path, 'w') + f = open(save_path, "w") f.close() - os.system(f'sudo chmod 777 {save_path}') + os.system(f"sudo chmod 777 {save_path}") for i, sample in enumerate(samples): tag = "// language: C++\n" sample = tag + sample @@ -223,22 +266,32 @@ def run_predict(model_predict, config, args_opt, rank): print(sample, flush=True) for j in range((gen_times + batch_size - 1) // batch_size): tokenized_token = tokenizer.encode_code(sample) - input_ids = np.array(tokenized_token).reshape(1, -1).repeat(batch_size, axis=0) + input_ids = ( + np.array(tokenized_token).reshape(1, -1).repeat(batch_size, axis=0) + ) # Call inference mindspore.set_seed(j + 8 * part) generate_func = generate_increment t0 = time.perf_counter() - output_ids = generate_func(model_predict, input_ids, args_opt, tokenizer, verbose) + output_ids = generate_func( + model_predict, input_ids, args_opt, tokenizer, verbose + ) t1 = time.perf_counter() if rank % 8 == 0: print(f"=== Batch time: {t1 - t0}s") for k, out in enumerate(output_ids): - print(f"=================== generation {j * batch_size + k} ====================") + print( + f"=================== generation {j * batch_size + k} ====================" + ) print(out, flush=True) - generations.append(json.dumps({'task_id': humaneval[i]['task_id'], 'completion': out})) + generations.append( + json.dumps( + {"task_id": humaneval[i]["task_id"], "completion": out} + ) + ) if rank == 0: - f = open(save_path, 'a') - f.write(generations[-1] + '\n') + f = open(save_path, "a") + f.write(generations[-1] + "\n") f.close() diff --git a/codegeex/mindspore/generation_values.py b/codegeex/mindspore/generation_values.py index 7c81035..54b8780 100644 --- a/codegeex/mindspore/generation_values.py +++ b/codegeex/mindspore/generation_values.py @@ -38,12 +38,12 @@ def load_model(args_opt): r""" - 
The main function for load model + The main function for load model """ # Set execution mode - context.set_context(save_graphs=False, - mode=context.GRAPH_MODE, - device_target=args_opt.device_target) + context.set_context( + save_graphs=False, mode=context.GRAPH_MODE, device_target=args_opt.device_target + ) context.set_context(variable_memory_max_size="30GB") # Set parallel context if args_opt.distribute == "true": @@ -58,7 +58,8 @@ def load_model(args_opt): full_batch=True, loss_repeated_mean=True, enable_parallel_optimizer=False, - pipeline_stages=args_opt.stage_num) + pipeline_stages=args_opt.stage_num, + ) set_algo_parameters(elementwise_op_strategy_follow=True) _set_multi_subgraphs() @@ -67,26 +68,29 @@ def load_model(args_opt): device_num = 1 context.reset_auto_parallel_context() context.set_auto_parallel_context( - strategy_ckpt_load_file=args_opt.strategy_load_ckpt_path) + strategy_ckpt_load_file=args_opt.strategy_load_ckpt_path + ) context.set_context( save_graphs=False, save_graphs_path="/cache/graphs_of_device_id_" + str(rank), ) - use_past = (args_opt.use_past == "true") - print('local_rank:{}, start to run...'.format(rank), flush=True) + use_past = args_opt.use_past == "true" + print("local_rank:{}, start to run...".format(rank), flush=True) if args_opt.export: use_past = True # Set model property model_parallel_num = args_opt.op_level_model_parallel_num data_parallel_num = int(device_num / model_parallel_num) - parallel_config = TransformerOpParallelConfig(data_parallel=data_parallel_num, - model_parallel=model_parallel_num, - pipeline_stage=args_opt.stage_num, - micro_batch_num=args_opt.micro_size, - optimizer_shard=False, - vocab_emb_dp=bool(args_opt.word_emb_dp), - recompute=True) + parallel_config = TransformerOpParallelConfig( + data_parallel=data_parallel_num, + model_parallel=model_parallel_num, + pipeline_stage=args_opt.stage_num, + micro_batch_num=args_opt.micro_size, + optimizer_shard=False, + vocab_emb_dp=bool(args_opt.word_emb_dp), + recompute=True, + ) per_batch_size = args_opt.per_batch_size batch_size = per_batch_size * data_parallel_num @@ -106,7 +110,7 @@ def load_model(args_opt): parallel_config=parallel_config, load_ckpt_path=args_opt.load_ckpt_path, param_init_type=mstype.float32 - if args_opt.param_init_type == 'fp32' + if args_opt.param_init_type == "fp32" else mstype.float16, ) print("===config is: ", config, flush=True) @@ -118,23 +122,31 @@ def load_model(args_opt): eval_net.set_train(False) model_predict = Model(eval_net) # Compile network and obtain tensor layout for loading ckpt - inputs_np = Tensor(np.ones(shape=(config.batch_size, config.seq_length)), mstype.int32) + inputs_np = Tensor( + np.ones(shape=(config.batch_size, config.seq_length)), mstype.int32 + ) current_index = Tensor(np.array([0 for _ in range(batch_size)]), mstype.int32) if args_opt.distribute == "false": predict_layout = None elif config.use_past: - batch_valid_length = Tensor(np.array([0 for _ in range(batch_size)]), mstype.int32) + batch_valid_length = Tensor( + np.array([0 for _ in range(batch_size)]), mstype.int32 + ) init_true = Tensor([True], mstype.bool_) print("Input shape:", inputs_np.shape, flush=True) inputs_np_1 = Tensor(np.ones(shape=(config.batch_size, 1)), mstype.int32) model_predict.predict_network.add_flags_recursive(is_first_iteration=True) print("is_first_iteration=True", flush=True) - predict_layout = model_predict.infer_predict_layout(inputs_np, init_true, batch_valid_length) + predict_layout = model_predict.infer_predict_layout( + inputs_np, init_true, 
batch_valid_length + ) model_predict.predict_network.add_flags_recursive(is_first_iteration=False) print("is_first_iteration=False", flush=True) init_false = Tensor([False], mstype.bool_) - _ = model_predict.infer_predict_layout(inputs_np_1, init_false, batch_valid_length) + _ = model_predict.infer_predict_layout( + inputs_np_1, init_false, batch_valid_length + ) else: predict_layout = model_predict.infer_predict_layout(inputs_np, current_index) @@ -143,65 +155,206 @@ def load_model(args_opt): jobid = os.environ["BATCH_JOB_ID"] rank_id = rank mox.file.make_dirs("s3://wudao-1/yyf/graphs_" + jobid) - mox.file.copy_parallel(src_url="/cache/graphs_of_device_id_" + str(rank_id), - dst_url="s3://wudao-1/yyf/graphs_" + jobid + "/" + str(rank_id)) + mox.file.copy_parallel( + src_url="/cache/graphs_of_device_id_" + str(rank_id), + dst_url="s3://wudao-1/yyf/graphs_" + jobid + "/" + str(rank_id), + ) print("======start load_distributed checkpoint", flush=True) if args_opt.load_ckpt_epoch > 0: time.sleep(rank * 0.1) os.mkdir(os.path.join(args_opt.save_checkpoint_path, f"rank_{rank}")) ckpt_name = f"code-13B{rank}-{args_opt.load_ckpt_epoch}.ckpt" - if not mox.file.exists(os.path.join(args_opt.load_ckpt_path, f"rank_{rank}", ckpt_name)): + if not mox.file.exists( + os.path.join(args_opt.load_ckpt_path, f"rank_{rank}", ckpt_name) + ): print(f"Checkpoint from rank {rank} doesn't exist!") - mox.file.copy(os.path.join(args_opt.load_ckpt_path, f"rank_{rank}", ckpt_name), - os.path.join(args_opt.save_checkpoint_path, f"rank_{rank}", ckpt_name)) - param_dict = load_checkpoint(os.path.join(args_opt.save_checkpoint_path, f"rank_{rank}", ckpt_name)) + mox.file.copy( + os.path.join(args_opt.load_ckpt_path, f"rank_{rank}", ckpt_name), + os.path.join(args_opt.save_checkpoint_path, f"rank_{rank}", ckpt_name), + ) + param_dict = load_checkpoint( + os.path.join(args_opt.save_checkpoint_path, f"rank_{rank}", ckpt_name) + ) if param_dict.get("epoch_num") and param_dict.get("step_num"): args_opt.has_trained_epoches = int(param_dict["epoch_num"].data.asnumpy()) args_opt.has_trained_steps = int(param_dict["step_num"].data.asnumpy()) os.mkdir(f'/home/work/sfs/cache/{os.environ["BATCH_JOB_ID"]}/1/rank_{rank}') while True: - num = len(os.listdir(f'/home/work/sfs/cache/{os.environ["BATCH_JOB_ID"]}/1')) + num = len( + os.listdir(f'/home/work/sfs/cache/{os.environ["BATCH_JOB_ID"]}/1') + ) if num == device_num: break if rank % 8 == 0: print("Loaded ckpt in step 1: ", num) time.sleep(1) net_not_load = load_param_into_net(pangu_alpha, param_dict) - print("====== load_distributed checkpoint done, net_not_load: ", net_not_load, flush=True) + print( + "====== load_distributed checkpoint done, net_not_load: ", + net_not_load, + flush=True, + ) return model_predict, config, rank def run_predict(model_predict, config, args_opt, rank): """run predict""" # Define tokenizer - tokenizer = CodeTokenizer(mode='6b') + tokenizer = CodeTokenizer(mode="6b") # Tokenize input sentence to ids batch_size = config.batch_size input_ids = np.array( - [8189, 11059, 198, 29584, 25, 198, 11377, 1398, 28186, 1391, 198, 50268, 11377, 9037, 25131, 468, 26125, 36, - 3639, 7, 600, 21737, 997, 82, 11, 493, 11387, 8, 1391, 198, 50272, 1640, 357, 600, 1312, 796, 657, 26, 1312, - 1279, 997, 82, 13, 13664, 532, 352, 26, 1312, 29577, 1391, 198, 50276, 1640, 357, 600, 474, 796, 1312, 1343, - 352, 26, 474, 1279, 997, 82, 13, 13664, 26, 474, 29577, 1391, 198, 50280, 361, 357, 37372, 13, 8937, 7, 77, - 5700, 58, 72, 60, 532, 997, 82, 58, 73, 12962, 1279, 11387, 8, 1391, 
198, 50284, 7783, 2081, 26, 198, 50280, - 92, 198, 50276, 92, 198, 50272, 92, 198, 50272, 7783, 3991, 26, 198, 50268, 92, 198, 92, 198, 5247, 25, 198], - dtype=np.int32) + [ + 8189, + 11059, + 198, + 29584, + 25, + 198, + 11377, + 1398, + 28186, + 1391, + 198, + 50268, + 11377, + 9037, + 25131, + 468, + 26125, + 36, + 3639, + 7, + 600, + 21737, + 997, + 82, + 11, + 493, + 11387, + 8, + 1391, + 198, + 50272, + 1640, + 357, + 600, + 1312, + 796, + 657, + 26, + 1312, + 1279, + 997, + 82, + 13, + 13664, + 532, + 352, + 26, + 1312, + 29577, + 1391, + 198, + 50276, + 1640, + 357, + 600, + 474, + 796, + 1312, + 1343, + 352, + 26, + 474, + 1279, + 997, + 82, + 13, + 13664, + 26, + 474, + 29577, + 1391, + 198, + 50280, + 361, + 357, + 37372, + 13, + 8937, + 7, + 77, + 5700, + 58, + 72, + 60, + 532, + 997, + 82, + 58, + 73, + 12962, + 1279, + 11387, + 8, + 1391, + 198, + 50284, + 7783, + 2081, + 26, + 198, + 50280, + 92, + 198, + 50276, + 92, + 198, + 50272, + 92, + 198, + 50272, + 7783, + 3991, + 26, + 198, + 50268, + 92, + 198, + 92, + 198, + 5247, + 25, + 198, + ], + dtype=np.int32, + ) valid_length = input_ids.shape[0] - input_ids = np.concatenate((input_ids, np.ones(2048 - valid_length, dtype=np.int32) * 50256)) + input_ids = np.concatenate( + (input_ids, np.ones(2048 - valid_length, dtype=np.int32) * 50256) + ) attention_mask = np.tril(np.ones((2048, 2048))) attention_mask[valid_length:] = 0 input_ids = input_ids.reshape(1, -1).repeat(config.batch_size, axis=0) current_index = valid_length - 1 if valid_length - 1 > 0 else 0 init = Tensor([False], mstype.bool_) model_predict.predict_network.add_flags_recursive(is_first_iteration=True) - batch_valid_length = Tensor(np.array([current_index for _ in range(batch_size)]), mstype.int32) - output_logits = model_predict.predict(Tensor(input_ids, mstype.int32), - init, batch_valid_length) + batch_valid_length = Tensor( + np.array([current_index for _ in range(batch_size)]), mstype.int32 + ) + output_logits = model_predict.predict( + Tensor(input_ids, mstype.int32), init, batch_valid_length + ) output = output_logits.asnumpy() if rank == 0: - np.save("/home/work/sfs/xx/pangu_alpha_code/output_6_7375_8.13.npy", output) # TODO: set as current save path + np.save( + "/home/work/sfs/xx/pangu_alpha_code/output_6_7375_8.13.npy", output + ) # TODO: set as current save path os.system( - "chmod 777 /home/work/sfs/xx/pangu_alpha_code/output_6_7375_8.13.npy") # TODO: set as current save path + "chmod 777 /home/work/sfs/xx/pangu_alpha_code/output_6_7375_8.13.npy" + ) # TODO: set as current save path print("== Output shape: ", output.shape) diff --git a/codegeex/mindspore/generation_values_1p.py b/codegeex/mindspore/generation_values_1p.py index 2d43e3c..364d305 100644 --- a/codegeex/mindspore/generation_values_1p.py +++ b/codegeex/mindspore/generation_values_1p.py @@ -39,12 +39,12 @@ def load_model(args_opt): r""" - The main function for load model + The main function for load model """ # Set execution mode - context.set_context(save_graphs=False, - mode=context.GRAPH_MODE, - device_target=args_opt.device_target) + context.set_context( + save_graphs=False, mode=context.GRAPH_MODE, device_target=args_opt.device_target + ) context.set_context(variable_memory_max_size="30GB") # Set parallel context if args_opt.distribute == "true": @@ -59,7 +59,8 @@ def load_model(args_opt): full_batch=True, loss_repeated_mean=True, enable_parallel_optimizer=False, - pipeline_stages=args_opt.stage_num) + pipeline_stages=args_opt.stage_num, + ) 
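
For readers skimming the patch: the run_predict hunk for generation.py above only rewraps the prompt-padding and attention-mask code; it does not change what that code computes. As a point of reference, the same logic can be reproduced in a few lines of plain NumPy. The sketch below is illustrative only: SEQ_LEN, PAD_ID and the three-token toy prompt are made-up stand-ins (8 in place of the script's hard-coded 2048), not identifiers from this repository.

    import numpy as np

    SEQ_LEN = 8      # stands in for the 2048 used by run_predict
    PAD_ID = 50256   # same pad/eod token id as in the script

    input_ids = np.array([8189, 11059, 198], dtype=np.int32)  # toy prompt
    valid_length = input_ids.shape[0]

    # Right-pad the prompt to the full sequence length with the pad id,
    # exactly as run_predict does with np.concatenate.
    input_ids = np.concatenate(
        (input_ids, np.ones(SEQ_LEN - valid_length, dtype=np.int32) * PAD_ID)
    )

    # Lower-triangular (causal) mask; rows past the real prompt are zeroed so
    # the padded query positions attend to nothing.
    attention_mask = np.tril(np.ones((SEQ_LEN, SEQ_LEN)))
    attention_mask[valid_length:] = 0

    # Index of the last real token, which the script feeds in as
    # current_index / batch_valid_length.
    current_index = valid_length - 1 if valid_length - 1 > 0 else 0

Running this toy version gives current_index == 2 and an 8-token input_ids vector ending in five PAD_ID entries, which mirrors what the reformatted code then tiles across the batch and passes into model_predict.predict.
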
set_algo_parameters(elementwise_op_strategy_follow=True) _set_multi_subgraphs() @@ -68,13 +69,14 @@ def load_model(args_opt): device_num = 1 context.reset_auto_parallel_context() context.set_auto_parallel_context( - strategy_ckpt_load_file=args_opt.strategy_load_ckpt_path) + strategy_ckpt_load_file=args_opt.strategy_load_ckpt_path + ) context.set_context( save_graphs=False, save_graphs_path="/cache/graphs_of_device_id_" + str(rank), ) - use_past = (args_opt.use_past == "true") - print('local_rank:{}, start to run...'.format(rank), flush=True) + use_past = args_opt.use_past == "true" + print("local_rank:{}, start to run...".format(rank), flush=True) if args_opt.export: use_past = True # Set model property @@ -85,13 +87,15 @@ def load_model(args_opt): data_parallel_num = int(device_num / model_parallel_num) print("===data_parallel_num is: ", data_parallel_num, flush=True) - parallel_config = TransformerOpParallelConfig(data_parallel=data_parallel_num, - model_parallel=model_parallel_num, - pipeline_stage=args_opt.stage_num, - micro_batch_num=args_opt.micro_size, - optimizer_shard=False, - vocab_emb_dp=bool(args_opt.word_emb_dp), - recompute=True) + parallel_config = TransformerOpParallelConfig( + data_parallel=data_parallel_num, + model_parallel=model_parallel_num, + pipeline_stage=args_opt.stage_num, + micro_batch_num=args_opt.micro_size, + optimizer_shard=False, + vocab_emb_dp=bool(args_opt.word_emb_dp), + recompute=True, + ) per_batch_size = args_opt.per_batch_size batch_size = per_batch_size * data_parallel_num @@ -114,7 +118,7 @@ def load_model(args_opt): parallel_config=parallel_config, load_ckpt_path=args_opt.load_ckpt_path, param_init_type=mstype.float32 - if args_opt.param_init_type == 'fp32' + if args_opt.param_init_type == "fp32" else mstype.float16, ) print("===config is: ", config, flush=True) @@ -126,7 +130,9 @@ def load_model(args_opt): eval_net.set_train(False) model_predict = Model(eval_net) # Compile network and obtain tensor layout for loading ckpt - inputs_np = Tensor(np.ones(shape=(config.batch_size, config.seq_length)), mstype.int32) + inputs_np = Tensor( + np.ones(shape=(config.batch_size, config.seq_length)), mstype.int32 + ) current_index = Tensor(np.array([0]), mstype.int32) if args_opt.distribute == "false": @@ -138,11 +144,15 @@ def load_model(args_opt): inputs_np_1 = Tensor(np.ones(shape=(config.batch_size, 1)), mstype.int32) model_predict.predict_network.add_flags_recursive(is_first_iteration=True) print("is_first_iteration=True", flush=True) - predict_layout = model_predict.infer_predict_layout(inputs_np, current_index, init_true, batch_valid_length) + predict_layout = model_predict.infer_predict_layout( + inputs_np, current_index, init_true, batch_valid_length + ) model_predict.predict_network.add_flags_recursive(is_first_iteration=False) print("is_first_iteration=False", flush=True) init_false = Tensor([False], mstype.bool_) - _ = model_predict.infer_predict_layout(inputs_np_1, current_index, init_false, batch_valid_length) + _ = model_predict.infer_predict_layout( + inputs_np_1, current_index, init_false, batch_valid_length + ) else: predict_layout = model_predict.infer_predict_layout(inputs_np, current_index) @@ -151,38 +161,54 @@ def load_model(args_opt): jobid = os.environ["BATCH_JOB_ID"] rank_id = rank mox.file.make_dirs("s3://wudao-1/yyf/graphs_" + jobid) - mox.file.copy_parallel(src_url="/cache/graphs_of_device_id_" + str(rank_id), - dst_url="s3://wudao-1/yyf/graphs_" + jobid + "/" + str(rank_id)) + mox.file.copy_parallel( + 
src_url="/cache/graphs_of_device_id_" + str(rank_id), + dst_url="s3://wudao-1/yyf/graphs_" + jobid + "/" + str(rank_id), + ) print("======start load_distributed checkpoint", flush=True) if args_opt.load_ckpt_epoch > 0: time.sleep(rank * 0.5) os.mkdir(os.path.join(args_opt.save_checkpoint_path, f"rank_{rank}")) ckpt_name = f"code-13B0-{args_opt.load_ckpt_epoch}.ckpt" - if not mox.file.exists(os.path.join(args_opt.load_ckpt_path, f"rank_{rank}", ckpt_name)): + if not mox.file.exists( + os.path.join(args_opt.load_ckpt_path, f"rank_{rank}", ckpt_name) + ): print(f"Checkpoint from rank {rank} doesn't exist!") - mox.file.copy(os.path.join(args_opt.load_ckpt_path, f"rank_{rank}", ckpt_name), - os.path.join(args_opt.save_checkpoint_path, f"rank_{rank}", ckpt_name)) - param_dict = load_checkpoint(os.path.join(args_opt.save_checkpoint_path, f"rank_{rank}", ckpt_name)) + mox.file.copy( + os.path.join(args_opt.load_ckpt_path, f"rank_{rank}", ckpt_name), + os.path.join(args_opt.save_checkpoint_path, f"rank_{rank}", ckpt_name), + ) + param_dict = load_checkpoint( + os.path.join(args_opt.save_checkpoint_path, f"rank_{rank}", ckpt_name) + ) # TODO: add them back if not for the 1st run! if param_dict.get("epoch_num") and param_dict.get("step_num"): args_opt.has_trained_epoches = int(param_dict["epoch_num"].data.asnumpy()) args_opt.has_trained_steps = int(param_dict["step_num"].data.asnumpy()) os.mkdir(f'/home/work/sfs/cache/{os.environ["BATCH_JOB_ID"]}/1/rank_{rank}') while True: - num = len(os.listdir(f'/home/work/sfs/cache/{os.environ["BATCH_JOB_ID"]}/1')) + num = len( + os.listdir(f'/home/work/sfs/cache/{os.environ["BATCH_JOB_ID"]}/1') + ) if num == device_num: break if rank % 8 == 0: print("Loaded ckpt in step 1: ", num) time.sleep(1) net_not_load = load_param_into_net(pangu_alpha, param_dict) - print("====== load_distributed checkpoint done, net_not_load: ", net_not_load, flush=True) + print( + "====== load_distributed checkpoint done, net_not_load: ", + net_not_load, + flush=True, + ) return model_predict, config, rank def export_mindir(model_predict, config): """Export mindir model""" - inputs_np = Tensor(np.ones(shape=(config.batch_size, config.seq_length)), mstype.int32) + inputs_np = Tensor( + np.ones(shape=(config.batch_size, config.seq_length)), mstype.int32 + ) current_index = Tensor(np.array([0]), mstype.int32) batch_valid_length = Tensor(np.array([0]), mstype.int32) @@ -190,26 +216,41 @@ def export_mindir(model_predict, config): inputs_np_1 = Tensor(np.ones(shape=(config.batch_size, 1)), mstype.int32) model_predict.predict_network.add_flags_recursive(is_first_iteration=True) - export(model_predict.predict_network, inputs_np, current_index, - init_true, batch_valid_length, file_name='pangu_alpha_1024', file_format='MINDIR') + export( + model_predict.predict_network, + inputs_np, + current_index, + init_true, + batch_valid_length, + file_name="pangu_alpha_1024", + file_format="MINDIR", + ) model_predict.predict_network.add_flags_recursive(is_first_iteration=False) - export(model_predict.predict_network, inputs_np_1, current_index, - init_true, batch_valid_length, file_name='pangu_alpha_1', file_format='MINDIR') + export( + model_predict.predict_network, + inputs_np_1, + current_index, + init_true, + batch_valid_length, + file_name="pangu_alpha_1", + file_format="MINDIR", + ) print("Export finished and now exit.") def run_predict(model_predict, config, args_opt, rank): """run predict""" from src.generate import generate, generate_increment + # Define tokenizer - tokenizer = 
CodeTokenizer(mode='6b') + tokenizer = CodeTokenizer(mode="6b") # Tokenize input sentence to ids samples = [ "# language: Python\ndef add(a, b):\n '''\n Find the sum of a and b.\n '''\n", "def add(a, b):\n '''\n Find the sum of a and b.\n '''\n", "# language: Python\ndef optimization():\n '''\n Find the maximum of P=E**2*R/(R + r)**2 if E and r are fixed but R varies. Import sympy. Use sympy. Find where the derivative is equal to zero. Substitute the value of R into P.\n '''\n", - "from typing import List\n\n\ndef has_close_elements(numbers: List[float], threshold: float) -> bool:\n \"\"\" Check if in given list of numbers, are any two numbers closer to each other than\n given threshold.\n >>> has_close_elements([1.0, 2.0, 3.0], 0.5)\n False\n >>> has_close_elements([1.0, 2.8, 3.0, 4.0, 5.0, 2.0], 0.3)\n True\n \"\"\"\n", + 'from typing import List\n\n\ndef has_close_elements(numbers: List[float], threshold: float) -> bool:\n """ Check if in given list of numbers, are any two numbers closer to each other than\n given threshold.\n >>> has_close_elements([1.0, 2.0, 3.0], 0.5)\n False\n >>> has_close_elements([1.0, 2.8, 3.0, 4.0, 5.0, 2.0], 0.3)\n True\n """\n', "// language: C++\nint add(int a, int b) {\n /* Find the sum of a and b. */\n", "int add(int a, int b) {\n /* Find the sum of a and b. */\n", "bool prime(int n) {\n // Find whether n is a prime number\n", @@ -236,7 +277,8 @@ def run_predict(model_predict, config, args_opt, rank): print(f"=================== generation {i} ====================") print(output_samples_str, flush=True) print( - f"=== Total time (s): {t1 - t0}, {output_ids.shape[-1] - input_ids.shape[-1]} tokens, {(output_ids.shape[-1] - input_ids.shape[-1]) / (t1 - t0)} token/s") + f"=== Total time (s): {t1 - t0}, {output_ids.shape[-1] - input_ids.shape[-1]} tokens, {(output_ids.shape[-1] - input_ids.shape[-1]) / (t1 - t0)} token/s" + ) break diff --git a/codegeex/mindspore/save_1p_ckpt_from_8p_ckpt.py b/codegeex/mindspore/save_1p_ckpt_from_8p_ckpt.py index 17c2096..1a10b63 100644 --- a/codegeex/mindspore/save_1p_ckpt_from_8p_ckpt.py +++ b/codegeex/mindspore/save_1p_ckpt_from_8p_ckpt.py @@ -41,8 +41,14 @@ from mindspore.train.callback import ModelCheckpoint, CheckpointConfig import mindspore -from mindspore.train.serialization import load_checkpoint, build_searched_strategy, save_checkpoint, \ - merge_sliced_parameter, _convert_to_list, _convert_to_layout +from mindspore.train.serialization import ( + load_checkpoint, + build_searched_strategy, + save_checkpoint, + merge_sliced_parameter, + _convert_to_list, + _convert_to_layout, +) from mindspore.common.parameter import Parameter from mindspore.train.node_strategy_pb2 import ParallelStrategyMap, ParallelLayouts from mindspore import Tensor @@ -51,7 +57,10 @@ from src.adam import AdamWeightDecayOp from src.dataset import create_dataset from src.pangu_alpha import PanGUAlphaWithLoss, PanguAlphaModel -from src.pangu_alpha_wrapcell import PanguAlphaTrainOneStepWithLossScaleCell, PanguAlphaTrainPipelineWithLossScaleCell +from src.pangu_alpha_wrapcell import ( + PanguAlphaTrainOneStepWithLossScaleCell, + PanguAlphaTrainPipelineWithLossScaleCell, +) from src.pangu_alpha_config import set_parse, PanguAlphaConfig from src.utils import LearningRate, get_args, FP32StateAdamWeightDecay from src.utils import download_data @@ -59,15 +68,18 @@ from mindspore.profiler import Profiler project_root = os.path.abspath( - os.path.dirname(os.path.realpath(__file__)) + os.path.sep + "..") -print('project_root:', project_root) + 
os.path.dirname(os.path.realpath(__file__)) + os.path.sep + ".." +) +print("project_root:", project_root) def set_weight_decay(params): """ Set weight decay coefficient, zero for bias and layernorm, 1e-1 for rest """ - decay_filter = lambda x: 'layernorm' not in x.name.lower() and "bias" not in x.name.lower() + decay_filter = ( + lambda x: "layernorm" not in x.name.lower() and "bias" not in x.name.lower() + ) decay_params = list(filter(decay_filter, params)) other_params = list(filter(lambda x: not decay_filter(x), params)) group_params = [ @@ -84,7 +96,12 @@ def add_checkpoint_callback_policy(args_param, callback, rank_id): """ if args_param.save_checkpoint: # checkpoint store epoch_num and step_num info - ckpt_append_info = [{"epoch_num": args_param.has_trained_epoches, "step_num": args_param.has_trained_steps}] + ckpt_append_info = [ + { + "epoch_num": args_param.has_trained_epoches, + "step_num": args_param.has_trained_steps, + } + ] ckpt_config = CheckpointConfig( save_checkpoint_steps=args_param.save_checkpoint_steps, keep_checkpoint_max=args_param.keep_checkpoint_max, @@ -93,18 +110,22 @@ def add_checkpoint_callback_policy(args_param, callback, rank_id): ) # save checkpoint into rank directory - ckpoint_cb = ModelCheckpoint(prefix=args_param.ckpt_name_prefix + str(rank_id), - directory=os.path.join(args_param.save_checkpoint_path, f"rank_{rank_id}"), - config=ckpt_config) + ckpoint_cb = ModelCheckpoint( + prefix=args_param.ckpt_name_prefix + str(rank_id), + directory=os.path.join(args_param.save_checkpoint_path, f"rank_{rank_id}"), + config=ckpt_config, + ) callback.append(ckpoint_cb) - saveckpt_cb = SaveCheckpointCallback(cache_dir=args_param.save_checkpoint_path, - bucket=args_param.save_checkpoint_obs_path, - local_rank=rank_id, - has_trained_epoch=args_param.has_trained_epoches, - has_trained_step=args_param.has_trained_steps, - syn_times=args_param.save_checkpoint_steps) + saveckpt_cb = SaveCheckpointCallback( + cache_dir=args_param.save_checkpoint_path, + bucket=args_param.save_checkpoint_obs_path, + local_rank=rank_id, + has_trained_epoch=args_param.has_trained_epoches, + has_trained_step=args_param.has_trained_steps, + syn_times=args_param.save_checkpoint_steps, + ) callback.append(saveckpt_cb) @@ -116,10 +137,15 @@ def set_parallel_context(args_opt): print("rank_id is {}, device_num is {}".format(rank, device_num)) context.reset_auto_parallel_context() context.set_auto_parallel_context( - parallel_mode=ParallelMode.SEMI_AUTO_PARALLEL, gradients_mean=False, - full_batch=bool(args_opt.full_batch), strategy_ckpt_load_file=args_opt.strategy_load_ckpt_path, - enable_parallel_optimizer=bool(args_opt.optimizer_shard), strategy_ckpt_save_file='strategy.ckpt', - optimizer_weight_shard_size=16, optimizer_weight_shard_aggregated_save=True) + parallel_mode=ParallelMode.SEMI_AUTO_PARALLEL, + gradients_mean=False, + full_batch=bool(args_opt.full_batch), + strategy_ckpt_load_file=args_opt.strategy_load_ckpt_path, + enable_parallel_optimizer=bool(args_opt.optimizer_shard), + strategy_ckpt_save_file="strategy.ckpt", + optimizer_weight_shard_size=16, + optimizer_weight_shard_aggregated_save=True, + ) set_algo_parameters(elementwise_op_strategy_follow=True) _set_multi_subgraphs() return rank, device_num @@ -129,16 +155,24 @@ def download_ckpt(args_opt, file_num, rank_num, rank_id): ckpt_list = [] for rank in range(0, file_num): ckpt_name = f"code-13B{rank}-{args_opt.load_ckpt_epoch}.ckpt" - local_file = os.path.join(args_opt.save_checkpoint_path, f"origin_rank_{rank}", ckpt_name) + local_file 
= os.path.join( + args_opt.save_checkpoint_path, f"origin_rank_{rank}", ckpt_name + ) ckpt_list.append(local_file) if rank % rank_num != rank_id: continue time.sleep(rank * 0.05) - if not os.path.exists(os.path.join(args_opt.save_checkpoint_path, f"origin_rank_{rank}")): + if not os.path.exists( + os.path.join(args_opt.save_checkpoint_path, f"origin_rank_{rank}") + ): os.mkdir(os.path.join(args_opt.save_checkpoint_path, f"origin_rank_{rank}")) - if not mox.file.exists(os.path.join(args_opt.load_ckpt_path, f"rank_{rank}", ckpt_name)): + if not mox.file.exists( + os.path.join(args_opt.load_ckpt_path, f"rank_{rank}", ckpt_name) + ): print(f"Checkpoint from rank {rank} doesn't exist!") - mox.file.copy(os.path.join(args_opt.load_ckpt_path, f"rank_{rank}", ckpt_name), local_file) + mox.file.copy( + os.path.join(args_opt.load_ckpt_path, f"rank_{rank}", ckpt_name), local_file + ) print("===download ckpt ok: ", local_file, flush=True) print(ckpt_list) return ckpt_list @@ -157,7 +191,9 @@ def get_needed_model_parallel_list(train_strategy_file, self_rank): return needed_ckpt_ranks -def transform_model_parallel(restore_local_ckpt_file_list, train_strategy_file, save_path, using_fp16=False): +def transform_model_parallel( + restore_local_ckpt_file_list, train_strategy_file, save_path, using_fp16=False +): # check whether the ckpt_file has been download for local_file in restore_local_ckpt_file_list: if not os.path.exists(local_file): @@ -183,23 +219,36 @@ def transform_model_parallel(restore_local_ckpt_file_list, train_strategy_file, if param_name not in strategy_keys: each_param = {"name": param_name} each_param["data"] = param_total_dict[param_name][0] - print("====", param_name, param_total_dict[param_name][0].data.asnumpy().shape, flush=True) + print( + "====", + param_name, + param_total_dict[param_name][0].data.asnumpy().shape, + flush=True, + ) merged_param_list.append(each_param) continue param_unique_strategy = _remove_repeated_slices(train_strategy[param_name]) _param_unique_strategy = _convert_to_layout(param_name, param_unique_strategy) sliced_params = [] - if using_fp16 and "embedding" not in param_name and "layernorm" not in param_name: + if ( + using_fp16 + and "embedding" not in param_name + and "layernorm" not in param_name + ): for i in rank_list[param_name][0]: slice_param = param_total_dict[param_name][i] layerwise_parallel = slice_param.layerwise_parallel requires_grad = slice_param.requires_grad sliced_data = sliced_params.data.asnumpy() sliced_data = sliced_data.astype(np.float16) - paramete_fp16 = Parameter(Tensor(sliced_data), param_name, requires_grad, layerwise_parallel) + paramete_fp16 = Parameter( + Tensor(sliced_data), param_name, requires_grad, layerwise_parallel + ) sliced_params.append(paramete_fp16) else: - sliced_params = [param_total_dict[param_name][i] for i in rank_list[param_name][0]] + sliced_params = [ + param_total_dict[param_name][i] for i in rank_list[param_name][0] + ] merged_param = merge_sliced_parameter(sliced_params, _param_unique_strategy) each_param = {"name": param_name} each_param["data"] = merged_param @@ -212,9 +261,7 @@ def transform_model_parallel(restore_local_ckpt_file_list, train_strategy_file, def run_transform_model_parallel_ckpt(args_opt): # Set execution mode - context.set_context( - mode=context.GRAPH_MODE, device_target=args_opt.device_target - ) + context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.device_target) # Set parallel context rank = 0 device_num = 1 @@ -224,13 +271,19 @@ def 
run_transform_model_parallel_ckpt(args_opt): ckpt_file_list = download_ckpt(args_opt, 8, device_num, rank) if rank != 0: return - needed_ckpt_ranks = get_needed_model_parallel_list(args_opt.strategy_load_ckpt_path, rank) + needed_ckpt_ranks = get_needed_model_parallel_list( + args_opt.strategy_load_ckpt_path, rank + ) restore_local_ckpt_file_list = [ckpt_file_list[i] for i in needed_ckpt_ranks] - print("====restore_local_ckpt_file_list====", restore_local_ckpt_file_list, flush=True) + print( + "====restore_local_ckpt_file_list====", restore_local_ckpt_file_list, flush=True + ) save_path = os.path.join(args_opt.save_checkpoint_path, f"rank_{rank}") if not os.path.exists(save_path): os.mkdir(save_path) - save_file = transform_model_parallel(restore_local_ckpt_file_list, args_opt.strategy_load_ckpt_path, save_path) + save_file = transform_model_parallel( + restore_local_ckpt_file_list, args_opt.strategy_load_ckpt_path, save_path + ) obs_save_path = args_opt.save_checkpoint_obs_path time.sleep(rank * 0.1) if not mox.file.exists(obs_save_path): @@ -238,7 +291,9 @@ def run_transform_model_parallel_ckpt(args_opt): rank_obs_save_path = os.path.join(obs_save_path, f"rank_{rank}") if not mox.file.exists(rank_obs_save_path): mox.file.make_dirs(rank_obs_save_path) - rank_obs_save_file = os.path.join(rank_obs_save_path, f"code-13B{rank}-{args_opt.load_ckpt_epoch}.ckpt") + rank_obs_save_file = os.path.join( + rank_obs_save_path, f"code-13B{rank}-{args_opt.load_ckpt_epoch}.ckpt" + ) if not os.path.exists(save_file): raise ValueError(save_file + " not exists") mox.file.copy(save_file, rank_obs_save_file) diff --git a/codegeex/mindspore/save_8p_ckpt.py b/codegeex/mindspore/save_8p_ckpt.py index eea3a15..1313ffc 100644 --- a/codegeex/mindspore/save_8p_ckpt.py +++ b/codegeex/mindspore/save_8p_ckpt.py @@ -37,25 +37,37 @@ from mindspore.parallel import set_algo_parameters from mindspore.parallel._cost_model_context import _set_multi_subgraphs from mindspore.train.callback import ModelCheckpoint, CheckpointConfig -from mindspore.train.serialization import load_distributed_checkpoint, load_checkpoint, load_param_into_net +from mindspore.train.serialization import ( + load_distributed_checkpoint, + load_checkpoint, + load_param_into_net, +) import mindspore -from mindspore.train.serialization import load_checkpoint, build_searched_strategy, save_checkpoint, \ - merge_sliced_parameter +from mindspore.train.serialization import ( + load_checkpoint, + build_searched_strategy, + save_checkpoint, + merge_sliced_parameter, +) from mindspore.common.parameter import Parameter from mindspore import Tensor from src.adam import AdamWeightDecayOp from src.dataset import create_dataset -from src.pangu_alpha_wrapcell import PanguAlphaTrainOneStepWithLossScaleCell, PanguAlphaTrainPipelineWithLossScaleCell +from src.pangu_alpha_wrapcell import ( + PanguAlphaTrainOneStepWithLossScaleCell, + PanguAlphaTrainPipelineWithLossScaleCell, +) from src.pangu_alpha_config import set_parse, PanguAlphaConfig from src.utils import LearningRate, get_args, FP32StateAdamWeightDecay from src.utils import download_data from mindspore.profiler import Profiler project_root = os.path.abspath( - os.path.dirname(os.path.realpath(__file__)) + os.path.sep + "..") -print('project_root:', project_root) + os.path.dirname(os.path.realpath(__file__)) + os.path.sep + ".." 
+) +print("project_root:", project_root) def set_parallel_context(args_opt): @@ -66,10 +78,15 @@ def set_parallel_context(args_opt): print("rank_id is {}, device_num is {}".format(rank, device_num)) context.reset_auto_parallel_context() context.set_auto_parallel_context( - parallel_mode=ParallelMode.SEMI_AUTO_PARALLEL, gradients_mean=False, - full_batch=bool(args_opt.full_batch), strategy_ckpt_load_file=args_opt.strategy_load_ckpt_path, - enable_parallel_optimizer=bool(args_opt.optimizer_shard), strategy_ckpt_save_file='strategy.ckpt', - optimizer_weight_shard_size=16, optimizer_weight_shard_aggregated_save=True) + parallel_mode=ParallelMode.SEMI_AUTO_PARALLEL, + gradients_mean=False, + full_batch=bool(args_opt.full_batch), + strategy_ckpt_load_file=args_opt.strategy_load_ckpt_path, + enable_parallel_optimizer=bool(args_opt.optimizer_shard), + strategy_ckpt_save_file="strategy.ckpt", + optimizer_weight_shard_size=16, + optimizer_weight_shard_aggregated_save=True, + ) set_algo_parameters(elementwise_op_strategy_follow=True) _set_multi_subgraphs() return rank, device_num @@ -79,15 +96,21 @@ def download_ckpt(args_opt, file_num, rank_num, rank_id): ckpt_list = [] for rank in range(0, file_num): ckpt_name = f"code-13B{rank}_22-{args_opt.load_ckpt_epoch}_2.ckpt" - local_file = os.path.join(args_opt.save_checkpoint_path, f"origin_rank_{rank}", ckpt_name) + local_file = os.path.join( + args_opt.save_checkpoint_path, f"origin_rank_{rank}", ckpt_name + ) ckpt_list.append(local_file) if rank % rank_num != rank_id: continue time.sleep(rank * 0.05) os.mkdir(os.path.join(args_opt.save_checkpoint_path, f"origin_rank_{rank}")) - if not mox.file.exists(os.path.join(args_opt.load_ckpt_path, f"rank_{rank}", ckpt_name)): + if not mox.file.exists( + os.path.join(args_opt.load_ckpt_path, f"rank_{rank}", ckpt_name) + ): print(f"Checkpoint from rank {rank} doesn't exist!") - mox.file.copy(os.path.join(args_opt.load_ckpt_path, f"rank_{rank}", ckpt_name), local_file) + mox.file.copy( + os.path.join(args_opt.load_ckpt_path, f"rank_{rank}", ckpt_name), local_file + ) print("===download ckpt ok: ", local_file, flush=True) return ckpt_list @@ -102,7 +125,10 @@ def get_needed_opt_shard_list(train_strategy_file, self_rank): if opt_weight_shard_size <= 0: continue group_index = self_rank % opt_weight_shard_step - current_needed_ckpt_ranks = [group_index + i * opt_weight_shard_step for i in range(0, opt_weight_shard_size)] + current_needed_ckpt_ranks = [ + group_index + i * opt_weight_shard_step + for i in range(0, opt_weight_shard_size) + ] if len(current_needed_ckpt_ranks) > len(needed_ckpt_ranks): needed_ckpt_ranks = current_needed_ckpt_ranks return needed_ckpt_ranks @@ -127,7 +153,12 @@ def transform_opt_shard(restore_local_ckpt_file_list, train_strategy_file, save_ if param_name not in strategy_keys: each_param = {"name": param_name} each_param["data"] = param_total_dict[param_name][0] - print("====", param_name, param_total_dict[param_name][0].data.asnumpy().shape, flush=True) + print( + "====", + param_name, + param_total_dict[param_name][0].data.asnumpy().shape, + flush=True, + ) merged_param_list.append(each_param) continue opt_weight_shard_size = train_strategy_origin[param_name].opt_weight_shard_size @@ -136,11 +167,19 @@ def transform_opt_shard(restore_local_ckpt_file_list, train_strategy_file, save_ print("====not opt shard:", param_name) each_param = {"name": param_name} each_param["data"] = param_total_dict[param_name][0] - print("====", param_name, param_total_dict[param_name][0].data.asnumpy().shape, 
flush=True) + print( + "====", + param_name, + param_total_dict[param_name][0].data.asnumpy().shape, + flush=True, + ) merged_param_list.append(each_param) continue print("====do opt shard:", param_name) - sliced_params = [param_total_dict[param_name][i] for i in range(len(param_total_dict[param_name]))] + sliced_params = [ + param_total_dict[param_name][i] + for i in range(len(param_total_dict[param_name])) + ] merged_param = merge_sliced_parameter(sliced_params, None) each_param = {"name": param_name} each_param["data"] = merged_param @@ -153,9 +192,7 @@ def transform_opt_shard(restore_local_ckpt_file_list, train_strategy_file, save_ def run_transform_opt_shard_ckpt(args_opt): # Set execution mode - context.set_context( - mode=context.GRAPH_MODE, device_target=args_opt.device_target - ) + context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.device_target) # Set parallel context rank = 0 device_num = 1 @@ -163,12 +200,18 @@ def run_transform_opt_shard_ckpt(args_opt): rank, device_num = set_parallel_context(args_opt) print("=====rank is: ", rank, flush=True) ckpt_file_list = download_ckpt(args_opt, 128, device_num, rank) - needed_ckpt_ranks = get_needed_opt_shard_list(args_opt.strategy_load_ckpt_path, rank) + needed_ckpt_ranks = get_needed_opt_shard_list( + args_opt.strategy_load_ckpt_path, rank + ) restore_local_ckpt_file_list = [ckpt_file_list[i] for i in needed_ckpt_ranks] - print("====restore_local_ckpt_file_list====", restore_local_ckpt_file_list, flush=True) + print( + "====restore_local_ckpt_file_list====", restore_local_ckpt_file_list, flush=True + ) save_path = os.path.join(args_opt.save_checkpoint_path, f"rank_{rank}") os.mkdir(save_path) - save_file = transform_opt_shard(restore_local_ckpt_file_list, args_opt.strategy_load_ckpt_path, save_path) + save_file = transform_opt_shard( + restore_local_ckpt_file_list, args_opt.strategy_load_ckpt_path, save_path + ) obs_save_path = args_opt.save_checkpoint_obs_path time.sleep(rank * 0.1) if not mox.file.exists(obs_save_path): @@ -176,7 +219,9 @@ def run_transform_opt_shard_ckpt(args_opt): rank_obs_save_path = os.path.join(obs_save_path, f"rank_{rank}") if not mox.file.exists(rank_obs_save_path): mox.file.make_dirs(rank_obs_save_path) - rank_obs_save_file = os.path.join(rank_obs_save_path, f"code-13B{rank}-{args_opt.load_ckpt_epoch}.ckpt") + rank_obs_save_file = os.path.join( + rank_obs_save_path, f"code-13B{rank}-{args_opt.load_ckpt_epoch}.ckpt" + ) if not os.path.exists(save_file): raise ValueError(save_file + " not exists") mox.file.copy(save_file, rank_obs_save_file) diff --git a/codegeex/mindspore/scripts/layer_norm.py b/codegeex/mindspore/scripts/layer_norm.py index a33c59c..07b399a 100644 --- a/codegeex/mindspore/scripts/layer_norm.py +++ b/codegeex/mindspore/scripts/layer_norm.py @@ -40,11 +40,19 @@ # 'pylint: disable = unused-argument # 'pylint: disable=too-many-arguments,too-many-locals -def get_op_support_info(input_x, input_gamma, input_beta, - output_y, output_mean, output_variance, - begin_norm_axis, begin_params_axis, - epsilon=1e-12, kernel_name="layer_norm", - impl_mode="high_performance"): +def get_op_support_info( + input_x, + input_gamma, + input_beta, + output_y, + output_mean, + output_variance, + begin_norm_axis, + begin_params_axis, + epsilon=1e-12, + kernel_name="layer_norm", + impl_mode="high_performance", +): """ get_op_support_info """ @@ -58,19 +66,27 @@ def get_op_support_info(input_x, input_gamma, input_beta, if format_x in ("ND", "NCHW", "NHWC", "NC1HWC0"): if begin_params_axis == 0: for i in 
range(begin_norm_axis): - split_0 = [SplitInput([0, [i], [-1], [-1]], [1, [i], [-1], [-1]], [2, [i], [-1], [-1]]), - SplitOutput([0, [i]], [1, [i]], [2, [i]])] + split_0 = [ + SplitInput( + [0, [i], [-1], [-1]], [1, [i], [-1], [-1]], [2, [i], [-1], [-1]] + ), + SplitOutput([0, [i]], [1, [i]], [2, [i]]), + ] axis_split_matrix.append(split_0) else: if begin_norm_axis <= begin_params_axis: for i in range(begin_norm_axis): - split_0 = [SplitInput([0, [i], [-1], [-1]]), - SplitOutput([0, [i]], [1, [i]], [2, [i]])] + split_0 = [ + SplitInput([0, [i], [-1], [-1]]), + SplitOutput([0, [i]], [1, [i]], [2, [i]]), + ] axis_split_matrix.append(split_0) else: for i in range(begin_params_axis): - split_0 = [SplitInput([0, [i], [-1], [-1]]), - SplitOutput([0, [i]], [1, [i]], [2, [i]])] + split_0 = [ + SplitInput([0, [i], [-1], [-1]]), + SplitOutput([0, [i]], [1, [i]], [2, [i]]), + ] axis_split_matrix.append(split_0) elif format_x == "FRACTAL_NZ": @@ -81,8 +97,10 @@ def get_op_support_info(input_x, input_gamma, input_beta, no_split_axis = to_frac_z_axis(ori_shape_x, no_split_axis) for i in range(len(shape_x)): if i not in no_split_axis: - split_0 = [SplitInput([0, [i], [-1], [-1]]), - SplitOutput([0, [i]], [1, [i]], [2, [i]])] + split_0 = [ + SplitInput([0, [i], [-1], [-1]]), + SplitOutput([0, [i]], [1, [i]], [2, [i]]), + ] axis_split_matrix.append(split_0) else: @@ -101,14 +119,16 @@ def _division_sixteen(shape, begin_norm_axis): if len(shape) < 2: if shape[-1] == 0: error_detail = "value of shape_x is illegal" - error_manager_vector.raise_err_input_shape_invalid("layer_norm", "input_x", - error_detail) + error_manager_vector.raise_err_input_shape_invalid( + "layer_norm", "input_x", error_detail + ) return False if shape[-1] == 0 or shape[-2] == 0: error_detail = "value of shape_x is illegal" - error_manager_vector.raise_err_input_shape_invalid("layer_norm", "input_x", - error_detail) + error_manager_vector.raise_err_input_shape_invalid( + "layer_norm", "input_x", error_detail + ) is_reduce_last = begin_norm_axis in (-1, len(shape) - 1) # if shape[-2] % constant.C0_SIZE == 0: @@ -118,10 +138,17 @@ def _division_sixteen(shape, begin_norm_axis): # 'pylint: disable=too-many-statements,too-many-branches -def op_select_format(input_x, input_gamma, input_beta, - output_y, output_mean, output_variance, - begin_norm_axis, begin_params_axis, - kernel_name="layer_norm"): +def op_select_format( + input_x, + input_gamma, + input_beta, + output_y, + output_mean, + output_variance, + begin_norm_axis, + begin_params_axis, + kernel_name="layer_norm", +): """ select format dynamically """ @@ -132,138 +159,182 @@ def op_select_format(input_x, input_gamma, input_beta, if begin_params_axis == 0: if len(shape_gamma) >= 2 or (not _division_sixteen(shape_x, begin_norm_axis)): - input0 = util_select_op_base.gen_param(classify="input0", name="x", - datatype="float16,float16,float16,float16," - "float,float,float,float", - format="NCHW,NC1HWC0,NHWC,ND,NCHW,NC1HWC0,NHWC,ND") - - input1 = util_select_op_base.gen_param(classify="input1", name="gamma", - datatype="float16,float16,float16,float16,float," - "float,float,float", - format="NCHW,NC1HWC0,NHWC,ND,NCHW,NC1HWC0,NHWC,ND") - - input2 = util_select_op_base.gen_param(classify="input2", name="beta", - datatype="float16,float16,float16,float16,float," - "float,float,float", - format="NCHW,NC1HWC0,NHWC,ND,NCHW,NC1HWC0,NHWC,ND") - - output0 = util_select_op_base.gen_param(classify="output0", name="y", - datatype="float16,float16,float16,float16,float," - "float,float,float", - 
format="NCHW,NC1HWC0,NHWC,ND,NCHW,NC1HWC0,NHWC,ND") - - output1 = util_select_op_base.gen_param(classify="output1", name="mean", - datatype="float16,float16,float16,float16,float," - "float,float,float", - format="NCHW,NC1HWC0,NHWC,ND,NCHW,NC1HWC0,NHWC,ND") - - output2 = util_select_op_base.gen_param(classify="output2", name="variance", - datatype="float16,float16,float16,float16,float," - "float,float,float", - format="NCHW,NC1HWC0,NHWC,ND,NCHW,NC1HWC0,NHWC,ND") + input0 = util_select_op_base.gen_param( + classify="input0", + name="x", + datatype="float16,float16,float16,float16," "float,float,float,float", + format="NCHW,NC1HWC0,NHWC,ND,NCHW,NC1HWC0,NHWC,ND", + ) + + input1 = util_select_op_base.gen_param( + classify="input1", + name="gamma", + datatype="float16,float16,float16,float16,float," "float,float,float", + format="NCHW,NC1HWC0,NHWC,ND,NCHW,NC1HWC0,NHWC,ND", + ) + + input2 = util_select_op_base.gen_param( + classify="input2", + name="beta", + datatype="float16,float16,float16,float16,float," "float,float,float", + format="NCHW,NC1HWC0,NHWC,ND,NCHW,NC1HWC0,NHWC,ND", + ) + + output0 = util_select_op_base.gen_param( + classify="output0", + name="y", + datatype="float16,float16,float16,float16,float," "float,float,float", + format="NCHW,NC1HWC0,NHWC,ND,NCHW,NC1HWC0,NHWC,ND", + ) + + output1 = util_select_op_base.gen_param( + classify="output1", + name="mean", + datatype="float16,float16,float16,float16,float," "float,float,float", + format="NCHW,NC1HWC0,NHWC,ND,NCHW,NC1HWC0,NHWC,ND", + ) + + output2 = util_select_op_base.gen_param( + classify="output2", + name="variance", + datatype="float16,float16,float16,float16,float," "float,float,float", + format="NCHW,NC1HWC0,NHWC,ND,NCHW,NC1HWC0,NHWC,ND", + ) else: - input0 = util_select_op_base.gen_param(classify="input0", name="x", - datatype="float16,float,float16,float16,float16," - "float16,float,float,float,float", - format="FRACTAL_NZ,FRACTAL_NZ,NCHW,NC1HWC0,NHWC," - "ND,NCHW,NC1HWC0,NHWC,ND") - - input1 = util_select_op_base.gen_param(classify="input1", name="gamma", - datatype="float16,float,float16,float16,float16," - "float16,float,float,float,float", - format="ND,ND,NCHW,NC1HWC0,NHWC,ND,NCHW,NC1HWC0," - "NHWC,ND") - - input2 = util_select_op_base.gen_param(classify="input2", name="beta", - datatype="float16,float,float16,float16,float16," - "float16,float,float,float,float", - format="ND,ND,NCHW,NC1HWC0,NHWC,ND,NCHW,NC1HWC0," - "NHWC,ND") - - output0 = util_select_op_base.gen_param(classify="output0", name="y", - datatype="float16,float,float16,float16,float16," - "float16,float,float,float,float", - format="FRACTAL_NZ,FRACTAL_NZ,NCHW,NC1HWC0,NHWC,ND," - "NCHW,NC1HWC0,NHWC,ND") - - output1 = util_select_op_base.gen_param(classify="output1", name="mean", - datatype="float16,float,float16,float16,float16," - "float16,float,float,float,float", - format="ND,ND,NCHW,NC1HWC0,NHWC,ND,NCHW,NC1HWC0," - "NHWC,ND") - - output2 = util_select_op_base.gen_param(classify="output2", name="variance", - datatype="float16,float,float16,float16,float16," - "float16,float,float,float,float", - format="ND,ND,NCHW,NC1HWC0,NHWC,ND,NCHW,NC1HWC0," - "NHWC,ND") + input0 = util_select_op_base.gen_param( + classify="input0", + name="x", + datatype="float16,float,float16,float16,float16," + "float16,float,float,float,float", + format="FRACTAL_NZ,FRACTAL_NZ,NCHW,NC1HWC0,NHWC," + "ND,NCHW,NC1HWC0,NHWC,ND", + ) + + input1 = util_select_op_base.gen_param( + classify="input1", + name="gamma", + datatype="float16,float,float16,float16,float16," + 
"float16,float,float,float,float", + format="ND,ND,NCHW,NC1HWC0,NHWC,ND,NCHW,NC1HWC0," "NHWC,ND", + ) + + input2 = util_select_op_base.gen_param( + classify="input2", + name="beta", + datatype="float16,float,float16,float16,float16," + "float16,float,float,float,float", + format="ND,ND,NCHW,NC1HWC0,NHWC,ND,NCHW,NC1HWC0," "NHWC,ND", + ) + + output0 = util_select_op_base.gen_param( + classify="output0", + name="y", + datatype="float16,float,float16,float16,float16," + "float16,float,float,float,float", + format="FRACTAL_NZ,FRACTAL_NZ,NCHW,NC1HWC0,NHWC,ND," + "NCHW,NC1HWC0,NHWC,ND", + ) + + output1 = util_select_op_base.gen_param( + classify="output1", + name="mean", + datatype="float16,float,float16,float16,float16," + "float16,float,float,float,float", + format="ND,ND,NCHW,NC1HWC0,NHWC,ND,NCHW,NC1HWC0," "NHWC,ND", + ) + + output2 = util_select_op_base.gen_param( + classify="output2", + name="variance", + datatype="float16,float,float16,float16,float16," + "float16,float,float,float,float", + format="ND,ND,NCHW,NC1HWC0,NHWC,ND,NCHW,NC1HWC0," "NHWC,ND", + ) else: if len(shape_gamma) >= 2 or (not _division_sixteen(shape_x, begin_norm_axis)): - input0 = util_select_op_base.gen_param(classify="input0", name="x", - datatype="float16,float16,float16," - "float,float,float", - format="NCHW,NHWC,ND,NCHW,NHWC,ND") - - input1 = util_select_op_base.gen_param(classify="input1", name="gamma", - datatype="float16,float16,float16," - "float,float,float", - format="NCHW,NHWC,ND,NCHW,NHWC,ND") - - input2 = util_select_op_base.gen_param(classify="input2", name="beta", - datatype="float16,float16,float16," - "float,float,float", - format="NCHW,NHWC,ND,NCHW,NHWC,ND") - - output0 = util_select_op_base.gen_param(classify="output0", name="y", - datatype="float16,float16,float16," - "float,float,float", - format="NCHW,NHWC,ND,NCHW,NHWC,ND") - - output1 = util_select_op_base.gen_param(classify="output1", name="mean", - datatype="float16,float16,float16," - "float,float,float", - format="NCHW,NHWC,ND,NCHW,NHWC,ND") - - output2 = util_select_op_base.gen_param(classify="output2", name="variance", - datatype="float16,float16,float16," - "float,float,float", - format="NCHW,NHWC,ND,NCHW,NHWC,ND") + input0 = util_select_op_base.gen_param( + classify="input0", + name="x", + datatype="float16,float16,float16," "float,float,float", + format="NCHW,NHWC,ND,NCHW,NHWC,ND", + ) + + input1 = util_select_op_base.gen_param( + classify="input1", + name="gamma", + datatype="float16,float16,float16," "float,float,float", + format="NCHW,NHWC,ND,NCHW,NHWC,ND", + ) + + input2 = util_select_op_base.gen_param( + classify="input2", + name="beta", + datatype="float16,float16,float16," "float,float,float", + format="NCHW,NHWC,ND,NCHW,NHWC,ND", + ) + + output0 = util_select_op_base.gen_param( + classify="output0", + name="y", + datatype="float16,float16,float16," "float,float,float", + format="NCHW,NHWC,ND,NCHW,NHWC,ND", + ) + + output1 = util_select_op_base.gen_param( + classify="output1", + name="mean", + datatype="float16,float16,float16," "float,float,float", + format="NCHW,NHWC,ND,NCHW,NHWC,ND", + ) + + output2 = util_select_op_base.gen_param( + classify="output2", + name="variance", + datatype="float16,float16,float16," "float,float,float", + format="NCHW,NHWC,ND,NCHW,NHWC,ND", + ) else: - input0 = util_select_op_base.gen_param(classify="input0", name="x", - datatype="float16,float,float16,float16," - "float16,float,float,float", - format="FRACTAL_NZ,FRACTAL_NZ,NCHW,NHWC," - "ND,NCHW,NHWC,ND") - - input1 = 
util_select_op_base.gen_param(classify="input1", name="gamma", - datatype="float16,float,float16,float16," - "float16,float,float,float", - format="ND,ND,NCHW,NHWC,ND,NCHW," - "NHWC,ND") - - input2 = util_select_op_base.gen_param(classify="input2", name="beta", - datatype="float16,float,float16,float16," - "float16,float,float,float", - format="ND,ND,NCHW,NHWC,ND,NCHW," - "NHWC,ND") - - output0 = util_select_op_base.gen_param(classify="output0", name="y", - datatype="float16,float,float16,float16," - "float16,float,float,float", - format="FRACTAL_NZ,FRACTAL_NZ,NCHW,NHWC,ND," - "NCHW,NHWC,ND") - - output1 = util_select_op_base.gen_param(classify="output1", name="mean", - datatype="float16,float,float16,float16," - "float16,float,float,float", - format="ND,ND,NCHW,NHWC,ND,NCHW," - "NHWC,ND") - - output2 = util_select_op_base.gen_param(classify="output2", name="variance", - datatype="float16,float,float16,float16," - "float16,float,float,float", - format="ND,ND,NCHW,NHWC,ND,NCHW," - "NHWC,ND") + input0 = util_select_op_base.gen_param( + classify="input0", + name="x", + datatype="float16,float,float16,float16," "float16,float,float,float", + format="FRACTAL_NZ,FRACTAL_NZ,NCHW,NHWC," "ND,NCHW,NHWC,ND", + ) + + input1 = util_select_op_base.gen_param( + classify="input1", + name="gamma", + datatype="float16,float,float16,float16," "float16,float,float,float", + format="ND,ND,NCHW,NHWC,ND,NCHW," "NHWC,ND", + ) + + input2 = util_select_op_base.gen_param( + classify="input2", + name="beta", + datatype="float16,float,float16,float16," "float16,float,float,float", + format="ND,ND,NCHW,NHWC,ND,NCHW," "NHWC,ND", + ) + + output0 = util_select_op_base.gen_param( + classify="output0", + name="y", + datatype="float16,float,float16,float16," "float16,float,float,float", + format="FRACTAL_NZ,FRACTAL_NZ,NCHW,NHWC,ND," "NCHW,NHWC,ND", + ) + + output1 = util_select_op_base.gen_param( + classify="output1", + name="mean", + datatype="float16,float,float16,float16," "float16,float,float,float", + format="ND,ND,NCHW,NHWC,ND,NCHW," "NHWC,ND", + ) + + output2 = util_select_op_base.gen_param( + classify="output2", + name="variance", + datatype="float16,float,float16,float16," "float16,float,float,float", + format="ND,ND,NCHW,NHWC,ND,NCHW," "NHWC,ND", + ) param_list = [input0, input1, input2, output0, output1, output2] param_dynamic_in_json = util_select_op_base.get_dynamic_param_in_json(param_list) @@ -318,9 +389,11 @@ def _broadcast_nz(tensor, shape): for i, _ in enumerate(shape): if shape[i] != src_shape[i]: broadcast_axes.append(i) - if len(broadcast_axes) == 2 and \ - broadcast_axes[1] - broadcast_axes[0] != 1 and \ - broadcast_axes[1] + 1 == len(shape): + if ( + len(broadcast_axes) == 2 + and broadcast_axes[1] - broadcast_axes[0] != 1 + and broadcast_axes[1] + 1 == len(shape) + ): temp_shape = src_shape[:-1] + [shape[-1]] tensor = tbe.broadcast(tensor, temp_shape) tensor = tbe.broadcast(tensor, shape) @@ -337,11 +410,22 @@ def _check_vector_to_cube(dtype, ori_shape_x, shape_x, begin_norm_axis, impl_mod def _check_shape_and_dtype(): if dtype != "float16": return False - if len(ori_shape_x) not in (2, 3) or ori_shape_x[-1] not in (1024, 768, 96, 384, 192, 128, 512, 256): + if len(ori_shape_x) not in (2, 3) or ori_shape_x[-1] not in ( + 1024, + 768, + 96, + 384, + 192, + 128, + 512, + 256, + ): return False if len(shape_x) not in (4, 5) or shape_x[-4] not in (64, 48, 6, 12, 24, 16, 32): return False - if "Ascend910" not in get_soc_spec(SOC_VERSION) and "Ascend710" not in get_soc_spec(SOC_VERSION): + if "Ascend910" 
not in get_soc_spec( + SOC_VERSION + ) and "Ascend710" not in get_soc_spec(SOC_VERSION): return False if begin_norm_axis != (len(ori_shape_x) - 1): return False @@ -351,22 +435,33 @@ def _check_shape_and_dtype(): # 'pylint: disable=too-many-locals,too-many-statements,too-many-branches -def nz_non_aligned(input_x, input_gamma, input_beta, - output_y, output_mean, output_variance, - begin_norm_axis, begin_params_axis, - ori_shape, epsilon, kernel_name="layer_norm", - impl_mode="high_performance"): +def nz_non_aligned( + input_x, + input_gamma, + input_beta, + output_y, + output_mean, + output_variance, + begin_norm_axis, + begin_params_axis, + ori_shape, + epsilon, + kernel_name="layer_norm", + impl_mode="high_performance", +): """ DSL description of the layernorm operator's mathematical calculation process for non_aligned scene """ shape_x = shape_util.shape_to_list(input_x.shape) dtype = input_x.dtype.lower() cast_dtype = "float16" - if dtype == "float16" and \ - ((tbe_platform.cce_conf.api_check_support - ("te.lang.cce.vexp", "float32") and - impl_mode == "high_performance") or - impl_mode == "high_precision"): + if dtype == "float16" and ( + ( + tbe_platform.cce_conf.api_check_support("te.lang.cce.vexp", "float32") + and impl_mode == "high_performance" + ) + or impl_mode == "high_precision" + ): cast_dtype = "float32" input_x = tbe.cast_to(input_x, "float32") input_gamma = tbe.cast_to(input_gamma, "float32") @@ -405,8 +500,7 @@ def nz_non_aligned(input_x, input_gamma, input_beta, normalize_add = tbe.vadds(variance, epsilon) normalize_log = tbe.vlog(normalize_add) - normalize_log_mul = \ - tbe.vmuls(normalize_log, tvm.const(-0.5, dtype=cast_dtype)) + normalize_log_mul = tbe.vmuls(normalize_log, tvm.const(-0.5, dtype=cast_dtype)) normalize_exp = tbe.vexp(normalize_log_mul) variance_normalize_broadcast = _broadcast_nz(normalize_exp, shape_x) normalize_mul = tbe.vmul(normalize_sub, variance_normalize_broadcast) @@ -417,11 +511,13 @@ def nz_non_aligned(input_x, input_gamma, input_beta, scale_mul = tbe.vmul(gamma_broadcast, normalize_mul) res = tbe.vadd(scale_mul, beta_broadcast) - if dtype == "float16" and \ - ((tbe_platform.cce_conf.api_check_support - ("te.lang.cce.vexp", "float32") and - impl_mode == "high_performance") or - impl_mode == "high_precision"): + if dtype == "float16" and ( + ( + tbe_platform.cce_conf.api_check_support("te.lang.cce.vexp", "float32") + and impl_mode == "high_performance" + ) + or impl_mode == "high_precision" + ): mean = tbe.cast_to(mean, "float16") variance = tbe.cast_to(variance, "float16") res = tbe.cast_to(res, "float16") @@ -430,11 +526,20 @@ def nz_non_aligned(input_x, input_gamma, input_beta, # 'pylint: disable=too-many-statements,too-many-branches -def layer_norm_compute_nz(input_x, input_gamma, input_beta, - output_y, output_mean, output_variance, - begin_norm_axis, begin_params_axis, - ori_shape, epsilon, kernel_name="layer_norm", - impl_mode="high_performance"): +def layer_norm_compute_nz( + input_x, + input_gamma, + input_beta, + output_y, + output_mean, + output_variance, + begin_norm_axis, + begin_params_axis, + ori_shape, + epsilon, + kernel_name="layer_norm", + impl_mode="high_performance", +): """ DSL description of the layernorm operator's mathematical calculation process @@ -470,11 +575,13 @@ def layer_norm_compute_nz(input_x, input_gamma, input_beta, dtype = input_x.dtype.lower() cast_dtype, cast_fp16_dtype = "float16", "float16" cast_dtype_precision = dtype - if dtype == "float16" and \ - ((tbe_platform.cce_conf.api_check_support - 
("te.lang.cce.vexp", "float32") and - impl_mode == "high_performance") or - impl_mode == "high_precision"): + if dtype == "float16" and ( + ( + tbe_platform.cce_conf.api_check_support("te.lang.cce.vexp", "float32") + and impl_mode == "high_performance" + ) + or impl_mode == "high_precision" + ): cast_dtype = "float32" cast_dtype_precision = "float32" input_x = tbe.cast_to(input_x, "float32") @@ -518,14 +625,11 @@ def layer_norm_compute_nz(input_x, input_gamma, input_beta, variance_normalize_broadcast = _broadcast_nz(variance, shape_x) normalize_add = tbe.vadds(variance_normalize_broadcast, epsilon) normalize_log = tbe.vlog(normalize_add) - normalize_log_mul = \ - tbe.vmuls(normalize_log, tvm.const(-0.5, dtype=cast_dtype)) + normalize_log_mul = tbe.vmuls(normalize_log, tvm.const(-0.5, dtype=cast_dtype)) normalize_exp = tbe.vexp(normalize_log_mul) normalize_mul = tbe.vmul(normalize_sub, normalize_exp) elif impl_mode == "high_precision": - tesor_one = tbe.broadcast(tvm.const - (1, cast_dtype_precision), - shape_x) + tesor_one = tbe.broadcast(tvm.const(1, cast_dtype_precision), shape_x) mean_normalize_broadcast = _broadcast_nz(mean, shape_x) normalize_sub = tbe.vsub(input_x, mean_normalize_broadcast) variance_normalize_broadcast = _broadcast_nz(variance, shape_x) @@ -538,8 +642,9 @@ def layer_norm_compute_nz(input_x, input_gamma, input_beta, epsilon = tvm.const(epsilon, dtype=cast_fp16_dtype) normalize_add = tbe.vadds(variance, epsilon) normalize_log = tbe.vlog(normalize_add) - normalize_log_mul = \ - tbe.vmuls(normalize_log, tvm.const(-0.5, dtype=cast_fp16_dtype)) + normalize_log_mul = tbe.vmuls( + normalize_log, tvm.const(-0.5, dtype=cast_fp16_dtype) + ) normalize_exp = tbe.vexp(normalize_log_mul) variance_normalize_broadcast = _broadcast_nz(normalize_exp, shape_x) normalize_mul = tbe.vmul(variance_sub, variance_normalize_broadcast) @@ -554,11 +659,13 @@ def layer_norm_compute_nz(input_x, input_gamma, input_beta, scale_mul = tbe.vmul(gamma_broadcast, normalize_mul) res = tbe.vadd(scale_mul, beta_broadcast) - if dtype == "float16" and \ - ((tbe_platform.cce_conf.api_check_support - ("te.lang.cce.vexp", "float32") and - impl_mode == "high_performance") or - impl_mode == "high_precision"): + if dtype == "float16" and ( + ( + tbe_platform.cce_conf.api_check_support("te.lang.cce.vexp", "float32") + and impl_mode == "high_performance" + ) + or impl_mode == "high_precision" + ): mean = tbe.cast_to(mean, "float16") variance = tbe.cast_to(variance, "float16") res = tbe.cast_to(res, "float16") @@ -568,11 +675,19 @@ def layer_norm_compute_nz(input_x, input_gamma, input_beta, # 'pylint: disable=too-many-statements,too-many-branches @tbe_platform.fusion_manager.fusion_manager.register("layer_norm") -def layer_norm_compute(input_x, input_gamma, input_beta, - output_y, output_mean, output_variance, - begin_norm_axis, begin_params_axis, - epsilon, kernel_name="layer_norm", - impl_mode="high_performance"): +def layer_norm_compute( + input_x, + input_gamma, + input_beta, + output_y, + output_mean, + output_variance, + begin_norm_axis, + begin_params_axis, + epsilon, + kernel_name="layer_norm", + impl_mode="high_performance", +): """ DSL description of the layernorm operator's mathematical calculation process @@ -608,11 +723,13 @@ def layer_norm_compute(input_x, input_gamma, input_beta, dtype = input_x.dtype.lower() cast_dtype, cast_fp16_dtype = "float16", "float16" cast_dtype_precision = dtype - if dtype == "float16" and \ - ((tbe_platform.cce_conf.api_check_support - ("te.lang.cce.vexp", "float32") and - 
impl_mode == "high_performance") or - impl_mode == "high_precision"): + if dtype == "float16" and ( + ( + tbe_platform.cce_conf.api_check_support("te.lang.cce.vexp", "float32") + and impl_mode == "high_performance" + ) + or impl_mode == "high_precision" + ): cast_dtype = "float32" cast_dtype_precision = "float32" input_x = tbe.cast_to(input_x, "float32") @@ -658,14 +775,11 @@ def layer_norm_compute(input_x, input_gamma, input_beta, variance_normalize_broadcast = tbe.broadcast(variance, shape_x) normalize_add = tbe.vadds(variance_normalize_broadcast, epsilon) normalize_log = tbe.vlog(normalize_add) - normalize_log_mul = \ - tbe.vmuls(normalize_log, tvm.const(-0.5, dtype=cast_dtype)) + normalize_log_mul = tbe.vmuls(normalize_log, tvm.const(-0.5, dtype=cast_dtype)) normalize_exp = tbe.vexp(normalize_log_mul) normalize_mul = tbe.vmul(normalize_sub, normalize_exp) elif impl_mode == "high_precision": - tesor_one = tbe.broadcast(tvm.const - (1, cast_dtype_precision), - shape_x) + tesor_one = tbe.broadcast(tvm.const(1, cast_dtype_precision), shape_x) mean_normalize_broadcast = tbe.broadcast(mean, shape_x) normalize_sub = tbe.vsub(input_x, mean_normalize_broadcast) variance_normalize_broadcast = tbe.broadcast(variance, shape_x) @@ -678,8 +792,9 @@ def layer_norm_compute(input_x, input_gamma, input_beta, epsilon = tvm.const(epsilon, dtype=cast_fp16_dtype) normalize_add = tbe.vadds(variance, epsilon) normalize_log = tbe.vlog(normalize_add) - normalize_log_mul = \ - tbe.vmuls(normalize_log, tvm.const(-0.5, dtype=cast_fp16_dtype)) + normalize_log_mul = tbe.vmuls( + normalize_log, tvm.const(-0.5, dtype=cast_fp16_dtype) + ) normalize_exp = tbe.vexp(normalize_log_mul) variance_normalize_broadcast = tbe.broadcast(normalize_exp, shape_x) normalize_mul = tbe.vmul(variance_sub, variance_normalize_broadcast) @@ -694,11 +809,13 @@ def layer_norm_compute(input_x, input_gamma, input_beta, scale_mul = tbe.vmul(gamma_broadcast, normalize_mul) res = tbe.vadd(scale_mul, beta_broadcast) - if dtype == "float16" and \ - ((tbe_platform.cce_conf.api_check_support - ("te.lang.cce.vexp", "float32") and - impl_mode == "high_performance") or - impl_mode == "high_precision"): + if dtype == "float16" and ( + ( + tbe_platform.cce_conf.api_check_support("te.lang.cce.vexp", "float32") + and impl_mode == "high_performance" + ) + or impl_mode == "high_precision" + ): mean = tbe.cast_to(mean, "float16") variance = tbe.cast_to(variance, "float16") res = tbe.cast_to(res, "float16") @@ -717,17 +834,32 @@ def is_support_nz_non_aligned(ori_shape_x, begin_params_axis, impl_mode): return False -@para_check.check_op_params(para_check.REQUIRED_INPUT, para_check.REQUIRED_INPUT, - para_check.REQUIRED_INPUT, para_check.REQUIRED_OUTPUT, - para_check.REQUIRED_OUTPUT, para_check.REQUIRED_OUTPUT, - para_check.REQUIRED_ATTR_INT, para_check.REQUIRED_ATTR_INT, - para_check.OPTION_ATTR_FLOAT, para_check.KERNEL_NAME, - para_check.OPTION_ATTR_STR) -def layer_norm(input_x, input_gamma, input_beta, - output_y, output_mean, output_variance, - begin_norm_axis, begin_params_axis, - epsilon=1e-12, kernel_name="layer_norm", - impl_mode="high_performance"): +@para_check.check_op_params( + para_check.REQUIRED_INPUT, + para_check.REQUIRED_INPUT, + para_check.REQUIRED_INPUT, + para_check.REQUIRED_OUTPUT, + para_check.REQUIRED_OUTPUT, + para_check.REQUIRED_OUTPUT, + para_check.REQUIRED_ATTR_INT, + para_check.REQUIRED_ATTR_INT, + para_check.OPTION_ATTR_FLOAT, + para_check.KERNEL_NAME, + para_check.OPTION_ATTR_STR, +) +def layer_norm( + input_x, + input_gamma, + 
input_beta, + output_y, + output_mean, + output_variance, + begin_norm_axis, + begin_params_axis, + epsilon=1e-12, + kernel_name="layer_norm", + impl_mode="high_performance", +): """ layernorm operator interface implementation calculating: x, gamma, beta @@ -786,34 +918,60 @@ def layer_norm(input_x, input_gamma, input_beta, shape_beta = list(input_beta.get("shape")) flag_vector2cube = False - tik_support = if_tik_support(input_x, input_gamma, input_beta, output_y, output_mean, - output_variance, begin_norm_axis, begin_params_axis, epsilon) + tik_support = if_tik_support( + input_x, + input_gamma, + input_beta, + output_y, + output_mean, + output_variance, + begin_norm_axis, + begin_params_axis, + epsilon, + ) if tik_support: - layer_normalize(input_x, input_gamma, input_beta, - output_y, output_mean, output_variance, - begin_norm_axis, begin_params_axis, - epsilon, kernel_name) + layer_normalize( + input_x, + input_gamma, + input_beta, + output_y, + output_mean, + output_variance, + begin_norm_axis, + begin_params_axis, + epsilon, + kernel_name, + ) else: if input_format == "FRACTAL_NZ": begin_norm_axis = shape_util.axis_check(len(ori_shape_x), begin_norm_axis) - begin_params_axis = shape_util.axis_check(len(ori_shape_x), begin_params_axis) + begin_params_axis = shape_util.axis_check( + len(ori_shape_x), begin_params_axis + ) - flag_vector2cube = _check_vector_to_cube(dtype, ori_shape_x, shape_x, begin_norm_axis, impl_mode) + flag_vector2cube = _check_vector_to_cube( + dtype, ori_shape_x, shape_x, begin_norm_axis, impl_mode + ) if input_gamma_format == "FRACTAL_NZ" or input_beta_format == "FRACTAL_NZ": error_detail = "gamma and beta not support Nz in bert" - error_manager_vector.raise_err_two_input_format_invalid(kernel_name, "input_gamma", - "input_beta", error_detail) + error_manager_vector.raise_err_two_input_format_invalid( + kernel_name, "input_gamma", "input_beta", error_detail + ) if shape_gamma != shape_beta: error_detail = "gamma and beta's shape must be same." - error_manager_vector.raise_err_two_input_shape_invalid(kernel_name, "input_gamma", - "input_beta", error_detail) + error_manager_vector.raise_err_two_input_shape_invalid( + kernel_name, "input_gamma", "input_beta", error_detail + ) if ori_shape_x[begin_params_axis:] != shape_gamma: error_detail = "x or gamma or begin_params_axis is wrong." - error_manager_vector.raise_err_two_input_shape_invalid(kernel_name, "x", - "input_gamma", error_detail) + error_manager_vector.raise_err_two_input_shape_invalid( + kernel_name, "x", "input_gamma", error_detail + ) if len(shape_gamma) > 1: error_detail = "shape of gamma or beta only support 1D in bert" - error_manager_vector.raise_err_input_shape_invalid(kernel_name, "input_gamma", error_detail) + error_manager_vector.raise_err_input_shape_invalid( + kernel_name, "input_gamma", error_detail + ) # make shape_x,shape_gamma,shape_beta dim same in vector case if not flag_vector2cube: @@ -826,7 +984,10 @@ def layer_norm(input_x, input_gamma, input_beta, shape_gamma.append(shape_x[-1]) if begin_params_axis > len(ori_shape_x) - 2: shape_x[-3:] = [shape_x[-3] * shape_x[-2], shape_x[-1]] - shape_gamma[-3:] = [shape_gamma[-3] * shape_gamma[-2], shape_gamma[-1]] + shape_gamma[-3:] = [ + shape_gamma[-3] * shape_gamma[-2], + shape_gamma[-1], + ] shape_beta = shape_gamma else: begin_norm_axis = shape_util.axis_check(len(shape_x), begin_norm_axis) @@ -834,8 +995,9 @@ def layer_norm(input_x, input_gamma, input_beta, if shape_gamma != shape_beta: error_detail = "gamma and beta's shape must be same." 
- error_manager_vector.raise_err_two_input_shape_invalid(kernel_name, "input_gamma", - "input_beta", error_detail) + error_manager_vector.raise_err_two_input_shape_invalid( + kernel_name, "input_gamma", "input_beta", error_detail + ) no_need_fix_gamma = False no_need_fix_beta = False if shape_x[begin_params_axis:] != shape_gamma: @@ -843,15 +1005,17 @@ def layer_norm(input_x, input_gamma, input_beta, no_need_fix_gamma = True else: error_detail = "x or gamma or begin_params_axis is wrong." - error_manager_vector.raise_err_two_input_shape_invalid(kernel_name, "x", - "input_gamma", error_detail) + error_manager_vector.raise_err_two_input_shape_invalid( + kernel_name, "x", "input_gamma", error_detail + ) if shape_x[begin_params_axis:] != shape_beta: if len(shape_x) == len(shape_beta): no_need_fix_beta = True else: error_detail = "x or gamma or begin_params_axis is wrong." - error_manager_vector.raise_err_two_input_shape_invalid(kernel_name, "x", - "input_beta", error_detail) + error_manager_vector.raise_err_two_input_shape_invalid( + kernel_name, "x", "input_beta", error_detail + ) # make shape_x,shape_gamma,shape_beta dim same if begin_params_axis != 0 and not no_need_fix_gamma: for i in range(begin_params_axis): @@ -869,68 +1033,153 @@ def layer_norm(input_x, input_gamma, input_beta, dyn_input_x = deepcopy(input_x) dyn_input_x["shape"] = shape_x if flag_vector2cube: - layer_norm_cube = LayerNormCube({"ori_shape": ori_shape_x, - "epsilon" : epsilon}) - mean, variance, res = \ - layer_norm_cube.layer_norm_cube_compute(data_x, data_gamma, data_beta) + layer_norm_cube = LayerNormCube( + {"ori_shape": ori_shape_x, "epsilon": epsilon} + ) + mean, variance, res = layer_norm_cube.layer_norm_cube_compute( + data_x, data_gamma, data_beta + ) elif is_support_nz_non_aligned(ori_shape_x, begin_params_axis, impl_mode): - mean, variance, res = \ - nz_non_aligned(data_x, data_gamma, data_beta, - output_y, output_mean, output_variance, - begin_norm_axis, begin_params_axis, - ori_shape_x, epsilon, kernel_name, impl_mode) - elif layer_norm_unify.is_special_cases(dyn_input_x, input_gamma, input_beta, begin_norm_axis, impl_mode): - __dynamic_template_api(input_x, input_gamma, input_beta, output_y, output_mean, output_variance, - begin_norm_axis, begin_params_axis, epsilon, kernel_name, impl_mode) + mean, variance, res = nz_non_aligned( + data_x, + data_gamma, + data_beta, + output_y, + output_mean, + output_variance, + begin_norm_axis, + begin_params_axis, + ori_shape_x, + epsilon, + kernel_name, + impl_mode, + ) + elif layer_norm_unify.is_special_cases( + dyn_input_x, input_gamma, input_beta, begin_norm_axis, impl_mode + ): + __dynamic_template_api( + input_x, + input_gamma, + input_beta, + output_y, + output_mean, + output_variance, + begin_norm_axis, + begin_params_axis, + epsilon, + kernel_name, + impl_mode, + ) return else: - mean, variance, res = \ - layer_norm_compute_nz(data_x, data_gamma, data_beta, - output_y, output_mean, output_variance, - begin_norm_axis, begin_params_axis, - ori_shape_x, epsilon, kernel_name, impl_mode) + mean, variance, res = layer_norm_compute_nz( + data_x, + data_gamma, + data_beta, + output_y, + output_mean, + output_variance, + begin_norm_axis, + begin_params_axis, + ori_shape_x, + epsilon, + kernel_name, + impl_mode, + ) else: - if layer_norm_unify.is_special_cases(input_x, input_gamma, input_beta, begin_norm_axis, impl_mode): - __dynamic_template_api(input_x, input_gamma, input_beta, output_y, output_mean, output_variance, - begin_norm_axis, begin_params_axis, epsilon, 
kernel_name, impl_mode) + if layer_norm_unify.is_special_cases( + input_x, input_gamma, input_beta, begin_norm_axis, impl_mode + ): + __dynamic_template_api( + input_x, + input_gamma, + input_beta, + output_y, + output_mean, + output_variance, + begin_norm_axis, + begin_params_axis, + epsilon, + kernel_name, + impl_mode, + ) return else: - mean, variance, res = \ - layer_norm_compute(data_x, data_gamma, data_beta, - output_y, output_mean, - output_variance, - begin_norm_axis, begin_params_axis, - epsilon, kernel_name, impl_mode) + mean, variance, res = layer_norm_compute( + data_x, + data_gamma, + data_beta, + output_y, + output_mean, + output_variance, + begin_norm_axis, + begin_params_axis, + epsilon, + kernel_name, + impl_mode, + ) with tvm.target.cce(): sch = tbe.auto_schedule([res, mean, variance]) - config = {"print_ir" : False, - "name" : kernel_name, - "tensor_list": [data_x, data_gamma, - data_beta, res, mean, variance]} + config = { + "print_ir": False, + "name": kernel_name, + "tensor_list": [data_x, data_gamma, data_beta, res, mean, variance], + } tbe.cce_build_code(sch, config) -def __dynamic_template_api(input_x, input_gamma, input_beta, output_y, output_mean, output_variance, - begin_norm_axis, begin_params_axis, epsilon, kernel_name, impl_mode): +def __dynamic_template_api( + input_x, + input_gamma, + input_beta, + output_y, + output_mean, + output_variance, + begin_norm_axis, + begin_params_axis, + epsilon, + kernel_name, + impl_mode, +): # when all reduce axis, or reduce axis non aligned or reduced mte data less one block etc. single-core cases will # transfer dynamic template to use multi-core - input_x, input_gamma, input_beta = layer_norm_unify.set_range(input_x, input_gamma, input_beta) + input_x, input_gamma, input_beta = layer_norm_unify.set_range( + input_x, input_gamma, input_beta + ) context_ops = tbe_context.op_context.get_context() if context_ops is not None: context_ops.set_op_mode("static") context_ops.add_addition("is_static", True) - dyn.layer_norm(input_x, input_gamma, input_beta, - output_y, output_mean, output_variance, - begin_norm_axis, begin_params_axis, - epsilon, kernel_name, impl_mode) + dyn.layer_norm( + input_x, + input_gamma, + input_beta, + output_y, + output_mean, + output_variance, + begin_norm_axis, + begin_params_axis, + epsilon, + kernel_name, + impl_mode, + ) else: with tbe_context.op_context.OpContext("static"): tbe_context.op_context.get_context().add_addition("is_static", True) - dyn.layer_norm(input_x, input_gamma, input_beta, - output_y, output_mean, output_variance, - begin_norm_axis, begin_params_axis, - epsilon, kernel_name, impl_mode) + dyn.layer_norm( + input_x, + input_gamma, + input_beta, + output_y, + output_mean, + output_variance, + begin_norm_axis, + begin_params_axis, + epsilon, + kernel_name, + impl_mode, + ) diff --git a/codegeex/mindspore/scripts/layer_norm_x_backprop_v2.py b/codegeex/mindspore/scripts/layer_norm_x_backprop_v2.py index 03e5e80..a685a3d 100644 --- a/codegeex/mindspore/scripts/layer_norm_x_backprop_v2.py +++ b/codegeex/mindspore/scripts/layer_norm_x_backprop_v2.py @@ -33,14 +33,16 @@ # 'pylint: disable=too-many-lines # 'pylint: disable = unused-argument,too-many-arguments,too-many-locals,global-variable-undefined -def get_op_support_info(input_dy, - input_x, - input_variance, - input_mean, - input_gamma, - output_pd_x, - res_for_gamma, - kernel_name="layer_norm_x_backprop_v2"): +def get_op_support_info( + input_dy, + input_x, + input_variance, + input_mean, + input_gamma, + output_pd_x, + res_for_gamma, 
+ kernel_name="layer_norm_x_backprop_v2", +): """ get_op_support_info """ @@ -59,17 +61,27 @@ def get_op_support_info(input_dy, if flag == -1: for i in range(len(shape_x) - 1): split_0 = [ - SplitInput([0, [i], [-1], [-1]], [1, [i], [-1], [-1]], [2, [i], [-1], [-1]], - [3, [i], [-1], [-1]], [4, [i], [-1], [-1]]), - SplitOutput([0, [i]]) + SplitInput( + [0, [i], [-1], [-1]], + [1, [i], [-1], [-1]], + [2, [i], [-1], [-1]], + [3, [i], [-1], [-1]], + [4, [i], [-1], [-1]], + ), + SplitOutput([0, [i]]), ] axis_split_matrix.append(split_0) else: for i in range(flag): split_0 = [ - SplitInput([0, [i], [-1], [-1]], [1, [i], [-1], [-1]], [2, [i], [-1], [-1]], - [3, [i], [-1], [-1]], [4, [i], [-1], [-1]]), - SplitOutput([0, [i]], [1, [i]]) + SplitInput( + [0, [i], [-1], [-1]], + [1, [i], [-1], [-1]], + [2, [i], [-1], [-1]], + [3, [i], [-1], [-1]], + [4, [i], [-1], [-1]], + ), + SplitOutput([0, [i]], [1, [i]]), ] axis_split_matrix.append(split_0) else: @@ -91,20 +103,21 @@ def _check_dynamic_format(shape_dy, shape_gamma, c_0): """ if len(shape_dy) < 2 or len(shape_gamma) != 1: return True - if shape_dy[-1] % c_0 != 0 or shape_dy[-2] % c_0 != 0 \ - or shape_gamma[-1] % c_0 != 0: + if shape_dy[-1] % c_0 != 0 or shape_dy[-2] % c_0 != 0 or shape_gamma[-1] % c_0 != 0: return True return True -def op_select_format(input_dy, - input_x, - input_variance, - input_mean, - input_gamma, - output_pd_x, - res_for_gamma, - kernel_name="layer_norm_x_backprop_v2"): +def op_select_format( + input_dy, + input_x, + input_variance, + input_mean, + input_gamma, + output_pd_x, + res_for_gamma, + kernel_name="layer_norm_x_backprop_v2", +): """ function of selecting dynamic format @@ -136,84 +149,91 @@ def op_select_format(input_dy, c_0 = 16 if _check_dynamic_format(shape_dy, shape_gamma, c_0): - input0 = util_select_op_base.gen_param(classify="input0", - name="dy", - datatype="float16,float16,float16,float," - "float,float", - format="NCHW,NHWC,ND,NCHW,NHWC,ND") - input1 = util_select_op_base.gen_param(classify="input1", - name="x", - datatype="float16,float16,float16,float," - "float,float", - format="NCHW,NHWC,ND,NCHW,NHWC,ND") - input2 = util_select_op_base.gen_param(classify="input2", - name="variance", - datatype="float16,float16,float16,float," - "float,float", - format="NCHW,NHWC,ND,NCHW,NHWC,ND") - input3 = util_select_op_base.gen_param(classify="input3", - name="mean", - datatype="float16,float16,float16,float," - "float,float", - format="NCHW,NHWC,ND,NCHW,NHWC,ND") - input4 = util_select_op_base.gen_param(classify="input4", - name="gamma", - datatype="float16,float16,float16,float," - "float,float", - format="NCHW,NHWC,ND,NCHW,NHWC,ND") - output0 = util_select_op_base.gen_param(classify="output0", - name="pd_x", - datatype="float16,float16,float16,float," - "float,float", - format="NCHW,NHWC,ND,NCHW,NHWC,ND") - output1 = util_select_op_base.gen_param(classify="output1", - name="res_for_gamma", - datatype="float,float,float,float," - "float,float", - format="NCHW,NHWC,ND,NCHW,NHWC,ND") + input0 = util_select_op_base.gen_param( + classify="input0", + name="dy", + datatype="float16,float16,float16,float," "float,float", + format="NCHW,NHWC,ND,NCHW,NHWC,ND", + ) + input1 = util_select_op_base.gen_param( + classify="input1", + name="x", + datatype="float16,float16,float16,float," "float,float", + format="NCHW,NHWC,ND,NCHW,NHWC,ND", + ) + input2 = util_select_op_base.gen_param( + classify="input2", + name="variance", + datatype="float16,float16,float16,float," "float,float", + format="NCHW,NHWC,ND,NCHW,NHWC,ND", + ) + 
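One behavioural note on the reformatted _check_dynamic_format above: its guards read as an alignment check against the FRACTAL_NZ block size c_0 = 16, but both branches return True, so op_select_format always falls into the ND/NCHW/NHWC parameter set. The sketch below restates the check as its conditions read, with the aligned path returning False; it is an interpretation for clarity (needs_nd_format is a hypothetical name), not the behaviour shipped in the hunk.

def needs_nd_format(shape_dy, shape_gamma, c_0=16):
    """Illustrative reading of _check_dynamic_format.

    True  -> offer only the ND/NCHW/NHWC parameter set.
    False -> dy's last two axes and gamma's axis are multiples of the
             FRACTAL_NZ block size, so the NZ variants could be offered.
    """
    if len(shape_dy) < 2 or len(shape_gamma) != 1:
        return True
    if shape_dy[-1] % c_0 or shape_dy[-2] % c_0 or shape_gamma[-1] % c_0:
        return True
    return False


assert needs_nd_format([30, 20], [20]) is True          # not 16-aligned
assert needs_nd_format([1024, 1024], [1024]) is False   # aligned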
input3 = util_select_op_base.gen_param( + classify="input3", + name="mean", + datatype="float16,float16,float16,float," "float,float", + format="NCHW,NHWC,ND,NCHW,NHWC,ND", + ) + input4 = util_select_op_base.gen_param( + classify="input4", + name="gamma", + datatype="float16,float16,float16,float," "float,float", + format="NCHW,NHWC,ND,NCHW,NHWC,ND", + ) + output0 = util_select_op_base.gen_param( + classify="output0", + name="pd_x", + datatype="float16,float16,float16,float," "float,float", + format="NCHW,NHWC,ND,NCHW,NHWC,ND", + ) + output1 = util_select_op_base.gen_param( + classify="output1", + name="res_for_gamma", + datatype="float,float,float,float," "float,float", + format="NCHW,NHWC,ND,NCHW,NHWC,ND", + ) else: - input0 = util_select_op_base.gen_param(classify="input0", - name="dy", - datatype="float16, float,float16,float16," - "float16,float,float,float", - format="FRACTAL_NZ,FRACTAL_NZ,NCHW,NHWC,ND," - "NCHW,NHWC,ND") - input1 = util_select_op_base.gen_param(classify="input1", - name="x", - datatype="float16, float,float16,float16," - "float16,float,float,float", - format="FRACTAL_NZ,FRACTAL_NZ,NCHW,NHWC,ND," - "NCHW,NHWC,ND") - input2 = util_select_op_base.gen_param(classify="input2", - name="variance", - datatype="float16, float,float16,float16," - "float16,float,float,float", - format="ND,ND,NCHW,NHWC,ND,NCHW," - "NHWC,ND") - input3 = util_select_op_base.gen_param(classify="input3", - name="mean", - datatype="float16, float,float16,float16," - "float16,float,float,float", - format="ND,ND,NCHW,NHWC,ND,NCHW," - "NHWC,ND") - input4 = util_select_op_base.gen_param(classify="input4", - name="gamma", - datatype="float16, float,float16,float16," - "float16,float,float,float", - format="ND,ND,NCHW,NHWC,ND,NCHW," - "NHWC,ND") - output0 = util_select_op_base.gen_param(classify="output0", - name="pd_x", - datatype="float16, float,float16,float16," - "float16,float,float,float", - format="FRACTAL_NZ,FRACTAL_NZ,NCHW,NHWC," - "ND,NCHW,NHWC,ND") - output1 = util_select_op_base.gen_param(classify="output1", - name="res_for_gamma", - datatype="float, float,float,float," - "float,float,float,float", - format="FRACTAL_NZ,FRACTAL_NZ,NCHW,NHWC," - "ND,NCHW,NHWC,ND") + input0 = util_select_op_base.gen_param( + classify="input0", + name="dy", + datatype="float16, float,float16,float16," "float16,float,float,float", + format="FRACTAL_NZ,FRACTAL_NZ,NCHW,NHWC,ND," "NCHW,NHWC,ND", + ) + input1 = util_select_op_base.gen_param( + classify="input1", + name="x", + datatype="float16, float,float16,float16," "float16,float,float,float", + format="FRACTAL_NZ,FRACTAL_NZ,NCHW,NHWC,ND," "NCHW,NHWC,ND", + ) + input2 = util_select_op_base.gen_param( + classify="input2", + name="variance", + datatype="float16, float,float16,float16," "float16,float,float,float", + format="ND,ND,NCHW,NHWC,ND,NCHW," "NHWC,ND", + ) + input3 = util_select_op_base.gen_param( + classify="input3", + name="mean", + datatype="float16, float,float16,float16," "float16,float,float,float", + format="ND,ND,NCHW,NHWC,ND,NCHW," "NHWC,ND", + ) + input4 = util_select_op_base.gen_param( + classify="input4", + name="gamma", + datatype="float16, float,float16,float16," "float16,float,float,float", + format="ND,ND,NCHW,NHWC,ND,NCHW," "NHWC,ND", + ) + output0 = util_select_op_base.gen_param( + classify="output0", + name="pd_x", + datatype="float16, float,float16,float16," "float16,float,float,float", + format="FRACTAL_NZ,FRACTAL_NZ,NCHW,NHWC," "ND,NCHW,NHWC,ND", + ) + output1 = util_select_op_base.gen_param( + classify="output1", + name="res_for_gamma", 
+ datatype="float, float,float,float," "float,float,float,float", + format="FRACTAL_NZ,FRACTAL_NZ,NCHW,NHWC," "ND,NCHW,NHWC,ND", + ) param_list = [input0, input1, input2, input3, input4, output0, output1] param_dynamic_in_json = util_select_op_base.get_dynamic_param_in_json(param_list) @@ -262,13 +282,17 @@ def _check_shape(params_map): """ if operator.ne(tuple(params_map.get("shape_dy")), tuple(params_map.get("shape_x"))): error_detail = "shape of input_dy and input_x should be same" - error_manager_vector.raise_err_two_input_shape_invalid("layer_norm_x_backprop_v2", "input_dy", "input_x", - error_detail) + error_manager_vector.raise_err_two_input_shape_invalid( + "layer_norm_x_backprop_v2", "input_dy", "input_x", error_detail + ) - if operator.ne(tuple(params_map.get("shape_var")), tuple(params_map.get("shape_mean"))): + if operator.ne( + tuple(params_map.get("shape_var")), tuple(params_map.get("shape_mean")) + ): error_detail = "shape of input_variance and input_mean should be same" - error_manager_vector.raise_err_two_input_shape_invalid("layer_norm_x_backprop_v2", "input_variance", - "input_mean", error_detail) + error_manager_vector.raise_err_two_input_shape_invalid( + "layer_norm_x_backprop_v2", "input_variance", "input_mean", error_detail + ) shape_x = params_map.get("shape_x") shape_mean = params_map.get("shape_mean") @@ -299,12 +323,15 @@ def _check_shape_mean(shape_x, shape_mean): """ if len(shape_x) != len(shape_mean): error_detail = "length of shape_x and shape_mean should be same" - error_manager_vector.raise_err_two_input_shape_invalid("layer_norm_x_backprop_v2", "input_x", "input_mean", - error_detail) + error_manager_vector.raise_err_two_input_shape_invalid( + "layer_norm_x_backprop_v2", "input_x", "input_mean", error_detail + ) if shape_mean[-1] != 1: error_detail = "value of shape_mean's last dim must be 1" - error_manager_vector.raise_err_input_shape_invalid("layer_norm_x_backprop_v2", "input_mean", error_detail) + error_manager_vector.raise_err_input_shape_invalid( + "layer_norm_x_backprop_v2", "input_mean", error_detail + ) flag = -1 for i, (xtem, mean) in enumerate(zip(shape_x, shape_mean)): @@ -318,8 +345,9 @@ def _check_shape_mean(shape_x, shape_mean): continue if mean != 1: error_detail = "value of shape_mean must be 1" - error_manager_vector.raise_err_input_shape_invalid("layer_norm_x_backprop_v2", "input_mean", - error_detail) + error_manager_vector.raise_err_input_shape_invalid( + "layer_norm_x_backprop_v2", "input_mean", error_detail + ) def _check_shape_gamma(shape_x, shape_gamma): @@ -339,13 +367,16 @@ def _check_shape_gamma(shape_x, shape_gamma): """ if len(shape_gamma) > len(shape_x): error_detail = "length of shape_gamma can not be longer than shape_x" - error_manager_vector.raise_err_two_input_shape_invalid("layer_norm_x_backprop_v2", "input_gamma", "input_x", - error_detail) + error_manager_vector.raise_err_two_input_shape_invalid( + "layer_norm_x_backprop_v2", "input_gamma", "input_x", error_detail + ) for xtem, gamma in zip(reversed(shape_x), reversed(shape_gamma)): if xtem != gamma: error_detail = "value of shape_gamma is wrong" - error_manager_vector.raise_err_input_shape_invalid("layer_norm_x_backprop_v2", "input_gamma", error_detail) + error_manager_vector.raise_err_input_shape_invalid( + "layer_norm_x_backprop_v2", "input_gamma", error_detail + ) def _broadcast_nz(tensor, shape): @@ -354,9 +385,11 @@ def _broadcast_nz(tensor, shape): for i, _ in enumerate(shape): if shape[i] != src_shape[i]: broadcast_axes.append(i) - if len(broadcast_axes) == 2 
and \ - broadcast_axes[1] - broadcast_axes[0] != 1 and \ - broadcast_axes[1] + 1 == len(shape): + if ( + len(broadcast_axes) == 2 + and broadcast_axes[1] - broadcast_axes[0] != 1 + and broadcast_axes[1] + 1 == len(shape) + ): temp_shape = src_shape[:-1] + [shape[-1]] tensor = tbe.broadcast(tensor, temp_shape) tensor = tbe.broadcast(tensor, shape) @@ -414,9 +447,13 @@ def _get_data_gm(shapes, dtype): """ data_dy = tvm.placeholder(shapes.get("shape_dy"), name="data_dy", dtype=dtype) data_x = tvm.placeholder(shapes.get("shape_x"), name="data_x", dtype=dtype) - data_variance = tvm.placeholder(shapes.get("shape_var"), name="data_variance", dtype=dtype) + data_variance = tvm.placeholder( + shapes.get("shape_var"), name="data_variance", dtype=dtype + ) data_mean = tvm.placeholder(shapes.get("shape_mean"), name="data_mean", dtype=dtype) - data_gamma = tvm.placeholder(shapes.get("shape_gamma"), name="data_gamma", dtype=dtype) + data_gamma = tvm.placeholder( + shapes.get("shape_gamma"), name="data_gamma", dtype=dtype + ) data_gm = (data_dy, data_x, data_variance, data_mean, data_gamma) @@ -461,7 +498,11 @@ def _get_params(shape_x, shape_mean, shape_gamma): for i in reduce_axis: mean_num *= shape_x[i] - params = {"param_axis": param_axis, "reduce_axis": reduce_axis, "mean_num": mean_num} + params = { + "param_axis": param_axis, + "reduce_axis": reduce_axis, + "mean_num": mean_num, + } return params @@ -628,7 +669,9 @@ def _get_pd_x(data, params, shape_x, dtype, cast_dtype): """ pd_xl = _get_pd_xl(data, shape_x) - pd_var, var_elta_2, sub_x_mean = _get_pd_var(data, params, shape_x, pd_xl, cast_dtype) + pd_var, var_elta_2, sub_x_mean = _get_pd_var( + data, params, shape_x, pd_xl, cast_dtype + ) pd_mean = _get_pd_mean(params, pd_xl, pd_var, var_elta_2, sub_x_mean, cast_dtype) @@ -636,10 +679,14 @@ def _get_pd_x(data, params, shape_x, dtype, cast_dtype): pd_x_1 = tbe.vmul(var_elta_2_cast, pd_xl) res_for_gamma = tbe.vmul(var_elta_2_cast, sub_x_mean) - pd_var = tbe.vmuls(pd_var, tvm.const((2 * (params.get("mean_num") ** (-1))), dtype=cast_dtype)) + pd_var = tbe.vmuls( + pd_var, tvm.const((2 * (params.get("mean_num") ** (-1))), dtype=cast_dtype) + ) pdx2_broad = tbe.broadcast(pd_var, shape_x) pd_x_2 = tbe.vmul(pdx2_broad, sub_x_mean) - pd_x_3 = tbe.vmuls(pd_mean, tvm.const((params.get("mean_num") ** (-1)), dtype=cast_dtype)) + pd_x_3 = tbe.vmuls( + pd_mean, tvm.const((params.get("mean_num") ** (-1)), dtype=cast_dtype) + ) pdx_broad = tbe.broadcast(pd_x_3, shape_x) pdx_add = tbe.vadd(pd_x_1, pd_x_2) @@ -719,7 +766,9 @@ def _get_pds(data_dy, data_x, data_variance, data_mean, data_gamma, shape_gamma_ has_improve_precision = False cast_dtype = dtype - if dtype == "float16" and tbe_platform.cce_conf.api_check_support("te.lang.cce.vexp", "float32"): + if dtype == "float16" and tbe_platform.cce_conf.api_check_support( + "te.lang.cce.vexp", "float32" + ): has_improve_precision = True cast_dtype = "float32" params = _get_params(shape_x, shape_mean, shape_gamma_ori) @@ -732,11 +781,11 @@ def _get_pds(data_dy, data_x, data_variance, data_mean, data_gamma, shape_gamma_ data_gamma = tbe.cast_to(data_gamma, "float32") data = { - "data_dy" : data_dy, - "data_x" : data_x, + "data_dy": data_dy, + "data_x": data_x, "data_variance": data_variance, - "data_mean" : data_mean, - "data_gamma" : data_gamma + "data_mean": data_mean, + "data_gamma": data_gamma, } pd_x, res_for_gamma = _get_res(data, params, shape_x, dtype, cast_dtype) @@ -745,13 +794,15 @@ def _get_pds(data_dy, data_x, data_variance, data_mean, data_gamma, 
shape_gamma_ @tbe_platform.fusion_manager.fusion_manager.register("layer_norm_x_backprop_v2") -def layer_norm_x_backprop_v2_compute(input_dy, - input_x, - input_variance, - input_mean, - input_gamma, - output_pd_x, - kernel_name="layer_norm_x_backprop_v2"): +def layer_norm_x_backprop_v2_compute( + input_dy, + input_x, + input_variance, + input_mean, + input_gamma, + output_pd_x, + kernel_name="layer_norm_x_backprop_v2", +): """ DSL description of the layernorm_grad operator's mathematical calculation process @@ -778,7 +829,9 @@ def layer_norm_x_backprop_v2_compute(input_dy, res_tuple: tuple (pd_x, pd_gamma, pd_beta) """ - pd_x, res_for_gamma = _get_pds(input_dy, input_x, input_variance, input_mean, input_gamma, input_gamma.shape) + pd_x, res_for_gamma = _get_pds( + input_dy, input_x, input_variance, input_mean, input_gamma, input_gamma.shape + ) res_list = [pd_x, res_for_gamma] return res_list @@ -833,12 +886,12 @@ def update_shape_nz(shape_x, shape_var, shape_gamma): mean_nz_num *= shape_x_nz[i] param_nz = { - "shape_x_nz" : shape_x_nz, - "shape_var_nz" : shape_var_nz, + "shape_x_nz": shape_x_nz, + "shape_var_nz": shape_var_nz, "shape_gamma_nz": shape_gamma_nz, - "reduce_axis" : reduce_nz_axis, - "param_axis" : param_nz_axis, - "mean_num" : mean_nz_num + "reduce_axis": reduce_nz_axis, + "param_axis": param_nz_axis, + "mean_num": mean_nz_num, } return param_nz @@ -851,9 +904,15 @@ def _get_data_nz(param_nz, dtype): """ data_dy = tvm.placeholder(param_nz.get("shape_x_nz"), name="data_dy", dtype=dtype) data_x = tvm.placeholder(param_nz.get("shape_x_nz"), name="data_x", dtype=dtype) - data_variance = tvm.placeholder(param_nz.get("shape_var_nz"), name="data_variance", dtype=dtype) - data_mean = tvm.placeholder(param_nz.get("shape_var_nz"), name="data_mean", dtype=dtype) - data_gamma = tvm.placeholder(param_nz.get("shape_gamma_nz"), name="data_gamma", dtype=dtype) + data_variance = tvm.placeholder( + param_nz.get("shape_var_nz"), name="data_variance", dtype=dtype + ) + data_mean = tvm.placeholder( + param_nz.get("shape_var_nz"), name="data_mean", dtype=dtype + ) + data_gamma = tvm.placeholder( + param_nz.get("shape_gamma_nz"), name="data_gamma", dtype=dtype + ) data_gm = (data_dy, data_x, data_variance, data_mean, data_gamma) @@ -926,15 +985,21 @@ def _get_pd_x_nz(data, param_nz, dtype, cast_dtype): pd_var, var_elta_2, sub_x_mean = _get_pd_var_nz(data, param_nz, pd_xl, cast_dtype) - pd_mean = _get_pd_mean_nz(param_nz, pd_xl, pd_var, var_elta_2, sub_x_mean, cast_dtype) + pd_mean = _get_pd_mean_nz( + param_nz, pd_xl, pd_var, var_elta_2, sub_x_mean, cast_dtype + ) var_elta_2_cast = _broadcast_nz(var_elta_2, param_nz.get("shape_x_nz")) pd_x_1 = tbe.vmul(var_elta_2_cast, pd_xl) res_for_gamma = tbe.vmul(var_elta_2_cast, sub_x_mean) - pd_var = tbe.vmuls(pd_var, tvm.const((2 * (param_nz.get("mean_num") ** (-1))), dtype=cast_dtype)) + pd_var = tbe.vmuls( + pd_var, tvm.const((2 * (param_nz.get("mean_num") ** (-1))), dtype=cast_dtype) + ) pdx2_broad = _broadcast_nz(pd_var, param_nz.get("shape_x_nz")) pd_x_2 = tbe.vmul(pdx2_broad, sub_x_mean) - pd_x_3 = tbe.vmuls(pd_mean, tvm.const((param_nz.get("mean_num") ** (-1)), dtype=cast_dtype)) + pd_x_3 = tbe.vmuls( + pd_mean, tvm.const((param_nz.get("mean_num") ** (-1)), dtype=cast_dtype) + ) pdx_broad = _broadcast_nz(pd_x_3, param_nz.get("shape_x_nz")) pdx_add = tbe.vadd(pd_x_1, pd_x_2) @@ -967,7 +1032,9 @@ def _get_pds_nz(data_dy, data_x, data_variance, data_mean, data_gamma, param_nz) has_improve_precision = False cast_dtype = dtype - if dtype == "float16" and 
tbe_platform.cce_conf.api_check_support("te.lang.cce.vexp", "float32"): + if dtype == "float16" and tbe_platform.cce_conf.api_check_support( + "te.lang.cce.vexp", "float32" + ): has_improve_precision = True cast_dtype = "float32" @@ -979,11 +1046,11 @@ def _get_pds_nz(data_dy, data_x, data_variance, data_mean, data_gamma, param_nz) data_gamma = tbe.cast_to(data_gamma, "float32") data = { - "data_dy" : data_dy, - "data_x" : data_x, + "data_dy": data_dy, + "data_x": data_x, "data_variance": data_variance, - "data_mean" : data_mean, - "data_gamma" : data_gamma + "data_mean": data_mean, + "data_gamma": data_gamma, } pd_x, res_for_gamma = _get_res_nz(data, param_nz, dtype, cast_dtype) @@ -991,7 +1058,9 @@ def _get_pds_nz(data_dy, data_x, data_variance, data_mean, data_gamma, param_nz) return pd_x, res_for_gamma -def layer_norm_x_back_nz_compute(data_dy, data_x, data_variance, data_mean, data_gamma, param_nz): +def layer_norm_x_back_nz_compute( + data_dy, data_x, data_variance, data_mean, data_gamma, param_nz +): """ DSL description of the layernorm_grad operator's mathematical calculation process @@ -1016,22 +1085,33 @@ def layer_norm_x_back_nz_compute(data_dy, data_x, data_variance, data_mean, data res_tuple: tuple (pd_x, res_for_gamma) """ - pd_x, res_for_gamma = _get_pds_nz(data_dy, data_x, data_variance, data_mean, data_gamma, param_nz) + pd_x, res_for_gamma = _get_pds_nz( + data_dy, data_x, data_variance, data_mean, data_gamma, param_nz + ) return [pd_x, res_for_gamma] -@para_check.check_op_params(para_check.REQUIRED_INPUT, para_check.REQUIRED_INPUT, para_check.REQUIRED_INPUT, - para_check.REQUIRED_INPUT, para_check.REQUIRED_INPUT, para_check.REQUIRED_OUTPUT, - para_check.REQUIRED_OUTPUT, para_check.KERNEL_NAME) -def layer_norm_x_backprop_v2(input_dy, - input_x, - input_variance, - input_mean, - input_gamma, - output_pd_x, - res_for_gamma, - kernel_name="layer_norm_x_backprop_v2"): +@para_check.check_op_params( + para_check.REQUIRED_INPUT, + para_check.REQUIRED_INPUT, + para_check.REQUIRED_INPUT, + para_check.REQUIRED_INPUT, + para_check.REQUIRED_INPUT, + para_check.REQUIRED_OUTPUT, + para_check.REQUIRED_OUTPUT, + para_check.KERNEL_NAME, +) +def layer_norm_x_backprop_v2( + input_dy, + input_x, + input_variance, + input_mean, + input_gamma, + output_pd_x, + res_for_gamma, + kernel_name="layer_norm_x_backprop_v2", +): """ algorithm: layernorm_grad calculating: gradient of layernorm @@ -1085,69 +1165,116 @@ def layer_norm_x_backprop_v2(input_dy, format_dy = input_dy.get("format") global EPSLON EPSLON = 1e-5 if dtype == "float16" else 1e-12 - if layer_norm_x_backprop_v2_unify.is_special_cases(shape_dy, shape_variance, shape_gamma): + if layer_norm_x_backprop_v2_unify.is_special_cases( + shape_dy, shape_variance, shape_gamma + ): context = tbe_context.op_context.get_context() if context is not None: context.set_op_mode("static") context.add_addition("is_static", True) - layer_norm_x_backprop_v2_unify.layer_norm_x_backprop_v2(input_dy, input_x, input_variance, input_mean, - input_gamma, output_pd_x, res_for_gamma, - kernel_name) + layer_norm_x_backprop_v2_unify.layer_norm_x_backprop_v2( + input_dy, + input_x, + input_variance, + input_mean, + input_gamma, + output_pd_x, + res_for_gamma, + kernel_name, + ) else: with tbe_context.op_context.OpContext("static"): tbe_context.op_context.get_context().add_addition("is_static", True) - layer_norm_x_backprop_v2_unify.layer_norm_x_backprop_v2(input_dy, input_x, input_variance, input_mean, - input_gamma, output_pd_x, res_for_gamma, - kernel_name) + 
layer_norm_x_backprop_v2_unify.layer_norm_x_backprop_v2( + input_dy, + input_x, + input_variance, + input_mean, + input_gamma, + output_pd_x, + res_for_gamma, + kernel_name, + ) return else: if format_dy.upper() == "FRACTAL_NZ": param_nz = update_shape_nz(shape_x, shape_variance, shape_gamma) - data_dy = tvm.placeholder(param_nz.get("shape_x_nz"), name="data_dy", dtype=dtype) - data_x = tvm.placeholder(param_nz.get("shape_x_nz"), name="data_x", dtype=dtype) - data_variance = tvm.placeholder(param_nz.get("shape_var_nz"), name="data_variance", dtype=dtype) - data_mean = tvm.placeholder(param_nz.get("shape_var_nz"), name="data_mean", dtype=dtype) - data_gamma = tvm.placeholder(param_nz.get("shape_gamma_nz"), name="data_gamma", dtype=dtype) - - res_list = layer_norm_x_back_nz_compute(data_dy, data_x, data_variance, data_mean, data_gamma, param_nz) - - tensor_list = [data_dy, data_x, data_variance, data_mean, data_gamma] + res_list + data_dy = tvm.placeholder( + param_nz.get("shape_x_nz"), name="data_dy", dtype=dtype + ) + data_x = tvm.placeholder( + param_nz.get("shape_x_nz"), name="data_x", dtype=dtype + ) + data_variance = tvm.placeholder( + param_nz.get("shape_var_nz"), name="data_variance", dtype=dtype + ) + data_mean = tvm.placeholder( + param_nz.get("shape_var_nz"), name="data_mean", dtype=dtype + ) + data_gamma = tvm.placeholder( + param_nz.get("shape_gamma_nz"), name="data_gamma", dtype=dtype + ) + + res_list = layer_norm_x_back_nz_compute( + data_dy, data_x, data_variance, data_mean, data_gamma, param_nz + ) + + tensor_list = [ + data_dy, + data_x, + data_variance, + data_mean, + data_gamma, + ] + res_list with tvm.target.cce(): sch = tbe.auto_schedule(res_list) - config = {"print_ir": False, "name": kernel_name, "tensor_list": tensor_list} + config = { + "print_ir": False, + "name": kernel_name, + "tensor_list": tensor_list, + } tbe.cce_build_code(sch, config) else: - _check_params({ - "shape_dy" : shape_dy, - "shape_x" : shape_x, - "shape_var" : shape_variance, - "shape_mean" : shape_mean, - "shape_gamma": shape_gamma, - "dtype" : dtype, - "kernel_name": kernel_name - }) + _check_params( + { + "shape_dy": shape_dy, + "shape_x": shape_x, + "shape_var": shape_variance, + "shape_mean": shape_mean, + "shape_gamma": shape_gamma, + "dtype": dtype, + "kernel_name": kernel_name, + } + ) shape_gamma = _update_gamma_shape(shape_x, shape_gamma)[0] data_gm = _get_data_gm( { - "shape_dy" : shape_dy, - "shape_x" : shape_x, - "shape_var" : shape_variance, - "shape_mean" : shape_mean, - "shape_gamma": shape_gamma - }, dtype) - - res_list = layer_norm_x_backprop_v2_compute(data_gm[0], data_gm[1], data_gm[2], data_gm[3], data_gm[4], - output_pd_x) + "shape_dy": shape_dy, + "shape_x": shape_x, + "shape_var": shape_variance, + "shape_mean": shape_mean, + "shape_gamma": shape_gamma, + }, + dtype, + ) + + res_list = layer_norm_x_backprop_v2_compute( + data_gm[0], data_gm[1], data_gm[2], data_gm[3], data_gm[4], output_pd_x + ) with tvm.target.cce(): sch = tbe.auto_schedule(res_list) tensor_list = list(data_gm) + list(res_list) - config = {"print_ir": False, "name": kernel_name, "tensor_list": tensor_list} + config = { + "print_ir": False, + "name": kernel_name, + "tensor_list": tensor_list, + } tbe.cce_build_code(sch, config) diff --git a/codegeex/mindspore/scripts/run_modelarts.py b/codegeex/mindspore/scripts/run_modelarts.py index 05715a0..9417d5f 100644 --- a/codegeex/mindspore/scripts/run_modelarts.py +++ b/codegeex/mindspore/scripts/run_modelarts.py @@ -11,7 +11,12 @@ args = parser.parse_args() 
-log_path = os.path.join(args.work_dir, "logs", os.environ.get("JOB_ID"), f'device{os.environ.get("RANK_ID")}') +log_path = os.path.join( + args.work_dir, + "logs", + os.environ.get("JOB_ID"), + f'device{os.environ.get("RANK_ID")}', +) tb_path = os.path.join(args.work_dir, "runs", os.environ.get("JOB_ID")) Path(log_path).mkdir(parents=True, exist_ok=True) @@ -25,7 +30,8 @@ print("=================ms import done", flush=True) time.sleep(10) os.system( - "cp /home/work/rank_table/jobstart_hccl.json /home/work/sfs/xx; sudo chmod +777 /home/work/rank_table/jobstart_hccl.json") + "cp /home/work/rank_table/jobstart_hccl.json /home/work/sfs/xx; sudo chmod +777 /home/work/rank_table/jobstart_hccl.json" +) ret = os.system(f"cd {log_path} && bash {args.script} 2>&1 | tee output.log") if os.environ.get("RANK_ID") == 0: log_dir = os.path.join(args.work_dir, "logs", os.environ.get("JOB_ID")) diff --git a/codegeex/mindspore/scripts/run_modelarts_gen_finetune.py b/codegeex/mindspore/scripts/run_modelarts_gen_finetune.py index 2e5f1ec..cf527f4 100644 --- a/codegeex/mindspore/scripts/run_modelarts_gen_finetune.py +++ b/codegeex/mindspore/scripts/run_modelarts_gen_finetune.py @@ -12,7 +12,12 @@ args = parser.parse_args() -log_path = os.path.join(args.work_dir, "logs", os.environ.get("JOB_ID"), f'device{os.environ.get("RANK_ID")}') +log_path = os.path.join( + args.work_dir, + "logs", + os.environ.get("JOB_ID"), + f'device{os.environ.get("RANK_ID")}', +) tb_path = os.path.join(args.work_dir, "runs", os.environ.get("JOB_ID")) Path(log_path).mkdir(parents=True, exist_ok=True) @@ -30,7 +35,8 @@ print("=================ms import done", flush=True) time.sleep(10) os.system( - "cp /home/work/rank_table/jobstart_hccl.json /home/work/sfs/xx; sudo chmod +777 /home/work/rank_table/jobstart_hccl.json") + "cp /home/work/rank_table/jobstart_hccl.json /home/work/sfs/xx; sudo chmod +777 /home/work/rank_table/jobstart_hccl.json" +) ret = os.system(f"cd {log_path} && bash {args.script} 2>&1 | tee output.log") if os.environ.get("RANK_ID") == 0: log_dir = os.path.join(args.work_dir, "logs", os.environ.get("JOB_ID")) diff --git a/codegeex/mindspore/scripts/run_modelarts_gen_humaneval_x.py b/codegeex/mindspore/scripts/run_modelarts_gen_humaneval_x.py index 256d7d1..069cdcd 100644 --- a/codegeex/mindspore/scripts/run_modelarts_gen_humaneval_x.py +++ b/codegeex/mindspore/scripts/run_modelarts_gen_humaneval_x.py @@ -12,7 +12,12 @@ args = parser.parse_args() -log_path = os.path.join(args.work_dir, "logs", os.environ.get("JOB_ID"), f'device{os.environ.get("RANK_ID")}') +log_path = os.path.join( + args.work_dir, + "logs", + os.environ.get("JOB_ID"), + f'device{os.environ.get("RANK_ID")}', +) tb_path = os.path.join(args.work_dir, "runs", os.environ.get("JOB_ID")) Path(log_path).mkdir(parents=True, exist_ok=True) @@ -30,7 +35,8 @@ print("=================ms import done", flush=True) time.sleep(10) os.system( - "cp /home/work/rank_table/jobstart_hccl.json /home/work/sfs/xx; sudo chmod +777 /home/work/rank_table/jobstart_hccl.json") + "cp /home/work/rank_table/jobstart_hccl.json /home/work/sfs/xx; sudo chmod +777 /home/work/rank_table/jobstart_hccl.json" +) ret = os.system(f"cd {log_path} && bash {args.script} 2>&1 | tee output.log") if os.environ.get("RANK_ID") == 0: log_dir = os.path.join(args.work_dir, "logs", os.environ.get("JOB_ID")) diff --git a/codegeex/mindspore/src/adam.py b/codegeex/mindspore/src/adam.py index c2ccfe8..cc0e792 100644 --- a/codegeex/mindspore/src/adam.py +++ b/codegeex/mindspore/src/adam.py @@ -31,21 +31,65 @@ 
_cpu_div = P.RealDiv().add_prim_attr("primitive_target", "CPU") -@_adam_opt.register("Function", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Number", "Tensor", "Tensor", - "Tensor", "Tensor", "Bool", "Bool") -def _update_run_kernel(opt, clip_value, beta1, beta2, eps, lr, weight_decay, - param, m, v, gradient, decay_flags, optim_filter): +@_adam_opt.register( + "Function", + "Tensor", + "Tensor", + "Tensor", + "Tensor", + "Tensor", + "Number", + "Tensor", + "Tensor", + "Tensor", + "Tensor", + "Bool", + "Bool", +) +def _update_run_kernel( + opt, + clip_value, + beta1, + beta2, + eps, + lr, + weight_decay, + param, + m, + v, + gradient, + decay_flags, + optim_filter, +): """ Update parameters by AdamWeightDecay op. """ success = True if optim_filter: if decay_flags: - next_param = opt(param, m, v, lr, beta1, beta2, eps, weight_decay, - _cpu_div(P.Cast()(gradient, mstype.float16), clip_value)) + next_param = opt( + param, + m, + v, + lr, + beta1, + beta2, + eps, + weight_decay, + _cpu_div(P.Cast()(gradient, mstype.float16), clip_value), + ) else: - next_param = opt(param, m, v, lr, beta1, beta2, eps, 0.0, - _cpu_div(P.Cast()(gradient, mstype.float16), clip_value)) + next_param = opt( + param, + m, + v, + lr, + beta1, + beta2, + eps, + 0.0, + _cpu_div(P.Cast()(gradient, mstype.float16), clip_value), + ) return F.depend(success, next_param) return success @@ -131,24 +175,33 @@ class AdamWeightDecayOp(Optimizer): >>> >>> loss = nn.SoftmaxCrossEntropyWithLogits() >>> model = Model(net, loss_fn=loss, optimizer=optim) - """ + """ - def __init__(self, params, learning_rate=1e-3, beta1=0.9, beta2=0.999, eps=1e-6, weight_decay=0.0, - clip_norm=1.0, param_init_type=mstype.float32): + def __init__( + self, + params, + learning_rate=1e-3, + beta1=0.9, + beta2=0.999, + eps=1e-6, + weight_decay=0.0, + clip_norm=1.0, + param_init_type=mstype.float32, + ): super(AdamWeightDecayOp, self).__init__(learning_rate, params, weight_decay) _check_param_value(beta1, beta2, eps, self.cls_name) self.beta1 = Tensor(np.array([beta1]).astype(np.float32)) self.beta2 = Tensor(np.array([beta2]).astype(np.float32)) self.eps = Tensor(np.array([eps]).astype(np.float32)) self.clip_norm = Tensor([clip_norm], mstype.float32) - self.enable_init_fp16 = (param_init_type == mstype.float16) + self.enable_init_fp16 = param_init_type == mstype.float16 if self.enable_init_fp16: - self.moments1 = self.clone_param32(prefix="adam_m", init='zeros') - self.moments2 = self.clone_param32(prefix="adam_v", init='zeros') + self.moments1 = self.clone_param32(prefix="adam_m", init="zeros") + self.moments2 = self.clone_param32(prefix="adam_v", init="zeros") self.opt = P.FusedCastAdamWeightDecay() else: - self.moments1 = self.parameters.clone(prefix="adam_m", init='zeros') - self.moments2 = self.parameters.clone(prefix="adam_v", init='zeros') + self.moments1 = self.parameters.clone(prefix="adam_m", init="zeros") + self.moments2 = self.parameters.clone(prefix="adam_v", init="zeros") self.opt = P.AdamWeightDecay() self.hyper_map = C.HyperMap() self.opt.add_prim_attr("primitive_target", "CPU") @@ -161,20 +214,62 @@ def construct(self, gradients, clip_value): global_norm = P.Cast()(global_norm, mstype.float16) if self.is_group: if self.is_group_lr: - optim_result = self.map_reverse(F.partial(_adam_opt, self.opt, global_norm, - self.beta1, self.beta2, self.eps), - lr, self.weight_decay, self.parameters, self.moments1, self.moments2, - gradients, self.decay_flags, self.optim_filter) + optim_result = self.map_reverse( + F.partial( + _adam_opt, + 
self.opt, + global_norm, + self.beta1, + self.beta2, + self.eps, + ), + lr, + self.weight_decay, + self.parameters, + self.moments1, + self.moments2, + gradients, + self.decay_flags, + self.optim_filter, + ) else: - optim_result = self.map_reverse(F.partial(_adam_opt, self.opt, global_norm, - self.beta1, self.beta2, self.eps, lr), - self.weight_decay, self.parameters, self.moments1, self.moments2, - gradients, self.decay_flags, self.optim_filter) + optim_result = self.map_reverse( + F.partial( + _adam_opt, + self.opt, + global_norm, + self.beta1, + self.beta2, + self.eps, + lr, + ), + self.weight_decay, + self.parameters, + self.moments1, + self.moments2, + gradients, + self.decay_flags, + self.optim_filter, + ) else: - optim_result = self.map_reverse(F.partial(_adam_opt, self.opt, global_norm, - self.beta1, self.beta2, self.eps, lr, - self.weight_decay), self.parameters, self.moments1, self.moments2, - gradients, self.decay_flags, self.optim_filter) + optim_result = self.map_reverse( + F.partial( + _adam_opt, + self.opt, + global_norm, + self.beta1, + self.beta2, + self.eps, + lr, + self.weight_decay, + ), + self.parameters, + self.moments1, + self.moments2, + gradients, + self.decay_flags, + self.optim_filter, + ) if self.use_parallel: self.broadcast_params(optim_result) return optim_result @@ -196,7 +291,9 @@ def clone_param32(self, prefix, init=None): param_init = init if init is None: param_init = old_param.init - new_state = Parameter(initializer(param_init, shape=old_param.shape, dtype=mstype.float32)) + new_state = Parameter( + initializer(param_init, shape=old_param.shape, dtype=mstype.float32) + ) new_state.param_info = old_param.param_info.clone() new_state.is_init = False new_state.is_param_ps = old_param.is_param_ps @@ -205,6 +302,6 @@ def clone_param32(self, prefix, init=None): new_state.requires_aggr = old_param.requires_aggr if old_param.cache_shape: new_state.cache_shape = old_param.cache_shape - new_state.name = prefix + '.' + new_state.name + new_state.name = prefix + "." 
+ new_state.name new.append(new_state) return ParameterTuple(new) diff --git a/codegeex/mindspore/src/callbacks.py b/codegeex/mindspore/src/callbacks.py index 569946e..c8dc62f 100644 --- a/codegeex/mindspore/src/callbacks.py +++ b/codegeex/mindspore/src/callbacks.py @@ -35,16 +35,16 @@ class LossCallBack(Callback): """ def __init__( - self, - name, - dataset_size=-1, - local_rank=0, - rank_size=1, - has_trained_epoch=0, - has_trained_step=0, - micro_size=1, - sink_size=2, - tb_writer=None, + self, + name, + dataset_size=-1, + local_rank=0, + rank_size=1, + has_trained_epoch=0, + has_trained_step=0, + micro_size=1, + sink_size=2, + tb_writer=None, ): super(LossCallBack, self).__init__() self._dataset_size = dataset_size @@ -56,7 +56,12 @@ def __init__( self.sink_size = sink_size self.summary_writer = tb_writer - print("load has trained epoch :{} and step: {}".format(has_trained_epoch, has_trained_step), flush=True) + print( + "load has trained epoch :{} and step: {}".format( + has_trained_epoch, has_trained_step + ), + flush=True, + ) def step_end(self, run_context): """ @@ -64,29 +69,29 @@ def step_end(self, run_context): """ cb_params = run_context.original_args() if self._dataset_size > 0 and self.local_rank % 8 == 0: - percent, epoch_num = math.modf(cb_params.cur_step_num / - self._dataset_size) + percent, epoch_num = math.modf(cb_params.cur_step_num / self._dataset_size) if percent == 0: epoch_num -= 1 date = time.asctime(time.localtime(time.time())) loss_value = cb_params.net_outputs[0].asnumpy() / self.micro_size if self.summary_writer is not None: - print(f"writing: {loss_value.item()}, {cb_params.net_outputs[2].asnumpy()}") + print( + f"writing: {loss_value.item()}, {cb_params.net_outputs[2].asnumpy()}" + ) self.summary_writer.add_scalar( tag="training_loss", scalar_value=loss_value.item(), - global_step=cb_params.cur_step_num - + int(self.has_trained_step), + global_step=cb_params.cur_step_num + int(self.has_trained_step), ) self.summary_writer.add_scalar( tag="loss_scale", scalar_value=cb_params.net_outputs[2].asnumpy(), - global_step=cb_params.cur_step_num - + int(self.has_trained_step), + global_step=cb_params.cur_step_num + int(self.has_trained_step), ) print( - f"time: {date} local_rank: {int(self.local_rank)}, epoch: {int(epoch_num) + int(self.has_trained_epoch)}, step: {cb_params.cur_step_num + int(self.has_trained_step)}, output is {loss_value}, overflow is {cb_params.net_outputs[1].asnumpy()}, scale is {cb_params.net_outputs[2].asnumpy()}") + f"time: {date} local_rank: {int(self.local_rank)}, epoch: {int(epoch_num) + int(self.has_trained_epoch)}, step: {cb_params.cur_step_num + int(self.has_trained_step)}, output is {loss_value}, overflow is {cb_params.net_outputs[1].asnumpy()}, scale is {cb_params.net_outputs[2].asnumpy()}" + ) class EvalCallBack(Callback): @@ -99,8 +104,19 @@ class EvalCallBack(Callback): print_per_step (int): Print loss every times. Default: 1. 
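Stepping back to the LossCallBack.step_end hunk above: it derives the epoch counter from the cumulative step count with math.modf and divides the summed loss by the micro-batch count before logging to TensorBoard. A small self-contained sketch of that bookkeeping, with hypothetical step and loss values, reads:

import math


def epoch_and_loss(cur_step_num, dataset_size, raw_loss, micro_size):
    """Sketch of the per-step bookkeeping in LossCallBack.step_end.

    math.modf splits cur_step / dataset_size into (fraction, whole); a zero
    fraction means the step just finished closes an epoch, so the whole part
    is decremented. The summed micro-batch loss is divided by micro_size to
    report a per-micro-batch value.
    """
    percent, epoch_num = math.modf(cur_step_num / dataset_size)
    if percent == 0:
        epoch_num -= 1
    return int(epoch_num), raw_loss / micro_size


# hypothetical numbers: 1000 steps per epoch, loss summed over 4 micro batches
assert epoch_and_loss(2000, 1000, raw_loss=8.0, micro_size=4) == (1, 2.0)  # closes epoch 1
assert epoch_and_loss(2001, 1000, raw_loss=8.0, micro_size=4) == (2, 2.0)  # starts epoch 2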
""" - def __init__(self, model, eval_dataset, ppl_metric, validation_loss, print_per_step=250, has_trained_step=0, - local_rank=0, rank_size=1, tb_writer=None, lang=None): + def __init__( + self, + model, + eval_dataset, + ppl_metric, + validation_loss, + print_per_step=250, + has_trained_step=0, + local_rank=0, + rank_size=1, + tb_writer=None, + lang=None, + ): super(EvalCallBack, self).__init__() if not isinstance(print_per_step, int) or print_per_step < 0: raise ValueError("print_per_step must be int and >= 0.") @@ -115,8 +131,12 @@ def __init__(self, model, eval_dataset, ppl_metric, validation_loss, print_per_s self.pplMetric.clear() self.validation_loss.clear() self.parallel_mode = context.get_auto_parallel_context("parallel_mode") - self.strategy_ckpt_save_file = context.get_auto_parallel_context("strategy_ckpt_save_file") - self.strategy_ckpt_load_file = context.get_auto_parallel_context("strategy_ckpt_load_file") + self.strategy_ckpt_save_file = context.get_auto_parallel_context( + "strategy_ckpt_save_file" + ) + self.strategy_ckpt_load_file = context.get_auto_parallel_context( + "strategy_ckpt_load_file" + ) self.summary_writer = tb_writer self.lang = lang @@ -130,12 +150,20 @@ def step_end(self, run_context): return self.pplMetric.clear() self.validation_loss.clear() - if self.parallel_mode in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL): - context.set_auto_parallel_context(strategy_ckpt_save_file="", - strategy_ckpt_load_file=self.strategy_ckpt_save_file) + if self.parallel_mode in ( + ParallelMode.SEMI_AUTO_PARALLEL, + ParallelMode.AUTO_PARALLEL, + ): + context.set_auto_parallel_context( + strategy_ckpt_save_file="", + strategy_ckpt_load_file=self.strategy_ckpt_save_file, + ) rank_id = 0 - if self.parallel_mode in (ParallelMode.SEMI_AUTO_PARALLEL, - ParallelMode.AUTO_PARALLEL, ParallelMode.DATA_PARALLEL): + if self.parallel_mode in ( + ParallelMode.SEMI_AUTO_PARALLEL, + ParallelMode.AUTO_PARALLEL, + ParallelMode.DATA_PARALLEL, + ): rank_id = get_rank() print("validation begin") start_time = time.time() @@ -148,18 +176,32 @@ def step_end(self, run_context): print(out_str) if self.summary_writer is not None: print(f"writing: {out}") - tag = "validation_loss" if self.lang is None else f"validaton_loss/{self.lang}" + tag = ( + "validation_loss" + if self.lang is None + else f"validaton_loss/{self.lang}" + ) self.summary_writer.add_scalar( tag=tag, - scalar_value=out['ppl'], + scalar_value=out["ppl"], global_step=cb_params.cur_step_num + int(self.has_trained_step), ) - context.set_auto_parallel_context(strategy_ckpt_save_file=self.strategy_ckpt_save_file, - strategy_ckpt_load_file=self.strategy_ckpt_load_file) + context.set_auto_parallel_context( + strategy_ckpt_save_file=self.strategy_ckpt_save_file, + strategy_ckpt_load_file=self.strategy_ckpt_load_file, + ) class SaveCheckpointCallback(Callback): - def __init__(self, cache_dir, bucket, local_rank=0, has_trained_epoch=0, has_trained_step=0, syn_times=100): + def __init__( + self, + cache_dir, + bucket, + local_rank=0, + has_trained_epoch=0, + has_trained_step=0, + syn_times=100, + ): self.cache_dir = os.path.join(cache_dir, f"rank_{local_rank}") self.local_rank = local_rank self.has_trained_epoch = has_trained_epoch @@ -181,5 +223,9 @@ def step_end(self, run_context): print("Copying checkpoint to the buckets ends", flush=True) def syn_files(self): - process = Process(target=mox.file.copy_parallel, args=(self.cache_dir, self.bucket), name="checkpoint_sync") + process = Process( + target=mox.file.copy_parallel, + 
args=(self.cache_dir, self.bucket), + name="checkpoint_sync", + ) process.start() diff --git a/codegeex/mindspore/src/code_tokenizer.py b/codegeex/mindspore/src/code_tokenizer.py index a434746..15be0f2 100644 --- a/codegeex/mindspore/src/code_tokenizer.py +++ b/codegeex/mindspore/src/code_tokenizer.py @@ -6,7 +6,7 @@ def encode_whitespaces(text, start_extra_id: int, max_len: int): - """ Encode whitespaces to extra tokens in GPT-J. + """Encode whitespaces to extra tokens in GPT-J. >>> encode_whitespaces('a\\n b\\n c', 10, 10) 'a\\n<|extratoken_10|>b\\n<|extratoken_11|>c' @@ -16,16 +16,18 @@ def push_acc_space(acc_len: int, text: str): if acc_len == 0: return text if acc_len == 1: - return text + ' ' - assert acc_len <= max_len, f'Max whitespace run length {max_len}, but found {acc_len}' + return text + " " + assert ( + acc_len <= max_len + ), f"Max whitespace run length {max_len}, but found {acc_len}" extra_id = start_extra_id - 2 + acc_len - extra_token = f'<|extratoken_{extra_id}|>' + extra_token = f"<|extratoken_{extra_id}|>" return text + extra_token acc_len = 0 - res = '' + res = "" for ch in text: - if ch == ' ': + if ch == " ": acc_len += 1 if acc_len == max_len: res = push_acc_space(acc_len, res) @@ -41,7 +43,7 @@ def push_acc_space(acc_len: int, text: str): def decode_whitespaces(text: str, start_extra_id: int, max_len: int): - """ Decode the whitespace-encoded strings produced by encode_whitespace. + """Decode the whitespace-encoded strings produced by encode_whitespace. >>> text = 'a\\n b\\n c' >>> s, l = 10, 10 @@ -50,17 +52,17 @@ def decode_whitespaces(text: str, start_extra_id: int, max_len: int): """ for l in range(2, max_len + 1): token_id = start_extra_id - 2 + l - token = f'<|extratoken_{token_id}|>' - text = text.replace(token, ' ' * l) + token = f"<|extratoken_{token_id}|>" + text = text.replace(token, " " * l) return text class Code13BDictionary(object): def __init__( - self, - dict_file: str, - extra_token_ids: List[str] = None, - pad_to_vocab_size: int = -1, + self, + dict_file: str, + extra_token_ids: List[str] = None, + pad_to_vocab_size: int = -1, ): self._idx = dict() self._count = dict() @@ -130,28 +132,36 @@ def decode_tokens(self, tokens): class CodeTokenizer(object): def __init__( - self, - tokenizer: GPT2TokenizerFast = None, - start_extra_id: int = 10, - max_len: int = 10, - mode='13b', - dict_file: str = None, + self, + tokenizer: GPT2TokenizerFast = None, + start_extra_id: int = 10, + max_len: int = 10, + mode="13b", + dict_file: str = None, ): - self.tokenizer = tokenizer if tokenizer is not None else AutoTokenizer.from_pretrained("EleutherAI/gpt-j-6B") - if mode not in ['6b', '13b']: + self.tokenizer = ( + tokenizer + if tokenizer is not None + else AutoTokenizer.from_pretrained("EleutherAI/gpt-j-6B") + ) + if mode not in ["6b", "13b"]: raise ValueError(f"Invalid mode {mode}, choose from ['6b', '13b']") self.start_extra_id = start_extra_id self.max_len = max_len self.mode = mode - self.code_dict = Code13BDictionary(dict_file, pad_to_vocab_size=51200) if self.mode == '13b' else None + self.code_dict = ( + Code13BDictionary(dict_file, pad_to_vocab_size=51200) + if self.mode == "13b" + else None + ) self.eos_token_id = self.tokenizer.eos_token_id def encode_code(self, code: str): - if self.mode == '6b': + if self.mode == "6b": code = encode_whitespaces(code, self.start_extra_id, self.max_len) input_ids = self.tokenizer(code).input_ids - elif self.mode == '13b': + elif self.mode == "13b": code = encode_whitespaces(code, self.start_extra_id, self.max_len) 
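The encode_whitespaces / decode_whitespaces pair reformatted above maps each run of 2 to max_len spaces onto a single <|extratoken_N|> symbol from the GPT-J vocabulary and back, which is what lets CodeTokenizer.encode_code round-trip indented source. A self-contained restatement of the two helpers (encode_ws and decode_ws are illustrative names; start_extra_id=10 and max_len=10 mirror the defaults visible in the class) is:

def encode_ws(text, start_extra_id=10, max_len=10):
    """Self-contained restatement of encode_whitespaces: a run of n spaces
    (2 <= n <= max_len) becomes <|extratoken_{start_extra_id - 2 + n}|>."""

    def flush(run):
        if run == 0:
            return ""
        if run == 1:
            return " "
        return f"<|extratoken_{start_extra_id - 2 + run}|>"

    out, run = "", 0
    for ch in text:
        if ch == " ":
            run += 1
            if run == max_len:  # longest representable run: flush immediately
                out += flush(run)
                run = 0
        else:
            out += flush(run) + ch
            run = 0
    return out + flush(run)


def decode_ws(text, start_extra_id=10, max_len=10):
    """Inverse mapping, mirroring decode_whitespaces."""
    for n in range(2, max_len + 1):
        text = text.replace(f"<|extratoken_{start_extra_id - 2 + n}|>", " " * n)
    return text


sample = "def f():\n    return 1"
encoded = encode_ws(sample)  # the 4-space indent becomes <|extratoken_12|>
assert decode_ws(encoded) == sample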
input_ids = self.code_dict.map_tokens(self.tokenizer.encode(code)) input_ids = np.array(input_ids, dtype=np.int64).reshape(1, -1) @@ -159,13 +169,19 @@ def encode_code(self, code: str): return input_ids def decode_code(self, input_ids): - if self.mode == '6b': + if self.mode == "6b": texts = self.tokenizer.batch_decode(input_ids) - output_code = [decode_whitespaces(text, self.start_extra_id, self.max_len) for text in texts] + output_code = [ + decode_whitespaces(text, self.start_extra_id, self.max_len) + for text in texts + ] - elif self.mode == '13b': + elif self.mode == "13b": input_ids = [self.code_dict.decode_tokens(input_ids.tolist()[0])] texts = self.tokenizer.batch_decode(input_ids) - output_code = [decode_whitespaces(text, self.start_extra_id, self.max_len) for text in texts] + output_code = [ + decode_whitespaces(text, self.start_extra_id, self.max_len) + for text in texts + ] return output_code diff --git a/codegeex/mindspore/src/dataset.py b/codegeex/mindspore/src/dataset.py index 725ecea..b6b2e2a 100644 --- a/codegeex/mindspore/src/dataset.py +++ b/codegeex/mindspore/src/dataset.py @@ -50,7 +50,7 @@ def get_input_data_batch_slice_map(input_ids, eod_id, rank, dis, eod_reset): # input_ids = input_ids[rank * dis : (rank + 1) * dis] if np.any(input_ids > 60000): raise ValueError("==exceed error") - # print("===input_ids tpye: ", input_ids.dtype, flush=True) + # print("===input_ids tpye: ", input_ids.dtype, flush=True) if not eod_reset: return input_ids seq_length = input_ids.shape[1] - 1 @@ -71,16 +71,29 @@ def get_input_data_batch_slice_map(input_ids, eod_id, rank, dis, eod_reset): for i in range(eod_index.size): # Reset position_ids and attention_mask considering EOD index = eod_index[i] - batch_attention_mask[bs_i, (index + 1):, :(index + 1)] = 0 - batch_position_ids[bs_i, (index + 1):] -= (index + 1 - prev_index) + batch_attention_mask[bs_i, (index + 1) :, : (index + 1)] = 0 + batch_position_ids[bs_i, (index + 1) :] -= index + 1 - prev_index prev_index = index + 1 return batch_input_ids, batch_position_ids, batch_attention_mask -def create_dataset(batch_size, data_path, args_opt, device_num=1, rank=0, drop=True, full_batch=False, - data_start_index=0, - eod_reset=False, eod_id=50256, column_name='input_ids', epoch=1, num_samples=None, - train_and_eval=False, val_ratio=0): +def create_dataset( + batch_size, + data_path, + args_opt, + device_num=1, + rank=0, + drop=True, + full_batch=False, + data_start_index=0, + eod_reset=False, + eod_id=50256, + column_name="input_ids", + epoch=1, + num_samples=None, + train_and_eval=False, + val_ratio=0, +): """ Create dataset Inputs: @@ -116,57 +129,102 @@ def create_dataset(batch_size, data_path, args_opt, device_num=1, rank=0, drop=T skip_num = args_opt.has_trained_steps * dis # skip_num = 0 num_parallel_workers = 4 - train_data = get_code_data_train(data_path, args_opt, skip_num=(skip_num // num_parallel_workers)) + train_data = get_code_data_train( + data_path, args_opt, skip_num=(skip_num // num_parallel_workers) + ) if train_and_eval: - val_data = get_code_data_eval("/home/work/sfs/xx/data_valid", - args_opt) # TODO: set as current validation set path + val_data = get_code_data_eval( + "/home/work/sfs/xx/data_valid", args_opt + ) # TODO: set as current validation set path else: val_data = None - dataset_train = ds.GeneratorDataset(train_data, column_names=[column_name], num_samples=num_samples, - num_shards=device_num, shard_id=rank, shuffle=True, - num_parallel_workers=num_parallel_workers) + dataset_train = ds.GeneratorDataset( + 
train_data, + column_names=[column_name], + num_samples=num_samples, + num_shards=device_num, + shard_id=rank, + shuffle=True, + num_parallel_workers=num_parallel_workers, + ) if train_and_eval: - dataset_val = ds.GeneratorDataset(val_data, column_names=[column_name], num_samples=num_samples, - num_shards=device_num, shard_id=rank, shuffle=True, - num_parallel_workers=num_parallel_workers) + dataset_val = ds.GeneratorDataset( + val_data, + column_names=[column_name], + num_samples=num_samples, + num_shards=device_num, + shard_id=rank, + shuffle=True, + num_parallel_workers=num_parallel_workers, + ) else: dataset_val = None type_cast_op = C.TypeCast(mstype.int32) type_cast_op_float = C.TypeCast(mstype.float16) - map_func = (lambda input_ids: get_input_data_batch_slice_map(input_ids, eod_id, rank, dis, eod_reset)) + map_func = lambda input_ids: get_input_data_batch_slice_map( + input_ids, eod_id, rank, dis, eod_reset + ) # If eod_reset enabled, another two inputs will be generated through input_ids dataset_train = dataset_train.skip(skip_num) if eod_reset: dataset_train = dataset_train.batch(dis, drop_remainder=drop) - dataset_train = dataset_train.map(operations=map_func, input_columns=[column_name], - output_columns=[column_name, "position_id", "attention_mask"], - column_order=[column_name, "position_id", "attention_mask"]) - dataset_train = dataset_train.map(input_columns="position_id", operations=type_cast_op) - dataset_train = dataset_train.map(input_columns="attention_mask", operations=type_cast_op_float) + dataset_train = dataset_train.map( + operations=map_func, + input_columns=[column_name], + output_columns=[column_name, "position_id", "attention_mask"], + column_order=[column_name, "position_id", "attention_mask"], + ) + dataset_train = dataset_train.map( + input_columns="position_id", operations=type_cast_op + ) + dataset_train = dataset_train.map( + input_columns="attention_mask", operations=type_cast_op_float + ) else: - dataset_train = dataset_train.map(input_columns=[column_name], operations=type_cast_op) + dataset_train = dataset_train.map( + input_columns=[column_name], operations=type_cast_op + ) dataset_train = dataset_train.batch(batch_size, drop_remainder=drop) - dataset_train = dataset_train.map(operations=map_func, input_columns=[column_name], - output_columns=[column_name]) - dataset_train = dataset_train.map(input_columns=column_name, operations=type_cast_op) + dataset_train = dataset_train.map( + operations=map_func, + input_columns=[column_name], + output_columns=[column_name], + ) + dataset_train = dataset_train.map( + input_columns=column_name, operations=type_cast_op + ) dataset_train = dataset_train.repeat(epoch) if dataset_val is not None: if eod_reset: dataset_val = dataset_val.batch(dis, drop_remainder=drop) - dataset_val = dataset_val.map(operations=map_func, input_columns=[column_name], - output_columns=[column_name, "position_id", "attention_mask"], - column_order=[column_name, "position_id", "attention_mask"]) - dataset_val = dataset_val.map(input_columns="position_id", operations=type_cast_op) - dataset_val = dataset_val.map(input_columns="attention_mask", operations=type_cast_op_float) + dataset_val = dataset_val.map( + operations=map_func, + input_columns=[column_name], + output_columns=[column_name, "position_id", "attention_mask"], + column_order=[column_name, "position_id", "attention_mask"], + ) + dataset_val = dataset_val.map( + input_columns="position_id", operations=type_cast_op + ) + dataset_val = dataset_val.map( + 
input_columns="attention_mask", operations=type_cast_op_float + ) else: - dataset_val = dataset_val.map(input_columns=[column_name], operations=type_cast_op) + dataset_val = dataset_val.map( + input_columns=[column_name], operations=type_cast_op + ) dataset_val = dataset_val.batch(batch_size, drop_remainder=drop) - dataset_val = dataset_val.map(operations=map_func, input_columns=[column_name], - output_columns=[column_name]) - dataset_val = dataset_val.map(input_columns=column_name, operations=type_cast_op) + dataset_val = dataset_val.map( + operations=map_func, + input_columns=[column_name], + output_columns=[column_name], + ) + dataset_val = dataset_val.map( + input_columns=column_name, operations=type_cast_op + ) return dataset_train, dataset_val @@ -177,8 +235,11 @@ def get_code_data_train(code_data_path, args_opt, process_fn=None, scale=1, skip for dir in sorted(os.listdir(code_data_path)): sub_dirs = os.listdir(os.path.join(code_data_path, dir)) for sub_dir in sub_dirs: - if os.path.exists(os.path.join(code_data_path, dir, sub_dir, 'data.mdb')) and os.path.exists( - os.path.join(code_data_path, dir, sub_dir, 'lock.mdb')): + if os.path.exists( + os.path.join(code_data_path, dir, sub_dir, "data.mdb") + ) and os.path.exists( + os.path.join(code_data_path, dir, sub_dir, "lock.mdb") + ): paths.append(os.path.join(code_data_path, dir, sub_dir)) for full_path in paths: @@ -205,8 +266,11 @@ def get_code_data_eval(code_data_path, args_opt, process_fn=None, scale=1): for dir in sorted(os.listdir(code_data_path)): sub_dirs = os.listdir(os.path.join(code_data_path, dir)) for sub_dir in sub_dirs: - if os.path.exists(os.path.join(code_data_path, dir, sub_dir, 'data.mdb')) and os.path.exists( - os.path.join(code_data_path, dir, sub_dir, 'lock.mdb')): + if os.path.exists( + os.path.join(code_data_path, dir, sub_dir, "data.mdb") + ) and os.path.exists( + os.path.join(code_data_path, dir, sub_dir, "lock.mdb") + ): paths.append(os.path.join(code_data_path, dir, sub_dir)) for full_path in paths: diff --git a/codegeex/mindspore/src/dataset_finetune.py b/codegeex/mindspore/src/dataset_finetune.py index dd85d28..c477777 100644 --- a/codegeex/mindspore/src/dataset_finetune.py +++ b/codegeex/mindspore/src/dataset_finetune.py @@ -50,7 +50,7 @@ def get_input_data_batch_slice_map(input_ids, loss_mask, eod_id, rank, dis, eod_ # input_ids = input_ids[rank * dis : (rank + 1) * dis] if np.any(input_ids > 60000): raise ValueError("==exceed error") - # print("===input_ids tpye: ", input_ids.dtype, flush=True) + # print("===input_ids tpye: ", input_ids.dtype, flush=True) if not eod_reset: return input_ids seq_length = input_ids.shape[1] - 1 @@ -72,16 +72,29 @@ def get_input_data_batch_slice_map(input_ids, loss_mask, eod_id, rank, dis, eod_ for i in range(eod_index.size): # Reset position_ids and attention_mask considering EOD index = eod_index[i] - batch_attention_mask[bs_i, (index + 1):, :(index + 1)] = 0 - batch_position_ids[bs_i, (index + 1):] -= (index + 1 - prev_index) + batch_attention_mask[bs_i, (index + 1) :, : (index + 1)] = 0 + batch_position_ids[bs_i, (index + 1) :] -= index + 1 - prev_index prev_index = index + 1 # print(f"===batch_loss_mask: {batch_loss_mask}, shape: {batch_loss_mask.shape}, nonzero: {batch_loss_mask.nonzero()}") return batch_input_ids, batch_loss_mask, batch_position_ids, batch_attention_mask -def create_dataset(batch_size, data_path, args_opt, device_num=1, rank=0, drop=True, full_batch=False, - data_start_index=0, - eod_reset=False, eod_id=50256, epoch=1, num_samples=None, 
train_and_eval=False, val_ratio=0): +def create_dataset( + batch_size, + data_path, + args_opt, + device_num=1, + rank=0, + drop=True, + full_batch=False, + data_start_index=0, + eod_reset=False, + eod_id=50256, + epoch=1, + num_samples=None, + train_and_eval=False, + val_ratio=0, +): """ Create dataset Inputs: @@ -117,48 +130,83 @@ def create_dataset(batch_size, data_path, args_opt, device_num=1, rank=0, drop=T # skip_num = args_opt.has_trained_steps * dis # skip_num = 0 num_parallel_workers = 4 - train_data = get_code_data(data_path, 'train', args_opt) + train_data = get_code_data(data_path, "train", args_opt) if train_and_eval: - val_data = get_code_data(data_path, 'val', args_opt) + val_data = get_code_data(data_path, "val", args_opt) else: val_data = None - dataset_train = ds.GeneratorDataset(train_data, column_names=['input_ids', 'loss_mask'], num_samples=num_samples, - num_shards=device_num, shard_id=rank, shuffle=True, - num_parallel_workers=num_parallel_workers) + dataset_train = ds.GeneratorDataset( + train_data, + column_names=["input_ids", "loss_mask"], + num_samples=num_samples, + num_shards=device_num, + shard_id=rank, + shuffle=True, + num_parallel_workers=num_parallel_workers, + ) if train_and_eval: - dataset_val = ds.GeneratorDataset(val_data, column_names=['input_ids', 'loss_mask'], num_samples=num_samples, - num_shards=device_num, shard_id=rank, shuffle=True, - num_parallel_workers=num_parallel_workers) + dataset_val = ds.GeneratorDataset( + val_data, + column_names=["input_ids", "loss_mask"], + num_samples=num_samples, + num_shards=device_num, + shard_id=rank, + shuffle=True, + num_parallel_workers=num_parallel_workers, + ) else: dataset_val = None type_cast_op = C.TypeCast(mstype.int32) type_cast_op_float = C.TypeCast(mstype.float16) type_cast_op_float2 = C.TypeCast(mstype.float32) - map_func = ( - lambda input_ids, loss_mask: get_input_data_batch_slice_map(input_ids, loss_mask, eod_id, rank, dis, eod_reset)) + map_func = lambda input_ids, loss_mask: get_input_data_batch_slice_map( + input_ids, loss_mask, eod_id, rank, dis, eod_reset + ) # If eod_reset enabled, another two inputs will be generated through input_ids # dataset_train = dataset_train.skip(skip_num) dataset_train = dataset_train.batch(dis, drop_remainder=drop) - dataset_train = dataset_train.map(operations=map_func, input_columns=["input_ids", "loss_mask"], - output_columns=["input_ids", "loss_mask", "position_id", "attention_mask"], - column_order=["input_ids", "loss_mask", "position_id", "attention_mask"]) - dataset_train = dataset_train.map(input_columns="position_id", operations=type_cast_op) - dataset_train = dataset_train.map(input_columns="attention_mask", operations=type_cast_op_float) - dataset_train = dataset_train.map(input_columns="loss_mask", operations=type_cast_op_float2) - dataset_train = dataset_train.map(input_columns="input_ids", operations=type_cast_op) + dataset_train = dataset_train.map( + operations=map_func, + input_columns=["input_ids", "loss_mask"], + output_columns=["input_ids", "loss_mask", "position_id", "attention_mask"], + column_order=["input_ids", "loss_mask", "position_id", "attention_mask"], + ) + dataset_train = dataset_train.map( + input_columns="position_id", operations=type_cast_op + ) + dataset_train = dataset_train.map( + input_columns="attention_mask", operations=type_cast_op_float + ) + dataset_train = dataset_train.map( + input_columns="loss_mask", operations=type_cast_op_float2 + ) + dataset_train = dataset_train.map( + input_columns="input_ids", 
operations=type_cast_op + ) dataset_train = dataset_train.repeat(epoch) if dataset_val is not None: dataset_val = dataset_val.batch(dis, drop_remainder=drop) - dataset_val = dataset_val.map(operations=map_func, input_columns=["input_ids", "loss_mask"], - output_columns=["input_ids", "loss_mask", "position_id", "attention_mask"], - column_order=["input_ids", "loss_mask", "position_id", "attention_mask"]) - dataset_val = dataset_val.map(input_columns="position_id", operations=type_cast_op) - dataset_val = dataset_val.map(input_columns="attention_mask", operations=type_cast_op_float) - dataset_val = dataset_val.map(input_columns="loss_mask", operations=type_cast_op_float2) - dataset_val = dataset_val.map(input_columns="input_ids", operations=type_cast_op) + dataset_val = dataset_val.map( + operations=map_func, + input_columns=["input_ids", "loss_mask"], + output_columns=["input_ids", "loss_mask", "position_id", "attention_mask"], + column_order=["input_ids", "loss_mask", "position_id", "attention_mask"], + ) + dataset_val = dataset_val.map( + input_columns="position_id", operations=type_cast_op + ) + dataset_val = dataset_val.map( + input_columns="attention_mask", operations=type_cast_op_float + ) + dataset_val = dataset_val.map( + input_columns="loss_mask", operations=type_cast_op_float2 + ) + dataset_val = dataset_val.map( + input_columns="input_ids", operations=type_cast_op + ) return dataset_train, dataset_val diff --git a/codegeex/mindspore/src/generate.py b/codegeex/mindspore/src/generate.py index 0fb9bc1..ea9add2 100644 --- a/codegeex/mindspore/src/generate.py +++ b/codegeex/mindspore/src/generate.py @@ -118,8 +118,12 @@ def generate(model, origin_inputs, config, verbose=False): pad_length = seq_length - origin_inputs.shape[-1] # Pad original inputs to seq_length print("Original shape:", origin_inputs.shape) - input_ids = np.pad(origin_inputs, ((0, 0), (0, pad_length)), - 'constant', constant_values=(end_token, end_token)) + input_ids = np.pad( + origin_inputs, + ((0, 0), (0, pad_length)), + "constant", + constant_values=(end_token, end_token), + ) # print("input_ids is ", input_ids) # A single loop generates one token, loop until reaching target seq_length or generating eod token @@ -133,8 +137,11 @@ def generate(model, origin_inputs, config, verbose=False): log_probs = model.predict(inputs, current_index) # Get the revised log_probs considering frequency and presence penalty to eliminate duplicate in generated results log_probs = log_probs.asnumpy().reshape(1, -1) - log_probs_revised = log_probs - frequency_list * \ - frequency_penalty - (frequency_list > 0) * presence_penalty + log_probs_revised = ( + log_probs + - frequency_list * frequency_penalty + - (frequency_list > 0) * presence_penalty + ) log_probs_revised /= temperature p, p_args = sampler(log_probs_revised, top_p, top_k_num, use_pynative) @@ -145,14 +152,15 @@ def generate(model, origin_inputs, config, verbose=False): print("=== log_probs_revised is", log_probs_revised) print("=== p:", p, "shape:", p.shape) print("=== p_args:", p_args, "shape", p_args.shape) - print(f"=== Length {valid_length}, target index {target_index}, chosen token {p_args[target_index]}.") + print( + f"=== Length {valid_length}, target index {target_index}, chosen token {p_args[target_index]}." 
+ ) # Stop judgment if p_args[target_index] == end_token or valid_length == target_length - 1: outputs = input_ids if verbose: - print( - f"=== generation end, last token: {p_args[target_index]}") + print(f"=== generation end, last token: {p_args[target_index]}") break # update frequency list @@ -202,8 +210,12 @@ def generate_increment(model, origin_inputs, config, verbose=False): frequency_list = np.array([[0 for _ in range(vocab_embedding_vocab_size)]]) pad_length = seq_length - origin_inputs.shape[-1] # Pad original inputs to seq_length - input_ids = np.pad(origin_inputs, ((0, 0), (0, pad_length)), - 'constant', constant_values=(end_token, end_token)) + input_ids = np.pad( + origin_inputs, + ((0, 0), (0, pad_length)), + "constant", + constant_values=(end_token, end_token), + ) print("input_ids is ", input_ids) # Indicate the exact token position @@ -217,8 +229,9 @@ def generate_increment(model, origin_inputs, config, verbose=False): # Claim the first graph model.predict_network.add_flags_recursive(is_first_iteration=True) # Call a single inference with input size of (bs, seq_length) - logits = model.predict(Tensor(input_ids, mstype.int32), - current_index, init, batch_valid_length) + logits = model.predict( + Tensor(input_ids, mstype.int32), current_index, init, batch_valid_length + ) # Claim the second graph and set not_init to true init = init_true @@ -231,8 +244,11 @@ def generate_increment(model, origin_inputs, config, verbose=False): log_probs = logits.reshape(1, vocab_embedding_vocab_size) # Get the revised log_probs considering frequency and presence penalty to eliminate duplicate in generated results - log_probs_revised = log_probs - frequency_list * \ - frequency_penalty - (frequency_list > 0) * presence_penalty + log_probs_revised = ( + log_probs + - frequency_list * frequency_penalty + - (frequency_list > 0) * presence_penalty + ) log_probs_revised /= temperature p, p_args = sampler(log_probs_revised, top_p, top_k_num, use_pynative) @@ -243,7 +259,9 @@ def generate_increment(model, origin_inputs, config, verbose=False): print("=== log_probs_revised is", log_probs_revised) print("=== p:", p, "shape:", p.shape) print("=== p_args:", p_args, "shape", p_args.shape) - print(f"=== Length {valid_length}, target index {target_index}, chosen token {p_args[target_index]}.") + print( + f"=== Length {valid_length}, target index {target_index}, chosen token {p_args[target_index]}." 
+ ) # Stop judgment if p_args[target_index] == end_token or valid_length == target_length - 1: @@ -262,7 +280,6 @@ def generate_increment(model, origin_inputs, config, verbose=False): outputs.append(int(target)) # Call a single inference with input size of (bs, 1) - logits = model.predict(input_id, current_index, - init, batch_valid_length) + logits = model.predict(input_id, current_index, init, batch_valid_length) # Return valid outputs out of padded outputs return np.array(outputs) diff --git a/codegeex/mindspore/src/generate_finetune.py b/codegeex/mindspore/src/generate_finetune.py index aad974d..3e6909f 100644 --- a/codegeex/mindspore/src/generate_finetune.py +++ b/codegeex/mindspore/src/generate_finetune.py @@ -32,7 +32,9 @@ def topk_fun(logits, topk=5): return value, index -def sampler(log_probs_revised, top_p, top_k_num, use_pynative=False, bad_words_index=[]): +def sampler( + log_probs_revised, top_p, top_k_num, use_pynative=False, bad_words_index=[] +): for i, bad_words in enumerate(bad_words_index): for bad_word in bad_words: log_probs_revised[i, bad_word] = -10000 @@ -75,7 +77,9 @@ def sampler(log_probs_revised, top_p, top_k_num, use_pynative=False, bad_words_i return p, p_args -def generate_increment(model, origin_inputs, origin_length, config, tokenizer, verbose=False): +def generate_increment( + model, origin_inputs, origin_length, config, tokenizer, verbose=False +): """ Text generation for incremental inference @@ -101,7 +105,9 @@ def generate_increment(model, origin_inputs, origin_length, config, tokenizer, v batch_size, valid_length = origin_inputs.shape # Init outputs with original inputs - outputs = [[origin_inputs[i][j] for j in range(valid_length)] for i in range(batch_size)] + outputs = [ + [origin_inputs[i][j] for j in range(valid_length)] for i in range(batch_size) + ] output_codes = [[] for _ in range(batch_size)] # If target length exceeds seq_length, use seq_length instead target_lengths = [min(seq_length, l + max_generate_length) for l in origin_length] @@ -112,8 +118,12 @@ def generate_increment(model, origin_inputs, origin_length, config, tokenizer, v frequency_list = np.zeros((batch_size, vocab_embedding_vocab_size)) pad_length = seq_length - origin_inputs.shape[-1] # Pad original inputs to seq_length - input_ids = np.pad(origin_inputs, ((0, 0), (0, pad_length)), - 'constant', constant_values=(end_token, end_token)) + input_ids = np.pad( + origin_inputs, + ((0, 0), (0, pad_length)), + "constant", + constant_values=(end_token, end_token), + ) if verbose: print("input_ids is ", input_ids) @@ -121,7 +131,10 @@ def generate_increment(model, origin_inputs, origin_length, config, tokenizer, v current_indexes = [max(l - 1, 0) for l in valid_lengths] # batch_valid_length = Tensor(np.array([current_index for _ in range(batch_size)]), mstype.int32) batch_valid_length = Tensor(np.array(current_indexes), mstype.int32) - current_indexes = Tensor(np.array([current_indexes[i] + i * seq_length for i in range(batch_size)]), mstype.int32) + current_indexes = Tensor( + np.array([current_indexes[i] + i * seq_length for i in range(batch_size)]), + mstype.int32, + ) # For first graph, not_init should be false init_true = Tensor([True], mstype.bool_) init_false = Tensor([False], mstype.bool_) @@ -129,15 +142,20 @@ def generate_increment(model, origin_inputs, origin_length, config, tokenizer, v # Claim the first graph model.predict_network.add_flags_recursive(is_first_iteration=True) # Call a single inference with input size of (bs, seq_length) - logits = 
model.predict(Tensor(input_ids, mstype.int32), - current_indexes, init, batch_valid_length) + logits = model.predict( + Tensor(input_ids, mstype.int32), current_indexes, init, batch_valid_length + ) # Claim the second graph and set not_init to true init = init_true model.predict_network.add_flags_recursive(is_first_iteration=False) - comments_index = [2, ] # '#': 2, ' #': 1303 - newline_index = [198, ] # '\n': 198 + comments_index = [ + 2, + ] # '#': 2, ' #': 1303 + newline_index = [ + 198, + ] # '\n': 198 # A single loop generates one token, loop until reaching target seq_length or generating eod token while not all(gen_end): # Reshape the output logits @@ -145,12 +163,22 @@ def generate_increment(model, origin_inputs, origin_length, config, tokenizer, v log_probs = logits.reshape(batch_size, vocab_embedding_vocab_size) # Get the revised log_probs considering frequency and presence penalty to eliminate duplicate in generated results - log_probs_revised = log_probs - frequency_list * frequency_penalty - (frequency_list > 0) * presence_penalty + log_probs_revised = ( + log_probs + - frequency_list * frequency_penalty + - (frequency_list > 0) * presence_penalty + ) log_probs_revised /= temperature bad_words_index = [[] for _ in range(batch_size)] - p, p_args = sampler(log_probs_revised, top_p, top_k_num, use_pynative, bad_words_index=bad_words_index) + p, p_args = sampler( + log_probs_revised, + top_p, + top_k_num, + use_pynative, + bad_words_index=bad_words_index, + ) # Random select a token as final output for this round target_index = np.zeros(batch_size, dtype=np.int64) for i in range(batch_size): @@ -160,11 +188,14 @@ def generate_increment(model, origin_inputs, origin_length, config, tokenizer, v print("=== p:", p, "shape:", p.shape) print("=== p_args:", p_args, "shape", p_args.shape) print( - f"=== Length {valid_lengths}, target index {target_index}, chosen token {p_args[np.arange(batch_size), target_index]}, generation end status {gen_end}.") + f"=== Length {valid_lengths}, target index {target_index}, chosen token {p_args[np.arange(batch_size), target_index]}, generation end status {gen_end}." 
+ ) # Update frequency list target = p_args[np.arange(batch_size), target_index] - frequency_list[np.arange(batch_size), target] = frequency_list[np.arange(batch_size), target] + 1 + frequency_list[np.arange(batch_size), target] = ( + frequency_list[np.arange(batch_size), target] + 1 + ) batch_valid_length = Tensor(np.array(valid_lengths), mstype.int32) current_indexes = Tensor(np.arange(batch_size, dtype=np.int32), mstype.int32) @@ -182,6 +213,5 @@ def generate_increment(model, origin_inputs, origin_length, config, tokenizer, v gen_end[i] = True # Call a single inference with input size of (bs, 1) - logits = model.predict(input_id, current_indexes, - init, batch_valid_length) + logits = model.predict(input_id, current_indexes, init, batch_valid_length) return tokenizer.decode_code(output_codes) diff --git a/codegeex/mindspore/src/generate_greedy.py b/codegeex/mindspore/src/generate_greedy.py index 83251c7..3eab89b 100644 --- a/codegeex/mindspore/src/generate_greedy.py +++ b/codegeex/mindspore/src/generate_greedy.py @@ -31,7 +31,9 @@ def topk_fun(logits, topk=5): return value, index -def sampler(log_probs_revised, top_p, top_k_num, use_pynative=False, bad_words_index=[]): +def sampler( + log_probs_revised, top_p, top_k_num, use_pynative=False, bad_words_index=[] +): for i, bad_words in enumerate(bad_words_index): for bad_word in bad_words: log_probs_revised[i, bad_word] = -10000 @@ -42,7 +44,9 @@ def sampler(log_probs_revised, top_p, top_k_num, use_pynative=False, bad_words_i return log_probs_revised.argmax(axis=1) -def generate_increment(model, origin_inputs, origin_length, config, tokenizer, verbose=False): +def generate_increment( + model, origin_inputs, origin_length, config, tokenizer, verbose=False +): """ Text generation for incremental inference @@ -68,7 +72,9 @@ def generate_increment(model, origin_inputs, origin_length, config, tokenizer, v batch_size, valid_length = origin_inputs.shape # Init outputs with original inputs - outputs = [[origin_inputs[i][j] for j in range(valid_length)] for i in range(batch_size)] + outputs = [ + [origin_inputs[i][j] for j in range(valid_length)] for i in range(batch_size) + ] output_codes = [[] for _ in range(batch_size)] # If target length exceeds seq_length, use seq_length instead target_lengths = [min(seq_length, l + max_generate_length) for l in origin_length] @@ -79,8 +85,12 @@ def generate_increment(model, origin_inputs, origin_length, config, tokenizer, v frequency_list = np.zeros((batch_size, vocab_embedding_vocab_size)) pad_length = seq_length - origin_inputs.shape[-1] # Pad original inputs to seq_length - input_ids = np.pad(origin_inputs, ((0, 0), (0, pad_length)), - 'constant', constant_values=(end_token, end_token)) + input_ids = np.pad( + origin_inputs, + ((0, 0), (0, pad_length)), + "constant", + constant_values=(end_token, end_token), + ) if verbose: print("input_ids is ", input_ids) @@ -88,7 +98,10 @@ def generate_increment(model, origin_inputs, origin_length, config, tokenizer, v current_indexes = [max(l - 1, 0) for l in valid_lengths] # batch_valid_length = Tensor(np.array([current_index for _ in range(batch_size)]), mstype.int32) batch_valid_length = Tensor(np.array(current_indexes), mstype.int32) - current_indexes = Tensor(np.array([current_indexes[i] + i * seq_length for i in range(batch_size)]), mstype.int32) + current_indexes = Tensor( + np.array([current_indexes[i] + i * seq_length for i in range(batch_size)]), + mstype.int32, + ) # For first graph, not_init should be false init_true = Tensor([True], mstype.bool_) 
init_false = Tensor([False], mstype.bool_) @@ -96,15 +109,20 @@ def generate_increment(model, origin_inputs, origin_length, config, tokenizer, v # Claim the first graph model.predict_network.add_flags_recursive(is_first_iteration=True) # Call a single inference with input size of (bs, seq_length) - logits = model.predict(Tensor(input_ids, mstype.int32), - current_indexes, init, batch_valid_length) + logits = model.predict( + Tensor(input_ids, mstype.int32), current_indexes, init, batch_valid_length + ) # Claim the second graph and set not_init to true init = init_true model.predict_network.add_flags_recursive(is_first_iteration=False) - comments_index = [2, ] # '#': 2, ' #': 1303 - newline_index = [198, ] # '\n': 198 + comments_index = [ + 2, + ] # '#': 2, ' #': 1303 + newline_index = [ + 198, + ] # '\n': 198 # A single loop generates one token, loop until reaching target seq_length or generating eod token while not all(gen_end): # Reshape the output logits @@ -112,18 +130,32 @@ def generate_increment(model, origin_inputs, origin_length, config, tokenizer, v log_probs = logits.reshape(batch_size, vocab_embedding_vocab_size) # Get the revised log_probs considering frequency and presence penalty to eliminate duplicate in generated results - log_probs_revised = log_probs - frequency_list * frequency_penalty - (frequency_list > 0) * presence_penalty + log_probs_revised = ( + log_probs + - frequency_list * frequency_penalty + - (frequency_list > 0) * presence_penalty + ) bad_words_index = [[] for _ in range(batch_size)] - target_index = sampler(log_probs_revised, top_p, top_k_num, use_pynative, bad_words_index=bad_words_index) + target_index = sampler( + log_probs_revised, + top_p, + top_k_num, + use_pynative, + bad_words_index=bad_words_index, + ) if verbose: - print(f"=== Length {valid_lengths}, target index {target_index}, generation end status {gen_end}.") + print( + f"=== Length {valid_lengths}, target index {target_index}, generation end status {gen_end}." + ) # Update frequency list target = target_index - frequency_list[np.arange(batch_size), target] = frequency_list[np.arange(batch_size), target] + 1 + frequency_list[np.arange(batch_size), target] = ( + frequency_list[np.arange(batch_size), target] + 1 + ) batch_valid_length = Tensor(np.array(valid_lengths), mstype.int32) current_indexes = Tensor(np.arange(batch_size, dtype=np.int32), mstype.int32) @@ -141,6 +173,5 @@ def generate_increment(model, origin_inputs, origin_length, config, tokenizer, v gen_end[i] = True # Call a single inference with input size of (bs, 1) - logits = model.predict(input_id, current_indexes, - init, batch_valid_length) + logits = model.predict(input_id, current_indexes, init, batch_valid_length) return tokenizer.decode_code(output_codes) diff --git a/codegeex/mindspore/src/generate_humaneval.py b/codegeex/mindspore/src/generate_humaneval.py index c6c2486..cbf9ef1 100644 --- a/codegeex/mindspore/src/generate_humaneval.py +++ b/codegeex/mindspore/src/generate_humaneval.py @@ -27,7 +27,7 @@ def is_code_generation_finished(text: str): Checks whether the generated code text is finished. """ # end_words = ['\ndef', '\nclass', '\nif', '\n#', '\nprint', '<|endoftext|>'] - end_words = ['\n}'] + end_words = ["\n}"] for w in end_words: if w in text: return True @@ -39,10 +39,10 @@ def cleanup_text(text: str): Cleans up the generated code text. 
""" # end_words = ['\ndef', '\nclass', '\nif', '\n#', '\nprint', '<|endoftext|>'] - end_words = ['\n}'] + end_words = ["\n}"] for w in end_words: if text.endswith(w): - text = text[:-len(w)] + text = text[: -len(w)] return text @@ -51,7 +51,7 @@ def truncate_text(text: str): Cleans up the generated code text. """ # end_words = ['\ndef', '\nclass', '\nif', '\n#', '\nprint', '<|endoftext|>'] - end_words = ['\n}'] + end_words = ["\n}"] for w in end_words: idx = text.find(w) if idx != -1: @@ -74,7 +74,9 @@ def topk_fun(logits, topk=5): return value, index -def sampler(log_probs_revised, top_p, top_k_num, use_pynative=False, bad_words_index=[]): +def sampler( + log_probs_revised, top_p, top_k_num, use_pynative=False, bad_words_index=[] +): for i, bad_words in enumerate(bad_words_index): for bad_word in bad_words: log_probs_revised[i, bad_word] = -10000 @@ -98,7 +100,7 @@ def sampler(log_probs_revised, top_p, top_k_num, use_pynative=False, bad_words_i cumsum_p = np.cumsum(sorted_p, axis=1) # index = index[0] # sorted_logits = sorted_logits[0] - # cumsum_p = cumsum_p[0] + # cumsum_p = cumsum_p[0] top_p_num = (cumsum_p < top_p).sum(axis=1) + 1 # Get the corresponding probs and indices @@ -154,12 +156,21 @@ def generate_increment(model, origin_inputs, config, tokenizer, verbose=False): batch_size, valid_length = origin_inputs.shape # Init outputs with original inputs # outputs = deepcopy(origin_inputs) - outputs = [[origin_inputs[i][j] for j in range(valid_length)] for i in range(batch_size)] + outputs = [ + [origin_inputs[i][j] for j in range(valid_length)] for i in range(batch_size) + ] output_codes = ["" for _ in range(batch_size)] # If target length exceeds seq_length, use seq_length instead target_length = valid_length + max_generate_length if verbose: - print("target_length was ", valid_length, " + ", max_generate_length, " = ", target_length) + print( + "target_length was ", + valid_length, + " + ", + max_generate_length, + " = ", + target_length, + ) target_length = seq_length if target_length > seq_length else target_length if verbose: print("target_length is ", target_length) @@ -171,15 +182,24 @@ def generate_increment(model, origin_inputs, config, tokenizer, verbose=False): frequency_list = np.zeros((batch_size, vocab_embedding_vocab_size)) pad_length = seq_length - origin_inputs.shape[-1] # Pad original inputs to seq_length - input_ids = np.pad(origin_inputs, ((0, 0), (0, pad_length)), - 'constant', constant_values=(end_token, end_token)) + input_ids = np.pad( + origin_inputs, + ((0, 0), (0, pad_length)), + "constant", + constant_values=(end_token, end_token), + ) if verbose: print("input_ids is ", input_ids) # Indicate the exact token position current_index = valid_length - 1 if valid_length - 1 > 0 else 0 - batch_valid_length = Tensor(np.array([current_index for _ in range(batch_size)]), mstype.int32) - current_index = Tensor(np.array([current_index + i * seq_length for i in range(batch_size)]), mstype.int32) + batch_valid_length = Tensor( + np.array([current_index for _ in range(batch_size)]), mstype.int32 + ) + current_index = Tensor( + np.array([current_index + i * seq_length for i in range(batch_size)]), + mstype.int32, + ) # For first graph, not_init should be false init_true = Tensor([True], mstype.bool_) init_false = Tensor([False], mstype.bool_) @@ -187,15 +207,20 @@ def generate_increment(model, origin_inputs, config, tokenizer, verbose=False): # Claim the first graph model.predict_network.add_flags_recursive(is_first_iteration=True) # Call a single inference with input size 
of (bs, seq_length) - logits = model.predict(Tensor(input_ids, mstype.int32), - current_index, init, batch_valid_length) + logits = model.predict( + Tensor(input_ids, mstype.int32), current_index, init, batch_valid_length + ) # Claim the second graph and set not_init to true init = init_true model.predict_network.add_flags_recursive(is_first_iteration=False) - comments_index = [2, ] # '#': 2, ' #': 1303 - newline_index = [198, ] # '\n': 198 + comments_index = [ + 2, + ] # '#': 2, ' #': 1303 + newline_index = [ + 198, + ] # '\n': 198 # A single loop generates one token, loop until reaching target seq_length or generating eod token while valid_length < target_length: if all(gen_end): @@ -205,8 +230,11 @@ def generate_increment(model, origin_inputs, config, tokenizer, verbose=False): log_probs = logits.reshape(batch_size, vocab_embedding_vocab_size) # Get the revised log_probs considering frequency and presence penalty to eliminate duplicate in generated results - log_probs_revised = log_probs - frequency_list * \ - frequency_penalty - (frequency_list > 0) * presence_penalty + log_probs_revised = ( + log_probs + - frequency_list * frequency_penalty + - (frequency_list > 0) * presence_penalty + ) log_probs_revised /= temperature bad_words_index = [[] for _ in range(batch_size)] @@ -214,7 +242,13 @@ def generate_increment(model, origin_inputs, config, tokenizer, verbose=False): # if not allow_comments[i]: # bad_words_index[i] += comments_index - p, p_args = sampler(log_probs_revised, top_p, top_k_num, use_pynative, bad_words_index=bad_words_index) + p, p_args = sampler( + log_probs_revised, + top_p, + top_k_num, + use_pynative, + bad_words_index=bad_words_index, + ) # Random select a token as final output for this round target_index = np.zeros(batch_size, dtype=np.int64) for i in range(batch_size): @@ -225,13 +259,18 @@ def generate_increment(model, origin_inputs, config, tokenizer, verbose=False): print("=== p:", p, "shape:", p.shape) print("=== p_args:", p_args, "shape", p_args.shape) print( - f"=== Length {valid_length}, target index {target_index}, chosen token {p_args[np.arange(batch_size), target_index]}, generation end status {gen_end}.") + f"=== Length {valid_length}, target index {target_index}, chosen token {p_args[np.arange(batch_size), target_index]}, generation end status {gen_end}." 
+ ) # Update frequency list target = p_args[np.arange(batch_size), target_index] - frequency_list[np.arange(batch_size), target] = frequency_list[np.arange(batch_size), target] + 1 + frequency_list[np.arange(batch_size), target] = ( + frequency_list[np.arange(batch_size), target] + 1 + ) - batch_valid_length = Tensor(np.array([valid_length for _ in range(batch_size)]), mstype.int32) + batch_valid_length = Tensor( + np.array([valid_length for _ in range(batch_size)]), mstype.int32 + ) current_index = Tensor(np.arange(batch_size, dtype=np.int32), mstype.int32) input_id = Tensor([target], mstype.int32).reshape(-1, 1) for i in range(batch_size): @@ -240,15 +279,14 @@ def generate_increment(model, origin_inputs, config, tokenizer, verbose=False): if is_code_generation_finished(output_codes[i]): gen_end[i] = True output_codes[i] = truncate_text(output_codes[i]) - if output_codes[i].endswith('#'): + if output_codes[i].endswith("#"): allow_comments_next[i] = False - elif output_codes[i].endswith('\n'): + elif output_codes[i].endswith("\n"): allow_comments[i] = allow_comments_next[i] allow_comments_next[i] = True outputs[i].append(int(target[i])) # Call a single inference with input size of (bs, 1) - logits = model.predict(input_id, current_index, - init, batch_valid_length) + logits = model.predict(input_id, current_index, init, batch_valid_length) valid_length += 1 return output_codes diff --git a/codegeex/mindspore/src/metrics.py b/codegeex/mindspore/src/metrics.py index 4452d60..82478be 100644 --- a/codegeex/mindspore/src/metrics.py +++ b/codegeex/mindspore/src/metrics.py @@ -35,7 +35,7 @@ def __init__(self, data_length): pipeline_stages = context.get_auto_parallel_context("pipeline_stages") per_stage_device_num = get_group_size() // pipeline_stages stage_id = get_rank() // per_stage_device_num - self.is_last_stage = (stage_id == pipeline_stages - 1) + self.is_last_stage = stage_id == pipeline_stages - 1 def clear(self): """Clear the internal evaluation result.""" @@ -72,7 +72,7 @@ def __init__(self, data_length): pipeline_stages = context.get_auto_parallel_context("pipeline_stages") per_stage_device_num = get_group_size() // pipeline_stages stage_id = get_rank() // per_stage_device_num - self.is_last_stage = (stage_id == pipeline_stages - 1) + self.is_last_stage = stage_id == pipeline_stages - 1 def clear(self): """Clear the internal evaluation result.""" diff --git a/codegeex/mindspore/src/pangu_alpha.py b/codegeex/mindspore/src/pangu_alpha.py index 32738ac..080875b 100644 --- a/codegeex/mindspore/src/pangu_alpha.py +++ b/codegeex/mindspore/src/pangu_alpha.py @@ -22,6 +22,7 @@ from mindspore import Tensor, Parameter from mindspore.common.initializer import initializer from mindspore.nn import Cell + # from mindspore.parallel.nn.layers import _LayerNorm from mindspore.nn.transformer.layers import _Dropout, _LayerNorm from mindspore.ops import functional as F @@ -89,9 +90,7 @@ def __init__(self, config): self.use_past = config.use_past self.batch_size = config.batch_size - def construct( - self, input_ids, input_position, init_reset, batch_valid_length - ): + def construct(self, input_ids, input_position, init_reset, batch_valid_length): word_embedding, word_table = self.word_embedding(input_ids) if self.use_past and not self.is_first_iteration: _, seq_length = F.shape(input_ids) @@ -110,20 +109,20 @@ class QueryLayer(TransformerEncoderLayer): r"""Query Layer at the final layer.""" def __init__( - self, - batch_size, - hidden_size, - ffn_hidden_size, - num_heads, - seq_length, - 
attention_dropout_rate=0.1, - hidden_dropout_rate=0.1, - post_layernorm_residual=False, - param_init_type=mstype.float32, - hidden_act="fast_gelu", - use_past=False, - parallel_config=None, - softmax_compute_type=mstype.float32, + self, + batch_size, + hidden_size, + ffn_hidden_size, + num_heads, + seq_length, + attention_dropout_rate=0.1, + hidden_dropout_rate=0.1, + post_layernorm_residual=False, + param_init_type=mstype.float32, + hidden_act="fast_gelu", + use_past=False, + parallel_config=None, + softmax_compute_type=mstype.float32, ): super(QueryLayer, self).__init__( batch_size=batch_size, @@ -142,12 +141,12 @@ def __init__( ) def construct( - self, - x, - query_vector, - input_mask, - init_reset=True, - batch_valid_length=None, + self, + x, + query_vector, + input_mask, + init_reset=True, + batch_valid_length=None, ): r""" The forward process of the block. @@ -229,9 +228,7 @@ class PanGuHead(Cell): logits: Tensor, the logits of the corresponding inputs """ - def __init__( - self, hidden_size, compute_type=mstype.float16, parallel_config=None - ): + def __init__(self, hidden_size, compute_type=mstype.float16, parallel_config=None): super(PanGuHead, self).__init__() if parallel_config.vocab_emb_dp: self.matmul = P.MatMul(transpose_b=True).shard( @@ -252,24 +249,22 @@ def __init__( def construct(self, state, embed): state = P.Reshape()(state, (-1, self.hidden_size)) # output logits over vocabulary [bs*seq_length, vocab_size] - logits = self.matmul( - self.cast(state, self.dtype), self.cast(embed, self.dtype) - ) + logits = self.matmul(self.cast(state, self.dtype), self.cast(embed, self.dtype)) return logits def set_parallel_configure_for_layer( - network, layer_id, offset, parallel_config, layers + network, layer_id, offset, parallel_config, layers ): r""" - Default setting for the pipeline is: `(layer_id + offset) // (layers / pipeline_stage)`. + Default setting for the pipeline is: `(layer_id + offset) // (layers / pipeline_stage)`. - Args: - network(Cell) - Represents the transformer block - layer_id(int) - Means the layer index for the current module, counts from zero. - offset(int) - Means the layer_index needs a offset, if there are other modules in the net. - layers(int) - The total layers used for the model. + Args: + network(Cell) - Represents the transformer block + layer_id(int) - Means the layer index for the current module, counts from zero. + offset(int) - Means the layer_index needs a offset, if there are other modules in the net. + layers(int) - The total layers used for the model. """ # Used for the pipeline's stages setting # As the final layer is not included here, so we need to manually add here. 
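# Illustrative sketch (separate from the diff hunks above, not part of the patch):
# the set_parallel_configure_for_layer() docstring in this hunk states that the
# default pipeline assignment is `(layer_id + offset) // (layers / pipeline_stage)`.
# The minimal, self-contained example below shows how that rule spreads transformer
# blocks over pipeline stages; the helper name and the sizes are invented purely
# for demonstration.

def stage_for_layer(layer_id: int, offset: int, layers: int, pipeline_stage: int) -> int:
    # Each stage receives layers / pipeline_stage consecutive blocks; flooring the
    # quotient yields the stage index of a given block.
    return int((layer_id + offset) // (layers / pipeline_stage))

# With 8 layers split over 4 stages, blocks 0-1 land on stage 0, 2-3 on stage 1,
# 4-5 on stage 2, and 6-7 on stage 3:
print([stage_for_layer(i, offset=0, layers=8, pipeline_stage=4) for i in range(8)])
# -> [0, 0, 1, 1, 2, 2, 3, 3]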
@@ -301,9 +296,7 @@ def __init__(self, config): self.is_pipeline = config.parallel_config.pipeline_stage > 1 self.embedding = EmbeddingLayer(config) self.config = config - self.layernorm = _LayerNorm((config.hidden_size,)).to_float( - mstype.float32 - ) + self.layernorm = _LayerNorm((config.hidden_size,)).to_float(mstype.float32) if config.parallel_config.pipeline_stage > 1: self.layernorm.set_comm_fusion(2) else: @@ -311,35 +304,35 @@ def __init__(self, config): config.parallel_config.gradient_aggregation_group ) self.layernorm.shard(((config.parallel_config.data_parallel, 1),)) - self.layernorm.pipeline_stage = ( - config.parallel_config.pipeline_stage - 1 - ) + self.layernorm.pipeline_stage = config.parallel_config.pipeline_stage - 1 # Configure the shard configure of the Embedding layer self.embedding.pipeline_stage = 0 self.num_layers = config.num_layers if config.use_moe: moe_config = MoEConfig( expert_num=config.parallel_config.data_parallel - * config.per_dp_dim_expert_num + * config.per_dp_dim_expert_num ) else: moe_config = MoEConfig(expert_num=1) # The shard setting of Transformer is set within the class StackedTransformer - self.blocks = TransformerEncoder(num_layers=config.num_layers - 1, - batch_size=config.batch_size, - hidden_size=config.hidden_size, - ffn_hidden_size=config.ffn_hidden_size, - num_heads=config.num_heads, - seq_length=config.seq_length, - attention_dropout_rate=config.dropout_rate, - hidden_dropout_rate=config.dropout_rate, - lambda_func=set_parallel_configure_for_layer, - hidden_act="fast_gelu", - param_init_type=config.param_init_type, - use_past=config.use_past, - parallel_config=config.parallel_config, - moe_config=moe_config, - softmax_compute_type=config.softmax_compute_type).blocks + self.blocks = TransformerEncoder( + num_layers=config.num_layers - 1, + batch_size=config.batch_size, + hidden_size=config.hidden_size, + ffn_hidden_size=config.ffn_hidden_size, + num_heads=config.num_heads, + seq_length=config.seq_length, + attention_dropout_rate=config.dropout_rate, + hidden_dropout_rate=config.dropout_rate, + lambda_func=set_parallel_configure_for_layer, + hidden_act="fast_gelu", + param_init_type=config.param_init_type, + use_past=config.use_past, + parallel_config=config.parallel_config, + moe_config=moe_config, + softmax_compute_type=config.softmax_compute_type, + ).blocks for block in self.blocks: block.attention.dense1.bias.parallel_optimizer = False block.attention.dense2.bias.parallel_optimizer = False @@ -347,37 +340,47 @@ def __init__(self, config): block.output.mapping.bias.parallel_optimizer = False copied_parallel_config = copy.deepcopy(config.parallel_config) copied_parallel_config.vocab_emb_dp = True - self.top_query_embedding = VocabEmbedding(vocab_size=config.seq_length, - embedding_size=config.hidden_size, - param_init=initializer("normal", - [config.seq_length, config.hidden_size], - dtype=mstype.float32), - # dtype=config.param_init_type), - parallel_config=copied_parallel_config.embedding_dp_mp_config) - self.top_query_embedding.pipeline_stage = config.parallel_config.pipeline_stage - 1 + self.top_query_embedding = VocabEmbedding( + vocab_size=config.seq_length, + embedding_size=config.hidden_size, + param_init=initializer( + "normal", [config.seq_length, config.hidden_size], dtype=mstype.float32 + ), + # dtype=config.param_init_type), + parallel_config=copied_parallel_config.embedding_dp_mp_config, + ) + self.top_query_embedding.pipeline_stage = ( + config.parallel_config.pipeline_stage - 1 + ) if 
config.parallel_config.pipeline_stage > 1: self.top_query_embedding.set_comm_fusion(2) else: - self.top_query_embedding.set_comm_fusion(config.parallel_config.gradient_aggregation_group) - - self.top_query_layer = QueryLayer(batch_size=config.batch_size, - hidden_size=config.hidden_size, - ffn_hidden_size=config.ffn_hidden_size, - num_heads=config.num_heads, - seq_length=config.seq_length, - attention_dropout_rate=config.dropout_rate, - hidden_dropout_rate=config.dropout_rate, - hidden_act=config.hidden_act, - param_init_type=config.param_init_type, - use_past=config.use_past, - parallel_config=config.parallel_config) + self.top_query_embedding.set_comm_fusion( + config.parallel_config.gradient_aggregation_group + ) + + self.top_query_layer = QueryLayer( + batch_size=config.batch_size, + hidden_size=config.hidden_size, + ffn_hidden_size=config.ffn_hidden_size, + num_heads=config.num_heads, + seq_length=config.seq_length, + attention_dropout_rate=config.dropout_rate, + hidden_dropout_rate=config.dropout_rate, + hidden_act=config.hidden_act, + param_init_type=config.param_init_type, + use_past=config.use_past, + parallel_config=config.parallel_config, + ) self.top_query_layer.attention.dense1.bias.parallel_optimizer = False self.top_query_layer.attention.dense2.bias.parallel_optimizer = False self.top_query_layer.attention.dense3.bias.parallel_optimizer = False self.top_query_layer.output.mapping.bias.parallel_optimizer = False if config.parallel_config.recompute: self.top_query_layer.recompute() - self.top_query_layer.set_comm_fusion(config.parallel_config.gradient_aggregation_group) + self.top_query_layer.set_comm_fusion( + config.parallel_config.gradient_aggregation_group + ) self.top_query_layer.pipeline_stage = config.parallel_config.pipeline_stage - 1 self.dtype = mstype.float16 @@ -385,25 +388,37 @@ def __init__(self, config): # if config.load_ckpt_path: # self.load_embedding_from_ckpt(config.load_ckpt_path) - def construct(self, input_ids, - input_position, - encoder_masks, - init_reset=True, - batch_valid_length=None): + def construct( + self, + input_ids, + input_position, + encoder_masks, + init_reset=True, + batch_valid_length=None, + ): r"""forward pass of the model""" - embed, word_table = self.embedding(input_ids, input_position, init_reset, batch_valid_length) + embed, word_table = self.embedding( + input_ids, input_position, init_reset, batch_valid_length + ) hidden_state = P.Cast()(embed, self.dtype) if init_reset is False: hidden_state = self.reshape_to_2d(hidden_state) # encoder_mask = self.create_encoder_mask(encoder_masks) if self.blocks is not None: for i in range(self.num_layers - 1): - hidden_state, _ = self.blocks[i](hidden_state, encoder_masks, init_reset, batch_valid_length) + hidden_state, _ = self.blocks[i]( + hidden_state, encoder_masks, init_reset, batch_valid_length + ) if self.is_pipeline: top_query_hidden_states, _ = self.top_query_embedding(input_position) top_query_hidden_states = self.reshape_to_2d(top_query_hidden_states) - encoder_output, _ = self.top_query_layer(hidden_state, top_query_hidden_states, - encoder_masks, init_reset, batch_valid_length) + encoder_output, _ = self.top_query_layer( + hidden_state, + top_query_hidden_states, + encoder_masks, + init_reset, + batch_valid_length, + ) encoder_output = self.layernorm(encoder_output) else: hidden_state = self.reshape_to_2d(hidden_state) @@ -411,8 +426,13 @@ def construct(self, input_ids, encoder_output = P.Cast()(encoder_output, self.dtype) top_query_hidden_states, _ = 
self.top_query_embedding(input_position) top_query_hidden_states = self.reshape_to_2d(top_query_hidden_states) - encoder_output, _ = self.top_query_layer(encoder_output, top_query_hidden_states, - encoder_masks, init_reset, batch_valid_length) + encoder_output, _ = self.top_query_layer( + encoder_output, + top_query_hidden_states, + encoder_masks, + init_reset, + batch_valid_length, + ) return encoder_output, word_table @@ -440,22 +460,35 @@ def load_param(path): # three embedding needed to be loaded # Loading the embedding table from the ckpt path: - position_embedding_path = os.path.join(load_ckpt_path, 'position_embedding.npy') - word_embedding_path = os.path.join(load_ckpt_path, 'word_embedding.npy') - top_query_embedding_path = os.path.join(load_ckpt_path, 'top_query_embedding.npy') - self.embedding.word_embedding.embedding_table = Parameter(initializer(load_param(word_embedding_path), - [self.config.vocab_size, - self.config.hidden_size]), - name='word_embedding_table', parallel_optimizer=False) - self.embedding.position_embedding.embedding_table = Parameter(initializer(load_param(position_embedding_path), - [self.config.seq_length, - self.config.hidden_size]), - name='position_embedding_table', - parallel_optimizer=False) - self.top_query_embedding.embedding_table = Parameter(initializer(load_param(top_query_embedding_path), - [self.config.seq_length, - self.config.hidden_size]), - name='query_embedding_table', parallel_optimizer=False) + position_embedding_path = os.path.join(load_ckpt_path, "position_embedding.npy") + word_embedding_path = os.path.join(load_ckpt_path, "word_embedding.npy") + top_query_embedding_path = os.path.join( + load_ckpt_path, "top_query_embedding.npy" + ) + self.embedding.word_embedding.embedding_table = Parameter( + initializer( + load_param(word_embedding_path), + [self.config.vocab_size, self.config.hidden_size], + ), + name="word_embedding_table", + parallel_optimizer=False, + ) + self.embedding.position_embedding.embedding_table = Parameter( + initializer( + load_param(position_embedding_path), + [self.config.seq_length, self.config.hidden_size], + ), + name="position_embedding_table", + parallel_optimizer=False, + ) + self.top_query_embedding.embedding_table = Parameter( + initializer( + load_param(top_query_embedding_path), + [self.config.seq_length, self.config.hidden_size], + ), + name="query_embedding_table", + parallel_optimizer=False, + ) class PanguAlphaModel(nn.Cell): @@ -483,12 +516,21 @@ def __init__(self, config): ) self.head.pipeline_stage = config.parallel_config.pipeline_stage - 1 self.backbone = PanguAlpha_Model(config) - self.backbone.embedding.word_embedding.embedding_table.add_pipeline_stage(self.head.pipeline_stage) + self.backbone.embedding.word_embedding.embedding_table.add_pipeline_stage( + self.head.pipeline_stage + ) - def construct(self, input_ids, input_position, attention_mask, - init_reset=True, batch_valid_length=None): - output_states, word_table = self.backbone(input_ids, input_position, attention_mask, - init_reset, batch_valid_length) + def construct( + self, + input_ids, + input_position, + attention_mask, + init_reset=True, + batch_valid_length=None, + ): + output_states, word_table = self.backbone( + input_ids, input_position, attention_mask, init_reset, batch_valid_length + ) logits = self.head(output_states, word_table) return logits @@ -528,9 +570,12 @@ def construct(self, input_ids, input_position=None, attention_mask=None): r"""Forward process of the pangu alpha model""" tokens = self.slice(input_ids, (0, 0), 
(self.batch_size, -1), (1, 1)) # P.Print()("==net tokens is:", tokens) - input_position = self.slice(input_position, (0, 0), (self.batch_size, self.len), (1, 1)) - decoder_attention_masks = self.slice2(attention_mask, (0, 0, 0), (self.batch_size, self.len, self.len), - (1, 1, 1)) + input_position = self.slice( + input_position, (0, 0), (self.batch_size, self.len), (1, 1) + ) + decoder_attention_masks = self.slice2( + attention_mask, (0, 0, 0), (self.batch_size, self.len, self.len), (1, 1, 1) + ) input_mask = F.cast(self.not_equal(tokens, self.eod_token), mstype.float32) logits = self.network(tokens, input_position, decoder_attention_masks) # P.Print()("==logits_is:", logits, ",shape is:", logits.shape) @@ -569,10 +614,14 @@ def __init__(self, backbone, generate=False, pad_token=6, seq_length=2048): self.log_softmax = P.LogSoftmax().shard(((1, 1, 1),)) self.get_attention_mask = AttentionMask(seq_length) self.expand = P.ExpandDims().shard(((1, 1, 1),)) - self.all_ones_attention_mask = Tensor(np.ones((1, 1, seq_length)), mstype.float32) + self.all_ones_attention_mask = Tensor( + np.ones((1, 1, seq_length)), mstype.float32 + ) self.not_equal = P.NotEqual().shard(((1, 1), ())) - def construct(self, input_ids, current_index, init_reset=True, batch_valid_length=None): + def construct( + self, input_ids, current_index, init_reset=True, batch_valid_length=None + ): """evaluation net""" # input_mask = F.cast(F.not_equal(input_ids, self.pad_token), mstype.float32) input_mask = F.cast(self.not_equal(input_ids, self.pad_token), mstype.float32) @@ -583,9 +632,12 @@ def construct(self, input_ids, current_index, init_reset=True, batch_valid_lengt attention_mask = self.get_attention_mask(input_mask) input_position = F.tuple_to_array(F.make_range(seq_length)) input_position = P.Tile()(input_position, (bs, 1)) - logits = self.backbone(input_ids, input_position, attention_mask, - init_reset, batch_valid_length) - index = current_index.view(-1, ) + logits = self.backbone( + input_ids, input_position, attention_mask, init_reset, batch_valid_length + ) + index = current_index.view( + -1, + ) # P.Print()("==logits_is:", logits, ",shape is:", logits.shape) # P.Print()("==index_is:", index, ",shape is:", index.shape) logits = self.gather(logits, index, 0) @@ -618,10 +670,14 @@ def __init__(self, backbone, generate=False, pad_token=6, seq_length=2048): self.log_softmax = P.LogSoftmax().shard(((1, 1, 1),)) self.get_attention_mask = AttentionMask(seq_length) self.expand = P.ExpandDims().shard(((1, 1, 1),)) - self.all_ones_attention_mask = Tensor(np.ones((1, 1, seq_length)), mstype.float32) + self.all_ones_attention_mask = Tensor( + np.ones((1, 1, seq_length)), mstype.float32 + ) self.not_equal = P.NotEqual().shard(((1, 1), ())) - def construct(self, input_ids, init_reset=True, batch_valid_length=None, attention_mask=None): + def construct( + self, input_ids, init_reset=True, batch_valid_length=None, attention_mask=None + ): """evaluation net""" # input_mask = F.cast(F.not_equal(input_ids, self.pad_token), mstype.float32) input_mask = F.cast(self.not_equal(input_ids, self.pad_token), mstype.float32) @@ -633,8 +689,9 @@ def construct(self, input_ids, init_reset=True, batch_valid_length=None, attenti attention_mask = self.get_attention_mask(input_mask) input_position = F.tuple_to_array(F.make_range(seq_length)) input_position = P.Tile()(input_position, (bs, 1)) - logits = self.backbone(input_ids, input_position, attention_mask, - init_reset, batch_valid_length) + logits = self.backbone( + input_ids, input_position, 
attention_mask, init_reset, batch_valid_length + ) return logits @@ -674,9 +731,12 @@ def construct(self, input_ids, loss_mask, input_position, attention_mask): r"""Forward process of the pangu alpha model""" tokens = self.slice(input_ids, (0, 0), (self.batch_size, -1), (1, 1)) # P.Print()("==net tokens is:", tokens) - input_position = self.slice(input_position, (0, 0), (self.batch_size, self.len), (1, 1)) - decoder_attention_masks = self.slice2(attention_mask, (0, 0, 0), (self.batch_size, self.len, self.len), - (1, 1, 1)) + input_position = self.slice( + input_position, (0, 0), (self.batch_size, self.len), (1, 1) + ) + decoder_attention_masks = self.slice2( + attention_mask, (0, 0, 0), (self.batch_size, self.len, self.len), (1, 1, 1) + ) input_mask = F.cast(self.not_equal(tokens, self.eod_token), mstype.float32) logits = self.network(tokens, input_position, decoder_attention_masks) # P.Print()("===logits: ", logits, ", shape: ", logits.shape) diff --git a/codegeex/mindspore/src/pangu_alpha_config.py b/codegeex/mindspore/src/pangu_alpha_config.py index eff164e..ee31a34 100644 --- a/codegeex/mindspore/src/pangu_alpha_config.py +++ b/codegeex/mindspore/src/pangu_alpha_config.py @@ -23,28 +23,29 @@ class PanguAlphaConfig: PanGUConfig config class which defines the model size """ - def __init__(self, - batch_size=32, - seq_length=2048, - vocab_size=40000, - hidden_size=768, - ffn_hidden_size=768, - num_layers=12, - num_heads=12, - load_ckpt_path=None, - param_init_type=mstype.float32, - post_layernorm_residual=False, - dropout_rate=0.1, - eod_token=50256, - use_past=False, - hidden_act="fast_gelu", - eod_reset=True, - enable_offload=False, - use_moe=False, - per_dp_dim_expert_num=4, - parallel_config=None, - softmax_compute_type=mstype.float16, - ): + def __init__( + self, + batch_size=32, + seq_length=2048, + vocab_size=40000, + hidden_size=768, + ffn_hidden_size=768, + num_layers=12, + num_heads=12, + load_ckpt_path=None, + param_init_type=mstype.float32, + post_layernorm_residual=False, + dropout_rate=0.1, + eod_token=50256, + use_past=False, + hidden_act="fast_gelu", + eod_reset=True, + enable_offload=False, + use_moe=False, + per_dp_dim_expert_num=4, + parallel_config=None, + softmax_compute_type=mstype.float16, + ): self.batch_size = batch_size self.seq_length = seq_length self.vocab_size = vocab_size @@ -79,7 +80,7 @@ def __str__(self): def set_parse(args_opt): r""" - Set config according to the mode + Set config according to the mode """ if args_opt.mode == "200B": args_opt.embedding_size = 16384 diff --git a/codegeex/mindspore/src/pangu_alpha_fp16_predict.py b/codegeex/mindspore/src/pangu_alpha_fp16_predict.py index 8948224..07f0502 100644 --- a/codegeex/mindspore/src/pangu_alpha_fp16_predict.py +++ b/codegeex/mindspore/src/pangu_alpha_fp16_predict.py @@ -86,9 +86,7 @@ def __init__(self, config): self.use_past = config.use_past self.batch_size = config.batch_size - def construct( - self, input_ids, input_position, init_reset, batch_valid_length - ): + def construct(self, input_ids, input_position, init_reset, batch_valid_length): word_embedding, word_table = self.word_embedding(input_ids) if self.use_past and not self.is_first_iteration: _, seq_length = F.shape(input_ids) @@ -106,20 +104,20 @@ class QueryLayer(TransformerEncoderLayer): r"""Query Layer at the final layer.""" def __init__( - self, - batch_size, - hidden_size, - ffn_hidden_size, - num_heads, - seq_length, - attention_dropout_rate=0.1, - hidden_dropout_rate=0.1, - post_layernorm_residual=False, - 
param_init_type=mstype.float32, - hidden_act="fast_gelu", - use_past=False, - parallel_config=None, - softmax_compute_type=mstype.float32, + self, + batch_size, + hidden_size, + ffn_hidden_size, + num_heads, + seq_length, + attention_dropout_rate=0.1, + hidden_dropout_rate=0.1, + post_layernorm_residual=False, + param_init_type=mstype.float32, + hidden_act="fast_gelu", + use_past=False, + parallel_config=None, + softmax_compute_type=mstype.float32, ): super(QueryLayer, self).__init__( batch_size=batch_size, @@ -138,12 +136,12 @@ def __init__( ) def construct( - self, - x, - query_vector, - input_mask, - init_reset=True, - batch_valid_length=None, + self, + x, + query_vector, + input_mask, + init_reset=True, + batch_valid_length=None, ): r""" The forward process of the block. @@ -225,9 +223,7 @@ class PanGuHead(Cell): logits: Tensor, the logits of the corresponding inputs """ - def __init__( - self, hidden_size, compute_type=mstype.float16, parallel_config=None - ): + def __init__(self, hidden_size, compute_type=mstype.float16, parallel_config=None): super(PanGuHead, self).__init__() if parallel_config.vocab_emb_dp: self.matmul = P.MatMul(transpose_b=True).shard( @@ -248,24 +244,22 @@ def __init__( def construct(self, state, embed): state = P.Reshape()(state, (-1, self.hidden_size)) # output logits over vocabulary [bs*seq_length, vocab_size] - logits = self.matmul( - self.cast(state, self.dtype), self.cast(embed, self.dtype) - ) + logits = self.matmul(self.cast(state, self.dtype), self.cast(embed, self.dtype)) return logits def set_parallel_configure_for_layer( - network, layer_id, offset, parallel_config, layers + network, layer_id, offset, parallel_config, layers ): r""" - Default setting for the pipeline is: `(layer_id + offset) // (layers / pipeline_stage)`. + Default setting for the pipeline is: `(layer_id + offset) // (layers / pipeline_stage)`. - Args: - network(Cell) - Represents the transformer block - layer_id(int) - Means the layer index for the current module, counts from zero. - offset(int) - Means the layer_index needs a offset, if there are other modules in the net. - layers(int) - The total layers used for the model. + Args: + network(Cell) - Represents the transformer block + layer_id(int) - Means the layer index for the current module, counts from zero. + offset(int) - Means the layer_index needs a offset, if there are other modules in the net. + layers(int) - The total layers used for the model. """ # Used for the pipeline's stages setting # As the final layer is not included here, so we need to manually add here. 
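# Illustrative sketch (separate from the diff hunks above, not part of the patch):
# PanGuHead.construct() in this file produces "output logits over vocabulary
# [bs*seq_length, vocab_size]" by multiplying the hidden states with the word
# embedding table via P.MatMul(transpose_b=True), i.e. the output projection shares
# weights with the input embedding. A NumPy equivalent with made-up sizes:
import numpy as np

bs, seq_length, hidden_size, vocab_size = 2, 4, 8, 16
state = np.random.randn(bs * seq_length, hidden_size).astype(np.float16)  # flattened hidden states
embed = np.random.randn(vocab_size, hidden_size).astype(np.float16)       # shared embedding table

# [bs*seq_length, hidden] x [hidden, vocab] -> logits over the vocabulary
logits = state @ embed.T
assert logits.shape == (bs * seq_length, vocab_size)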
@@ -297,9 +291,7 @@ def __init__(self, config): self.is_pipeline = config.parallel_config.pipeline_stage > 1 self.embedding = EmbeddingLayer(config) self.config = config - self.layernorm = _LayerNorm((config.hidden_size,)).to_float( - mstype.float32 - ) + self.layernorm = _LayerNorm((config.hidden_size,)).to_float(mstype.float32) if config.parallel_config.pipeline_stage > 1: self.layernorm.set_comm_fusion(2) else: @@ -307,35 +299,35 @@ def __init__(self, config): config.parallel_config.gradient_aggregation_group ) self.layernorm.shard(((config.parallel_config.data_parallel, 1),)) - self.layernorm.pipeline_stage = ( - config.parallel_config.pipeline_stage - 1 - ) + self.layernorm.pipeline_stage = config.parallel_config.pipeline_stage - 1 # Configure the shard configure of the Embedding layer self.embedding.pipeline_stage = 0 self.num_layers = config.num_layers if config.use_moe: moe_config = MoEConfig( expert_num=config.parallel_config.data_parallel - * config.per_dp_dim_expert_num + * config.per_dp_dim_expert_num ) else: moe_config = MoEConfig(expert_num=1) # The shard setting of Transformer is set within the class StackedTransformer - self.blocks = TransformerEncoder(num_layers=config.num_layers - 1, - batch_size=config.batch_size, - hidden_size=config.hidden_size, - ffn_hidden_size=config.ffn_hidden_size, - num_heads=config.num_heads, - seq_length=config.seq_length, - attention_dropout_rate=config.dropout_rate, - hidden_dropout_rate=config.dropout_rate, - lambda_func=set_parallel_configure_for_layer, - hidden_act="fast_gelu", - param_init_type=config.param_init_type, - use_past=config.use_past, - parallel_config=config.parallel_config, - moe_config=moe_config, - softmax_compute_type=config.softmax_compute_type).blocks + self.blocks = TransformerEncoder( + num_layers=config.num_layers - 1, + batch_size=config.batch_size, + hidden_size=config.hidden_size, + ffn_hidden_size=config.ffn_hidden_size, + num_heads=config.num_heads, + seq_length=config.seq_length, + attention_dropout_rate=config.dropout_rate, + hidden_dropout_rate=config.dropout_rate, + lambda_func=set_parallel_configure_for_layer, + hidden_act="fast_gelu", + param_init_type=config.param_init_type, + use_past=config.use_past, + parallel_config=config.parallel_config, + moe_config=moe_config, + softmax_compute_type=config.softmax_compute_type, + ).blocks for block in self.blocks: block.attention.dense1.bias.parallel_optimizer = False block.attention.dense2.bias.parallel_optimizer = False @@ -343,36 +335,48 @@ def __init__(self, config): block.output.mapping.bias.parallel_optimizer = False copied_parallel_config = copy.deepcopy(config.parallel_config) copied_parallel_config.vocab_emb_dp = True - self.top_query_embedding = VocabEmbedding(vocab_size=config.seq_length, - embedding_size=config.hidden_size, - param_init=initializer("normal", - [config.seq_length, config.hidden_size], - dtype=config.param_init_type), - parallel_config=copied_parallel_config.embedding_dp_mp_config) - self.top_query_embedding.pipeline_stage = config.parallel_config.pipeline_stage - 1 + self.top_query_embedding = VocabEmbedding( + vocab_size=config.seq_length, + embedding_size=config.hidden_size, + param_init=initializer( + "normal", + [config.seq_length, config.hidden_size], + dtype=config.param_init_type, + ), + parallel_config=copied_parallel_config.embedding_dp_mp_config, + ) + self.top_query_embedding.pipeline_stage = ( + config.parallel_config.pipeline_stage - 1 + ) if config.parallel_config.pipeline_stage > 1: 
self.top_query_embedding.set_comm_fusion(2) else: - self.top_query_embedding.set_comm_fusion(config.parallel_config.gradient_aggregation_group) - - self.top_query_layer = QueryLayer(batch_size=config.batch_size, - hidden_size=config.hidden_size, - ffn_hidden_size=config.ffn_hidden_size, - num_heads=config.num_heads, - seq_length=config.seq_length, - attention_dropout_rate=config.dropout_rate, - hidden_dropout_rate=config.dropout_rate, - hidden_act=config.hidden_act, - param_init_type=config.param_init_type, - use_past=config.use_past, - parallel_config=config.parallel_config) + self.top_query_embedding.set_comm_fusion( + config.parallel_config.gradient_aggregation_group + ) + + self.top_query_layer = QueryLayer( + batch_size=config.batch_size, + hidden_size=config.hidden_size, + ffn_hidden_size=config.ffn_hidden_size, + num_heads=config.num_heads, + seq_length=config.seq_length, + attention_dropout_rate=config.dropout_rate, + hidden_dropout_rate=config.dropout_rate, + hidden_act=config.hidden_act, + param_init_type=config.param_init_type, + use_past=config.use_past, + parallel_config=config.parallel_config, + ) self.top_query_layer.attention.dense1.bias.parallel_optimizer = False self.top_query_layer.attention.dense2.bias.parallel_optimizer = False self.top_query_layer.attention.dense3.bias.parallel_optimizer = False self.top_query_layer.output.mapping.bias.parallel_optimizer = False if config.parallel_config.recompute: self.top_query_layer.recompute() - self.top_query_layer.set_comm_fusion(config.parallel_config.gradient_aggregation_group) + self.top_query_layer.set_comm_fusion( + config.parallel_config.gradient_aggregation_group + ) self.top_query_layer.pipeline_stage = config.parallel_config.pipeline_stage - 1 self.dtype = mstype.float16 @@ -380,13 +384,18 @@ def __init__(self, config): # if config.load_ckpt_path: # self.load_embedding_from_ckpt(config.load_ckpt_path) - def construct(self, input_ids, - input_position, - encoder_masks, - init_reset=True, - batch_valid_length=None): + def construct( + self, + input_ids, + input_position, + encoder_masks, + init_reset=True, + batch_valid_length=None, + ): r"""forward pass of the model""" - embed, word_table = self.embedding(input_ids, input_position, init_reset, batch_valid_length) + embed, word_table = self.embedding( + input_ids, input_position, init_reset, batch_valid_length + ) self.print("PanguAlpha_Model: embed_0", embed) hidden_state = P.Cast()(embed, self.dtype) self.print("PanguAlpha_Model: hidden_state_0", embed) @@ -396,13 +405,20 @@ def construct(self, input_ids, # encoder_mask = self.create_encoder_mask(encoder_masks) if self.blocks is not None: for i in range(self.num_layers - 1): - hidden_state, _ = self.blocks[i](hidden_state, encoder_masks, init_reset, batch_valid_length) + hidden_state, _ = self.blocks[i]( + hidden_state, encoder_masks, init_reset, batch_valid_length + ) self.print("PanguAlpha_Model: hidden_state_", hidden_state) if self.is_pipeline: top_query_hidden_states, _ = self.top_query_embedding(input_position) top_query_hidden_states = self.reshape_to_2d(top_query_hidden_states) - encoder_output, _ = self.top_query_layer(hidden_state, top_query_hidden_states, - encoder_masks, init_reset, batch_valid_length) + encoder_output, _ = self.top_query_layer( + hidden_state, + top_query_hidden_states, + encoder_masks, + init_reset, + batch_valid_length, + ) encoder_output = self.layernorm(encoder_output) else: hidden_state = self.reshape_to_2d(hidden_state) @@ -410,8 +426,13 @@ def construct(self, input_ids, 
encoder_output = P.Cast()(encoder_output, self.dtype) top_query_hidden_states, _ = self.top_query_embedding(input_position) top_query_hidden_states = self.reshape_to_2d(top_query_hidden_states) - encoder_output, _ = self.top_query_layer(encoder_output, top_query_hidden_states, - encoder_masks, init_reset, batch_valid_length) + encoder_output, _ = self.top_query_layer( + encoder_output, + top_query_hidden_states, + encoder_masks, + init_reset, + batch_valid_length, + ) self.print("PanguAlpha_Model: input_position", input_position) self.print("PanguAlpha_Model: top_query_hidden_state", top_query_hidden_states) self.print("PanguAlpha_Model: encoder_output", encoder_output) @@ -443,29 +464,51 @@ def load_param(path): # three embedding needed to be loaded # Loading the embedding table from the ckpt path: - position_embedding_path = os.path.join(load_ckpt_path, 'position_embedding.npy') - word_embedding_path = os.path.join(load_ckpt_path, 'word_embedding.npy') - top_query_embedding_path = os.path.join(load_ckpt_path, 'top_query_embedding.npy') - self.embedding.word_embedding.embedding_table = Parameter(initializer(load_param(word_embedding_path), - [self.config.vocab_size, - self.config.hidden_size]), - name='word_embedding_table', parallel_optimizer=False) + position_embedding_path = os.path.join(load_ckpt_path, "position_embedding.npy") + word_embedding_path = os.path.join(load_ckpt_path, "word_embedding.npy") + top_query_embedding_path = os.path.join( + load_ckpt_path, "top_query_embedding.npy" + ) + self.embedding.word_embedding.embedding_table = Parameter( + initializer( + load_param(word_embedding_path), + [self.config.vocab_size, self.config.hidden_size], + ), + name="word_embedding_table", + parallel_optimizer=False, + ) # self.summary("load_word_embedding", self.embedding.word_embedding.embedding_table) - self.print("PanguAlpha_Model: load_word_embedding", self.embedding.word_embedding.embedding_table) - self.embedding.position_embedding.embedding_table = Parameter(initializer(load_param(position_embedding_path), - [self.config.seq_length, - self.config.hidden_size]), - name='position_embedding_table', - parallel_optimizer=False) + self.print( + "PanguAlpha_Model: load_word_embedding", + self.embedding.word_embedding.embedding_table, + ) + self.embedding.position_embedding.embedding_table = Parameter( + initializer( + load_param(position_embedding_path), + [self.config.seq_length, self.config.hidden_size], + ), + name="position_embedding_table", + parallel_optimizer=False, + ) # self.summary("load_position_embedding", self.embedding.position_embedding.embedding_table) - self.print("PanguAlpha_Model: load_position_embedding", self.embedding.position_embedding.embedding_table) + self.print( + "PanguAlpha_Model: load_position_embedding", + self.embedding.position_embedding.embedding_table, + ) - self.top_query_embedding.embedding_table = Parameter(initializer(load_param(top_query_embedding_path), - [self.config.seq_length, - self.config.hidden_size]), - name='query_embedding_table', parallel_optimizer=False) + self.top_query_embedding.embedding_table = Parameter( + initializer( + load_param(top_query_embedding_path), + [self.config.seq_length, self.config.hidden_size], + ), + name="query_embedding_table", + parallel_optimizer=False, + ) # self.summary("top_query_embedding", self.embedding.top_query_embedding.embedding_table) - self.print("PanguAlpha_Model: top_query_embedding", self.embedding.top_query_embedding.embedding_table) + self.print( + "PanguAlpha_Model: top_query_embedding", + 
self.embedding.top_query_embedding.embedding_table, + ) class PanguAlphaModel(nn.Cell): @@ -493,13 +536,22 @@ def __init__(self, config): ) self.head.pipeline_stage = config.parallel_config.pipeline_stage - 1 self.backbone = PanguAlpha_Model(config) - self.backbone.embedding.word_embedding.embedding_table.add_pipeline_stage(self.head.pipeline_stage) + self.backbone.embedding.word_embedding.embedding_table.add_pipeline_stage( + self.head.pipeline_stage + ) self.print = P.Print() - def construct(self, input_ids, input_position, attention_mask, - init_reset=True, batch_valid_length=None): - output_states, word_table = self.backbone(input_ids, input_position, attention_mask, - init_reset, batch_valid_length) + def construct( + self, + input_ids, + input_position, + attention_mask, + init_reset=True, + batch_valid_length=None, + ): + output_states, word_table = self.backbone( + input_ids, input_position, attention_mask, init_reset, batch_valid_length + ) self.print("PanguAlphaModel: output_states", output_states) self.print("PanguAlphaModel: word_table", word_table) logits = self.head(output_states, word_table) @@ -543,9 +595,12 @@ def construct(self, input_ids, input_position=None, attention_mask=None): r"""Forward process of the pangu alpha model""" tokens = self.slice(input_ids, (0, 0), (self.batch_size, -1), (1, 1)) # P.Print()("==net tokens is:", tokens) - input_position = self.slice(input_position, (0, 0), (self.batch_size, self.len), (1, 1)) - decoder_attention_masks = self.slice2(attention_mask, (0, 0, 0), (self.batch_size, self.len, self.len), - (1, 1, 1)) + input_position = self.slice( + input_position, (0, 0), (self.batch_size, self.len), (1, 1) + ) + decoder_attention_masks = self.slice2( + attention_mask, (0, 0, 0), (self.batch_size, self.len, self.len), (1, 1, 1) + ) input_mask = F.cast(self.not_equal(tokens, self.eod_token), mstype.float32) logits = self.network(tokens, input_position, decoder_attention_masks) # P.Print()("==logits_is:", logits, ",shape is:", logits.shape) @@ -584,10 +639,14 @@ def __init__(self, backbone, generate=False, pad_token=6, seq_length=2048): self.log_softmax = P.LogSoftmax().shard(((1, 1, 1),)) self.get_attention_mask = AttentionMask(seq_length) self.expand = P.ExpandDims().shard(((1, 1, 1),)) - self.all_ones_attention_mask = Tensor(np.ones((1, 1, seq_length)), mstype.float32) + self.all_ones_attention_mask = Tensor( + np.ones((1, 1, seq_length)), mstype.float32 + ) self.print = P.Print() - def construct(self, input_ids, current_index, init_reset=True, batch_valid_length=None): + def construct( + self, input_ids, current_index, init_reset=True, batch_valid_length=None + ): """evaluation net""" input_mask = F.cast(F.not_equal(input_ids, self.pad_token), mstype.float32) bs, seq_length = F.shape(input_ids) @@ -600,10 +659,13 @@ def construct(self, input_ids, current_index, init_reset=True, batch_valid_lengt self.print("EvalNet: input_position_0", input_position) input_position = P.Tile()(input_position, (bs, 1)) self.print("EvalNet: input_position_1", input_position) - logits = self.backbone(input_ids, input_position, attention_mask, - init_reset, batch_valid_length) + logits = self.backbone( + input_ids, input_position, attention_mask, init_reset, batch_valid_length + ) self.print("EvalNet: logits", logits) - index = current_index.view(-1, ) + index = current_index.view( + -1, + ) self.print("EvalNet: index", index) logits = self.gather(logits, index, 0) logits = logits.view(bs, 1, -1) diff --git a/codegeex/mindspore/src/pangu_alpha_wrapcell.py 
b/codegeex/mindspore/src/pangu_alpha_wrapcell.py index da733a3..3b254b6 100644 --- a/codegeex/mindspore/src/pangu_alpha_wrapcell.py +++ b/codegeex/mindspore/src/pangu_alpha_wrapcell.py @@ -57,9 +57,7 @@ def _clip_grad(clip_type, clip_value, grad): F.cast(F.tuple_to_array((clip_value,)), dt), ) else: - new_grad = nn.ClipByNorm()( - grad, F.cast(F.tuple_to_array((clip_value,)), dt) - ) + new_grad = nn.ClipByNorm()(grad, F.cast(F.tuple_to_array((clip_value,)), dt)) return new_grad @@ -105,14 +103,16 @@ class PanguAlphaTrainOneStepWithLossScaleCell(TrainOneStepWithLossScaleCell): """ def __init__( - self, - network, - optimizer, - scale_update_cell=None, - enable_global_norm=False, - config=None, + self, + network, + optimizer, + scale_update_cell=None, + enable_global_norm=False, + config=None, ): - super(PanguAlphaTrainOneStepWithLossScaleCell, self).__init__(network, optimizer, scale_update_cell) + super(PanguAlphaTrainOneStepWithLossScaleCell, self).__init__( + network, optimizer, scale_update_cell + ) self.network = network self.config = config self.weights = optimizer.parameters @@ -127,7 +127,9 @@ def __init__( self.clip = ClipByGlobalNorm(self.weights, config) self.cast = P.Cast() - def construct(self, input_ids, input_position, attention_mask, layer_past=None, sens=None): + def construct( + self, input_ids, input_position, attention_mask, layer_past=None, sens=None + ): """Defines the computation performed.""" weights = self.weights # Forward process @@ -139,8 +141,8 @@ def construct(self, input_ids, input_position, attention_mask, layer_past=None, scaling_sens_filled = C.ones_like(loss) * F.cast(scaling_sens, F.dtype(loss)) # Backward process using loss scale grads = self.grad(self.network, weights)( - input_ids, input_position, attention_mask, - scaling_sens_filled) + input_ids, input_position, attention_mask, scaling_sens_filled + ) # apply grad reducer on grads grads = self.grad_reducer(grads) @@ -150,8 +152,8 @@ def construct(self, input_ids, input_position, attention_mask, layer_past=None, grads, clip_value = self.clip(grads) else: grads = self.hyper_map( - F.partial(clip_grad, GRADIENT_CLIP_TYPE, GRADIENT_CLIP_VALUE), - grads) + F.partial(clip_grad, GRADIENT_CLIP_TYPE, GRADIENT_CLIP_VALUE), grads + ) # Check whether overflow cond = self.get_overflow_status(status, grads) overflow = self.process_loss_scale(cond) @@ -178,8 +180,17 @@ class PanguAlphaTrainPipelineWithLossScaleCell(nn.Cell): scale_update_cell (Cell): Cell to do the loss scale. Default: None. 
""" - def __init__(self, network, optimizer, config, scale_update_cell=None, enable_global_norm=True): - super(PanguAlphaTrainPipelineWithLossScaleCell, self).__init__(auto_prefix=False) + def __init__( + self, + network, + optimizer, + config, + scale_update_cell=None, + enable_global_norm=True, + ): + super(PanguAlphaTrainPipelineWithLossScaleCell, self).__init__( + auto_prefix=False + ) self.config = config self.network = network self.network.add_flags(defer_inline=True) @@ -191,14 +202,19 @@ def __init__(self, network, optimizer, config, scale_update_cell=None, enable_gl self.reducer_flag = False self.allreduce = P.AllReduce() self.parallel_mode = context.get_auto_parallel_context("parallel_mode") - if self.parallel_mode in [ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL]: + if self.parallel_mode in [ + ParallelMode.DATA_PARALLEL, + ParallelMode.HYBRID_PARALLEL, + ]: self.reducer_flag = True self.grad_reducer = F.identity self.degree = 1 if self.reducer_flag: self.degree = get_group_size() - self.grad_reducer = DistributedGradReducer(optimizer.parameters, False, self.degree) - self.is_distributed = (self.parallel_mode != ParallelMode.STAND_ALONE) + self.grad_reducer = DistributedGradReducer( + optimizer.parameters, False, self.degree + ) + self.is_distributed = self.parallel_mode != ParallelMode.STAND_ALONE self.cast = P.Cast() self.alloc_status = P.NPUAllocFloatStatus() self.get_status = P.NPUGetFloatStatus() @@ -211,20 +227,22 @@ def __init__(self, network, optimizer, config, scale_update_cell=None, enable_gl self.reshape = P.Reshape() self.loss_scaling_manager = scale_update_cell if scale_update_cell: - self.loss_scale = Parameter(Tensor(scale_update_cell.get_loss_scale(), dtype=mstype.float32), - name="loss_scale") + self.loss_scale = Parameter( + Tensor(scale_update_cell.get_loss_scale(), dtype=mstype.float32), + name="loss_scale", + ) self.clip = ClipByGlobalNorm(self.weights, self.config) self.micro_size = config.parallel_config.micro_batch_num self.opt_shard = _get_enable_parallel_optimizer() @C.add_flags(has_effect=True) def construct( - self, - input_ids, - input_position, - attention_mask, - past=None, - sens=None, + self, + input_ids, + input_position, + attention_mask, + past=None, + sens=None, ): """Defines the computation performed.""" weights = self.weights @@ -251,16 +269,22 @@ def construct( # apply grad reducer on grads if self.opt_shard: grads = self.grad_reducer(grads) - grads = self.hyper_map(F.partial(shard_grad_scale, scaling_sens * self.degree), grads, self.accu_grads) + grads = self.hyper_map( + F.partial(shard_grad_scale, scaling_sens * self.degree), + grads, + self.accu_grads, + ) else: accu_grads = self.grad_reducer(self.accu_grads) - grads = self.hyper_map(F.partial(grad_scale, scaling_sens * self.degree), grads, accu_grads) + grads = self.hyper_map( + F.partial(grad_scale, scaling_sens * self.degree), grads, accu_grads + ) if self.enable_global_norm: grads, _ = self.clip(grads) else: grads = self.hyper_map( - F.partial(clip_grad, GRADIENT_CLIP_TYPE, GRADIENT_CLIP_VALUE), - grads) + F.partial(clip_grad, GRADIENT_CLIP_TYPE, GRADIENT_CLIP_VALUE), grads + ) if self.is_distributed: # sum overflow flag over devices flag_reduce = self.allreduce(flag_sum) diff --git a/codegeex/mindspore/src/pangu_alpha_wrapcell_finetune.py b/codegeex/mindspore/src/pangu_alpha_wrapcell_finetune.py index faebd0e..2e2f2e7 100644 --- a/codegeex/mindspore/src/pangu_alpha_wrapcell_finetune.py +++ b/codegeex/mindspore/src/pangu_alpha_wrapcell_finetune.py @@ -57,9 +57,7 @@ def 
_clip_grad(clip_type, clip_value, grad): F.cast(F.tuple_to_array((clip_value,)), dt), ) else: - new_grad = nn.ClipByNorm()( - grad, F.cast(F.tuple_to_array((clip_value,)), dt) - ) + new_grad = nn.ClipByNorm()(grad, F.cast(F.tuple_to_array((clip_value,)), dt)) return new_grad @@ -105,14 +103,16 @@ class PanguAlphaTrainOneStepWithLossScaleCell(TrainOneStepWithLossScaleCell): """ def __init__( - self, - network, - optimizer, - scale_update_cell=None, - enable_global_norm=False, - config=None, + self, + network, + optimizer, + scale_update_cell=None, + enable_global_norm=False, + config=None, ): - super(PanguAlphaTrainOneStepWithLossScaleCell, self).__init__(network, optimizer, scale_update_cell) + super(PanguAlphaTrainOneStepWithLossScaleCell, self).__init__( + network, optimizer, scale_update_cell + ) self.network = network self.config = config self.weights = optimizer.parameters @@ -139,8 +139,8 @@ def construct(self, input_ids, loss_mask, input_position, attention_mask): scaling_sens_filled = C.ones_like(loss) * F.cast(scaling_sens, F.dtype(loss)) # Backward process using loss scale grads = self.grad(self.network, weights)( - input_ids, loss_mask, input_position, attention_mask, - scaling_sens_filled) + input_ids, loss_mask, input_position, attention_mask, scaling_sens_filled + ) # apply grad reducer on grads grads = self.grad_reducer(grads) @@ -150,8 +150,8 @@ def construct(self, input_ids, loss_mask, input_position, attention_mask): grads, clip_value = self.clip(grads) else: grads = self.hyper_map( - F.partial(clip_grad, GRADIENT_CLIP_TYPE, GRADIENT_CLIP_VALUE), - grads) + F.partial(clip_grad, GRADIENT_CLIP_TYPE, GRADIENT_CLIP_VALUE), grads + ) # Check whether overflow cond = self.get_overflow_status(status, grads) overflow = self.process_loss_scale(cond) @@ -178,8 +178,17 @@ class PanguAlphaTrainPipelineWithLossScaleCell(nn.Cell): scale_update_cell (Cell): Cell to do the loss scale. Default: None. 
""" - def __init__(self, network, optimizer, config, scale_update_cell=None, enable_global_norm=True): - super(PanguAlphaTrainPipelineWithLossScaleCell, self).__init__(auto_prefix=False) + def __init__( + self, + network, + optimizer, + config, + scale_update_cell=None, + enable_global_norm=True, + ): + super(PanguAlphaTrainPipelineWithLossScaleCell, self).__init__( + auto_prefix=False + ) self.config = config self.network = network self.network.add_flags(defer_inline=True) @@ -191,14 +200,19 @@ def __init__(self, network, optimizer, config, scale_update_cell=None, enable_gl self.reducer_flag = False self.allreduce = P.AllReduce() self.parallel_mode = context.get_auto_parallel_context("parallel_mode") - if self.parallel_mode in [ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL]: + if self.parallel_mode in [ + ParallelMode.DATA_PARALLEL, + ParallelMode.HYBRID_PARALLEL, + ]: self.reducer_flag = True self.grad_reducer = F.identity self.degree = 1 if self.reducer_flag: self.degree = get_group_size() - self.grad_reducer = DistributedGradReducer(optimizer.parameters, False, self.degree) - self.is_distributed = (self.parallel_mode != ParallelMode.STAND_ALONE) + self.grad_reducer = DistributedGradReducer( + optimizer.parameters, False, self.degree + ) + self.is_distributed = self.parallel_mode != ParallelMode.STAND_ALONE self.cast = P.Cast() self.alloc_status = P.NPUAllocFloatStatus() self.get_status = P.NPUGetFloatStatus() @@ -211,21 +225,23 @@ def __init__(self, network, optimizer, config, scale_update_cell=None, enable_gl self.reshape = P.Reshape() self.loss_scaling_manager = scale_update_cell if scale_update_cell: - self.loss_scale = Parameter(Tensor(scale_update_cell.get_loss_scale(), dtype=mstype.float32), - name="loss_scale") + self.loss_scale = Parameter( + Tensor(scale_update_cell.get_loss_scale(), dtype=mstype.float32), + name="loss_scale", + ) self.clip = ClipByGlobalNorm(self.weights, self.config) self.micro_size = config.parallel_config.micro_batch_num self.opt_shard = _get_enable_parallel_optimizer() @C.add_flags(has_effect=True) def construct( - self, - input_ids, - loss_mask, - input_position, - attention_mask, - past=None, - sens=None, + self, + input_ids, + loss_mask, + input_position, + attention_mask, + past=None, + sens=None, ): """Defines the computation performed.""" weights = self.weights @@ -253,16 +269,22 @@ def construct( # apply grad reducer on grads if self.opt_shard: grads = self.grad_reducer(grads) - grads = self.hyper_map(F.partial(shard_grad_scale, scaling_sens * self.degree), grads, self.accu_grads) + grads = self.hyper_map( + F.partial(shard_grad_scale, scaling_sens * self.degree), + grads, + self.accu_grads, + ) else: accu_grads = self.grad_reducer(self.accu_grads) - grads = self.hyper_map(F.partial(grad_scale, scaling_sens * self.degree), grads, accu_grads) + grads = self.hyper_map( + F.partial(grad_scale, scaling_sens * self.degree), grads, accu_grads + ) if self.enable_global_norm: grads, _ = self.clip(grads) else: grads = self.hyper_map( - F.partial(clip_grad, GRADIENT_CLIP_TYPE, GRADIENT_CLIP_VALUE), - grads) + F.partial(clip_grad, GRADIENT_CLIP_TYPE, GRADIENT_CLIP_VALUE), grads + ) if self.is_distributed: # sum overflow flag over devices flag_reduce = self.allreduce(flag_sum) diff --git a/codegeex/mindspore/src/preprocess.py b/codegeex/mindspore/src/preprocess.py index e19a602..527290a 100644 --- a/codegeex/mindspore/src/preprocess.py +++ b/codegeex/mindspore/src/preprocess.py @@ -32,7 +32,7 @@ def chunks(lst, n): """yield n sized chunks from 
list""" for i in range(0, len(lst), n): - yield lst[i: i + n] + yield lst[i : i + n] def package_file(it, n): @@ -94,9 +94,7 @@ def tokenize_openwebtext(tokenizer, iterator, seq_length, eot): for para in f.read().split("\n\n"): if para: tokenized_text = tokenizer.tokenize(para) - content += tokenizer.convert_tokens_to_ids( - tokenized_text - ) + [eot] + content += tokenizer.convert_tokens_to_ids(tokenized_text) + [eot] for chunk in chunks(content, seq_length): sample = {} if len(chunk) == seq_length: @@ -111,9 +109,7 @@ def tokenize_wiki(tokenizer, file_path, seq_length, eot): for para in clean_wikitext(f.read()).split("\n\n"): if para and para.strip().startswith("=") is False: tokenized_text = tokenizer.tokenize(para) - content += tokenizer.convert_tokens_to_ids(tokenized_text) + [ - eot - ] + content += tokenizer.convert_tokens_to_ids(tokenized_text) + [eot] for chunk in chunks(content, seq_length): sample = {} if len(chunk) == seq_length: @@ -160,9 +156,7 @@ def task_unit(iterator, tokenizer, seq_length, eot, parallel_writer=True): print("Process {} transformed {} records.".format(index, count)) except StopIteration: if data_batch: - writer.write_raw_data( - data_batch, parallel_writer=parallel_writer - ) + writer.write_raw_data(data_batch, parallel_writer=parallel_writer) print("Process {} transformed {} records.".format(index, count)) break @@ -194,9 +188,7 @@ def task_unit(iterator, tokenizer, seq_length, eot, parallel_writer=True): schema = { args.data_column_name: {"type": "int32", "shape": [-1]}, } - writer = FileWriter( - file_name=args.output_file, shard_num=args.file_partition - ) + writer = FileWriter(file_name=args.output_file, shard_num=args.file_partition) writer.add_schema(schema, args.dataset_type) writer.open_and_set_header() @@ -219,14 +211,14 @@ def task_unit(iterator, tokenizer, seq_length, eot, parallel_writer=True): transforms_count = 0 if args.dataset_type == "wiki": for x in tokenize_wiki( - word_tokenizer, args.input_glob, args.seq_length, args.eot + word_tokenizer, args.input_glob, args.seq_length, args.eot ): transforms_count += 1 writer.write_raw_data([x]) print("Transformed {} records.".format(transforms_count)) elif args.dataset_type == "lambada": for x in tokenize_lambada( - word_tokenizer, args.input_glob, args.seq_length, args.eot + word_tokenizer, args.input_glob, args.seq_length, args.eot ): transforms_count += 1 writer.write_raw_data([x]) @@ -242,9 +234,7 @@ def task_unit(iterator, tokenizer, seq_length, eot, parallel_writer=True): ) pool.map(map_func, package_file(file_iter, args.file_batch_size)) else: - raise ValueError( - "Not support dataset type: {}".format(args.dataset_type) - ) + raise ValueError("Not support dataset type: {}".format(args.dataset_type)) writer.commit() out_file = args.output_file diff --git a/codegeex/mindspore/src/sat_dataset.py b/codegeex/mindspore/src/sat_dataset.py index bb6295e..4093f23 100644 --- a/codegeex/mindspore/src/sat_dataset.py +++ b/codegeex/mindspore/src/sat_dataset.py @@ -64,20 +64,27 @@ def __len__(self): def __getitem__(self, idx): item = self.dataset[idx][0] - return (item[:self.seq_len],) if self.seq_len <= len(item) else ( - np.concatenate((item, np.ones(self.seq_len - len(item)) * self.eod_id), axis=0),) + return ( + (item[: self.seq_len],) + if self.seq_len <= len(item) + else ( + np.concatenate( + (item, np.ones(self.seq_len - len(item)) * self.eod_id), axis=0 + ), + ) + ) # return (np.pad(item, (0, 1), constant_values=self.eod_id),) class BinaryDataset(Dataset): def __init__( - self, - path, - 
process_fn, - length_per_sample=64 + 1024 + 4096, - dtype="int32", - preload=False, - **kwargs, + self, + path, + process_fn, + length_per_sample=64 + 1024 + 4096, + dtype="int32", + preload=False, + **kwargs, ): # TODO ARGS assert length_per_sample is not None self.length_per_sample = length_per_sample @@ -169,7 +176,9 @@ def __getitem__(self, idx): else: sample_idx = idx - self.cumulative_sizes[dataset_idx - 1] sample_idx = sample_idx % len(self.datasets[dataset_idx]) - return tuple(col.astype(np.int64) for col in self.datasets[dataset_idx][sample_idx]) + return tuple( + col.astype(np.int64) for col in self.datasets[dataset_idx][sample_idx] + ) class RandomMappingDataset(Dataset): @@ -209,20 +218,25 @@ def __init__(self, ds, indices, block_size): self.wrapped_data = ds self.wrapped_data_len = len(ds) self.indices = indices - self.len = len(indices) * (len(ds) // block_size) + np.sum(indices < (len(ds) % block_size)) + self.len = len(indices) * (len(ds) // block_size) + np.sum( + indices < (len(ds) % block_size) + ) def __len__(self): return self.len def __getitem__(self, index): return self.wrapped_data[ - (index // len(self.indices)) * self.block_size + self.indices[index % len(self.indices)] - ] + (index // len(self.indices)) * self.block_size + + self.indices[index % len(self.indices)] + ] class SubsetDataset(Dataset): def __init__(self, ds, start, length): - assert start >= 0 and length > 0 and start + length <= len(ds), "Illegal start or length" + assert ( + start >= 0 and length > 0 and start + length <= len(ds) + ), "Illegal start or length" self.ds = ds self.start = start self.length = length @@ -256,6 +270,7 @@ def split_train_val_test(ds, split=[0.99, 0.01, 0.0], seed=None): start_idx += proportion return rtn_ds + # def split_ds(ds, split=[0.99, 0.01, 0.0], seed=1): # """ # Split a dataset into subsets given proportions of how diff --git a/codegeex/mindspore/src/tokenization_jieba.py b/codegeex/mindspore/src/tokenization_jieba.py index 3f047eb..49507f3 100644 --- a/codegeex/mindspore/src/tokenization_jieba.py +++ b/codegeex/mindspore/src/tokenization_jieba.py @@ -13,8 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. """Tokenization classes for OpenAI GPT.""" -from __future__ import (absolute_import, division, print_function, - unicode_literals) +from __future__ import absolute_import, division, print_function, unicode_literals from io import open @@ -22,18 +21,18 @@ import sentencepiece as spm -class JIEBATokenizer(): +class JIEBATokenizer: r""" Jieba Tokenizer """ def __init__(self, vocab_file, model_file, max_len=None): self.max_len = max_len if max_len is not None else int(1e12) - f = open(vocab_file, 'r') + f = open(vocab_file, "r") lines = f.readlines() self.encoder = {} for line in enumerate(lines): - key = line[1].split('\t')[0] + key = line[1].split("\t")[0] self.encoder[key] = line[0] self.decoder = {v: k for k, v in self.encoder.items()} @@ -41,9 +40,9 @@ def __init__(self, vocab_file, model_file, max_len=None): self.sp = spm.SentencePieceProcessor(model_file=model_file) self.translator = str.maketrans(" \n", "\u2582\u2583") - self.eod_id = self.encoder[''] - self.eot_id = self.encoder[''] - self.pad_id = self.encoder[''] + self.eod_id = self.encoder[""] + self.eot_id = self.encoder[""] + self.pad_id = self.encoder[""] @property def vocab_size(self): @@ -57,8 +56,10 @@ def eod(self): return self.eod_id def tokenize(self, text): - """ Tokenize a string. 
""" - seg_list = [x.translate(self.translator) for x in jieba.cut(text, cut_all=False)] + """Tokenize a string.""" + seg_list = [ + x.translate(self.translator) for x in jieba.cut(text, cut_all=False) + ] new_seg = " ".join(seg_list) return self.sp.encode(new_seg) @@ -74,5 +75,5 @@ def encode(self, text): def decode(self, tokens): text = self.sp.decode(tokens) - text = text.replace(' ', '').replace('\u2582', ' ').replace('\u2583', '\n') + text = text.replace(" ", "").replace("\u2582", " ").replace("\u2583", "\n") return text diff --git a/codegeex/mindspore/src/utils.py b/codegeex/mindspore/src/utils.py index 767fc5d..70694fc 100644 --- a/codegeex/mindspore/src/utils.py +++ b/codegeex/mindspore/src/utils.py @@ -30,7 +30,12 @@ from mindspore.common.tensor import Tensor from mindspore.communication.management import get_rank, get_group_size, create_group from mindspore.nn import AdamWeightDecay -from mindspore.nn.learning_rate_schedule import LearningRateSchedule, PolynomialDecayLR, WarmUpLR, CosineDecayLR +from mindspore.nn.learning_rate_schedule import ( + LearningRateSchedule, + PolynomialDecayLR, + WarmUpLR, + CosineDecayLR, +) from mindspore.ops import composite as C from mindspore.ops import functional as F from mindspore.ops import operations as P @@ -39,36 +44,51 @@ class FP32StateAdamWeightDecay(AdamWeightDecay): r""" - This class is almost same with the mindspore's AdamWeightDecay implements, the - only difference is the optimizer's state will be always initialized with float32, - where the original AdamWeightDecay will initialize the optimizer's state with float16, - if the parameters are initialized with fp16. - This setting will avoid overflow in training PanGu-Alpha model using fp16. + This class is almost same with the mindspore's AdamWeightDecay implements, the + only difference is the optimizer's state will be always initialized with float32, + where the original AdamWeightDecay will initialize the optimizer's state with float16, + if the parameters are initialized with fp16. + This setting will avoid overflow in training PanGu-Alpha model using fp16. """ - def __init__(self, params, learning_rate=1e-3, beta1=0.9, beta2=0.999, eps=1e-6, weight_decay=0.0): - super(FP32StateAdamWeightDecay, self).__init__(params, learning_rate=learning_rate, - beta1=beta1, - beta2=beta2, - eps=eps, - weight_decay=weight_decay) - - self.moments1 = self.clone_state(self.parameters, prefix='adam_m', init='zeros') - self.moments2 = self.clone_state(self.parameters, prefix='adam_v', init='zeros') + def __init__( + self, + params, + learning_rate=1e-3, + beta1=0.9, + beta2=0.999, + eps=1e-6, + weight_decay=0.0, + ): + super(FP32StateAdamWeightDecay, self).__init__( + params, + learning_rate=learning_rate, + beta1=beta1, + beta2=beta2, + eps=eps, + weight_decay=weight_decay, + ) + + self.moments1 = self.clone_state(self.parameters, prefix="adam_m", init="zeros") + self.moments2 = self.clone_state(self.parameters, prefix="adam_v", init="zeros") def clone_state(self, parameter_tuple, prefix, init): r""" - parameter_tuple: ParameterTuple. The parameters of the network - prefix: str. The prefix name of the parameters - init: str. The initialization method + parameter_tuple: ParameterTuple. The parameters of the network + prefix: str. The prefix name of the parameters + init: str. 
The initialization method """ new = [] for old_param in parameter_tuple: - new_state = Parameter(initializer(init, shape=old_param.shape, dtype=mstype.float32)) + new_state = Parameter( + initializer(init, shape=old_param.shape, dtype=mstype.float32) + ) new_state.param_info = old_param.param_info.clone() new_state.is_init = False - new_state.set_data(initializer(init, shape=old_param.shape, dtype=mstype.float32)) - new_state.name = prefix + '.' + new_state.name + new_state.set_data( + initializer(init, shape=old_param.shape, dtype=mstype.float32) + ) + new_state.name = prefix + "." + new_state.name new.append(new_state) return ParameterTuple(new) @@ -109,7 +129,9 @@ def _get_model_parallel_group(mp): local_stage_rank_id = rank % per_stage_device_nums index = local_stage_rank_id // mp group = range(0, mp) - rank_str_list = [str(x + index * mp + stage_id * per_stage_device_nums) for x in group] + rank_str_list = [ + str(x + index * mp + stage_id * per_stage_device_nums) for x in group + ] rank_list_str = "-".join(rank_str_list) rank_list = [x + index * mp + stage_id * per_stage_device_nums for x in group] return rank_list, rank_list_str @@ -128,7 +150,9 @@ def _get_pipeline_group(): local_stage_rank_id = rank % per_stage_device_nums group = range(0, stage_nums) rank_list = [local_stage_rank_id + x * per_stage_device_nums for x in group] - rank_str_list = [str(local_stage_rank_id + x * per_stage_device_nums) for x in group] + rank_str_list = [ + str(local_stage_rank_id + x * per_stage_device_nums) for x in group + ] rank_list_str = "-".join(rank_str_list) return rank_list, rank_list_str @@ -143,7 +167,9 @@ def __init__(self, params, config): self.norm = nn.Norm() self.hyper_map = C.HyperMap() self.is_pipeline = context.get_auto_parallel_context("pipeline_stages") > 1 - optimizer_weight_shard_size = context.get_auto_parallel_context("optimizer_weight_shard_size") + optimizer_weight_shard_size = context.get_auto_parallel_context( + "optimizer_weight_shard_size" + ) if self.is_pipeline: if context.get_auto_parallel_context("enable_parallel_optimizer"): group_size = get_group_size() // config.parallel_config.pipeline_stage @@ -160,14 +186,19 @@ def __init__(self, params, config): self.allreduce = P.AllReduce(group=group_name) pipeline_group_list, pipeline_group_name = _get_pipeline_group() hashed = hashlib.md5(pipeline_group_name.encode()).hexdigest()[:48] - print(f"Creating hash value for the group_name hash({pipeline_group_name})={hashed}") + print( + f"Creating hash value for the group_name hash({pipeline_group_name})={hashed}" + ) pipeline_group_name = str(hashed) create_group(pipeline_group_name, pipeline_group_list) self.allreduce2 = P.AllReduce(group=pipeline_group_name) else: opt_shard_size = config.parallel_config.data_parallel mp = config.parallel_config.model_parallel - if context.get_auto_parallel_context("enable_parallel_optimizer") and optimizer_weight_shard_size > 0: + if ( + context.get_auto_parallel_context("enable_parallel_optimizer") + and optimizer_weight_shard_size > 0 + ): opt_shard_size = optimizer_weight_shard_size group_size = opt_shard_size * mp world_size = get_group_size() @@ -178,17 +209,31 @@ def __init__(self, params, config): self.allreduce_group_size = () for x in params: - if "projection.bias" not in x.name and "layernorm" not in x.name and "embedding_table" not in x.name: - self.allreduce_group_size = self.allreduce_group_size + (dense_repeat_num * 1.0,) + if ( + "projection.bias" not in x.name + and "layernorm" not in x.name + and "embedding_table" not in 
x.name + ): + self.allreduce_group_size = self.allreduce_group_size + ( + dense_repeat_num * 1.0, + ) elif "embedding_table" not in x.name: - self.allreduce_group_size = self.allreduce_group_size + (layernorm_and_bias_repeat_num * 1.0,) + self.allreduce_group_size = self.allreduce_group_size + ( + layernorm_and_bias_repeat_num * 1.0, + ) else: - if not config.parallel_config.vocab_emb_dp and "position_embedding.embedding_table" not in x.name \ - and "top_query_embedding_table" not in x.name: - self.allreduce_group_size = self.allreduce_group_size + \ - (word_embbedding_repeat_num * 1.0,) + if ( + not config.parallel_config.vocab_emb_dp + and "position_embedding.embedding_table" not in x.name + and "top_query_embedding_table" not in x.name + ): + self.allreduce_group_size = self.allreduce_group_size + ( + word_embbedding_repeat_num * 1.0, + ) else: - self.allreduce_group_size = self.allreduce_group_size + (position_embedding_repeat_num * 1.0,) + self.allreduce_group_size = self.allreduce_group_size + ( + position_embedding_repeat_num * 1.0, + ) def construct(self, grads): """Calculate global norm construct""" @@ -225,7 +270,12 @@ def construct(self, grads): grads, global_norm_value = self.global_norm(grads) cond = P.GreaterEqual()(global_norm_value, self.clip_norm) global_norm = F.select(cond, global_norm_value, self.clip_norm) - grads = self.hyper_map(F.partial(apply_global_norm, self.enable_grad_fp16, self.clip_norm, global_norm), grads) + grads = self.hyper_map( + F.partial( + apply_global_norm, self.enable_grad_fp16, self.clip_norm, global_norm + ), + grads, + ) return grads, global_norm_value @@ -234,22 +284,26 @@ class LearningRate(LearningRateSchedule): Warmup-decay learning rate for PanguAlpha network. """ - def __init__(self, - learning_rate, - end_learning_rate, - warmup_steps, - decay_steps, - power=1.0, - use_cosine=True): + def __init__( + self, + learning_rate, + end_learning_rate, + warmup_steps, + decay_steps, + power=1.0, + use_cosine=True, + ): super(LearningRate, self).__init__() self.warmup_flag = False if warmup_steps > 0: self.warmup_flag = True self.warmup_lr = WarmUpLR(learning_rate, warmup_steps) - self.decay_lr = PolynomialDecayLR(learning_rate, end_learning_rate, - decay_steps, power) - self.cosine_decay_lr = CosineDecayLR(end_learning_rate, learning_rate, - decay_steps) + self.decay_lr = PolynomialDecayLR( + learning_rate, end_learning_rate, decay_steps, power + ) + self.cosine_decay_lr = CosineDecayLR( + end_learning_rate, learning_rate, decay_steps + ) self.warmup_steps = Tensor(np.array([warmup_steps]).astype(np.float32)) self.greater = P.Greater() @@ -265,8 +319,9 @@ def construct(self, global_step): else: decay_lr = self.cosine_decay_lr(global_step) if self.warmup_flag: - is_warmup = self.cast(self.greater(self.warmup_steps, global_step), - mstype.float32) + is_warmup = self.cast( + self.greater(self.warmup_steps, global_step), mstype.float32 + ) warmup_lr = self.warmup_lr(global_step) lr = (self.one - is_warmup) * decay_lr + is_warmup * warmup_lr else: @@ -277,286 +332,335 @@ def construct(self, global_step): def add_inference_params(opt): """Add inference params""" - opt.add_argument("--frequency_penalty", - type=float, - default=1.5, - help="coefficient for frequency_penalty") - opt.add_argument("--presence_penalty", - type=float, - default=0.3, - help="coefficient for presence_penalty") - opt.add_argument("--max_generate_length", - type=int, - default=2048, - help="the maximum number of generated token") - opt.add_argument("--top_k_num", - type=int, - 
default=3, - help="the number for top_k sampling") - opt.add_argument("--top_p", - type=float, - default=1.0, - help="top_p sampling threshold, enabled if less than 1.0") - opt.add_argument("--end_token", - type=int, - default=50256, - help="the token id for ") - opt.add_argument("--use_pynative_op", - type=int, - default=0, - help="Whether use pynative op for postproecess") - opt.add_argument("--use_past", - type=str, - default="true", - choices=["true", "false"], - help="Whether enable state reuse") + opt.add_argument( + "--frequency_penalty", + type=float, + default=1.5, + help="coefficient for frequency_penalty", + ) + opt.add_argument( + "--presence_penalty", + type=float, + default=0.3, + help="coefficient for presence_penalty", + ) + opt.add_argument( + "--max_generate_length", + type=int, + default=2048, + help="the maximum number of generated token", + ) + opt.add_argument( + "--top_k_num", type=int, default=3, help="the number for top_k sampling" + ) + opt.add_argument( + "--top_p", + type=float, + default=1.0, + help="top_p sampling threshold, enabled if less than 1.0", + ) + opt.add_argument( + "--end_token", + type=int, + default=50256, + help="the token id for ", + ) + opt.add_argument( + "--use_pynative_op", + type=int, + default=0, + help="Whether use pynative op for postproecess", + ) + opt.add_argument( + "--use_past", + type=str, + default="true", + choices=["true", "false"], + help="Whether enable state reuse", + ) def add_training_params(opt): """Add training params""" - opt.add_argument("--seq_length", - type=int, - default=2048, - help="sequence length, default is 2048.") - opt.add_argument("--vocab_size", - type=int, - default=40000, - help="vocabulary size, default is 40000.") - opt.add_argument("--embedding_size", - type=int, - default=16384, - help="embedding table size, default is 16384.") - opt.add_argument("--num_layers", - type=int, - default=64, - help="total layers, default is 64.") - opt.add_argument("--num_heads", - type=int, - default=128, - help="head size, default is 128.") - opt.add_argument("--stage_num", - type=int, - default=1, - help="Pipeline stage num, default is 1.") - opt.add_argument("--micro_size", - type=int, - default=1, - help="Pipeline micro_size, default is 1.") - opt.add_argument("--eod_reset", - type=int, - default=1, - help="Enable eod mask, default is 1.") - opt.add_argument("--warmup_step", - type=int, - default=2000, - help="Warmup step, default is 2000.") - opt.add_argument("--decay_steps", - type=int, - default=200000, - help="Decay step, default is 200000.") - opt.add_argument("--optimizer", - type=str, - default="adam", - choices=["adam", "lamb"], - help="select which optimizer to be used, default adam") - opt.add_argument("--opt_offload", - type=int, default=0, - help="Enable optimizer status offload to host CPU, default is 0") - opt.add_argument("--use_moe", - type=int, default=0, - help="Use moe, default is 0") - opt.add_argument("--per_dp_dim_expert_num", - type=int, default=1, - help="Expert nums in one data parallel dim, only effective when applying moe, default is 1") - opt.add_argument("--eod_id", - type=int, default=50256, - help="The id of end of document") - opt.add_argument("--epoch_size", - type=int, default=1, - help="The training epoch") - opt.add_argument("--sink_size", - type=int, default=2, - help="The sink size of the training. 
default is 2") - opt.add_argument("--full_batch", - default=1, type=int, - help="Import the full size of a batch for each card, default is 1") - opt.add_argument("--optimizer_shard", - type=int, - default=1, - help="Enable optimizer parallel, default is 1") - opt.add_argument("--per_batch_size", - type=int, - default=0, - help="The batch size for each data parallel way. default 6") - opt.add_argument("--start_lr", - type=float, - default=5e-5, - help="The start learning rate. default 5e-5") - opt.add_argument("--dropout_rate", - type=float, - default=0.1, - help="The dropout rate. default 0.1") - opt.add_argument("--end_lr", - type=float, - default=1e-6, - help="The end learning rate. default 1e-6") - opt.add_argument("--op_level_model_parallel_num", - type=int, - default=8, - help="The model parallel way. default 8") - opt.add_argument("--word_emb_dp", - type=int, default=1, - choices=[0, 1], - help="Whether do data parallel in word embedding. default 1") - opt.add_argument("--gradient_aggregation_group", - type=int, default=4, - help="The gradient communication fusion group. default 4") - opt.add_argument("--data_column_name", - type=str, default="input_ids", - help="Column name of datasets") + opt.add_argument( + "--seq_length", type=int, default=2048, help="sequence length, default is 2048." + ) + opt.add_argument( + "--vocab_size", + type=int, + default=40000, + help="vocabulary size, default is 40000.", + ) + opt.add_argument( + "--embedding_size", + type=int, + default=16384, + help="embedding table size, default is 16384.", + ) + opt.add_argument( + "--num_layers", type=int, default=64, help="total layers, default is 64." + ) + opt.add_argument( + "--num_heads", type=int, default=128, help="head size, default is 128." + ) + opt.add_argument( + "--stage_num", type=int, default=1, help="Pipeline stage num, default is 1." + ) + opt.add_argument( + "--micro_size", type=int, default=1, help="Pipeline micro_size, default is 1." + ) + opt.add_argument( + "--eod_reset", type=int, default=1, help="Enable eod mask, default is 1." + ) + opt.add_argument( + "--warmup_step", type=int, default=2000, help="Warmup step, default is 2000." + ) + opt.add_argument( + "--decay_steps", type=int, default=200000, help="Decay step, default is 200000." + ) + opt.add_argument( + "--optimizer", + type=str, + default="adam", + choices=["adam", "lamb"], + help="select which optimizer to be used, default adam", + ) + opt.add_argument( + "--opt_offload", + type=int, + default=0, + help="Enable optimizer status offload to host CPU, default is 0", + ) + opt.add_argument("--use_moe", type=int, default=0, help="Use moe, default is 0") + opt.add_argument( + "--per_dp_dim_expert_num", + type=int, + default=1, + help="Expert nums in one data parallel dim, only effective when applying moe, default is 1", + ) + opt.add_argument( + "--eod_id", type=int, default=50256, help="The id of end of document" + ) + opt.add_argument("--epoch_size", type=int, default=1, help="The training epoch") + opt.add_argument( + "--sink_size", + type=int, + default=2, + help="The sink size of the training. default is 2", + ) + opt.add_argument( + "--full_batch", + default=1, + type=int, + help="Import the full size of a batch for each card, default is 1", + ) + opt.add_argument( + "--optimizer_shard", + type=int, + default=1, + help="Enable optimizer parallel, default is 1", + ) + opt.add_argument( + "--per_batch_size", + type=int, + default=0, + help="The batch size for each data parallel way. 
default 6", + ) + opt.add_argument( + "--start_lr", + type=float, + default=5e-5, + help="The start learning rate. default 5e-5", + ) + opt.add_argument( + "--dropout_rate", type=float, default=0.1, help="The dropout rate. default 0.1" + ) + opt.add_argument( + "--end_lr", type=float, default=1e-6, help="The end learning rate. default 1e-6" + ) + opt.add_argument( + "--op_level_model_parallel_num", + type=int, + default=8, + help="The model parallel way. default 8", + ) + opt.add_argument( + "--word_emb_dp", + type=int, + default=1, + choices=[0, 1], + help="Whether do data parallel in word embedding. default 1", + ) + opt.add_argument( + "--gradient_aggregation_group", + type=int, + default=4, + help="The gradient communication fusion group. default 4", + ) + opt.add_argument( + "--data_column_name", + type=str, + default="input_ids", + help="Column name of datasets", + ) def add_retrain_params(opt): """ Add parameters about retrain. """ - opt.add_argument("--pre_trained", - type=str, - default=None, - help="Pretrained checkpoint path.") - opt.add_argument("--save_checkpoint_path", - type=str, - default=None, - help="Save checkpoint path.") - opt.add_argument("--save_checkpoint_obs_path", - type=str, - default=None, - help="Save checkpoint path on OBS.") - opt.add_argument("--keep_checkpoint_max", - type=int, - default=1, - help="Max checkpoint save number.") - opt.add_argument("--save_checkpoint_steps", - type=int, - default=2000, - help="Save checkpoint step number.") - opt.add_argument("--save_checkpoint", - type=ast.literal_eval, - default=False, - help="Whether save checkpoint in local disk.") - opt.add_argument("--ckpt_name_prefix", - type=str, - default="pangu", - help="Saving checkpoint name prefix.") - opt.add_argument("--has_trained_epoches", - type=int, - default=0, - help="Epoches has been trained before.") - opt.add_argument("--has_trained_steps", - type=int, - default=0, - help="Steps has been trained before.") + opt.add_argument( + "--pre_trained", type=str, default=None, help="Pretrained checkpoint path." + ) + opt.add_argument( + "--save_checkpoint_path", type=str, default=None, help="Save checkpoint path." + ) + opt.add_argument( + "--save_checkpoint_obs_path", + type=str, + default=None, + help="Save checkpoint path on OBS.", + ) + opt.add_argument( + "--keep_checkpoint_max", type=int, default=1, help="Max checkpoint save number." 
+ ) + opt.add_argument( + "--save_checkpoint_steps", + type=int, + default=2000, + help="Save checkpoint step number.", + ) + opt.add_argument( + "--save_checkpoint", + type=ast.literal_eval, + default=False, + help="Whether save checkpoint in local disk.", + ) + opt.add_argument( + "--ckpt_name_prefix", + type=str, + default="pangu", + help="Saving checkpoint name prefix.", + ) + opt.add_argument( + "--has_trained_epoches", + type=int, + default=0, + help="Epoches has been trained before.", + ) + opt.add_argument( + "--has_trained_steps", + type=int, + default=0, + help="Steps has been trained before.", + ) def get_args(inference=False): """train function for PanguAlpha""" parser = argparse.ArgumentParser(description="PanguAlpha training") - parser.add_argument('--device_id', - type=int, - default=0, - help="Device id, default is 0.") - parser.add_argument("--device_num", - type=int, - default=128, - help="Use device nums, default is 128.") - parser.add_argument("--distribute", - type=str, - default="true", - choices=["true", "false"], - help="Run distribute, default is true.") - parser.add_argument("--load_ckpt_name", - type=str, - default=None, - help="checkpint file name.") - parser.add_argument("--load_ckpt_path", - type=str, - default=None, - help="checkpoint file path.") - parser.add_argument("--load_ckpt_epoch", - type=int, - default=None, - help="checkpoint epoch.") - parser.add_argument('--code_data', - type=str, - required=True, - help='Location of code data.') - parser.add_argument("--tb_dir", - type=str, - required=True, - help="Location of tensorboard log") - parser.add_argument("--language", - type=str, - default=None, - help="Language of task") - parser.add_argument("--part", - type=int, - default=None, - help="Part of task") - parser.add_argument('--eval_data_url', - required=False, - default=None, - help='Location of eval data.') - parser.add_argument('--train_url', - required=False, - default=None, - help='Location of training outputs.') - parser.add_argument("--run_type", - type=str, - default="predict", - choices=["train", "predict"], - help="The run type") - parser.add_argument("--mode", - type=str, - default="2.6B", - choices=["200B", "13B", "2.6B", "base", "dev", "self_define"], - help="The scale of the model parameters") - parser.add_argument("--device_target", - type=str, - default="Ascend", - choices=["Ascend", "GPU"], - help="The running device") - parser.add_argument("--strategy_load_ckpt_path", - type=str, - default="", - help="The training prallel strategy for the model.") - parser.add_argument("--tokenizer_path", - type=str, - default="./tokenizer_path", - help="The path where stores vocab and vocab model file") - parser.add_argument("--param_init_type", - type=str, - default="fp32", - help="The initialization type for parameters. Default fp32.") - parser.add_argument("--offline", - type=int, - default=1, - help="Running on cloud of not. Default 1.") - parser.add_argument("--export", - type=int, - default=0, - help="Whether export mindir for serving.") - parser.add_argument("--incremental_training", - type=int, - default=0, - help="Enable incremental training. Default 0.") - parser.add_argument("--train_and_eval_mode", - type=int, - default=0, - help="Enable evaling while training. Default 0.") - parser.add_argument("--eval_steps", - type=int, - default=10, - help="The eval step in train and eval mode. Default 10.") + parser.add_argument( + "--device_id", type=int, default=0, help="Device id, default is 0." 
+ ) + parser.add_argument( + "--device_num", type=int, default=128, help="Use device nums, default is 128." + ) + parser.add_argument( + "--distribute", + type=str, + default="true", + choices=["true", "false"], + help="Run distribute, default is true.", + ) + parser.add_argument( + "--load_ckpt_name", type=str, default=None, help="checkpint file name." + ) + parser.add_argument( + "--load_ckpt_path", type=str, default=None, help="checkpoint file path." + ) + parser.add_argument( + "--load_ckpt_epoch", type=int, default=None, help="checkpoint epoch." + ) + parser.add_argument( + "--code_data", type=str, required=True, help="Location of code data." + ) + parser.add_argument( + "--tb_dir", type=str, required=True, help="Location of tensorboard log" + ) + parser.add_argument("--language", type=str, default=None, help="Language of task") + parser.add_argument("--part", type=int, default=None, help="Part of task") + parser.add_argument( + "--eval_data_url", required=False, default=None, help="Location of eval data." + ) + parser.add_argument( + "--train_url", + required=False, + default=None, + help="Location of training outputs.", + ) + parser.add_argument( + "--run_type", + type=str, + default="predict", + choices=["train", "predict"], + help="The run type", + ) + parser.add_argument( + "--mode", + type=str, + default="2.6B", + choices=["200B", "13B", "2.6B", "base", "dev", "self_define"], + help="The scale of the model parameters", + ) + parser.add_argument( + "--device_target", + type=str, + default="Ascend", + choices=["Ascend", "GPU"], + help="The running device", + ) + parser.add_argument( + "--strategy_load_ckpt_path", + type=str, + default="", + help="The training prallel strategy for the model.", + ) + parser.add_argument( + "--tokenizer_path", + type=str, + default="./tokenizer_path", + help="The path where stores vocab and vocab model file", + ) + parser.add_argument( + "--param_init_type", + type=str, + default="fp32", + help="The initialization type for parameters. Default fp32.", + ) + parser.add_argument( + "--offline", type=int, default=1, help="Running on cloud of not. Default 1." + ) + parser.add_argument( + "--export", type=int, default=0, help="Whether export mindir for serving." + ) + parser.add_argument( + "--incremental_training", + type=int, + default=0, + help="Enable incremental training. Default 0.", + ) + parser.add_argument( + "--train_and_eval_mode", + type=int, + default=0, + help="Enable evaling while training. Default 0.", + ) + parser.add_argument( + "--eval_steps", + type=int, + default=10, + help="The eval step in train and eval mode. Default 10.", + ) parser.add_argument( "--profiling", type=int, @@ -586,16 +690,17 @@ def get_args(inference=False): def download_data(src_data_url, tgt_data_path, rank): """ - Download the dataset from the obs. - src_data_url (Str): should be the dataset path in the obs - tgt_data_path (Str): the local dataset path - rank (Int): the current rank id + Download the dataset from the obs. 
+ src_data_url (Str): should be the dataset path in the obs + tgt_data_path (Str): the local dataset path + rank (Int): the current rank id """ cache_url = tgt_data_path EXEC_PATH = "/tmp" if rank % 8 == 0: import moxing as mox + print("Modify the time out from 300 to 30000") print("begin download dataset", flush=True) @@ -610,6 +715,7 @@ def download_data(src_data_url, tgt_data_path, rank): while not os.path.exists("%s/install.txt" % (EXEC_PATH)): time.sleep(1) + # class LossSummaryCallback(Callback): # def __init__(self, summary_dir, bucket, local_rank=0, has_trained_epoch=0, has_trained_step=0, syn_times=100): # self._summary_dir = summary_dir diff --git a/codegeex/mindspore/train.py b/codegeex/mindspore/train.py index 9ccd461..2b62357 100644 --- a/codegeex/mindspore/train.py +++ b/codegeex/mindspore/train.py @@ -29,7 +29,11 @@ import moxing as mox from mindspore import context from mindspore.context import ParallelMode -from mindspore.nn.wrap.cell_wrapper import PipelineCell, _VirtualDatasetCell, MicroBatchInterleaved +from mindspore.nn.wrap.cell_wrapper import ( + PipelineCell, + _VirtualDatasetCell, + MicroBatchInterleaved, +) from mindspore.nn.wrap.loss_scale import DynamicLossScaleUpdateCell from mindspore.parallel import set_algo_parameters from mindspore.parallel._cost_model_context import _set_multi_subgraphs @@ -38,7 +42,11 @@ from mindspore.train.callback import ModelCheckpoint, CheckpointConfig from mindspore.train.callback import TimeMonitor from mindspore.train.model import Model -from mindspore.train.serialization import load_distributed_checkpoint, load_checkpoint, load_param_into_net +from mindspore.train.serialization import ( + load_distributed_checkpoint, + load_checkpoint, + load_param_into_net, +) from tensorboardX import SummaryWriter from src.adam import AdamWeightDecayOp @@ -47,20 +55,26 @@ from src.metrics import PPLMetric, ValidationLoss from src.pangu_alpha import PanGUAlphaWithLoss, PanguAlphaModel from src.pangu_alpha_config import set_parse, PanguAlphaConfig -from src.pangu_alpha_wrapcell import PanguAlphaTrainOneStepWithLossScaleCell, PanguAlphaTrainPipelineWithLossScaleCell +from src.pangu_alpha_wrapcell import ( + PanguAlphaTrainOneStepWithLossScaleCell, + PanguAlphaTrainPipelineWithLossScaleCell, +) from src.utils import LearningRate, get_args, FP32StateAdamWeightDecay from src.utils import download_data project_root = os.path.abspath( - os.path.dirname(os.path.realpath(__file__)) + os.path.sep + "..") -print('project_root:', project_root) + os.path.dirname(os.path.realpath(__file__)) + os.path.sep + ".." 
+) +print("project_root:", project_root) def set_weight_decay(params): """ Set weight decay coefficient, zero for bias and layernorm, 1e-1 for rest """ - decay_filter = lambda x: 'layernorm' not in x.name.lower() and "bias" not in x.name.lower() + decay_filter = ( + lambda x: "layernorm" not in x.name.lower() and "bias" not in x.name.lower() + ) decay_params = list(filter(decay_filter, params)) other_params = list(filter(lambda x: not decay_filter(x), params)) group_params = [ @@ -77,7 +91,12 @@ def add_checkpoint_callback_policy(args_param, callback, rank_id): """ if args_param.save_checkpoint: # checkpoint store epoch_num and step_num info - ckpt_append_info = [{"epoch_num": args_param.has_trained_epoches, "step_num": args_param.has_trained_steps}] + ckpt_append_info = [ + { + "epoch_num": args_param.has_trained_epoches, + "step_num": args_param.has_trained_steps, + } + ] ckpt_config = CheckpointConfig( save_checkpoint_steps=args_param.save_checkpoint_steps, keep_checkpoint_max=args_param.keep_checkpoint_max, @@ -86,18 +105,22 @@ def add_checkpoint_callback_policy(args_param, callback, rank_id): ) # save checkpoint into rank directory - ckpoint_cb = ModelCheckpoint(prefix=args_param.ckpt_name_prefix + str(rank_id), - directory=os.path.join(args_param.save_checkpoint_path, f"rank_{rank_id}"), - config=ckpt_config) + ckpoint_cb = ModelCheckpoint( + prefix=args_param.ckpt_name_prefix + str(rank_id), + directory=os.path.join(args_param.save_checkpoint_path, f"rank_{rank_id}"), + config=ckpt_config, + ) callback.append(ckpoint_cb) - saveckpt_cb = SaveCheckpointCallback(cache_dir=args_param.save_checkpoint_path, - bucket=args_param.save_checkpoint_obs_path, - local_rank=rank_id, - has_trained_epoch=args_param.has_trained_epoches, - has_trained_step=args_param.has_trained_steps, - syn_times=args_param.save_checkpoint_steps) + saveckpt_cb = SaveCheckpointCallback( + cache_dir=args_param.save_checkpoint_path, + bucket=args_param.save_checkpoint_obs_path, + local_rank=rank_id, + has_trained_epoch=args_param.has_trained_epoches, + has_trained_step=args_param.has_trained_steps, + syn_times=args_param.save_checkpoint_steps, + ) callback.append(saveckpt_cb) @@ -109,10 +132,14 @@ def set_parallel_context(args_opt): print("rank_id is {}, device_num is {}".format(rank, device_num)) context.reset_auto_parallel_context() context.set_auto_parallel_context( - parallel_mode=ParallelMode.SEMI_AUTO_PARALLEL, gradients_mean=False, - full_batch=bool(args_opt.full_batch), strategy_ckpt_load_file=args_opt.strategy_load_ckpt_path, - enable_parallel_optimizer=bool(args_opt.optimizer_shard), strategy_ckpt_save_file='strategy.ckpt', - optimizer_weight_shard_size=16) + parallel_mode=ParallelMode.SEMI_AUTO_PARALLEL, + gradients_mean=False, + full_batch=bool(args_opt.full_batch), + strategy_ckpt_load_file=args_opt.strategy_load_ckpt_path, + enable_parallel_optimizer=bool(args_opt.optimizer_shard), + strategy_ckpt_save_file="strategy.ckpt", + optimizer_weight_shard_size=16, + ) set_algo_parameters(elementwise_op_strategy_follow=True) _set_multi_subgraphs() return rank, device_num @@ -122,9 +149,7 @@ def run_train(args_opt): r"""The main training process.""" os.environ["HCCL_CONNECT_TIMEOUT"] = "2000" # Set execution mode - context.set_context( - mode=context.GRAPH_MODE, device_target=args_opt.device_target - ) + context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.device_target) if args_opt.profiling: profiler = Profiler(output_path="/cache/profiler_data") context.set_context(variable_memory_max_size="30GB") 
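# The next hunk in run_train derives the data-parallel group size and the global batch
# size from the device count before building TransformerOpParallelConfig. A minimal
# sketch of that arithmetic, using illustrative numbers (device_num, model_parallel_num
# and per_batch_size below are example values, not repository defaults):
device_num = 128                   # total devices reported by the communicator
model_parallel_num = 8             # op-level model-parallel group size
per_batch_size = 16                # samples per data-parallel rank per step
data_parallel_num = device_num // model_parallel_num  # 128 / 8 -> 16 replicas
batch_size = per_batch_size * data_parallel_num       # 16 * 16 -> 256 samples per step
print(data_parallel_num, batch_size)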
@@ -137,21 +162,29 @@ def run_train(args_opt): save_graphs=False, save_graphs_path="/cache/graphs_of_device_id_" + str(rank), ) - cache_url = '/cache/Data/' - eval_cache_url = '/cache/EvalData/' + cache_url = "/cache/Data/" + eval_cache_url = "/cache/EvalData/" if not args_opt.offline: - download_data(src_data_url=args_opt.data_url, tgt_data_path=cache_url, rank=rank) - download_data(src_data_url=args_opt.eval_data_url, tgt_data_path=eval_cache_url, rank=rank) + download_data( + src_data_url=args_opt.data_url, tgt_data_path=cache_url, rank=rank + ) + download_data( + src_data_url=args_opt.eval_data_url, tgt_data_path=eval_cache_url, rank=rank + ) # Set model property model_parallel_num = args_opt.op_level_model_parallel_num data_parallel_num = int(device_num / model_parallel_num) batch_size = args_opt.per_batch_size * data_parallel_num - parallel_config = TransformerOpParallelConfig(data_parallel=data_parallel_num, model_parallel=model_parallel_num, - pipeline_stage=args_opt.stage_num, - micro_batch_num=args_opt.micro_size, - optimizer_shard=bool(args_opt.optimizer_shard), - vocab_emb_dp=bool(args_opt.word_emb_dp), recompute=True, - gradient_aggregation_group=args_opt.gradient_aggregation_group) + parallel_config = TransformerOpParallelConfig( + data_parallel=data_parallel_num, + model_parallel=model_parallel_num, + pipeline_stage=args_opt.stage_num, + micro_batch_num=args_opt.micro_size, + optimizer_shard=bool(args_opt.optimizer_shard), + vocab_emb_dp=bool(args_opt.word_emb_dp), + recompute=True, + gradient_aggregation_group=args_opt.gradient_aggregation_group, + ) micro_interleaved_size = args_opt.micro_interleaved_size config = PanguAlphaConfig( @@ -180,24 +213,37 @@ def run_train(args_opt): loss = CrossEntropyLoss(config.parallel_config.dp_mp_config) if micro_interleaved_size > 1: print("===using MicroBatchInterleaved", flush=True) - pangu_alpha_with_loss_net = MicroBatchInterleaved(PanGUAlphaWithLoss(config, pangu_alpha, loss), - micro_interleaved_size) + pangu_alpha_with_loss_net = MicroBatchInterleaved( + PanGUAlphaWithLoss(config, pangu_alpha, loss), micro_interleaved_size + ) else: pangu_alpha_with_loss_net = PanGUAlphaWithLoss(config, pangu_alpha, loss) pangu_alpha_with_loss = _VirtualDatasetCell(pangu_alpha_with_loss_net) print("=====args_opt is: ", args_opt, flush=True) # Warm-up and cosine decay learning rate - lr = LearningRate(learning_rate=args_opt.start_lr, end_learning_rate=args_opt.end_lr, - warmup_steps=args_opt.warmup_step, decay_steps=args_opt.decay_steps) + lr = LearningRate( + learning_rate=args_opt.start_lr, + end_learning_rate=args_opt.end_lr, + warmup_steps=args_opt.warmup_step, + decay_steps=args_opt.decay_steps, + ) params = pangu_alpha_with_loss.trainable_params() group_params = set_weight_decay(params) if args_opt.optimizer == "lamb": optimizer = nn.Lamb(group_params, learning_rate=lr) elif args_opt.opt_offload: - optimizer = AdamWeightDecayOp(group_params, learning_rate=lr, eps=1e-8, beta1=0.9, beta2=0.95, - param_init_type=config.param_init_type) + optimizer = AdamWeightDecayOp( + group_params, + learning_rate=lr, + eps=1e-8, + beta1=0.9, + beta2=0.95, + param_init_type=config.param_init_type, + ) else: - optimizer = FP32StateAdamWeightDecay(group_params, learning_rate=lr, eps=1e-8, beta1=0.9, beta2=0.95) + optimizer = FP32StateAdamWeightDecay( + group_params, learning_rate=lr, eps=1e-8, beta1=0.9, beta2=0.95 + ) # Initial scaling sens loss_scale_value = math.pow(2, 32) epoch_num = args_opt.epoch_size @@ -206,11 +252,17 @@ def run_train(args_opt): 
time.sleep(rank * 0.05) os.mkdir(os.path.join(args_opt.save_checkpoint_path, f"rank_{rank}")) ckpt_name = f"code-13B{rank}_20-{args_opt.load_ckpt_epoch}_2.ckpt" - if not mox.file.exists(os.path.join(args_opt.load_ckpt_path, f"rank_{rank}", ckpt_name)): + if not mox.file.exists( + os.path.join(args_opt.load_ckpt_path, f"rank_{rank}", ckpt_name) + ): print(f"Checkpoint from rank {rank} doesn't exist!") - mox.file.copy(os.path.join(args_opt.load_ckpt_path, f"rank_{rank}", ckpt_name), - os.path.join(args_opt.save_checkpoint_path, f"rank_{rank}", ckpt_name)) - param_dict = load_checkpoint(os.path.join(args_opt.save_checkpoint_path, f"rank_{rank}", ckpt_name)) + mox.file.copy( + os.path.join(args_opt.load_ckpt_path, f"rank_{rank}", ckpt_name), + os.path.join(args_opt.save_checkpoint_path, f"rank_{rank}", ckpt_name), + ) + param_dict = load_checkpoint( + os.path.join(args_opt.save_checkpoint_path, f"rank_{rank}", ckpt_name) + ) # TODO: remove after warming-up! # param_dict.pop('global_step') # TODO: add them back if not for the 1st run! @@ -221,7 +273,9 @@ def run_train(args_opt): os.mkdir(f'/home/work/sfs/cache/{os.environ["BATCH_JOB_ID"]}/1/rank_{rank}') while True: - num = len(os.listdir(f'/home/work/sfs/cache/{os.environ["BATCH_JOB_ID"]}/1')) + num = len( + os.listdir(f'/home/work/sfs/cache/{os.environ["BATCH_JOB_ID"]}/1') + ) if num == device_num: break if rank % 64 == 0: @@ -231,30 +285,47 @@ def run_train(args_opt): if args_opt.tb_dir is not None and rank == device_num - 1: os.makedirs(args_opt.tb_dir, exist_ok=True) summary_writer = SummaryWriter(args_opt.tb_dir) - os.system(f'chomd 777 -R {args_opt.tb_dir}') + os.system(f"chomd 777 -R {args_opt.tb_dir}") else: summary_writer = None # Dataset loading mindrecord files - ds, ds_eval = create_dataset(config.batch_size * micro_interleaved_size, data_path=args_opt.code_data, - args_opt=args_opt, data_start_index=0, - eod_reset=config.eod_reset, full_batch=bool(args_opt.full_batch), - eod_id=args_opt.eod_id, - device_num=device_num, rank=rank, epoch=epoch_num, - train_and_eval=bool(args_opt.train_and_eval_mode), val_ratio=0.001) + ds, ds_eval = create_dataset( + config.batch_size * micro_interleaved_size, + data_path=args_opt.code_data, + args_opt=args_opt, + data_start_index=0, + eod_reset=config.eod_reset, + full_batch=bool(args_opt.full_batch), + eod_id=args_opt.eod_id, + device_num=device_num, + rank=rank, + epoch=epoch_num, + train_and_eval=bool(args_opt.train_and_eval_mode), + val_ratio=0.001, + ) actual_epoch_num = int(ds.get_dataset_size() / args_opt.sink_size) callback = [ TimeMonitor(args_opt.sink_size), ] - update_cell = DynamicLossScaleUpdateCell(loss_scale_value=loss_scale_value, scale_factor=2, scale_window=1000) + update_cell = DynamicLossScaleUpdateCell( + loss_scale_value=loss_scale_value, scale_factor=2, scale_window=1000 + ) pangu_alpha_with_grads = PanguAlphaTrainOneStepWithLossScaleCell( - pangu_alpha_with_loss, optimizer=optimizer, scale_update_cell=update_cell, enable_global_norm=True, - config=config) + pangu_alpha_with_loss, + optimizer=optimizer, + scale_update_cell=update_cell, + enable_global_norm=True, + config=config, + ) if ds_eval: ppl_metric = PPLMetric(config.seq_length) validation_loss = ValidationLoss(eod_token=args_opt.eod_id) - model = Model(pangu_alpha_with_grads, eval_network=pangu_alpha_with_loss, - metrics={"ppl": ppl_metric, "validation_loss": validation_loss}) + model = Model( + pangu_alpha_with_grads, + eval_network=pangu_alpha_with_loss, + metrics={"ppl": ppl_metric, "validation_loss": 
validation_loss}, + ) callback.append( EvalCallBack( model=model, @@ -265,7 +336,7 @@ def run_train(args_opt): has_trained_step=args_opt.has_trained_steps, local_rank=rank, rank_size=device_num, - tb_writer=summary_writer + tb_writer=summary_writer, ) ) else: @@ -273,16 +344,26 @@ def run_train(args_opt): if args_opt.load_ckpt_epoch > 0: print("===build model and load ckpt") time_stamp = datetime.datetime.now() - print(f"time stamp {time_stamp.strftime('%Y.%m.%d-%H:%M:%S')} before building", flush=True) - model.build(train_dataset=ds, sink_size=args_opt.sink_size, epoch=actual_epoch_num) + print( + f"time stamp {time_stamp.strftime('%Y.%m.%d-%H:%M:%S')} before building", + flush=True, + ) + model.build( + train_dataset=ds, sink_size=args_opt.sink_size, epoch=actual_epoch_num + ) time_stamp = datetime.datetime.now() - print(f"time stamp {time_stamp.strftime('%Y.%m.%d-%H:%M:%S')} before loading ckpt", flush=True) + print( + f"time stamp {time_stamp.strftime('%Y.%m.%d-%H:%M:%S')} before loading ckpt", + flush=True, + ) net_not_load = load_param_into_net(pangu_alpha_with_loss, param_dict) opt_not_load = load_param_into_net(optimizer, param_dict) os.mkdir(f'/home/work/sfs/cache/{os.environ["BATCH_JOB_ID"]}/2/rank_{rank}') while True: - num = len(os.listdir(f'/home/work/sfs/cache/{os.environ["BATCH_JOB_ID"]}/2')) + num = len( + os.listdir(f'/home/work/sfs/cache/{os.environ["BATCH_JOB_ID"]}/2') + ) if num == device_num: break if rank % 64 == 0: @@ -303,18 +384,32 @@ def run_train(args_opt): if not args_opt.profiling: add_checkpoint_callback_policy(args_opt, callback, rank) if args_opt.incremental_training: - strategy = model.infer_train_layout(train_dataset=ds, sink_size=args_opt.sink_size) + strategy = model.infer_train_layout( + train_dataset=ds, sink_size=args_opt.sink_size + ) print("======start load_distributed checkpoint", flush=True) # For 2.6B and 13B models, the number of ckpt files is 512. 
- ckpt_file_list = [os.path.join(args_opt.load_ckpt_path, f"filerted_{ckpt_rank}.ckpt") for ckpt_rank in - range(0, 512)] + ckpt_file_list = [ + os.path.join(args_opt.load_ckpt_path, f"filerted_{ckpt_rank}.ckpt") + for ckpt_rank in range(0, 512) + ] print(f"Loading from path {ckpt_file_list[0]}", flush=True) load_distributed_checkpoint(model.train_network, ckpt_file_list, strategy) - print("Dataset size: {}, actual_epoch_num: {}".format(ds.get_dataset_size(), actual_epoch_num), flush=True) + print( + "Dataset size: {}, actual_epoch_num: {}".format( + ds.get_dataset_size(), actual_epoch_num + ), + flush=True, + ) try: - model.train(10 if args_opt.profiling else actual_epoch_num, ds, callbacks=callback, - sink_size=args_opt.sink_size, dataset_sink_mode=True) + model.train( + 10 if args_opt.profiling else actual_epoch_num, + ds, + callbacks=callback, + sink_size=args_opt.sink_size, + dataset_sink_mode=True, + ) finally: if args_opt.profiling: jobid = os.environ["BATCH_JOB_ID"] @@ -322,12 +417,16 @@ def run_train(args_opt): rank_id = rank if context.get_context("save_graphs"): mox.file.make_dirs("s3://wudao-1/yyf/graphs_" + jobid) - mox.file.copy_parallel(src_url="/cache/graphs_of_device_id_" + str(rank_id), - dst_url="s3://wudao-1/yyf/graphs_" + jobid + "/" + str(rank_id)) + mox.file.copy_parallel( + src_url="/cache/graphs_of_device_id_" + str(rank_id), + dst_url="s3://wudao-1/yyf/graphs_" + jobid + "/" + str(rank_id), + ) if rank_id % 8 == 0: mox.file.make_dirs("s3://wudao-1/yyf/profiler_" + jobid) - mox.file.copy_parallel(src_url="/cache/profiler_data", - dst_url="s3://wudao-1/yyf/profiler_" + jobid + "/" + str(rank_id)) + mox.file.copy_parallel( + src_url="/cache/profiler_data", + dst_url="s3://wudao-1/yyf/profiler_" + jobid + "/" + str(rank_id), + ) def restore_checkpoint(args_param, sink_size, dataset, model, network, epoch): @@ -336,17 +435,26 @@ def restore_checkpoint(args_param, sink_size, dataset, model, network, epoch): """ print("======start single checkpoint", flush=True) ckpt_name = args_param.ckpt_name_prefix - ckpt_pattern = os.path.join(args_param.save_checkpoint_path, "rank_{}".format(D.get_rank()), f"{ckpt_name}*.ckpt") + ckpt_pattern = os.path.join( + args_param.save_checkpoint_path, + "rank_{}".format(D.get_rank()), + f"{ckpt_name}*.ckpt", + ) ckpt_all_files = glob.glob(ckpt_pattern) if not ckpt_all_files: - print(f"There is no ckpt file in {args_param.save_checkpoint_path}, " - f"current ckpt_files found is {ckpt_all_files} " - f"with pattern {ckpt_pattern}, so skip the loading.") + print( + f"There is no ckpt file in {args_param.save_checkpoint_path}, " + f"current ckpt_files found is {ckpt_all_files} " + f"with pattern {ckpt_pattern}, so skip the loading." 
+ ) return - ckpt_exp_pattern = os.path.join(args_param.save_checkpoint_path, "rank_{}".format(D.get_rank()), - f"{ckpt_name}*_breakpoint.ckpt") + ckpt_exp_pattern = os.path.join( + args_param.save_checkpoint_path, + "rank_{}".format(D.get_rank()), + f"{ckpt_name}*_breakpoint.ckpt", + ) ckpt_exp_files = glob.glob(ckpt_exp_pattern) ckpt_files = [] for file in ckpt_all_files: @@ -354,17 +462,21 @@ def restore_checkpoint(args_param, sink_size, dataset, model, network, epoch): ckpt_files.append(file) if not ckpt_files: - print(f"There is no ckpt file in {args_param.save_checkpoint_path}, " - f"current ckpt_files found is {ckpt_files} " - f"with pattern {ckpt_pattern}, so skip the loading.") + print( + f"There is no ckpt file in {args_param.save_checkpoint_path}, " + f"current ckpt_files found is {ckpt_files} " + f"with pattern {ckpt_pattern}, so skip the loading." + ) return ckpt_files.sort(key=os.path.getmtime, reverse=True) time_stamp = datetime.datetime.now() - print(f"time stamp {time_stamp.strftime('%Y.%m.%d-%H:%M:%S')} pre trained ckpt model {ckpt_files} loading", - flush=True) + print( + f"time stamp {time_stamp.strftime('%Y.%m.%d-%H:%M:%S')} pre trained ckpt model {ckpt_files} loading", + flush=True, + ) # Load checkpoint files latest file - print(f'Start to load from {ckpt_files[0]}') + print(f"Start to load from {ckpt_files[0]}") param_dict = load_checkpoint(ckpt_files[0]) if param_dict.get("epoch_num") and param_dict.get("step_num"): args_param.has_trained_epoches = int(param_dict["epoch_num"].data.asnumpy()) @@ -389,15 +501,18 @@ def get_exception_checkpoints(args_param): ckpt_file_list = [] ckpt_name = args_param.ckpt_name_prefix for ckpt_rank in restore_rank_list: - ckpt_pattern = os.path.join(args_param.save_checkpoint_path, - f"rank_{ckpt_rank}", - f"{ckpt_name}*_breakpoint.ckpt") + ckpt_pattern = os.path.join( + args_param.save_checkpoint_path, + f"rank_{ckpt_rank}", + f"{ckpt_name}*_breakpoint.ckpt", + ) ckpt_files = glob.glob(ckpt_pattern) if not ckpt_files: print( f"There is no ckpt file in {args_param.save_checkpoint_path}, " f"current ckpt_files found is {ckpt_files} " - f"with pattern {ckpt_pattern}, so skip the loading.") + f"with pattern {ckpt_pattern}, so skip the loading." 
+ ) return None ckpt_files.sort(key=os.path.getmtime, reverse=True) ckpt_file_list.append(ckpt_files[0]) @@ -461,19 +576,19 @@ def restore_exception_checkpoint(args_param, sink_size, dataset, model, network, map_rank_id = restore_ranks_map_json.get(key) print(f"loading map rank id {map_rank_id}") - ckpt_pattern = os.path.join(args_param.save_checkpoint_path, - f"rank_{map_rank_id}", - f"{ckpt_name}*breakpoint.ckpt") + ckpt_pattern = os.path.join( + args_param.save_checkpoint_path, + f"rank_{map_rank_id}", + f"{ckpt_name}*breakpoint.ckpt", + ) ckpt_files = glob.glob(ckpt_pattern) ckpt_files.sort(key=os.path.getmtime, reverse=True) print(f" checkpoint files {ckpt_files[0]}") param_dict = load_checkpoint(ckpt_files[0]) print(f" checkpoint param dict epoch num {param_dict.get('epoch_num')}") if param_dict.get("epoch_num") and param_dict.get("step_num"): - args_param.has_trained_epoches = int( - param_dict["epoch_num"].data.asnumpy()) - args_param.has_trained_steps = int( - param_dict["step_num"].data.asnumpy()) + args_param.has_trained_epoches = int(param_dict["epoch_num"].data.asnumpy()) + args_param.has_trained_steps = int(param_dict["step_num"].data.asnumpy()) # Load checkpoint files model.build(train_dataset=dataset, sink_size=sink_size, epoch=epoch) @@ -494,10 +609,14 @@ def set_pipeline_parallel_context(args_opt): print("rank_id is {}, device_num is {}".format(rank_id, device_num)) context.reset_auto_parallel_context() context.set_auto_parallel_context( - parallel_mode=ParallelMode.SEMI_AUTO_PARALLEL, gradients_mean=False, - full_batch=bool(args_opt.full_batch), loss_repeated_mean=True, - device_num=device_num, enable_parallel_optimizer=bool(args_opt.optimizer_shard), - pipeline_stages=args_opt.stage_num) + parallel_mode=ParallelMode.SEMI_AUTO_PARALLEL, + gradients_mean=False, + full_batch=bool(args_opt.full_batch), + loss_repeated_mean=True, + device_num=device_num, + enable_parallel_optimizer=bool(args_opt.optimizer_shard), + pipeline_stages=args_opt.stage_num, + ) set_algo_parameters(elementwise_op_strategy_follow=True) _set_multi_subgraphs() return rank_id, device_num @@ -506,9 +625,11 @@ def set_pipeline_parallel_context(args_opt): def run_train_pipeline(args_opt): r"""The main training process in pipeline.""" # Set hccl connect time - os.environ['HCCL_CONNECT_TIMEOUT'] = "6000" + os.environ["HCCL_CONNECT_TIMEOUT"] = "6000" - context.set_context(save_graphs=False, mode=context.GRAPH_MODE, device_target=args_opt.device_target) + context.set_context( + save_graphs=False, mode=context.GRAPH_MODE, device_target=args_opt.device_target + ) if args_opt.profiling: profiler = Profiler(output_path="./profiler_data") context.set_context(variable_memory_max_size="30GB") @@ -519,11 +640,17 @@ def run_train_pipeline(args_opt): rank_id, device_num = set_pipeline_parallel_context(args_opt) # copy data from the cloud to the /cache/Data - cache_url = '/cache/Data/' - eval_cache_url = '/cache/EvalData/' + cache_url = "/cache/Data/" + eval_cache_url = "/cache/EvalData/" if not args_opt.offline: - download_data(src_data_url=args_opt.data_url, tgt_data_path=cache_url, rank=rank_id) - download_data(src_data_url=args_opt.eval_data_url, tgt_data_path=eval_cache_url, rank=rank_id) + download_data( + src_data_url=args_opt.data_url, tgt_data_path=cache_url, rank=rank_id + ) + download_data( + src_data_url=args_opt.eval_data_url, + tgt_data_path=eval_cache_url, + rank=rank_id, + ) model_parallel_num = args_opt.op_level_model_parallel_num stage_device_num = int(device_num / args_opt.stage_num) data_parallel_num 
= int(stage_device_num / model_parallel_num) @@ -533,14 +660,15 @@ def run_train_pipeline(args_opt): per_batch_size = args_opt.per_batch_size batch_size = per_batch_size * data_parallel_num * args_opt.micro_size - parallel_config = TransformerOpParallelConfig(data_parallel=data_parallel_num, - model_parallel=model_parallel_num, - pipeline_stage=args_opt.stage_num, - micro_batch_num=args_opt.micro_size, - optimizer_shard=bool(args_opt.optimizer_shard), - vocab_emb_dp=bool(args_opt.word_emb_dp), - recompute=True, - ) + parallel_config = TransformerOpParallelConfig( + data_parallel=data_parallel_num, + model_parallel=model_parallel_num, + pipeline_stage=args_opt.stage_num, + micro_batch_num=args_opt.micro_size, + optimizer_shard=bool(args_opt.optimizer_shard), + vocab_emb_dp=bool(args_opt.word_emb_dp), + recompute=True, + ) config = PanguAlphaConfig( batch_size=batch_size // parallel_config.micro_batch_num, num_heads=args_opt.num_heads, @@ -562,27 +690,47 @@ def run_train_pipeline(args_opt): print("===config is: ", config, flush=True) pangu_alpha = PanguAlphaModel(config=config) loss = CrossEntropyLoss(config.parallel_config.dp_mp_config) - pangu_alpha_with_loss_net = PipelineCell(PanGUAlphaWithLoss(config, pangu_alpha, loss), - config.parallel_config.micro_batch_num) + pangu_alpha_with_loss_net = PipelineCell( + PanGUAlphaWithLoss(config, pangu_alpha, loss), + config.parallel_config.micro_batch_num, + ) pangu_alpha_with_loss = _VirtualDatasetCell(pangu_alpha_with_loss_net) print("=====args_opt is: ", args_opt, flush=True) - lr = LearningRate(learning_rate=args_opt.start_lr, end_learning_rate=args_opt.end_lr, - warmup_steps=args_opt.warmup_step, decay_steps=args_opt.decay_steps) + lr = LearningRate( + learning_rate=args_opt.start_lr, + end_learning_rate=args_opt.end_lr, + warmup_steps=args_opt.warmup_step, + decay_steps=args_opt.decay_steps, + ) params = pangu_alpha.infer_param_pipeline_stage() group_params = set_weight_decay(params) if args_opt.optimizer == "lamb": optimizer = nn.Lamb(group_params, learning_rate=lr) elif args_opt.opt_offload: - optimizer = AdamWeightDecayOp(group_params, learning_rate=lr, eps=1e-8, beta1=0.9, beta2=0.95, - param_init_type=config.param_init_type) + optimizer = AdamWeightDecayOp( + group_params, + learning_rate=lr, + eps=1e-8, + beta1=0.9, + beta2=0.95, + param_init_type=config.param_init_type, + ) else: - optimizer = nn.AdamWeightDecay(group_params, learning_rate=lr, beta1=0.9, beta2=0.95, eps=1e-8) + optimizer = nn.AdamWeightDecay( + group_params, learning_rate=lr, beta1=0.9, beta2=0.95, eps=1e-8 + ) - ds = create_dataset(config.batch_size * parallel_config.micro_batch_num, data_path=args_opt.code_data, - device_num=stage_device_num, args_opt=args_opt, - rank=rank_id % stage_device_num, eod_reset=True, data_start_index=0, - full_batch=context.get_auto_parallel_context("full_batch"), - column_name=args_opt.data_column_name) + ds = create_dataset( + config.batch_size * parallel_config.micro_batch_num, + data_path=args_opt.code_data, + device_num=stage_device_num, + args_opt=args_opt, + rank=rank_id % stage_device_num, + eod_reset=True, + data_start_index=0, + full_batch=context.get_auto_parallel_context("full_batch"), + column_name=args_opt.data_column_name, + ) epoch_num = args_opt.epoch_size step_per_epoch = ds.get_dataset_size() callback_size = args_opt.sink_size @@ -599,19 +747,37 @@ def run_train_pipeline(args_opt): ), ] loss_scale_value = math.pow(2, 32) - update_cell = DynamicLossScaleUpdateCell(loss_scale_value=loss_scale_value, scale_factor=2, 
scale_window=1000) + update_cell = DynamicLossScaleUpdateCell( + loss_scale_value=loss_scale_value, scale_factor=2, scale_window=1000 + ) pangu_alpha_with_grads = PanguAlphaTrainPipelineWithLossScaleCell( - pangu_alpha_with_loss, optimizer=optimizer, config=config, scale_update_cell=update_cell) + pangu_alpha_with_loss, + optimizer=optimizer, + config=config, + scale_update_cell=update_cell, + ) if args_opt.train_and_eval_mode: - ds_eval = create_dataset(config.batch_size * parallel_config.micro_batch_num, data_path=eval_cache_url, - args_opt=args_opt, - device_num=stage_device_num, rank=rank_id % stage_device_num, eod_reset=True, - data_start_index=0, full_batch=bool(args_opt.full_batch), - column_name=args_opt.data_column_name, - num_samples=args_opt.eval_steps * config.batch_size) + ds_eval = create_dataset( + config.batch_size * parallel_config.micro_batch_num, + data_path=eval_cache_url, + args_opt=args_opt, + device_num=stage_device_num, + rank=rank_id % stage_device_num, + eod_reset=True, + data_start_index=0, + full_batch=bool(args_opt.full_batch), + column_name=args_opt.data_column_name, + num_samples=args_opt.eval_steps * config.batch_size, + ) ppl_metric = PPLMetric(config.seq_length) - pangu_alpha_with_loss_eval_net = _VirtualDatasetCell(PanGUAlphaWithLoss(config, pangu_alpha, loss)) - model = Model(pangu_alpha_with_grads, eval_network=pangu_alpha_with_loss_eval_net, metrics={"ppl": ppl_metric}) + pangu_alpha_with_loss_eval_net = _VirtualDatasetCell( + PanGUAlphaWithLoss(config, pangu_alpha, loss) + ) + model = Model( + pangu_alpha_with_grads, + eval_network=pangu_alpha_with_loss_eval_net, + metrics={"ppl": ppl_metric}, + ) model.build(ds, ds_eval, sink_size=callback_size) eval_callback = EvalCallBack(model, ds_eval, ppl_metric) callback.append(eval_callback) @@ -619,10 +785,23 @@ def run_train_pipeline(args_opt): model = Model(pangu_alpha_with_grads) if args_opt.pre_trained: - flag = restore_exception_checkpoint(args_opt, callback_size, ds, model, - pangu_alpha_with_grads, epoch=actual_epoch_num) + flag = restore_exception_checkpoint( + args_opt, + callback_size, + ds, + model, + pangu_alpha_with_grads, + epoch=actual_epoch_num, + ) if not flag: - restore_checkpoint(args_opt, callback_size, ds, model, pangu_alpha_with_grads, epoch=actual_epoch_num) + restore_checkpoint( + args_opt, + callback_size, + ds, + model, + pangu_alpha_with_grads, + epoch=actual_epoch_num, + ) callback = [ TimeMonitor(callback_size), @@ -639,8 +818,13 @@ def run_train_pipeline(args_opt): # add_checkpoint_callback_policy(args_opt, callback, rank_id) print("------ train start -------") - model.train(10 if args_opt.profiling else actual_epoch_num, ds, callbacks=callback, - sink_size=callback_size, dataset_sink_mode=True) + model.train( + 10 if args_opt.profiling else actual_epoch_num, + ds, + callbacks=callback, + sink_size=callback_size, + dataset_sink_mode=True, + ) if args_opt.profiling: profiler.analyse() @@ -652,7 +836,9 @@ def run_train_pipeline(args_opt): raise ValueError("The per_batch_size has not been configured.") if opt.stage_num > 1: if bool(opt.use_moe) or bool(opt.opt_offload): - raise ValueError("Currently, moe and host device mode is not supported in pipeline parallel.") + raise ValueError( + "Currently, moe and host device mode is not supported in pipeline parallel." 
+ ) run_train_pipeline(opt) else: run_train(opt) diff --git a/codegeex/oneflow/__init__.py b/codegeex/oneflow/__init__.py index 16975c0..f97fbb6 100644 --- a/codegeex/oneflow/__init__.py +++ b/codegeex/oneflow/__init__.py @@ -1 +1 @@ -from .codegeex_model import CodeGeeXModel \ No newline at end of file +from .codegeex_model import CodeGeeXModel diff --git a/codegeex/oneflow/codegeex_model.py b/codegeex/oneflow/codegeex_model.py index b1a2bea..64a70f3 100644 --- a/codegeex/oneflow/codegeex_model.py +++ b/codegeex/oneflow/codegeex_model.py @@ -4,11 +4,16 @@ from oneflow.nn.parameter import Parameter from ..quantization import QuantizedLinear + def fast_gelu(x): """Mindspore's fast gelu implementation.""" - if hasattr(torch._C, 'quick_gelu'): + if hasattr(torch._C, "quick_gelu"): return torch._C.quick_gelu(x) - return x / (1 + torch.exp(-1.702 * torch.abs(x))) * torch.exp(0.851 * (x - torch.abs(x))) + return ( + x + / (1 + torch.exp(-1.702 * torch.abs(x))) + * torch.exp(0.851 * (x - torch.abs(x))) + ) class MLP(torch.nn.Module): @@ -20,7 +25,7 @@ class MLP(torch.nn.Module): """ def __init__( - self, + self, hidden_size, ): super(MLP, self).__init__() @@ -47,7 +52,7 @@ def forward(self, hidden_states): output = self.dense_4h_to_h(intermediate_parallel) return output - + class SelfAttention(torch.nn.Module): """self-attention layer abstract class. @@ -56,9 +61,9 @@ class SelfAttention(torch.nn.Module): """ def __init__( - self, + self, hidden_size, - num_attention_heads, + num_attention_heads, layer_number, fp16=True, attention_softmax_in_fp32=True, @@ -71,8 +76,10 @@ def __init__( self.layer_number = max(1, layer_number) assert self.hidden_size % self.num_attention_heads == 0 - self.hidden_size_per_attention_head = int(self.hidden_size // self.num_attention_heads) - + self.hidden_size_per_attention_head = int( + self.hidden_size // self.num_attention_heads + ) + self.query = torch.nn.Linear(self.hidden_size, self.hidden_size) self.key = torch.nn.Linear(self.hidden_size, self.hidden_size) self.value = torch.nn.Linear(self.hidden_size, self.hidden_size) @@ -98,34 +105,47 @@ def forward( # Query, Key, and Value # ===================== - if hasattr(torch._C, 'grouped_matmul_bias') and not isinstance(self.query, QuantizedLinear): - query_layer, key_layer, value_layer = torch._C.grouped_matmul_bias([hidden_states, hidden_states, hidden_states], - [self.query.weight, self.key.weight, self.value.weight], - [self.query.bias, self.key.bias, self.value.bias]) + if hasattr(torch._C, "grouped_matmul_bias") and not isinstance( + self.query, QuantizedLinear + ): + query_layer, key_layer, value_layer = torch._C.grouped_matmul_bias( + [hidden_states, hidden_states, hidden_states], + [self.query.weight, self.key.weight, self.value.weight], + [self.query.bias, self.key.bias, self.value.bias], + ) else: query_layer = self.query(hidden_states) key_layer = self.key(hidden_states) value_layer = self.value(hidden_states) - - fallback = not hasattr(torch._C, 'fused_multi_head_attention_inference_v2') + + fallback = not hasattr(torch._C, "fused_multi_head_attention_inference_v2") if fallback: - if hasattr(torch._C, 'fused_codegeex_qkv_reshape'): - query_layer, key_layer, value_layer = torch._C.fused_codegeex_qkv_reshape(query_layer, key_layer, value_layer, self.num_attention_heads) + if hasattr(torch._C, "fused_codegeex_qkv_reshape"): + ( + query_layer, + key_layer, + value_layer, + ) = torch._C.fused_codegeex_qkv_reshape( + query_layer, key_layer, value_layer, self.num_attention_heads + ) else: - 
new_query_layer_shape = query_layer.size()[:-1] + \ - (self.num_attention_heads, - self.hidden_size_per_attention_head) + new_query_layer_shape = query_layer.size()[:-1] + ( + self.num_attention_heads, + self.hidden_size_per_attention_head, + ) query_layer = query_layer.view(*new_query_layer_shape) - new_query_layer_shape = key_layer.size()[:-1] + \ - (self.num_attention_heads, - self.hidden_size_per_attention_head) + new_query_layer_shape = key_layer.size()[:-1] + ( + self.num_attention_heads, + self.hidden_size_per_attention_head, + ) key_layer = key_layer.view(*new_query_layer_shape) - new_query_layer_shape = value_layer.size()[:-1] + \ - (self.num_attention_heads, - self.hidden_size_per_attention_head) + new_query_layer_shape = value_layer.size()[:-1] + ( + self.num_attention_heads, + self.hidden_size_per_attention_head, + ) value_layer = value_layer.view(*new_query_layer_shape) # ================================== @@ -134,10 +154,10 @@ def forward( if layer_past is not None: past_key, past_value = layer_past - key_layer = torch.cat((past_key.type_as(key_layer), - key_layer), dim=0) - value_layer = torch.cat((past_value.type_as(value_layer), - value_layer), dim=0) + key_layer = torch.cat((past_key.type_as(key_layer), key_layer), dim=0) + value_layer = torch.cat( + (past_value.type_as(value_layer), value_layer), dim=0 + ) if get_key_value: present = (key_layer, value_layer) @@ -146,18 +166,29 @@ def forward( # =================================== # [b, np, sq, sk] - output_size = (query_layer.size(1), - query_layer.size(2), - query_layer.size(0), - key_layer.size(0)) + output_size = ( + query_layer.size(1), + query_layer.size(2), + query_layer.size(0), + key_layer.size(0), + ) # [sq, b, np, hn] -> [sq, b * np, hn] - query_layer = query_layer.contiguous().view(output_size[2], output_size[0] * output_size[1], -1) - key_layer = key_layer.contiguous().view(output_size[3], output_size[0] * output_size[1], -1) + query_layer = query_layer.contiguous().view( + output_size[2], output_size[0] * output_size[1], -1 + ) + key_layer = key_layer.contiguous().view( + output_size[3], output_size[0] * output_size[1], -1 + ) # Raw attention scores. 
[b * np, sq, sk] - matmul_result = torch.matmul(query_layer.transpose(0, 1), - key_layer.transpose(0, 1).transpose(1, 2)) / self.norm_factor + matmul_result = ( + torch.matmul( + query_layer.transpose(0, 1), + key_layer.transpose(0, 1).transpose(1, 2), + ) + / self.norm_factor + ) # change view to [b, np, sq, sk] attention_scores = matmul_result.view(*output_size) @@ -171,29 +202,38 @@ def forward( with torch.no_grad(): if layer_past is not None: attention_mask = attention_mask[ - ..., - attention_scores.size(3) - 1, - :attention_scores.size(3)].unsqueeze(2) + ..., + attention_scores.size(3) - 1, + : attention_scores.size(3), + ].unsqueeze(2) else: attention_mask = attention_mask[ - ..., - :attention_scores.size(3), - :attention_scores.size(3)] + ..., + : attention_scores.size(3), + : attention_scores.size(3), + ] if context_length is not None: attention_mask = torch.clone(attention_mask) attention_mask[:, :, context_length:, :] = True - + attention_mask = ~attention_mask attention_mask = attention_mask.contiguous() # attention scores and attention mask [b, np, sq, sk] # attention_scores = attention_mask_func(attention_scores, attention_mask) - if hasattr(torch._C, 'fused_scale_mask_softmax'): + if hasattr(torch._C, "fused_scale_mask_softmax"): if self.attention_softmax_in_fp32: - attention_probs = torch._C.fused_scale_mask_softmax(attention_scores.float(), attention_mask, fill_value=-10000.0, scale=1.0).half() + attention_probs = torch._C.fused_scale_mask_softmax( + attention_scores.float(), + attention_mask, + fill_value=-10000.0, + scale=1.0, + ).half() else: - attention_probs = torch._C.fused_scale_mask_softmax(attention_scores, attention_mask, fill_value=-10000.0, scale=1.0) + attention_probs = torch._C.fused_scale_mask_softmax( + attention_scores, attention_mask, fill_value=-10000.0, scale=1.0 + ) else: attention_scores = attention_scores - attention_mask * 10000.0 if self.attention_softmax_in_fp32: @@ -209,19 +249,26 @@ def forward( # [sq, b, np, hn] --> [b, np, sq, hn] # context layer shape: [b, np, sq, hn] - output_size = (value_layer.size(1), - value_layer.size(2), - query_layer.size(0), - value_layer.size(3)) + output_size = ( + value_layer.size(1), + value_layer.size(2), + query_layer.size(0), + value_layer.size(3), + ) - # change view [sq, b * np, hn] - value_layer = value_layer.view(value_layer.size(0), output_size[0] * output_size[1], -1) + # change view [sq, b * np, hn] + value_layer = value_layer.view( + value_layer.size(0), output_size[0] * output_size[1], -1 + ) # change view [b * np, sq, sk] - attention_probs = attention_probs.view(output_size[0] * output_size[1], - output_size[2], -1) + attention_probs = attention_probs.view( + output_size[0] * output_size[1], output_size[2], -1 + ) - context_layer = torch.bmm(attention_probs, value_layer.unsqueeze(0).transpose(1, 2).squeeze(0)) + context_layer = torch.bmm( + attention_probs, value_layer.unsqueeze(0).transpose(1, 2).squeeze(0) + ) # change view [b, np, sq, hn] context_layer = context_layer.view(*output_size) @@ -230,8 +277,7 @@ def forward( context_layer = context_layer.permute(2, 0, 1, 3).contiguous() # # [sq, b, np, hn] --> [sq, b, hp] - new_context_layer_shape = context_layer.size()[:-2] + \ - (self.hidden_size,) + new_context_layer_shape = context_layer.size()[:-2] + (self.hidden_size,) context_layer = context_layer.view(*new_context_layer_shape) else: if layer_past is not None: @@ -249,20 +295,19 @@ def forward( ) if get_key_value: present = (key_layer, value_layer) - - context_layer = 
torch._C.fused_multi_head_attention_inference_v2( - query=query_layer, - key=key_layer, - value=value_layer, - query_head_size=self.hidden_size_per_attention_head, - causal=True, - causal_diagonal_offset=key_layer.shape[0]-query_layer.shape[0], - query_layout="MB(HK)", - key_layout="MB(HK)", - value_layout="MB(HK)", - output_layout="MB(HK)", - ) + context_layer = torch._C.fused_multi_head_attention_inference_v2( + query=query_layer, + key=key_layer, + value=value_layer, + query_head_size=self.hidden_size_per_attention_head, + causal=True, + causal_diagonal_offset=key_layer.shape[0] - query_layer.shape[0], + query_layout="MB(HK)", + key_layout="MB(HK)", + value_layout="MB(HK)", + output_layout="MB(HK)", + ) # ================= # Output. [sq, b, h] @@ -298,7 +343,9 @@ def __init__( self.layer_number = max(1, layer_number) assert self.hidden_size % self.num_attention_heads == 0 - self.hidden_size_per_attention_head = int(self.hidden_size // self.num_attention_heads) + self.hidden_size_per_attention_head = int( + self.hidden_size // self.num_attention_heads + ) self.query = torch.nn.Linear(self.hidden_size, self.hidden_size) self.key = torch.nn.Linear(self.hidden_size, self.hidden_size) @@ -308,7 +355,7 @@ def __init__( self.softmax = torch.nn.Softmax(dim=-1) self.dense = torch.nn.Linear(self.hidden_size, self.hidden_size) - + def forward( self, hidden_states, @@ -321,34 +368,47 @@ def forward( ): # hidden_states: [sq, b, h] - if hasattr(torch._C, 'grouped_matmul_bias') and not isinstance(self.query, QuantizedLinear): - query_layer, key_layer, value_layer = torch._C.grouped_matmul_bias([query_hidden_state, hidden_states, hidden_states], - [self.query.weight, self.key.weight, self.value.weight], - [self.query.bias, self.key.bias, self.value.bias]) + if hasattr(torch._C, "grouped_matmul_bias") and not isinstance( + self.query, QuantizedLinear + ): + query_layer, key_layer, value_layer = torch._C.grouped_matmul_bias( + [query_hidden_state, hidden_states, hidden_states], + [self.query.weight, self.key.weight, self.value.weight], + [self.query.bias, self.key.bias, self.value.bias], + ) else: query_layer = self.query(query_hidden_state) key_layer = self.key(hidden_states) value_layer = self.value(hidden_states) - - fallback = not hasattr(torch._C, 'fused_multi_head_attention_inference_v2') + + fallback = not hasattr(torch._C, "fused_multi_head_attention_inference_v2") if fallback: - if hasattr(torch._C, 'fused_codegeex_qkv_reshape'): - query_layer, key_layer, value_layer = torch._C.fused_codegeex_qkv_reshape(query_layer, key_layer, value_layer, self.num_attention_heads) + if hasattr(torch._C, "fused_codegeex_qkv_reshape"): + ( + query_layer, + key_layer, + value_layer, + ) = torch._C.fused_codegeex_qkv_reshape( + query_layer, key_layer, value_layer, self.num_attention_heads + ) else: - new_query_layer_shape = query_layer.size()[:-1] + \ - (self.num_attention_heads, - self.hidden_size_per_attention_head) + new_query_layer_shape = query_layer.size()[:-1] + ( + self.num_attention_heads, + self.hidden_size_per_attention_head, + ) query_layer = query_layer.view(*new_query_layer_shape) - new_query_layer_shape = key_layer.size()[:-1] + \ - (self.num_attention_heads, - self.hidden_size_per_attention_head) + new_query_layer_shape = key_layer.size()[:-1] + ( + self.num_attention_heads, + self.hidden_size_per_attention_head, + ) key_layer = key_layer.view(*new_query_layer_shape) - new_query_layer_shape = value_layer.size()[:-1] + \ - (self.num_attention_heads, - self.hidden_size_per_attention_head) + 
new_query_layer_shape = value_layer.size()[:-1] + ( + self.num_attention_heads, + self.hidden_size_per_attention_head, + ) value_layer = value_layer.view(*new_query_layer_shape) # ================================== @@ -357,10 +417,10 @@ def forward( if layer_past is not None: past_key, past_value = layer_past - key_layer = torch.cat((past_key.type_as(key_layer), - key_layer), dim=0) - value_layer = torch.cat((past_value.type_as(value_layer), - value_layer), dim=0) + key_layer = torch.cat((past_key.type_as(key_layer), key_layer), dim=0) + value_layer = torch.cat( + (past_value.type_as(value_layer), value_layer), dim=0 + ) if get_key_value: present = (key_layer, value_layer) @@ -369,18 +429,29 @@ def forward( # =================================== # [b, np, sq, sk] - output_size = (query_layer.size(1), - query_layer.size(2), - query_layer.size(0), - key_layer.size(0)) + output_size = ( + query_layer.size(1), + query_layer.size(2), + query_layer.size(0), + key_layer.size(0), + ) # [s, b, np, hn] -> [s, b * np, hn] - query_layer = query_layer.contiguous().view(output_size[2], output_size[0] * output_size[1], -1) - key_layer = key_layer.contiguous().view(output_size[3], output_size[0] * output_size[1], -1) + query_layer = query_layer.contiguous().view( + output_size[2], output_size[0] * output_size[1], -1 + ) + key_layer = key_layer.contiguous().view( + output_size[3], output_size[0] * output_size[1], -1 + ) # Raw attention scores. [b * np, sq, sk] - matmul_result = torch.matmul(query_layer.transpose(0, 1), - key_layer.transpose(0, 1).transpose(1, 2)) / self.norm_factor + matmul_result = ( + torch.matmul( + query_layer.transpose(0, 1), + key_layer.transpose(0, 1).transpose(1, 2), + ) + / self.norm_factor + ) # change view to [b, np, s, s] attention_scores = matmul_result.view(*output_size) @@ -393,14 +464,14 @@ def forward( with torch.no_grad(): if layer_past is not None: attention_mask = attention_mask[ - ..., - attention_scores.size(3) - 1, - :attention_scores.size(3)].unsqueeze(2) + ..., + attention_scores.size(3) - 1, + : attention_scores.size(3), + ].unsqueeze(2) else: attention_mask = attention_mask[ - ..., - :attention_scores.size(3), - :attention_scores.size(3)] + ..., : attention_scores.size(3), : attention_scores.size(3) + ] if context_length is not None: attention_mask = torch.clone(attention_mask) @@ -413,7 +484,7 @@ def forward( attention_probs = self.softmax(attention_scores.float()).half() else: attention_probs = self.softmax(attention_scores) - + # ========================= # Context layer. 
[sq, b, hp] # ========================= @@ -422,20 +493,27 @@ def forward( # [sq, b, np, hn] --> [b, np, sq, hn] # context layer shape: [b, np, sq, hn] - output_size = (value_layer.size(1), - value_layer.size(2), - query_layer.size(0), - value_layer.size(3)) + output_size = ( + value_layer.size(1), + value_layer.size(2), + query_layer.size(0), + value_layer.size(3), + ) # change view [sq, b * np, hn] - value_layer = value_layer.view(value_layer.size(0), output_size[0] * output_size[1], -1) + value_layer = value_layer.view( + value_layer.size(0), output_size[0] * output_size[1], -1 + ) # change view [b * np, sq, sk] - attention_probs = attention_probs.view(output_size[0] * output_size[1], - output_size[2], -1) + attention_probs = attention_probs.view( + output_size[0] * output_size[1], output_size[2], -1 + ) # matmul: [b * np, sq, hn] - context_layer = torch.bmm(attention_probs, value_layer.unsqueeze(0).transpose(1, 2).squeeze(0)) + context_layer = torch.bmm( + attention_probs, value_layer.unsqueeze(0).transpose(1, 2).squeeze(0) + ) # change view [b, np, sq, hn] context_layer = context_layer.view(*output_size) @@ -444,8 +522,7 @@ def forward( context_layer = context_layer.permute(2, 0, 1, 3).contiguous() # [sq, b, np, hn] --> [sq, b, hp] - new_context_layer_shape = context_layer.size()[:-2] + \ - (self.hidden_size,) + new_context_layer_shape = context_layer.size()[:-2] + (self.hidden_size,) context_layer = context_layer.view(*new_context_layer_shape) else: @@ -465,18 +542,18 @@ def forward( if get_key_value: present = (key_layer, value_layer) - if hasattr(torch._C, 'fused_multi_head_attention_inference_v2'): + if hasattr(torch._C, "fused_multi_head_attention_inference_v2"): context_layer = torch._C.fused_multi_head_attention_inference_v2( - query=query_layer, - key=key_layer, - value=value_layer, - query_head_size=self.hidden_size_per_attention_head, - causal=True, - causal_diagonal_offset=key_layer.shape[0]-query_layer.shape[0], - query_layout="MB(HK)", - key_layout="MB(HK)", - value_layout="MB(HK)", - output_layout="MB(HK)", + query=query_layer, + key=key_layer, + value=value_layer, + query_head_size=self.hidden_size_per_attention_head, + causal=True, + causal_diagonal_offset=key_layer.shape[0] - query_layer.shape[0], + query_layout="MB(HK)", + key_layout="MB(HK)", + value_layout="MB(HK)", + output_layout="MB(HK)", ) # ================= @@ -498,10 +575,10 @@ class TransformerLayer(torch.nn.Module): """ def __init__( - self, + self, hidden_size, num_attention_heads, - layer_number, + layer_number, layernorm_epsilon=1e-5, fp16=True, attention_softmax_in_fp32=True, @@ -512,19 +589,23 @@ def __init__( self.layer_number = layer_number # Layernorm on the input data. - self.input_layernorm = torch.nn.LayerNorm(hidden_size, - eps=self.layernorm_epsilon) + self.input_layernorm = torch.nn.LayerNorm( + hidden_size, eps=self.layernorm_epsilon + ) # Self attention. - self.attention = SelfAttention(hidden_size, - num_attention_heads, - layer_number, - fp16, - attention_softmax_in_fp32) + self.attention = SelfAttention( + hidden_size, + num_attention_heads, + layer_number, + fp16, + attention_softmax_in_fp32, + ) # Layernorm on the input data. - self.post_attention_layernorm = torch.nn.LayerNorm(self.hidden_size, - eps=self.layernorm_epsilon) + self.post_attention_layernorm = torch.nn.LayerNorm( + self.hidden_size, eps=self.layernorm_epsilon + ) self.mlp = MLP(self.hidden_size) def forward( @@ -543,13 +624,15 @@ def forward( layernorm_output = self.input_layernorm(hidden_states) # Self attention. 
- attention_output, attention_mask = self.attention(layernorm_output, - attention_mask, - layer_past=layer_past, - get_key_value=get_key_value, - prompt_length=prompt_length, - context_length=context_length, - layer_id=layer_id) + attention_output, attention_mask = self.attention( + layernorm_output, + attention_mask, + layer_past=layer_past, + get_key_value=get_key_value, + prompt_length=prompt_length, + context_length=context_length, + layer_id=layer_id, + ) if get_key_value: attention_output, presents = attention_output @@ -557,7 +640,7 @@ def forward( # Residual connection. residual = hidden_states layernorm_input = attention_output + residual - + # Use FP32 for Layernorm # layernorm_output = self.post_attention_layernorm(layernorm_input.float()).half() layernorm_output = self.post_attention_layernorm(layernorm_input) @@ -577,7 +660,7 @@ class TopQueryLayer(torch.nn.Module): """ def __init__( - self, + self, hidden_size, num_attention_heads, layer_number, @@ -590,16 +673,18 @@ def __init__( self.layer_number = layer_number # Use FP32 for Layernorm - self.input_layernorm = torch.nn.LayerNorm(self.hidden_size, - eps=self.layernorm_epsilon) + self.input_layernorm = torch.nn.LayerNorm( + self.hidden_size, eps=self.layernorm_epsilon + ) # Self attention. - self.attention = TopQuerySelfAttention(self.hidden_size, - self.num_attention_heads, - self.layer_number) + self.attention = TopQuerySelfAttention( + self.hidden_size, self.num_attention_heads, self.layer_number + ) # Layernorm on the input data. - self.post_attention_layernorm = torch.nn.LayerNorm(self.hidden_size, - eps=self.layernorm_epsilon) + self.post_attention_layernorm = torch.nn.LayerNorm( + self.hidden_size, eps=self.layernorm_epsilon + ) # MLP self.mlp = MLP(self.hidden_size) @@ -622,13 +707,15 @@ def forward( layernorm_output = self.input_layernorm(hidden_states) # Self attention. - attention_output = self.attention(layernorm_output, - query_hidden_state, - attention_mask, - layer_past=layer_past, - get_key_value=get_key_value, - prompt_length=prompt_length, - context_length=context_length) + attention_output = self.attention( + layernorm_output, + query_hidden_state, + attention_mask, + layer_past=layer_past, + get_key_value=get_key_value, + prompt_length=prompt_length, + context_length=context_length, + ) if get_key_value: attention_output, presents = attention_output @@ -636,7 +723,7 @@ def forward( # Residual connection. residual = hidden_states layernorm_input = attention_output + residual - + # Use FP32 for Layernorm # layernorm_output = self.post_attention_layernorm(layernorm_input.float()).half() layernorm_output = self.post_attention_layernorm(layernorm_input) @@ -678,22 +765,27 @@ def __init__( if self.num_unique_layers is None: self.num_unique_layers = self.num_layers - assert self.num_layers % self.num_unique_layers == 0, \ - 'number of layers should be divisible by number of unique layers' - + assert ( + self.num_layers % self.num_unique_layers == 0 + ), "number of layers should be divisible by number of unique layers" + # Transformer layers. 
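# The Transformer wrapper here appears to allocate only num_unique_layers physical
# TransformerLayer modules and, per _get_layer_index below, maps each of the num_layers
# logical positions onto them with a modulo. A minimal sketch of that index mapping,
# with illustrative sizes (num_layers and num_unique_layers below are example values):
num_layers = 8
num_unique_layers = 4
assert num_layers % num_unique_layers == 0, "number of layers should be divisible by number of unique layers"
layer_indices = [layer_number % num_unique_layers for layer_number in range(num_layers)]
print(layer_indices)  # [0, 1, 2, 3, 0, 1, 2, 3]: each physical layer serves two logical positions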
def build_layer(layer_number): - return TransformerLayer(self.hidden_size, self.num_attention_heads, layer_number) + return TransformerLayer( + self.hidden_size, self.num_attention_heads, layer_number + ) self.layers = torch.nn.ModuleList( - [build_layer(i + 1) for i in range(self.num_unique_layers)]) + [build_layer(i + 1) for i in range(self.num_unique_layers)] + ) - self.topQueryLayer = TopQueryLayer(self.hidden_size, - self.num_attention_heads, - self.num_unique_layers) + self.topQueryLayer = TopQueryLayer( + self.hidden_size, self.num_attention_heads, self.num_unique_layers + ) - self.final_layernorm = torch.nn.LayerNorm(self.hidden_size, - eps=self.layernorm_epsilon) + self.final_layernorm = torch.nn.LayerNorm( + self.hidden_size, eps=self.layernorm_epsilon + ) def _get_layer_index(self, layer_number): return layer_number % self.num_unique_layers @@ -723,13 +815,15 @@ def forward( past = None if layer_past is not None: past = layer_past[index] - hidden_states, attention_mask = layer(hidden_states, - attention_mask, - layer_past=past, - get_key_value=get_key_value, - prompt_length=prompt_length, - context_length=context_length, - layer_id=index) + hidden_states, attention_mask = layer( + hidden_states, + attention_mask, + layer_past=past, + get_key_value=get_key_value, + prompt_length=prompt_length, + context_length=context_length, + layer_id=index, + ) if get_key_value: hidden_states, present = hidden_states presents.append(present) @@ -744,13 +838,15 @@ def forward( past = None if layer_past is not None: past = layer_past[self.num_layers] - hidden_states = self.topQueryLayer(hidden_states_, - query_hidden_state, - origin_attention_mask, - layer_past=past, - get_key_value=get_key_value, - prompt_length=prompt_length, - context_length=context_length) + hidden_states = self.topQueryLayer( + hidden_states_, + query_hidden_state, + origin_attention_mask, + layer_past=past, + get_key_value=get_key_value, + prompt_length=prompt_length, + context_length=context_length, + ) if get_key_value: hidden_states, present = hidden_states @@ -789,35 +885,39 @@ def __init__( self.hidden_size = hidden_size self.vocab_size = vocab_size self.max_sequence_length = max_sequence_length - + # Word embeddings. self.word_embeddings = torch.nn.Embedding(self.vocab_size, self.hidden_size) - self._word_embeddings_key = 'word_embeddings' - + self._word_embeddings_key = "word_embeddings" + # Position embedding. - self.position_embeddings = torch.nn.Embedding(self.max_sequence_length, self.hidden_size) + self.position_embeddings = torch.nn.Embedding( + self.max_sequence_length, self.hidden_size + ) self.position_embeddings = self.position_embeddings.half() - self._position_embeddings_key = 'position_embeddings' - + self._position_embeddings_key = "position_embeddings" + def forward(self, input_ids, position_ids): # Embeddings. 
words_embeddings = self.word_embeddings(input_ids) position_embeddings = self.position_embeddings(position_ids) embeddings = words_embeddings + position_embeddings - + return embeddings - def state_dict_for_save_checkpoint(self, destination=None, prefix='', - keep_vars=False): + def state_dict_for_save_checkpoint( + self, destination=None, prefix="", keep_vars=False + ): """For easy load.""" state_dict_ = {} - state_dict_[self._word_embeddings_key] \ - = self.word_embeddings.state_dict(destination, prefix, keep_vars) - state_dict_[self._position_embeddings_key] \ - = self.position_embeddings.state_dict( - destination, prefix, keep_vars) - + state_dict_[self._word_embeddings_key] = self.word_embeddings.state_dict( + destination, prefix, keep_vars + ) + state_dict_[ + self._position_embeddings_key + ] = self.position_embeddings.state_dict(destination, prefix, keep_vars) + return state_dict_ def load_state_dict(self, state_dict, strict=True): @@ -830,10 +930,9 @@ def load_state_dict(self, state_dict, strict=True): # for backward compatibility. state_dict_ = {} for key in state_dict.keys(): - if 'word_embeddings' in key: - state_dict_[key.split('word_embeddings.')[1]] \ - = state_dict[key] - state_dict_["weight"] = state_dict_["weight"][:self.vocab_size] + if "word_embeddings" in key: + state_dict_[key.split("word_embeddings.")[1]] = state_dict[key] + state_dict_["weight"] = state_dict_["weight"][: self.vocab_size] self.word_embeddings.load_state_dict(state_dict_, strict=strict) # Position embedding. @@ -843,11 +942,10 @@ def load_state_dict(self, state_dict, strict=True): # for backward compatibility. state_dict_ = {} for key in state_dict.keys(): - if 'position_embeddings' in key: - state_dict_[key.split('position_embeddings.')[1]] \ - = state_dict[key] + if "position_embeddings" in key: + state_dict_[key.split("position_embeddings.")[1]] = state_dict[key] self.position_embeddings.load_state_dict(state_dict_, strict=strict) - + class QueryEmbedding(torch.nn.Module): """Language model embeddings. @@ -871,24 +969,27 @@ def __init__( self.max_sequence_length = max_sequence_length # Top query position embedding (serial). - self.top_query_embeddings = torch.nn.Embedding(self.max_sequence_length, self.hidden_size) + self.top_query_embeddings = torch.nn.Embedding( + self.max_sequence_length, self.hidden_size + ) self.top_query_embeddings = self.top_query_embeddings.half() - self._top_query_embeddings_key = 'top_query_embeddings' - + self._top_query_embeddings_key = "top_query_embeddings" + def forward(self, position_ids): # Embeddings. embeddings = self.top_query_embeddings(position_ids) - + return embeddings - def state_dict_for_save_checkpoint(self, destination=None, prefix='', - keep_vars=False): + def state_dict_for_save_checkpoint( + self, destination=None, prefix="", keep_vars=False + ): """For easy load.""" state_dict_ = {} - state_dict_[self._top_query_embeddings_key] \ - = self.top_query_embeddings.state_dict( - destination, prefix, keep_vars) + state_dict_[ + self._top_query_embeddings_key + ] = self.top_query_embeddings.state_dict(destination, prefix, keep_vars) return state_dict_ @@ -902,11 +1003,10 @@ def load_state_dict(self, state_dict, strict=True): # for backward compatibility. 
state_dict_ = {} for key in state_dict.keys(): - if 'top_query_embeddings' in key: - state_dict_[key.split('top_query_embeddings.')[1]] \ - = state_dict[key] + if "top_query_embeddings" in key: + state_dict_[key.split("top_query_embeddings.")[1]] = state_dict[key] self.top_query_embeddings.load_state_dict(state_dict_, strict=strict) - + class TransformerLanguageModel(torch.nn.Module): """Transformer language model. @@ -939,32 +1039,32 @@ def __init__( self.max_position_embeddings = max_position_embeddings # Embeddings - self.embedding = Embedding(self.hidden_size, - self.padded_vocab_size, - self.max_position_embeddings) - self._embedding_key = 'embedding' + self.embedding = Embedding( + self.hidden_size, self.padded_vocab_size, self.max_position_embeddings + ) + self._embedding_key = "embedding" # Query embeddings - self.topQueryEmbedding = QueryEmbedding(self.hidden_size, - self.padded_vocab_size, - self.max_position_embeddings) - self._topQueryEmbedding_key = 'topQueryEmbedding' + self.topQueryEmbedding = QueryEmbedding( + self.hidden_size, self.padded_vocab_size, self.max_position_embeddings + ) + self._topQueryEmbedding_key = "topQueryEmbedding" # Transformer - self.transformer = Transformer(self.hidden_size, - self.num_attention_heads, - self.num_layers) - self._transformer_key = 'transformer' + self.transformer = Transformer( + self.hidden_size, self.num_attention_heads, self.num_layers + ) + self._transformer_key = "transformer" def forward( - self, - input_ids, - position_ids, - attention_mask, - layer_past=None, - get_key_value=False, - prompt_length=None, - context_length=None, + self, + input_ids, + position_ids, + attention_mask, + layer_past=None, + get_key_value=False, + prompt_length=None, + context_length=None, ): # Embeddings. @@ -973,30 +1073,39 @@ def forward( queryEmbedding_out = self.topQueryEmbedding(query_position_ids) # Transformer. - transformer_output = self.transformer(embedding_output, - queryEmbedding_out, - attention_mask, - layer_past=layer_past, - get_key_value=get_key_value, - prompt_length=prompt_length, - context_length=context_length) + transformer_output = self.transformer( + embedding_output, + queryEmbedding_out, + attention_mask, + layer_past=layer_past, + get_key_value=get_key_value, + prompt_length=prompt_length, + context_length=context_length, + ) return transformer_output - def state_dict_for_save_checkpoint(self, destination=None, prefix='', - keep_vars=False): + def state_dict_for_save_checkpoint( + self, destination=None, prefix="", keep_vars=False + ): """For easy load.""" state_dict_ = {} - state_dict_[self._embedding_key] \ - = self.embedding.state_dict_for_save_checkpoint( - destination, prefix, keep_vars) - state_dict_[self._topQueryEmbedding_key] \ - = self.topQueryEmbedding.state_dict_for_save_checkpoint( - destination, prefix, keep_vars) - state_dict_[self._transformer_key] \ - = self.transformer.state_dict_for_save_checkpoint( - destination, prefix, keep_vars) + state_dict_[ + self._embedding_key + ] = self.embedding.state_dict_for_save_checkpoint( + destination, prefix, keep_vars + ) + state_dict_[ + self._topQueryEmbedding_key + ] = self.topQueryEmbedding.state_dict_for_save_checkpoint( + destination, prefix, keep_vars + ) + state_dict_[ + self._transformer_key + ] = self.transformer.state_dict_for_save_checkpoint( + destination, prefix, keep_vars + ) return state_dict_ @@ -1010,7 +1119,7 @@ def load_state_dict(self, state_dict, strict=True): # for backward compatibility. 
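# A hedged sketch of the data flow wired up in TransformerLanguageModel above (the module
# objects are assumed to exist): token/position embeddings and the separate top-query
# embedding both feed the Transformer.
def language_model_forward(embedding, top_query_embedding, transformer,
                           input_ids, position_ids, query_position_ids, attention_mask):
    hidden = embedding(input_ids, position_ids)              # word + position lookup
    query_hidden = top_query_embedding(query_position_ids)   # "top query" positions
    return transformer(hidden, query_hidden, attention_mask)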
state_dict_ = {} for key in state_dict.keys(): - if '_embeddings' in key: + if "_embeddings" in key: state_dict_[key] = state_dict[key] self.embedding.load_state_dict(state_dict_, strict=strict) @@ -1020,7 +1129,7 @@ def load_state_dict(self, state_dict, strict=True): # for backward compatibility. state_dict_ = {} for key in state_dict.keys(): - if '_embeddings' in key: + if "_embeddings" in key: state_dict_[key] = state_dict[key] self.topQueryEmbedding.load_state_dict(state_dict_, strict=strict) @@ -1031,8 +1140,8 @@ def load_state_dict(self, state_dict, strict=True): # for backward compatibility. state_dict_ = {} for key in state_dict.keys(): - if 'transformer.' in key: - state_dict_[key.split('transformer.')[1]] = state_dict[key] + if "transformer." in key: + state_dict_[key.split("transformer.")[1]] = state_dict[key] self.transformer.load_state_dict(state_dict_, strict=strict) @@ -1048,14 +1157,16 @@ def __init__( max_position_embeddings, ): super(CodeGeeXModel, self).__init__() - - self.language_model = TransformerLanguageModel(hidden_size, - num_layers, - num_attention_heads, - padded_vocab_size, - max_position_embeddings) + + self.language_model = TransformerLanguageModel( + hidden_size, + num_layers, + num_attention_heads, + padded_vocab_size, + max_position_embeddings, + ) self._language_model_key = "language_model" - + def forward( self, input_ids, @@ -1067,31 +1178,38 @@ def forward( context_length=None, ): # Language model. - lm_output = self.language_model(input_ids, - position_ids, - attention_mask, - layer_past=layer_past, - get_key_value=get_key_value, - prompt_length=prompt_length, - context_length=context_length) + lm_output = self.language_model( + input_ids, + position_ids, + attention_mask, + layer_past=layer_past, + get_key_value=get_key_value, + prompt_length=prompt_length, + context_length=context_length, + ) if get_key_value: lm_output, presents = lm_output - output = F.linear(lm_output, self.language_model.embedding.word_embeddings.weight.half()) - + output = F.linear( + lm_output, self.language_model.embedding.word_embeddings.weight.half() + ) + if get_key_value: output = [output, presents] return output - def state_dict_for_save_checkpoint(self, destination=None, prefix='', - keep_vars=False): + def state_dict_for_save_checkpoint( + self, destination=None, prefix="", keep_vars=False + ): state_dict_ = {} - state_dict_[self._language_model_key] \ - = self.language_model.state_dict_for_save_checkpoint( - destination, prefix, keep_vars) + state_dict_[ + self._language_model_key + ] = self.language_model.state_dict_for_save_checkpoint( + destination, prefix, keep_vars + ) return state_dict_ def load_state_dict(self, state_dict, strict=True): diff --git a/codegeex/oneflow/inference.py b/codegeex/oneflow/inference.py index 23db106..2346b1e 100644 --- a/codegeex/oneflow/inference.py +++ b/codegeex/oneflow/inference.py @@ -10,10 +10,10 @@ def get_ltor_masks_and_position_ids( - data, - eod_token, - reset_position_ids, - reset_attention_mask, + data, + eod_token, + reset_position_ids, + reset_attention_mask, ): """Build masks and position id for left to right model.""" @@ -65,9 +65,9 @@ def get_ltor_masks_and_position_ids( def get_batch( - context_tokens, - micro_batch_size, - eod_token, + context_tokens, + micro_batch_size, + eod_token, reset_position_ids=False, reset_attention_mask=False, ): @@ -125,15 +125,15 @@ def pad_batch(batch, pad_id, seq_length): def forward_step( - model, - tokens, - seq_length, - position_ids, - attention_mask, - layer_past=None, - 
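# A minimal sketch of the weight-tied output projection in CodeGeeXModel.forward above:
# logits come from multiplying the hidden states with the word-embedding table rather than
# a separate output layer (shapes are illustrative assumptions).
import torch
import torch.nn.functional as F

hidden = torch.randn(4, 8, 16)           # [batch, seq, hidden]
emb_weight = torch.randn(100, 16)        # embedding table [vocab, hidden]
logits = F.linear(hidden, emb_weight)    # [batch, seq, vocab]; F.linear computes x @ W.T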
get_key_value=None, - prompt_length=None, - context_length=None, + model, + tokens, + seq_length, + position_ids, + attention_mask, + layer_past=None, + get_key_value=None, + prompt_length=None, + context_length=None, ): # Forward pass through the model. output_tensor = model( @@ -156,28 +156,30 @@ def forward_step( def get_token_stream( - model, - tokenizer, - seq_length, - out_seq_length, - context_tokens, - return_scores: bool = False, - prompt_length: int = None, - micro_batch_size: int = None, - bad_ids: List = None, - temperature: float = 1.0, - topp: float = 1.0, - topk: int = 0.0, - greedy: bool = False, - recompute: bool = False, + model, + tokenizer, + seq_length, + out_seq_length, + context_tokens, + return_scores: bool = False, + prompt_length: int = None, + micro_batch_size: int = None, + bad_ids: List = None, + temperature: float = 1.0, + topp: float = 1.0, + topk: int = 0.0, + greedy: bool = False, + recompute: bool = False, ): - context_tokens, context_lengths = pad_batch(context_tokens, tokenizer.eos_token_id, seq_length) + context_tokens, context_lengths = pad_batch( + context_tokens, tokenizer.eos_token_id, seq_length + ) context_tokens_tensor = torch.cuda.LongTensor(context_tokens) context_length_tensor = torch.cuda.LongTensor(context_lengths) context_length = context_length_tensor.min().item() tokens, attention_mask, position_ids = get_batch( - context_tokens_tensor, + context_tokens_tensor, micro_batch_size, tokenizer.eos_token_id, ) @@ -215,23 +217,23 @@ def switch(val1, val2, boolean): def sample_sequence_batch( - model, - tokenizer, - context_tokens, - context_lengths, - attention_mask, - position_ids, - seq_length, - out_seq_length, - maxlen=None, - return_scores: bool = False, - prompt_length: int = None, - bad_ids: List = None, - temperature: float = 1.0, - topp: float = 1.0, - topk: int = 0.0, - recompute: bool = False, - greedy: bool = False, + model, + tokenizer, + context_tokens, + context_lengths, + attention_mask, + position_ids, + seq_length, + out_seq_length, + maxlen=None, + return_scores: bool = False, + prompt_length: int = None, + bad_ids: List = None, + temperature: float = 1.0, + topp: float = 1.0, + topk: int = 0.0, + recompute: bool = False, + greedy: bool = False, ): model.eval() with torch.no_grad(): @@ -257,30 +259,32 @@ def sample_sequence_batch( while context_length <= (maxlen): if recompute: - logits = model(tokens, - position_ids, - attention_mask, - prompt_length=prompt_length, - context_length=context_length, - ) + logits = model( + tokens, + position_ids, + attention_mask, + prompt_length=prompt_length, + context_length=context_length, + ) logits = logits[:, context_length - 1, :] else: if counter == 0: tokens2use = tokens[:, :context_length] positions2use = position_ids[:, :context_length] else: - tokens2use = tokens[:, context_length - 1].view( - batch_size, -1) + tokens2use = tokens[:, context_length - 1].view(batch_size, -1) positions2use = position_ids[:, context_length - 1].view( - batch_size, -1) - logits, layer_past = model(tokens2use, - positions2use, - attention_mask, - layer_past=layer_past, - get_key_value=True, - prompt_length=prompt_length, - context_length=context_length, - ) + batch_size, -1 + ) + logits, layer_past = model( + tokens2use, + positions2use, + attention_mask, + layer_past=layer_past, + get_key_value=True, + prompt_length=prompt_length, + context_length=context_length, + ) logits = logits[:, -1].view(batch_size, -1).contiguous() if bad_ids is not None: @@ -314,12 +318,12 @@ def sample_sequence_batch( 
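# A hedged sketch of what pad_batch above is expected to do: right-pad every prompt in the
# micro-batch with the EOS id up to seq_length and keep the original lengths so sampling
# knows where each prompt ends (details assumed from the call site).
def pad_batch_sketch(batch, pad_id, seq_length):
    lengths, padded = [], []
    for tokens in batch:
        lengths.append(len(tokens))
        padded.append(list(tokens) + [pad_id] * (seq_length - len(tokens)))
    return padded, lengths

padded, lengths = pad_batch_sketch([[5, 6, 7], [8, 9]], pad_id=0, seq_length=4)
# padded -> [[5, 6, 7, 0], [8, 9, 0, 0]], lengths -> [3, 2]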
lengths[just_finished.view(-1)] = context_length is_done = is_done | done_token done = torch.all(is_done) - + if return_scores: yield tokens, (lengths, scores) else: yield tokens, lengths - + context_length += 1 counter += 1 if done: diff --git a/codegeex/paddle/__init__.py b/codegeex/paddle/__init__.py index 16975c0..f97fbb6 100644 --- a/codegeex/paddle/__init__.py +++ b/codegeex/paddle/__init__.py @@ -1 +1 @@ -from .codegeex_model import CodeGeeXModel \ No newline at end of file +from .codegeex_model import CodeGeeXModel diff --git a/codegeex/paddle/codegeex_model.py b/codegeex/paddle/codegeex_model.py index 4d946b0..f59a083 100644 --- a/codegeex/paddle/codegeex_model.py +++ b/codegeex/paddle/codegeex_model.py @@ -5,7 +5,11 @@ def fast_gelu(x): """Mindspore's fast gelu implementation.""" - return x / (1 + paddle.exp(-1.702 * paddle.abs(x))) * paddle.exp(0.851 * (x - paddle.abs(x))) + return ( + x + / (1 + paddle.exp(-1.702 * paddle.abs(x))) + * paddle.exp(0.851 * (x - paddle.abs(x))) + ) class MLP(paddle.nn.Layer): @@ -18,7 +22,7 @@ class MLP(paddle.nn.Layer): """ def __init__( - self, + self, hidden_size, ): super(MLP, self).__init__() @@ -45,7 +49,7 @@ def forward(self, hidden_states): output = self.dense_4h_to_h(intermediate_parallel) return output - + class SelfAttention(paddle.nn.Layer): """self-attention layer abstract class. @@ -55,9 +59,9 @@ class SelfAttention(paddle.nn.Layer): """ def __init__( - self, + self, hidden_size, - num_attention_heads, + num_attention_heads, layer_number, fp16=True, attention_softmax_in_fp32=True, @@ -70,8 +74,10 @@ def __init__( self.layer_number = max(1, layer_number) assert self.hidden_size % self.num_attention_heads == 0 - self.hidden_size_per_attention_head = int(self.hidden_size // self.num_attention_heads) - + self.hidden_size_per_attention_head = int( + self.hidden_size // self.num_attention_heads + ) + self.query = paddle.nn.Linear(self.hidden_size, self.hidden_size) self.key = paddle.nn.Linear(self.hidden_size, self.hidden_size) self.value = paddle.nn.Linear(self.hidden_size, self.hidden_size) @@ -100,19 +106,22 @@ def forward( key_layer = self.key(hidden_states) value_layer = self.value(hidden_states) - new_query_layer_shape = query_layer.shape[:-1] + \ - [self.num_attention_heads, - self.hidden_size_per_attention_head] + new_query_layer_shape = query_layer.shape[:-1] + [ + self.num_attention_heads, + self.hidden_size_per_attention_head, + ] query_layer = query_layer.reshape(new_query_layer_shape) - new_query_layer_shape = key_layer.shape[:-1] + \ - [self.num_attention_heads, - self.hidden_size_per_attention_head] + new_query_layer_shape = key_layer.shape[:-1] + [ + self.num_attention_heads, + self.hidden_size_per_attention_head, + ] key_layer = key_layer.reshape(new_query_layer_shape) - new_query_layer_shape = value_layer.shape[:-1] + \ - [self.num_attention_heads, - self.hidden_size_per_attention_head] + new_query_layer_shape = value_layer.shape[:-1] + [ + self.num_attention_heads, + self.hidden_size_per_attention_head, + ] value_layer = value_layer.reshape(new_query_layer_shape) # ================================== @@ -121,10 +130,12 @@ def forward( if layer_past is not None: past_key, past_value = layer_past - key_layer = paddle.concat((past_key.cast(key_layer.dtype), - key_layer), axis=0) - value_layer = paddle.concat((past_value.cast(value_layer.dtype), - value_layer), axis=0) + key_layer = paddle.concat( + (past_key.cast(key_layer.dtype), key_layer), axis=0 + ) + value_layer = paddle.concat( + (past_value.cast(value_layer.dtype), 
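# A hedged check that the fast_gelu above is algebraically the sigmoid approximation of
# GELU, x * sigmoid(1.702 * x), written so the exponent argument is never positive (safer
# in fp16). Sketch only, using torch for illustration.
import torch

def fast_gelu_sketch(x):
    return x / (1 + torch.exp(-1.702 * torch.abs(x))) * torch.exp(0.851 * (x - torch.abs(x)))

x = torch.linspace(-4.0, 4.0, steps=9)
assert torch.allclose(fast_gelu_sketch(x), x * torch.sigmoid(1.702 * x), atol=1e-6)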
value_layer), axis=0 + ) if get_key_value: present = (key_layer, value_layer) @@ -133,18 +144,29 @@ def forward( # =================================== # [b, np, sq, sk] - output_size = (query_layer.shape[1], - query_layer.shape[2], - query_layer.shape[0], - key_layer.shape[0]) + output_size = ( + query_layer.shape[1], + query_layer.shape[2], + query_layer.shape[0], + key_layer.shape[0], + ) # [sq, b, np, hn] -> [sq, b * np, hn] - query_layer = query_layer.reshape([output_size[2], output_size[0] * output_size[1], -1]) - key_layer = key_layer.reshape([output_size[3], output_size[0] * output_size[1], -1]) + query_layer = query_layer.reshape( + [output_size[2], output_size[0] * output_size[1], -1] + ) + key_layer = key_layer.reshape( + [output_size[3], output_size[0] * output_size[1], -1] + ) # Raw attention scores. [b * np, sq, sk] - matmul_result = paddle.matmul(query_layer.transpose([1, 0, 2]), - key_layer.transpose([1, 0, 2]).transpose([0, 2, 1])) / self.norm_factor + matmul_result = ( + paddle.matmul( + query_layer.transpose([1, 0, 2]), + key_layer.transpose([1, 0, 2]).transpose([0, 2, 1]), + ) + / self.norm_factor + ) # change view to [b, np, sq, sk] attention_scores = matmul_result.reshape(output_size) @@ -157,14 +179,12 @@ def forward( with paddle.no_grad(): if layer_past is not None: attention_mask = attention_mask[ - ..., - attention_scores.shape[3] - 1, - :attention_scores.shape[3]].unsqueeze(2) + ..., attention_scores.shape[3] - 1, : attention_scores.shape[3] + ].unsqueeze(2) else: attention_mask = attention_mask[ - ..., - :attention_scores.shape[3], - :attention_scores.shape[3]] + ..., : attention_scores.shape[3], : attention_scores.shape[3] + ] if context_length is not None: attention_mask = paddle.clone(attention_mask) @@ -174,7 +194,9 @@ def forward( # attention_scores = attention_mask_func(attention_scores, attention_mask) attention_scores = attention_scores - attention_mask * 10000.0 if self.attention_softmax_in_fp32: - attention_probs = self.softmax(attention_scores.cast("float32")).cast("float16") + attention_probs = self.softmax(attention_scores.cast("float32")).cast( + "float16" + ) else: attention_probs = self.softmax(attention_scores) @@ -186,19 +208,26 @@ def forward( # [sq, b, np, hn] --> [b, np, sq, hn] # context layer shape: [b, np, sq, hn] - output_size = (value_layer.shape[1], - value_layer.shape[2], - query_layer.shape[0], - value_layer.shape[3]) + output_size = ( + value_layer.shape[1], + value_layer.shape[2], + query_layer.shape[0], + value_layer.shape[3], + ) - # change view [sq, b * np, hn] - value_layer = value_layer.reshape([value_layer.shape[0], output_size[0] * output_size[1], -1]) + # change view [sq, b * np, hn] + value_layer = value_layer.reshape( + [value_layer.shape[0], output_size[0] * output_size[1], -1] + ) # change view [b * np, sq, sk] - attention_probs = attention_probs.reshape([output_size[0] * output_size[1], - output_size[2], -1]) + attention_probs = attention_probs.reshape( + [output_size[0] * output_size[1], output_size[2], -1] + ) - context_layer = paddle.bmm(attention_probs, value_layer.unsqueeze(0).transpose([0, 2, 1, 3]).squeeze(0)) + context_layer = paddle.bmm( + attention_probs, value_layer.unsqueeze(0).transpose([0, 2, 1, 3]).squeeze(0) + ) # change view [b, np, sq, hn] context_layer = context_layer.reshape(output_size) @@ -207,8 +236,9 @@ def forward( context_layer = context_layer.transpose([2, 0, 1, 3]) # # [sq, b, np, hn] --> [sq, b, hp] - new_context_layer_shape = context_layer.shape[:-2] + \ - [self.hidden_size,] + 
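# A minimal sketch of the attention math being reshaped above: scores are Q @ K^T scaled
# by a norm factor, the additive mask pushes disallowed positions far negative (the
# `scores - mask * 10000.0` trick), and the softmax-weighted sum of V is the context layer.
import math
import torch

b, heads, sq, hn = 2, 4, 8, 16
q, k, v = (torch.randn(b, heads, sq, hn) for _ in range(3))
mask = torch.triu(torch.ones(sq, sq), diagonal=1)     # 1 where attention is not allowed

scores = q @ k.transpose(-1, -2) / math.sqrt(hn)      # [b, heads, sq, sq]
scores = scores - mask * 10000.0
context = torch.softmax(scores, dim=-1) @ v           # [b, heads, sq, hn]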
new_context_layer_shape = context_layer.shape[:-2] + [ + self.hidden_size, + ] context_layer = context_layer.reshape(new_context_layer_shape) # ================= @@ -246,7 +276,9 @@ def __init__( self.layer_number = max(1, layer_number) assert self.hidden_size % self.num_attention_heads == 0 - self.hidden_size_per_attention_head = int(self.hidden_size // self.num_attention_heads) + self.hidden_size_per_attention_head = int( + self.hidden_size // self.num_attention_heads + ) self.query = paddle.nn.Linear(self.hidden_size, self.hidden_size) self.key = paddle.nn.Linear(self.hidden_size, self.hidden_size) @@ -256,7 +288,7 @@ def __init__( self.softmax = paddle.nn.Softmax(axis=-1) self.dense = paddle.nn.Linear(self.hidden_size, self.hidden_size) - + def forward( self, hidden_states, @@ -273,19 +305,22 @@ def forward( key_layer = self.key(hidden_states) value_layer = self.value(hidden_states) - new_query_layer_shape = query_layer.shape[:-1] + \ - [self.num_attention_heads, - self.hidden_size_per_attention_head] + new_query_layer_shape = query_layer.shape[:-1] + [ + self.num_attention_heads, + self.hidden_size_per_attention_head, + ] query_layer = query_layer.reshape(new_query_layer_shape) - new_query_layer_shape = key_layer.shape[:-1] + \ - [self.num_attention_heads, - self.hidden_size_per_attention_head] + new_query_layer_shape = key_layer.shape[:-1] + [ + self.num_attention_heads, + self.hidden_size_per_attention_head, + ] key_layer = key_layer.reshape(new_query_layer_shape) - new_query_layer_shape = value_layer.shape[:-1] + \ - [self.num_attention_heads, - self.hidden_size_per_attention_head] + new_query_layer_shape = value_layer.shape[:-1] + [ + self.num_attention_heads, + self.hidden_size_per_attention_head, + ] value_layer = value_layer.reshape(new_query_layer_shape) # ================================== @@ -294,10 +329,12 @@ def forward( if layer_past is not None: past_key, past_value = layer_past - key_layer = paddle.concat((past_key.cast(key_layer.dtype), - key_layer), axis=0) - value_layer = paddle.concat((past_value.cast(value_layer.dtype), - value_layer), axis=0) + key_layer = paddle.concat( + (past_key.cast(key_layer.dtype), key_layer), axis=0 + ) + value_layer = paddle.concat( + (past_value.cast(value_layer.dtype), value_layer), axis=0 + ) if get_key_value: present = (key_layer, value_layer) @@ -306,18 +343,29 @@ def forward( # =================================== # [b, np, sq, sk] - output_size = (query_layer.shape[1], - query_layer.shape[2], - query_layer.shape[0], - key_layer.shape[0]) + output_size = ( + query_layer.shape[1], + query_layer.shape[2], + query_layer.shape[0], + key_layer.shape[0], + ) # [s, b, np, hn] -> [s, b * np, hn] - query_layer = query_layer.reshape([output_size[2], output_size[0] * output_size[1], -1]) - key_layer = key_layer.reshape([output_size[3], output_size[0] * output_size[1], -1]) + query_layer = query_layer.reshape( + [output_size[2], output_size[0] * output_size[1], -1] + ) + key_layer = key_layer.reshape( + [output_size[3], output_size[0] * output_size[1], -1] + ) # Raw attention scores. 
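# A hedged sketch of the key/value cache handling repeated in the attention modules above:
# cached keys/values from earlier steps are concatenated with the new ones along the
# sequence dimension, and the pair is returned as the "present" for the next step.
import torch

past_key = torch.randn(7, 2, 4, 16)      # [past_seq, batch, heads, head_dim] (assumed layout)
past_value = torch.randn(7, 2, 4, 16)
new_key = torch.randn(1, 2, 4, 16)       # one freshly decoded position
new_value = torch.randn(1, 2, 4, 16)

key = torch.cat((past_key, new_key), dim=0)
value = torch.cat((past_value, new_value), dim=0)
present = (key, value)                   # cached by the caller when get_key_value=True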
[b * np, sq, sk] - matmul_result = paddle.matmul(query_layer.transpose([1, 0, 2]), - key_layer.transpose([1, 0, 2]).transpose([0, 2, 1])) / self.norm_factor + matmul_result = ( + paddle.matmul( + query_layer.transpose([1, 0, 2]), + key_layer.transpose([1, 0, 2]).transpose([0, 2, 1]), + ) + / self.norm_factor + ) # change view to [b, np, s, s] attention_scores = matmul_result.reshape(output_size) @@ -330,14 +378,12 @@ def forward( with paddle.no_grad(): if layer_past is not None: attention_mask = attention_mask[ - ..., - attention_scores.shape[3] - 1, - :attention_scores.shape[3]].unsqueeze(2) + ..., attention_scores.shape[3] - 1, : attention_scores.shape[3] + ].unsqueeze(2) else: attention_mask = attention_mask[ - ..., - :attention_scores.shape[3], - :attention_scores.shape[3]] + ..., : attention_scores.shape[3], : attention_scores.shape[3] + ] if context_length is not None: attention_mask = paddle.clone(attention_mask) @@ -347,10 +393,12 @@ def forward( # attention_scores = attention_mask_func(attention_scores, attention_mask) attention_scores = attention_scores - attention_mask * 10000.0 if self.attention_softmax_in_fp32: - attention_probs = self.softmax(attention_scores.cast("float32")).cast("float16") + attention_probs = self.softmax(attention_scores.cast("float32")).cast( + "float16" + ) else: attention_probs = self.softmax(attention_scores) - + # ========================= # Context layer. [sq, b, hp] # ========================= @@ -359,20 +407,27 @@ def forward( # [sq, b, np, hn] --> [b, np, sq, hn] # context layer shape: [b, np, sq, hn] - output_size = (value_layer.shape[1], - value_layer.shape[2], - query_layer.shape[0], - value_layer.shape[3]) + output_size = ( + value_layer.shape[1], + value_layer.shape[2], + query_layer.shape[0], + value_layer.shape[3], + ) # change view [sq, b * np, hn] - value_layer = value_layer.reshape([value_layer.shape[0], output_size[0] * output_size[1], -1]) + value_layer = value_layer.reshape( + [value_layer.shape[0], output_size[0] * output_size[1], -1] + ) # change view [b * np, sq, sk] - attention_probs = attention_probs.reshape([output_size[0] * output_size[1], - output_size[2], -1]) + attention_probs = attention_probs.reshape( + [output_size[0] * output_size[1], output_size[2], -1] + ) # matmul: [b * np, sq, hn] - context_layer = paddle.bmm(attention_probs, value_layer.unsqueeze(0).transpose([0, 2, 1, 3]).squeeze(0)) + context_layer = paddle.bmm( + attention_probs, value_layer.unsqueeze(0).transpose([0, 2, 1, 3]).squeeze(0) + ) # change view [b, np, sq, hn] context_layer = context_layer.reshape(output_size) @@ -381,8 +436,9 @@ def forward( context_layer = context_layer.transpose([2, 0, 1, 3]) # [sq, b, np, hn] --> [sq, b, hp] - new_context_layer_shape = context_layer.shape[:-2] + \ - [self.hidden_size,] + new_context_layer_shape = context_layer.shape[:-2] + [ + self.hidden_size, + ] context_layer = context_layer.reshape(new_context_layer_shape) # ================= @@ -405,10 +461,10 @@ class TransformerLayer(paddle.nn.Layer): """ def __init__( - self, + self, hidden_size, num_attention_heads, - layer_number, + layer_number, layernorm_epsilon=1e-5, fp16=True, attention_softmax_in_fp32=True, @@ -419,19 +475,23 @@ def __init__( self.layer_number = layer_number # Layernorm on the input data. - self.input_layernorm = paddle.nn.LayerNorm(hidden_size, - epsilon=self.layernorm_epsilon) + self.input_layernorm = paddle.nn.LayerNorm( + hidden_size, epsilon=self.layernorm_epsilon + ) # Self attention. 
- self.attention = SelfAttention(hidden_size, - num_attention_heads, - layer_number, - fp16, - attention_softmax_in_fp32) + self.attention = SelfAttention( + hidden_size, + num_attention_heads, + layer_number, + fp16, + attention_softmax_in_fp32, + ) # Layernorm on the input data. - self.post_attention_layernorm = paddle.nn.LayerNorm(self.hidden_size, - epsilon=self.layernorm_epsilon) + self.post_attention_layernorm = paddle.nn.LayerNorm( + self.hidden_size, epsilon=self.layernorm_epsilon + ) self.mlp = MLP(self.hidden_size) def forward( @@ -449,12 +509,14 @@ def forward( layernorm_output = self.input_layernorm(hidden_states) # Self attention. - attention_output = self.attention(layernorm_output, - attention_mask, - layer_past=layer_past, - get_key_value=get_key_value, - prompt_length=prompt_length, - context_length=context_length) + attention_output = self.attention( + layernorm_output, + attention_mask, + layer_past=layer_past, + get_key_value=get_key_value, + prompt_length=prompt_length, + context_length=context_length, + ) if get_key_value: attention_output, presents = attention_output @@ -462,7 +524,7 @@ def forward( # Residual connection. residual = hidden_states layernorm_input = attention_output + residual - + # Use FP32 for Layernorm # layernorm_output = self.post_attention_layernorm(layernorm_input.cast("float32")).cast("float16") layernorm_output = self.post_attention_layernorm(layernorm_input) @@ -483,7 +545,7 @@ class TopQueryLayer(paddle.nn.Layer): """ def __init__( - self, + self, hidden_size, num_attention_heads, layer_number, @@ -496,16 +558,18 @@ def __init__( self.layer_number = layer_number # Use FP32 for Layernorm - self.input_layernorm = paddle.nn.LayerNorm(self.hidden_size, - epsilon=self.layernorm_epsilon) + self.input_layernorm = paddle.nn.LayerNorm( + self.hidden_size, epsilon=self.layernorm_epsilon + ) # Self attention. - self.attention = TopQuerySelfAttention(self.hidden_size, - self.num_attention_heads, - self.layer_number) + self.attention = TopQuerySelfAttention( + self.hidden_size, self.num_attention_heads, self.layer_number + ) # Layernorm on the input data. - self.post_attention_layernorm = paddle.nn.LayerNorm(self.hidden_size, - epsilon=self.layernorm_epsilon) + self.post_attention_layernorm = paddle.nn.LayerNorm( + self.hidden_size, epsilon=self.layernorm_epsilon + ) # MLP self.mlp = MLP(self.hidden_size) @@ -528,13 +592,15 @@ def forward( layernorm_output = self.input_layernorm(hidden_states) # Self attention. - attention_output = self.attention(layernorm_output, - query_hidden_state, - attention_mask, - layer_past=layer_past, - get_key_value=get_key_value, - prompt_length=prompt_length, - context_length=context_length) + attention_output = self.attention( + layernorm_output, + query_hidden_state, + attention_mask, + layer_past=layer_past, + get_key_value=get_key_value, + prompt_length=prompt_length, + context_length=context_length, + ) if get_key_value: attention_output, presents = attention_output @@ -542,7 +608,7 @@ def forward( # Residual connection. 
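# A minimal sketch of the pre-LayerNorm residual pattern the TransformerLayer above
# follows: LayerNorm feeds each sub-module, and the sub-module output is added back onto
# the un-normalized input. Identity() stands in for attention and the MLP.
import torch

class PreLNBlockSketch(torch.nn.Module):
    def __init__(self, hidden_size):
        super().__init__()
        self.ln1 = torch.nn.LayerNorm(hidden_size)
        self.ln2 = torch.nn.LayerNorm(hidden_size)
        self.attn = torch.nn.Identity()
        self.mlp = torch.nn.Identity()

    def forward(self, x):
        x = x + self.attn(self.ln1(x))   # attention sub-block + residual
        x = x + self.mlp(self.ln2(x))    # MLP sub-block + residual
        return x

out = PreLNBlockSketch(16)(torch.randn(2, 8, 16))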
residual = hidden_states layernorm_input = attention_output + residual - + # Use FP32 for Layernorm # layernorm_output = self.post_attention_layernorm(layernorm_input.cast("float32")).cast("float16") layernorm_output = self.post_attention_layernorm(layernorm_input) @@ -584,22 +650,27 @@ def __init__( if self.num_unique_layers is None: self.num_unique_layers = self.num_layers - assert self.num_layers % self.num_unique_layers == 0, \ - 'number of layers should be divisible by number of unique layers' - + assert ( + self.num_layers % self.num_unique_layers == 0 + ), "number of layers should be divisible by number of unique layers" + # Transformer layers. def build_layer(layer_number): - return TransformerLayer(self.hidden_size, self.num_attention_heads, layer_number) + return TransformerLayer( + self.hidden_size, self.num_attention_heads, layer_number + ) self.layers = paddle.nn.LayerList( - [build_layer(i + 1) for i in range(self.num_unique_layers)]) + [build_layer(i + 1) for i in range(self.num_unique_layers)] + ) - self.topQueryLayer = TopQueryLayer(self.hidden_size, - self.num_attention_heads, - self.num_unique_layers) + self.topQueryLayer = TopQueryLayer( + self.hidden_size, self.num_attention_heads, self.num_unique_layers + ) - self.final_layernorm = paddle.nn.LayerNorm(self.hidden_size, - epsilon=self.layernorm_epsilon) + self.final_layernorm = paddle.nn.LayerNorm( + self.hidden_size, epsilon=self.layernorm_epsilon + ) def _get_layer_index(self, layer_number): return layer_number % self.num_unique_layers @@ -621,7 +692,6 @@ def forward( hidden_states = hidden_states.transpose([1, 0, 2]) query_hidden_state = query_hidden_state.transpose([1, 0, 2]) - if get_key_value: presents = [] for index in range(self.num_layers): @@ -629,12 +699,14 @@ def forward( past = None if layer_past is not None: past = layer_past[index] - hidden_states = layer(hidden_states, - attention_mask, - layer_past=past, - get_key_value=get_key_value, - prompt_length=prompt_length, - context_length=context_length) + hidden_states = layer( + hidden_states, + attention_mask, + layer_past=past, + get_key_value=get_key_value, + prompt_length=prompt_length, + context_length=context_length, + ) if get_key_value: hidden_states, present = hidden_states presents.append(present) @@ -649,13 +721,15 @@ def forward( past = None if layer_past is not None: past = layer_past[self.num_layers] - hidden_states = self.topQueryLayer(hidden_states_, - query_hidden_state, - attention_mask, - layer_past=past, - get_key_value=get_key_value, - prompt_length=prompt_length, - context_length=context_length) + hidden_states = self.topQueryLayer( + hidden_states_, + query_hidden_state, + attention_mask, + layer_past=past, + get_key_value=get_key_value, + prompt_length=prompt_length, + context_length=context_length, + ) if get_key_value: hidden_states, present = hidden_states @@ -695,35 +769,39 @@ def __init__( self.hidden_size = hidden_size self.vocab_size = vocab_size self.max_sequence_length = max_sequence_length - + # Word embeddings. self.word_embeddings = paddle.nn.Embedding(self.vocab_size, self.hidden_size) - self._word_embeddings_key = 'word_embeddings' - + self._word_embeddings_key = "word_embeddings" + # Position embedding. 
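# A minimal sketch of the parameter-sharing indexing in the Transformer above: num_layers
# forward passes reuse only num_unique_layers distinct modules, chosen by a modulo, which
# is why num_layers must be divisible by num_unique_layers.
num_layers, num_unique_layers = 8, 4
assert num_layers % num_unique_layers == 0

schedule = [i % num_unique_layers for i in range(num_layers)]
# -> [0, 1, 2, 3, 0, 1, 2, 3]: layers 4..7 reuse the weights of layers 0..3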
- self.position_embeddings = paddle.nn.Embedding(self.max_sequence_length, self.hidden_size) + self.position_embeddings = paddle.nn.Embedding( + self.max_sequence_length, self.hidden_size + ) self.position_embeddings = self.position_embeddings.to(dtype="float16") - self._position_embeddings_key = 'position_embeddings' - + self._position_embeddings_key = "position_embeddings" + def forward(self, input_ids, position_ids): # Embeddings. words_embeddings = self.word_embeddings(input_ids) position_embeddings = self.position_embeddings(position_ids) embeddings = words_embeddings + position_embeddings - + return embeddings - def state_dict_for_save_checkpoint(self, destination=None, prefix='', - keep_vars=False): + def state_dict_for_save_checkpoint( + self, destination=None, prefix="", keep_vars=False + ): """For easy load.""" state_dict_ = {} - state_dict_[self._word_embeddings_key] \ - = self.word_embeddings.state_dict(destination, prefix, keep_vars) - state_dict_[self._position_embeddings_key] \ - = self.position_embeddings.state_dict( - destination, prefix, keep_vars) - + state_dict_[self._word_embeddings_key] = self.word_embeddings.state_dict( + destination, prefix, keep_vars + ) + state_dict_[ + self._position_embeddings_key + ] = self.position_embeddings.state_dict(destination, prefix, keep_vars) + return state_dict_ def set_state_dict(self, state_dict, use_structured_name=True): @@ -736,11 +814,12 @@ def set_state_dict(self, state_dict, use_structured_name=True): # for backward compatibility. state_dict_ = {} for key in state_dict.keys(): - if 'word_embeddings' in key: - state_dict_[key.split('word_embeddings.')[1]] \ - = state_dict[key] - state_dict_["weight"] = state_dict_["weight"][:self.vocab_size] - self.word_embeddings.set_state_dict(state_dict_, use_structured_name=use_structured_name) + if "word_embeddings" in key: + state_dict_[key.split("word_embeddings.")[1]] = state_dict[key] + state_dict_["weight"] = state_dict_["weight"][: self.vocab_size] + self.word_embeddings.set_state_dict( + state_dict_, use_structured_name=use_structured_name + ) # Position embedding. if self._position_embeddings_key in state_dict: @@ -749,11 +828,12 @@ def set_state_dict(self, state_dict, use_structured_name=True): # for backward compatibility. state_dict_ = {} for key in state_dict.keys(): - if 'position_embeddings' in key: - state_dict_[key.split('position_embeddings.')[1]] \ - = state_dict[key] - self.position_embeddings.set_state_dict(state_dict_, use_structured_name=use_structured_name) - + if "position_embeddings" in key: + state_dict_[key.split("position_embeddings.")[1]] = state_dict[key] + self.position_embeddings.set_state_dict( + state_dict_, use_structured_name=use_structured_name + ) + class QueryEmbedding(paddle.nn.Layer): """Language model embeddings. @@ -778,24 +858,27 @@ def __init__( self.max_sequence_length = max_sequence_length # Top query position embedding (serial). - self.top_query_embeddings = paddle.nn.Embedding(self.max_sequence_length, self.hidden_size) + self.top_query_embeddings = paddle.nn.Embedding( + self.max_sequence_length, self.hidden_size + ) self.top_query_embeddings = self.top_query_embeddings.to(dtype="float16") - self._top_query_embeddings_key = 'top_query_embeddings' - + self._top_query_embeddings_key = "top_query_embeddings" + def forward(self, position_ids): # Embeddings. 
embeddings = self.top_query_embeddings(position_ids) - + return embeddings - def state_dict_for_save_checkpoint(self, destination=None, prefix='', - keep_vars=False): + def state_dict_for_save_checkpoint( + self, destination=None, prefix="", keep_vars=False + ): """For easy load.""" state_dict_ = {} - state_dict_[self._top_query_embeddings_key] \ - = self.top_query_embeddings.state_dict( - destination, prefix, keep_vars) + state_dict_[ + self._top_query_embeddings_key + ] = self.top_query_embeddings.state_dict(destination, prefix, keep_vars) return state_dict_ @@ -809,11 +892,12 @@ def set_state_dict(self, state_dict, use_structured_name=True): # for backward compatibility. state_dict_ = {} for key in state_dict.keys(): - if 'top_query_embeddings' in key: - state_dict_[key.split('top_query_embeddings.')[1]] \ - = state_dict[key] - self.top_query_embeddings.set_state_dict(state_dict_, use_structured_name=use_structured_name) - + if "top_query_embeddings" in key: + state_dict_[key.split("top_query_embeddings.")[1]] = state_dict[key] + self.top_query_embeddings.set_state_dict( + state_dict_, use_structured_name=use_structured_name + ) + class TransformerLanguageModel(paddle.nn.Layer): """Transformer language model. @@ -847,32 +931,32 @@ def __init__( self.max_position_embeddings = max_position_embeddings # Embeddings - self.embedding = Embedding(self.hidden_size, - self.padded_vocab_size, - self.max_position_embeddings) - self._embedding_key = 'embedding' + self.embedding = Embedding( + self.hidden_size, self.padded_vocab_size, self.max_position_embeddings + ) + self._embedding_key = "embedding" # Query embeddings - self.topQueryEmbedding = QueryEmbedding(self.hidden_size, - self.padded_vocab_size, - self.max_position_embeddings) - self._topQueryEmbedding_key = 'topQueryEmbedding' + self.topQueryEmbedding = QueryEmbedding( + self.hidden_size, self.padded_vocab_size, self.max_position_embeddings + ) + self._topQueryEmbedding_key = "topQueryEmbedding" # Transformer - self.transformer = Transformer(self.hidden_size, - self.num_attention_heads, - self.num_layers) - self._transformer_key = 'transformer' + self.transformer = Transformer( + self.hidden_size, self.num_attention_heads, self.num_layers + ) + self._transformer_key = "transformer" def forward( - self, - input_ids, - position_ids, - attention_mask, - layer_past=None, - get_key_value=False, - prompt_length=None, - context_length=None, + self, + input_ids, + position_ids, + attention_mask, + layer_past=None, + get_key_value=False, + prompt_length=None, + context_length=None, ): # Embeddings. @@ -881,30 +965,39 @@ def forward( queryEmbedding_out = self.topQueryEmbedding(query_position_ids) # Transformer. 
- transformer_output = self.transformer(embedding_output, - queryEmbedding_out, - attention_mask, - layer_past=layer_past, - get_key_value=get_key_value, - prompt_length=prompt_length, - context_length=context_length) + transformer_output = self.transformer( + embedding_output, + queryEmbedding_out, + attention_mask, + layer_past=layer_past, + get_key_value=get_key_value, + prompt_length=prompt_length, + context_length=context_length, + ) return transformer_output - def state_dict_for_save_checkpoint(self, destination=None, prefix='', - keep_vars=False): + def state_dict_for_save_checkpoint( + self, destination=None, prefix="", keep_vars=False + ): """For easy load.""" state_dict_ = {} - state_dict_[self._embedding_key] \ - = self.embedding.state_dict_for_save_checkpoint( - destination, prefix, keep_vars) - state_dict_[self._topQueryEmbedding_key] \ - = self.topQueryEmbedding.state_dict_for_save_checkpoint( - destination, prefix, keep_vars) - state_dict_[self._transformer_key] \ - = self.transformer.state_dict_for_save_checkpoint( - destination, prefix, keep_vars) + state_dict_[ + self._embedding_key + ] = self.embedding.state_dict_for_save_checkpoint( + destination, prefix, keep_vars + ) + state_dict_[ + self._topQueryEmbedding_key + ] = self.topQueryEmbedding.state_dict_for_save_checkpoint( + destination, prefix, keep_vars + ) + state_dict_[ + self._transformer_key + ] = self.transformer.state_dict_for_save_checkpoint( + destination, prefix, keep_vars + ) return state_dict_ @@ -918,9 +1011,11 @@ def set_state_dict(self, state_dict, use_structured_name=True): # for backward compatibility. state_dict_ = {} for key in state_dict.keys(): - if '_embeddings' in key: + if "_embeddings" in key: state_dict_[key] = state_dict[key] - self.embedding.set_state_dict(state_dict_, use_structured_name=use_structured_name) + self.embedding.set_state_dict( + state_dict_, use_structured_name=use_structured_name + ) if self._topQueryEmbedding_key in state_dict: state_dict_ = state_dict[self._topQueryEmbedding_key] @@ -928,9 +1023,11 @@ def set_state_dict(self, state_dict, use_structured_name=True): # for backward compatibility. state_dict_ = {} for key in state_dict.keys(): - if '_embeddings' in key: + if "_embeddings" in key: state_dict_[key] = state_dict[key] - self.topQueryEmbedding.set_state_dict(state_dict_, use_structured_name=use_structured_name) + self.topQueryEmbedding.set_state_dict( + state_dict_, use_structured_name=use_structured_name + ) # Transformer. if self._transformer_key in state_dict: @@ -939,9 +1036,11 @@ def set_state_dict(self, state_dict, use_structured_name=True): # for backward compatibility. state_dict_ = {} for key in state_dict.keys(): - if 'transformer.' in key: - state_dict_[key.split('transformer.')[1]] = state_dict[key] - self.transformer.set_state_dict(state_dict_, use_structured_name=use_structured_name) + if "transformer." 
in key: + state_dict_[key.split("transformer.")[1]] = state_dict[key] + self.transformer.set_state_dict( + state_dict_, use_structured_name=use_structured_name + ) class CodeGeeXModel(paddle.nn.Layer): @@ -956,14 +1055,16 @@ def __init__( max_position_embeddings, ): super(CodeGeeXModel, self).__init__() - - self.language_model = TransformerLanguageModel(hidden_size, - num_layers, - num_attention_heads, - padded_vocab_size, - max_position_embeddings) + + self.language_model = TransformerLanguageModel( + hidden_size, + num_layers, + num_attention_heads, + padded_vocab_size, + max_position_embeddings, + ) self._language_model_key = "language_model" - + def forward( self, input_ids, @@ -975,31 +1076,41 @@ def forward( context_length=None, ): # Language model. - lm_output = self.language_model(input_ids, - position_ids, - attention_mask, - layer_past=layer_past, - get_key_value=get_key_value, - prompt_length=prompt_length, - context_length=context_length) + lm_output = self.language_model( + input_ids, + position_ids, + attention_mask, + layer_past=layer_past, + get_key_value=get_key_value, + prompt_length=prompt_length, + context_length=context_length, + ) if get_key_value: lm_output, presents = lm_output - output = F.linear(lm_output, self.language_model.embedding.word_embeddings.weight.cast("float16").transpose([1, 0])) - + output = F.linear( + lm_output, + self.language_model.embedding.word_embeddings.weight.cast( + "float16" + ).transpose([1, 0]), + ) + if get_key_value: output = [output, presents] return output - def state_dict_for_save_checkpoint(self, destination=None, prefix='', - keep_vars=False): + def state_dict_for_save_checkpoint( + self, destination=None, prefix="", keep_vars=False + ): state_dict_ = {} - state_dict_[self._language_model_key] \ - = self.language_model.state_dict_for_save_checkpoint( - destination, prefix, keep_vars) + state_dict_[ + self._language_model_key + ] = self.language_model.state_dict_for_save_checkpoint( + destination, prefix, keep_vars + ) return state_dict_ def set_state_dict(self, state_dict, use_structured_name=True): @@ -1007,4 +1118,6 @@ def set_state_dict(self, state_dict, use_structured_name=True): if self._language_model_key in state_dict: state_dict = state_dict[self._language_model_key] - self.language_model.set_state_dict(state_dict, use_structured_name=use_structured_name) + self.language_model.set_state_dict( + state_dict, use_structured_name=use_structured_name + ) diff --git a/codegeex/paddle/inference.py b/codegeex/paddle/inference.py index cbc8158..1a32d9b 100644 --- a/codegeex/paddle/inference.py +++ b/codegeex/paddle/inference.py @@ -10,10 +10,10 @@ def get_ltor_masks_and_position_ids( - data, - eod_token, - reset_position_ids, - reset_attention_mask, + data, + eod_token, + reset_position_ids, + reset_attention_mask, ): """Build masks and position id for left to right model.""" @@ -65,9 +65,9 @@ def get_ltor_masks_and_position_ids( def get_batch( - context_tokens, - micro_batch_size, - eod_token, + context_tokens, + micro_batch_size, + eod_token, reset_position_ids=False, reset_attention_mask=False, ): @@ -125,15 +125,15 @@ def pad_batch(batch, pad_id, seq_length): def forward_step( - model, - tokens, - seq_length, - position_ids, - attention_mask, - layer_past=None, - get_key_value=None, - prompt_length=None, - context_length=None, + model, + tokens, + seq_length, + position_ids, + attention_mask, + layer_past=None, + get_key_value=None, + prompt_length=None, + context_length=None, ): # Forward pass through the model. 
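# The paddle forward above transposes the embedding table before the output projection;
# unlike torch's F.linear (which multiplies by the transposed weight), paddle's linear
# uses the weight as given. A torch-only sketch of the identity being relied on:
import torch
import torch.nn.functional as F

hidden = torch.randn(2, 5, 16)
emb = torch.randn(100, 16)                                   # [vocab, hidden]
assert torch.allclose(F.linear(hidden, emb), hidden @ emb.t(), atol=1e-5)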
output_tensor = model( @@ -156,28 +156,30 @@ def forward_step( def get_token_stream( - model, - tokenizer, - seq_length, - out_seq_length, - context_tokens, - return_scores: bool = False, - prompt_length: int = None, - micro_batch_size: int = None, - bad_ids: List = None, - temperature: float = 1.0, - topp: float = 1.0, - topk: int = 0.0, - greedy: bool = False, - recompute: bool = False, + model, + tokenizer, + seq_length, + out_seq_length, + context_tokens, + return_scores: bool = False, + prompt_length: int = None, + micro_batch_size: int = None, + bad_ids: List = None, + temperature: float = 1.0, + topp: float = 1.0, + topk: int = 0.0, + greedy: bool = False, + recompute: bool = False, ): - context_tokens, context_lengths = pad_batch(context_tokens, tokenizer.eos_token_id, seq_length) + context_tokens, context_lengths = pad_batch( + context_tokens, tokenizer.eos_token_id, seq_length + ) context_tokens_tensor = paddle.to_tensor(context_tokens, dtype="int64") context_length_tensor = paddle.to_tensor(context_lengths, dtype="int64") context_length = context_length_tensor.min().item() tokens, attention_mask, position_ids = get_batch( - context_tokens_tensor, + context_tokens_tensor, micro_batch_size, tokenizer.eos_token_id, ) @@ -215,23 +217,23 @@ def switch(val1, val2, boolean): def sample_sequence_batch( - model, - tokenizer, - context_tokens, - context_lengths, - attention_mask, - position_ids, - seq_length, - out_seq_length, - maxlen=None, - return_scores: bool = False, - prompt_length: int = None, - bad_ids: List = None, - temperature: float = 1.0, - topp: float = 1.0, - topk: int = 0.0, - recompute: bool = False, - greedy: bool = False, + model, + tokenizer, + context_tokens, + context_lengths, + attention_mask, + position_ids, + seq_length, + out_seq_length, + maxlen=None, + return_scores: bool = False, + prompt_length: int = None, + bad_ids: List = None, + temperature: float = 1.0, + topp: float = 1.0, + topk: int = 0.0, + recompute: bool = False, + greedy: bool = False, ): model.eval() with paddle.no_grad(): @@ -257,30 +259,32 @@ def sample_sequence_batch( while context_length <= (maxlen): if recompute: - logits = model(tokens, - position_ids, - attention_mask, - prompt_length=prompt_length, - context_length=context_length, - ) + logits = model( + tokens, + position_ids, + attention_mask, + prompt_length=prompt_length, + context_length=context_length, + ) logits = logits[:, context_length - 1, :] else: if counter == 0: tokens2use = tokens[:, :context_length] positions2use = position_ids[:, :context_length] else: - tokens2use = tokens[:, context_length - 1].reshape([ - batch_size, -1]) - positions2use = position_ids[:, context_length - 1].reshape([ - batch_size, -1]) - logits, layer_past = model(tokens2use, - positions2use, - attention_mask, - layer_past=layer_past, - get_key_value=True, - prompt_length=prompt_length, - context_length=context_length, - ) + tokens2use = tokens[:, context_length - 1].reshape([batch_size, -1]) + positions2use = position_ids[:, context_length - 1].reshape( + [batch_size, -1] + ) + logits, layer_past = model( + tokens2use, + positions2use, + attention_mask, + layer_past=layer_past, + get_key_value=True, + prompt_length=prompt_length, + context_length=context_length, + ) logits = logits[:, -1].reshape([batch_size, -1]) if bad_ids is not None: @@ -314,12 +318,12 @@ def sample_sequence_batch( lengths[just_finished.reshape([-1])] = context_length is_done = is_done | done_token done = paddle.all(is_done.cast("bool")) - + if return_scores: yield tokens, 
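# A hedged sketch of the two decoding paths in sample_sequence_batch above: with
# recompute=True the whole prefix is re-run each step; otherwise the first step feeds the
# full prompt and later steps feed only the newest token plus the cached layer_past.
def decode_step_sketch(model, tokens, position_ids, attention_mask,
                       context_length, counter, layer_past):
    if counter == 0:
        tokens2use = tokens[:, :context_length]                        # whole prompt once
        positions2use = position_ids[:, :context_length]
    else:
        tokens2use = tokens[:, context_length - 1:context_length]      # only the new token
        positions2use = position_ids[:, context_length - 1:context_length]
    logits, layer_past = model(tokens2use, positions2use, attention_mask,
                               layer_past=layer_past, get_key_value=True)
    return logits[:, -1], layer_past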
(lengths, scores) else: yield tokens, lengths - + context_length += 1 counter += 1 if done: diff --git a/codegeex/paddle/pt_to_pdparams.py b/codegeex/paddle/pt_to_pdparams.py index c2638f4..fb3260d 100644 --- a/codegeex/paddle/pt_to_pdparams.py +++ b/codegeex/paddle/pt_to_pdparams.py @@ -28,17 +28,9 @@ def WalkDict(x): def parse_opt(): parser = argparse.ArgumentParser() + parser.add_argument("--pt", type=str, required=True, help="Path to pt checkpoint.") parser.add_argument( - "--pt", - type=str, - required=True, - help="Path to pt checkpoint." - ) - parser.add_argument( - "--pdparams", - type=str, - required=True, - help="Path to pdparams checkpoint." + "--pdparams", type=str, required=True, help="Path to pdparams checkpoint." ) opt = parser.parse_args() return opt diff --git a/codegeex/quantization/quantize.py b/codegeex/quantization/quantize.py index dbca608..25eccc8 100644 --- a/codegeex/quantization/quantize.py +++ b/codegeex/quantization/quantize.py @@ -3,12 +3,23 @@ from torch.nn.parameter import Parameter from codegeex.kernels import extract_weight_to_half from codegeex.megatron.mpu.layers import RowParallelLinear, ColumnParallelLinear -from codegeex.megatron.mpu.mappings import copy_to_tensor_model_parallel_region, gather_from_tensor_model_parallel_region, reduce_from_tensor_model_parallel_region, scatter_to_tensor_model_parallel_region +from codegeex.megatron.mpu.mappings import ( + copy_to_tensor_model_parallel_region, + gather_from_tensor_model_parallel_region, + reduce_from_tensor_model_parallel_region, + scatter_to_tensor_model_parallel_region, +) class W8A16Linear(torch.autograd.Function): @staticmethod - def forward(ctx, inp: torch.Tensor, quant_w: torch.Tensor, scale_w: torch.Tensor, weight_bit_width): + def forward( + ctx, + inp: torch.Tensor, + quant_w: torch.Tensor, + scale_w: torch.Tensor, + weight_bit_width, + ): ctx.inp_shape = inp.size() ctx.weight_shape = quant_w.size() ctx.weight_bit_width = weight_bit_width @@ -31,46 +42,59 @@ def backward(ctx, grad_output: torch.Tensor): class QuantizedLinear(torch.nn.Module): def __init__( - self, + self, in_features: int, out_features: int, - weight_bit_width: int, - weight: torch.Tensor = None, - bias: torch.Tensor = None, - *args, - **kwargs + weight_bit_width: int, + weight: torch.Tensor = None, + bias: torch.Tensor = None, + *args, + **kwargs, ): super(QuantizedLinear, self).__init__() - + self.in_features = in_features self.out_features = out_features self.weight_bit_width = weight_bit_width if weight is None: self.weight = torch.empty( - self.out_features, self.in_features * weight_bit_width // 8, dtype=torch.int8, device=kwargs["device"] + self.out_features, + self.in_features * weight_bit_width // 8, + dtype=torch.int8, + device=kwargs["device"], + ) + self.weight_scale = torch.empty( + self.out_features, dtype=kwargs["params_dtype"], device=kwargs["device"] ) - self.weight_scale = torch.empty(self.out_features, dtype=kwargs["params_dtype"], device=kwargs["device"]) else: - self.weight_scale = (weight.abs().max(dim=-1).values / ((2 ** (weight_bit_width - 1)) - 1)).half() - self.weight = torch.round(weight / self.weight_scale[:, None]).to(torch.int8) + self.weight_scale = ( + weight.abs().max(dim=-1).values / ((2 ** (weight_bit_width - 1)) - 1) + ).half() + self.weight = torch.round(weight / self.weight_scale[:, None]).to( + torch.int8 + ) if weight_bit_width == 4: self.weight = compress_int4_weight(self.weight) if bias is None: - self.register_parameter('bias', None) + self.register_parameter("bias", None) else: 
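# A minimal sketch of the symmetric per-row int8 quantization done in QuantizedLinear
# above: each output row gets one scale from its absolute maximum, and the weights are
# rounded to int8 against that scale (fp16 cast omitted for a CPU-friendly illustration).
import torch

weight = torch.randn(8, 16)                                         # [out_features, in_features]
bits = 8
scale = weight.abs().max(dim=-1).values / ((2 ** (bits - 1)) - 1)   # one scale per row
q = torch.round(weight / scale[:, None]).to(torch.int8)
dequant = q.float() * scale[:, None]                                # approximate reconstruction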
self.bias = bias - + self.weight = Parameter(self.weight.to(kwargs["device"]), requires_grad=False) - self.weight_scale = Parameter(self.weight_scale.to(kwargs["device"]), requires_grad=False) + self.weight_scale = Parameter( + self.weight_scale.to(kwargs["device"]), requires_grad=False + ) def forward(self, input_): # Matrix multiply. - output = W8A16Linear.apply(input_, self.weight, self.weight_scale, self.weight_bit_width) + output = W8A16Linear.apply( + input_, self.weight, self.weight_scale, self.weight_bit_width + ) if self.bias is not None: output = output + self.bias - + return output @@ -79,13 +103,15 @@ def __init__( self, input_size: int, output_size: int, - weight_bit_width: int, - weight: torch.Tensor = None, - bias: torch.Tensor = None, - *args, + weight_bit_width: int, + weight: torch.Tensor = None, + bias: torch.Tensor = None, + *args, **kwargs, ): - super(QuantizedColumnParallelLinear, self).__init__(input_size, output_size, *args, **kwargs) + super(QuantizedColumnParallelLinear, self).__init__( + input_size, output_size, *args, **kwargs + ) self.input_size = input_size self.output_size = output_size self.weight_bit_width = weight_bit_width @@ -97,29 +123,42 @@ def __init__( if weight is None: self.weight = torch.empty( - self.output_size, self.input_size * weight_bit_width // 8, dtype=torch.int8, device=kwargs["device"] + self.output_size, + self.input_size * weight_bit_width // 8, + dtype=torch.int8, + device=kwargs["device"], + ) + self.weight_scale = torch.empty( + self.output_size, dtype=kwargs["params_dtype"], device=kwargs["device"] ) - self.weight_scale = torch.empty(self.output_size, dtype=kwargs["params_dtype"], device=kwargs["device"]) else: - self.weight_scale = (weight.abs().max(dim=-1).values / ((2 ** (weight_bit_width - 1)) - 1)).half() - self.weight = torch.round(weight / self.weight_scale[:, None]).to(torch.int8) + self.weight_scale = ( + weight.abs().max(dim=-1).values / ((2 ** (weight_bit_width - 1)) - 1) + ).half() + self.weight = torch.round(weight / self.weight_scale[:, None]).to( + torch.int8 + ) if weight_bit_width == 4: self.weight = compress_int4_weight(self.weight) if bias is None: - self.register_parameter('bias', None) + self.register_parameter("bias", None) else: del self.bias self.bias = bias - + self.weight = Parameter(self.weight.to(kwargs["device"]), requires_grad=False) - self.weight_scale = Parameter(self.weight_scale.to(kwargs["device"]), requires_grad=False) + self.weight_scale = Parameter( + self.weight_scale.to(kwargs["device"]), requires_grad=False + ) def forward(self, input_): # Set up backprop all-reduce. input_parallel = copy_to_tensor_model_parallel_region(input_) # Matrix multiply. 
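# A hedged sketch of the W8A16 matmul the quantized layers above delegate to
# W8A16Linear.apply: activations stay in higher precision while the int8 weights are
# expanded back with their per-row scales right before the multiply (the real kernel
# fuses this step).
import torch

def w8a16_linear_sketch(inp, quant_w, scale_w):
    weight = quant_w.to(inp.dtype) * scale_w.to(inp.dtype)[:, None]   # dequantize per row
    return inp @ weight.t()

out = w8a16_linear_sketch(torch.randn(4, 16),
                          torch.randint(-127, 128, (8, 16), dtype=torch.int8),
                          torch.rand(8))                              # -> [4, 8]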
- output_parallel = W8A16Linear.apply(input_parallel, self.weight, self.weight_scale, self.weight_bit_width) + output_parallel = W8A16Linear.apply( + input_parallel, self.weight, self.weight_scale, self.weight_bit_width + ) if self.bias is not None and not self.skip_bias_add: output_parallel = output_parallel + self.bias if self.gather_output: @@ -127,9 +166,9 @@ def forward(self, input_): output = gather_from_tensor_model_parallel_region(output_parallel) else: output = output_parallel - + output_bias = self.bias if self.skip_bias_add else None - + return output, output_bias @@ -138,13 +177,15 @@ def __init__( self, input_size: int, output_size: int, - weight_bit_width: int, - weight: torch.Tensor = None, - bias: torch.Tensor = None, - *args, + weight_bit_width: int, + weight: torch.Tensor = None, + bias: torch.Tensor = None, + *args, **kwargs, ): - super(QuantizedRowParallelLinear, self).__init__(input_size, output_size, *args, **kwargs) + super(QuantizedRowParallelLinear, self).__init__( + input_size, output_size, *args, **kwargs + ) self.input_size = input_size self.output_size = output_size self.weight_bit_width = weight_bit_width @@ -153,26 +194,37 @@ def __init__( else: self.skip_bias_add = False del self.weight - + if weight is None: self.weight = torch.empty( - self.output_size, self.input_size * weight_bit_width // 8, dtype=torch.int8, device=kwargs["device"] + self.output_size, + self.input_size * weight_bit_width // 8, + dtype=torch.int8, + device=kwargs["device"], + ) + self.weight_scale = torch.empty( + self.output_size, dtype=kwargs["params_dtype"], device=kwargs["device"] ) - self.weight_scale = torch.empty(self.output_size, dtype=kwargs["params_dtype"], device=kwargs["device"]) else: - self.weight_scale = (weight.abs().max(dim=-1).values / ((2 ** (weight_bit_width - 1)) - 1)).half() - self.weight = torch.round(weight / self.weight_scale[:, None]).to(torch.int8) + self.weight_scale = ( + weight.abs().max(dim=-1).values / ((2 ** (weight_bit_width - 1)) - 1) + ).half() + self.weight = torch.round(weight / self.weight_scale[:, None]).to( + torch.int8 + ) if weight_bit_width == 4: self.weight = compress_int4_weight(self.weight) if bias is None: - self.register_parameter('bias', None) + self.register_parameter("bias", None) else: del self.bias self.bias = bias - + self.weight = Parameter(self.weight.to(kwargs["device"]), requires_grad=False) - self.weight_scale = Parameter(self.weight_scale.to(kwargs["device"]), requires_grad=False) + self.weight_scale = Parameter( + self.weight_scale.to(kwargs["device"]), requires_grad=False + ) def forward(self, input_): # Set up backprop all-reduce. @@ -181,7 +233,9 @@ def forward(self, input_): else: input_parallel = scatter_to_tensor_model_parallel_region(input_) # Matrix multiply. - output_parallel = W8A16Linear.apply(input_parallel, self.weight, self.weight_scale, self.weight_bit_width) + output_parallel = W8A16Linear.apply( + input_parallel, self.weight, self.weight_scale, self.weight_bit_width + ) # All-reduce across all the partitions. 
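# A hedged, single-process sketch of the sharding the two parallel classes above
# implement: column-parallel splits output features and concatenates ("gather"), while
# row-parallel splits input features and sums partial results ("all-reduce"). No real
# communication happens here.
import torch

x = torch.randn(4, 16)
w = torch.randn(8, 16)                                               # full weight [out, in]

col_out = torch.cat([x @ w[:4].t(), x @ w[4:].t()], dim=-1)          # column parallel
row_out = x[:, :8] @ w[:, :8].t() + x[:, 8:] @ w[:, 8:].t()          # row parallel

assert torch.allclose(col_out, x @ w.t(), atol=1e-5)
assert torch.allclose(row_out, x @ w.t(), atol=1e-5)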
output_ = reduce_from_tensor_model_parallel_region(output_parallel) if self.bias is not None and not self.skip_bias_add: @@ -189,19 +243,19 @@ def forward(self, input_): else: output = output_ output_bias = self.bias if self.skip_bias_add else None - + return output, output_bias - + def quantize(model, weight_bit_width, backend="torch"): """Replace fp16 linear with quantized linear""" - + for i in range(len(model.language_model.transformer.layers) + 1): if i == len(model.language_model.transformer.layers): layer = model.language_model.transformer.topQueryLayer else: layer = model.language_model.transformer.layers[i] - + if backend == "torch": layer.attention.query = QuantizedLinear( in_features=layer.attention.query.in_features, @@ -325,5 +379,5 @@ def quantize(model, weight_bit_width, backend="torch"): params_dtype=torch.half, device=layer.mlp.dense_4h_to_h.weight.device, ) - - return model \ No newline at end of file + + return model diff --git a/codegeex/quantization/quantize_oneflow.py b/codegeex/quantization/quantize_oneflow.py index e752b2f..3e5208d 100644 --- a/codegeex/quantization/quantize_oneflow.py +++ b/codegeex/quantization/quantize_oneflow.py @@ -1,7 +1,8 @@ -import numpy as np +import numpy as np import oneflow as torch from oneflow.nn.parameter import Parameter + def _pack_int8_to_int4(x): np_x = x.numpy() l = np_x[..., 0::2] @@ -35,12 +36,14 @@ def _quantize(num_bits, symmetric, x, group_dim, group_size, quant_type): quantized = _pack_int8_to_int4(quantized) return (quantized, scale_float.squeeze(group_dim + 1).to(x.dtype), None) else: - unsigned_max = float(2 ** num_bits) - 1 + unsigned_max = float(2**num_bits) - 1 mn = x_reshaped.min(dim=group_dim + 1, keepdim=True).values mx = x_reshaped.max(dim=group_dim + 1, keepdim=True).values scale_float = (mx - mn) / unsigned_max quantized = ( - torch.round((x_reshaped - mn) / scale_float).reshape(x.shape).to(torch.uint8) + torch.round((x_reshaped - mn) / scale_float) + .reshape(x.shape) + .to(torch.uint8) ) if num_bits == 4: quantized = _pack_int8_to_int4(quantized) @@ -50,19 +53,20 @@ def _quantize(num_bits, symmetric, x, group_dim, group_size, quant_type): mn.squeeze(group_dim + 1).to(x.dtype), ) + class QuantizedLinear(torch.nn.Module): def __init__( - self, + self, in_features: int, out_features: int, - weight_bit_width: int, - weight: torch.Tensor = None, - bias: torch.Tensor = None, - *args, + weight_bit_width: int, + weight: torch.Tensor = None, + bias: torch.Tensor = None, + *args, **kwargs ): super(QuantizedLinear, self).__init__() - + self.in_features = in_features self.out_features = out_features self.weight_bit_width = weight_bit_width @@ -71,44 +75,56 @@ def __init__( self.group_size = in_features self.weight, self.weight_scale, self.weight_zero = _quantize( - self.weight_bit_width, self.symmetric, weight, self.group_dim, self.group_size, torch.int8 + self.weight_bit_width, + self.symmetric, + weight, + self.group_dim, + self.group_size, + torch.int8, ) if bias is None: - self.register_parameter('bias', None) + self.register_parameter("bias", None) else: self.bias = bias self.bias = self.bias.to(kwargs["device"]) - + self.weight = Parameter(self.weight.to(kwargs["device"]), requires_grad=False) - self.weight_scale = Parameter(self.weight_scale.to(kwargs["device"]), requires_grad=False) + self.weight_scale = Parameter( + self.weight_scale.to(kwargs["device"]), requires_grad=False + ) if self.bias is not None: self.bias = Parameter(self.bias.to(kwargs["device"]), requires_grad=False) if self.weight_zero is not None: - 
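# A hedged sketch of the int4 packing hinted at by _pack_int8_to_int4 above: two 4-bit
# values share one int8 byte, even-indexed elements in the high nibble and odd-indexed in
# the low nibble (the exact layout is an assumption for illustration).
import numpy as np

x = np.array([1, 2, 3, 4, 5, 6], dtype=np.int8)   # pretend these already fit in 4 bits
packed = ((x[0::2] << 4) | (x[1::2] & 0b1111)).astype(np.int8)
# len(packed) == len(x) // 2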
self.weight_zero = Parameter(self.weight_zero.to(kwargs["device"]), requires_grad=False) + self.weight_zero = Parameter( + self.weight_zero.to(kwargs["device"]), requires_grad=False + ) def forward(self, input_): # Matrix multiply. - output = torch._C.fused_linear_with_groupwise_quantized_weight(input_, - w=self.weight, - w_scale=self.weight_scale, - w_zero=self.weight_zero, - b=self.bias if self.bias is not None else None, - num_bits=self.weight_bit_width, - symmetric=self.symmetric, - group_dim=self.group_dim, - group_size=self.group_size) - + output = torch._C.fused_linear_with_groupwise_quantized_weight( + input_, + w=self.weight, + w_scale=self.weight_scale, + w_zero=self.weight_zero, + b=self.bias if self.bias is not None else None, + num_bits=self.weight_bit_width, + symmetric=self.symmetric, + group_dim=self.group_dim, + group_size=self.group_size, + ) + return output + def quantize_oneflow(model, weight_bit_width): """Replace fp16 linear with quantized linear""" - + for i in range(len(model.language_model.transformer.layers) + 1): if i == len(model.language_model.transformer.layers): layer = model.language_model.transformer.topQueryLayer else: layer = model.language_model.transformer.layers[i] - + layer.attention.query = QuantizedLinear( in_features=layer.attention.query.in_features, out_features=layer.attention.query.out_features, @@ -163,6 +179,5 @@ def quantize_oneflow(model, weight_bit_width): params_dtype=torch.half, device=layer.mlp.dense_4h_to_h.weight.device, ) - - - return model \ No newline at end of file + + return model diff --git a/codegeex/tokenizer/__init__.py b/codegeex/tokenizer/__init__.py index 257e5b2..b5734b4 100644 --- a/codegeex/tokenizer/__init__.py +++ b/codegeex/tokenizer/__init__.py @@ -1 +1 @@ -from .tokenizer import CodeGeeXTokenizer \ No newline at end of file +from .tokenizer import CodeGeeXTokenizer diff --git a/codegeex/tokenizer/tokenizer.py b/codegeex/tokenizer/tokenizer.py index d2f4f48..fa7cfcd 100644 --- a/codegeex/tokenizer/tokenizer.py +++ b/codegeex/tokenizer/tokenizer.py @@ -5,7 +5,7 @@ def encode_whitespaces(text: str, start_extra_id: int, max_len: int): - """ Encode whitespaces to extra tokens. + """Encode whitespaces to extra tokens. >>> encode_whitespaces('a\\n b\\n c', 10, 10) 'a\\n<|extratoken_10|>b\\n<|extratoken_11|>c' @@ -16,7 +16,7 @@ def encode_whitespaces(text: str, start_extra_id: int, max_len: int): def decode_whitespaces(text: str, start_extra_id: int, max_len: int): - """ Decode the whitespace-encoded strings produced by encode_whitespace. + """Decode the whitespace-encoded strings produced by encode_whitespace. 
>>> text = 'a\\n b\\n c' >>> s, l = 10, 10 @@ -25,39 +25,47 @@ def decode_whitespaces(text: str, start_extra_id: int, max_len: int): """ for l in range(2, max_len + 1): token_id = start_extra_id - 2 + l - token = f'<|extratoken_{token_id}|>' - text = text.replace(token, ' ' * l) + token = f"<|extratoken_{token_id}|>" + text = text.replace(token, " " * l) return text - + class CodeGeeXTokenizer(object): def __init__( - self, - tokenizer: GPT2TokenizerFast = None, - tokenizer_path: str = "EleutherAI/gpt-j-6B", + self, + tokenizer: GPT2TokenizerFast = None, + tokenizer_path: str = "EleutherAI/gpt-j-6B", start_extra_id: int = 10, - max_len : int = 10, - mode='codegeex-13b', + max_len: int = 10, + mode="codegeex-13b", dict_file: str = None, ): - self.tokenizer = tokenizer if tokenizer is not None else AutoTokenizer.from_pretrained(tokenizer_path) - if mode not in ['codegeex-13b']: + self.tokenizer = ( + tokenizer + if tokenizer is not None + else AutoTokenizer.from_pretrained(tokenizer_path) + ) + if mode not in ["codegeex-13b"]: raise ValueError(f"Invalid mode {mode}, choose from ['codegeex-13b']") self.start_extra_id = start_extra_id self.max_len = max_len self.mode = mode self.eos_token_id = self.tokenizer.eos_token_id - + def encode_code(self, code: str): - if self.mode == 'codegeex-13b': + if self.mode == "codegeex-13b": code = encode_whitespaces(code, self.start_extra_id, self.max_len) - input_ids = self.tokenizer(code, is_split_into_words=False, verbose=False).input_ids - + input_ids = self.tokenizer( + code, is_split_into_words=False, verbose=False + ).input_ids + return input_ids - + def decode_code(self, input_ids): - if self.mode == 'codegeex-13b': - text = self.tokenizer.decode(input_ids, skip_special_tokens=False, verbose=False) + if self.mode == "codegeex-13b": + text = self.tokenizer.decode( + input_ids, skip_special_tokens=False, verbose=False + ) output_code = decode_whitespaces(text, self.start_extra_id, self.max_len) - - return output_code \ No newline at end of file + + return output_code diff --git a/codegeex/torch/__init__.py b/codegeex/torch/__init__.py index 16975c0..f97fbb6 100644 --- a/codegeex/torch/__init__.py +++ b/codegeex/torch/__init__.py @@ -1 +1 @@ -from .codegeex_model import CodeGeeXModel \ No newline at end of file +from .codegeex_model import CodeGeeXModel diff --git a/codegeex/torch/codegeex_model.py b/codegeex/torch/codegeex_model.py index d3b9f0a..8d437de 100644 --- a/codegeex/torch/codegeex_model.py +++ b/codegeex/torch/codegeex_model.py @@ -6,7 +6,11 @@ def fast_gelu(x): """Mindspore's fast gelu implementation.""" - return x / (1 + torch.exp(-1.702 * torch.abs(x))) * torch.exp(0.851 * (x - torch.abs(x))) + return ( + x + / (1 + torch.exp(-1.702 * torch.abs(x))) + * torch.exp(0.851 * (x - torch.abs(x))) + ) class MLP(torch.nn.Module): @@ -19,7 +23,7 @@ class MLP(torch.nn.Module): """ def __init__( - self, + self, hidden_size, ): super(MLP, self).__init__() @@ -46,7 +50,7 @@ def forward(self, hidden_states): output = self.dense_4h_to_h(intermediate_parallel) return output - + class SelfAttention(torch.nn.Module): """self-attention layer abstract class. 
@@ -56,9 +60,9 @@ class SelfAttention(torch.nn.Module): """ def __init__( - self, + self, hidden_size, - num_attention_heads, + num_attention_heads, layer_number, fp16=True, attention_softmax_in_fp32=True, @@ -71,8 +75,10 @@ def __init__( self.layer_number = max(1, layer_number) assert self.hidden_size % self.num_attention_heads == 0 - self.hidden_size_per_attention_head = int(self.hidden_size // self.num_attention_heads) - + self.hidden_size_per_attention_head = int( + self.hidden_size // self.num_attention_heads + ) + self.query = torch.nn.Linear(self.hidden_size, self.hidden_size) self.key = torch.nn.Linear(self.hidden_size, self.hidden_size) self.value = torch.nn.Linear(self.hidden_size, self.hidden_size) @@ -101,19 +107,22 @@ def forward( key_layer = self.key(hidden_states) value_layer = self.value(hidden_states) - new_query_layer_shape = query_layer.size()[:-1] + \ - (self.num_attention_heads, - self.hidden_size_per_attention_head) + new_query_layer_shape = query_layer.size()[:-1] + ( + self.num_attention_heads, + self.hidden_size_per_attention_head, + ) query_layer = query_layer.view(*new_query_layer_shape) - new_query_layer_shape = key_layer.size()[:-1] + \ - (self.num_attention_heads, - self.hidden_size_per_attention_head) + new_query_layer_shape = key_layer.size()[:-1] + ( + self.num_attention_heads, + self.hidden_size_per_attention_head, + ) key_layer = key_layer.view(*new_query_layer_shape) - new_query_layer_shape = value_layer.size()[:-1] + \ - (self.num_attention_heads, - self.hidden_size_per_attention_head) + new_query_layer_shape = value_layer.size()[:-1] + ( + self.num_attention_heads, + self.hidden_size_per_attention_head, + ) value_layer = value_layer.view(*new_query_layer_shape) # ================================== @@ -122,10 +131,10 @@ def forward( if layer_past is not None: past_key, past_value = layer_past - key_layer = torch.cat((past_key.type_as(key_layer), - key_layer), dim=0) - value_layer = torch.cat((past_value.type_as(value_layer), - value_layer), dim=0) + key_layer = torch.cat((past_key.type_as(key_layer), key_layer), dim=0) + value_layer = torch.cat( + (past_value.type_as(value_layer), value_layer), dim=0 + ) if get_key_value: present = (key_layer, value_layer) @@ -134,18 +143,28 @@ def forward( # =================================== # [b, np, sq, sk] - output_size = (query_layer.size(1), - query_layer.size(2), - query_layer.size(0), - key_layer.size(0)) + output_size = ( + query_layer.size(1), + query_layer.size(2), + query_layer.size(0), + key_layer.size(0), + ) # [sq, b, np, hn] -> [sq, b * np, hn] - query_layer = query_layer.contiguous().view(output_size[2], output_size[0] * output_size[1], -1) - key_layer = key_layer.contiguous().view(output_size[3], output_size[0] * output_size[1], -1) + query_layer = query_layer.contiguous().view( + output_size[2], output_size[0] * output_size[1], -1 + ) + key_layer = key_layer.contiguous().view( + output_size[3], output_size[0] * output_size[1], -1 + ) # Raw attention scores. 
[b * np, sq, sk] - matmul_result = torch.matmul(query_layer.transpose(0, 1), - key_layer.transpose(0, 1).transpose(1, 2)) / self.norm_factor + matmul_result = ( + torch.matmul( + query_layer.transpose(0, 1), key_layer.transpose(0, 1).transpose(1, 2) + ) + / self.norm_factor + ) # change view to [b, np, sq, sk] attention_scores = matmul_result.view(*output_size) @@ -158,14 +177,12 @@ def forward( with torch.no_grad(): if layer_past is not None: attention_mask = attention_mask[ - ..., - attention_scores.size(3) - 1, - :attention_scores.size(3)].unsqueeze(2) + ..., attention_scores.size(3) - 1, : attention_scores.size(3) + ].unsqueeze(2) else: attention_mask = attention_mask[ - ..., - :attention_scores.size(3), - :attention_scores.size(3)] + ..., : attention_scores.size(3), : attention_scores.size(3) + ] if context_length is not None: attention_mask = torch.clone(attention_mask) @@ -187,19 +204,26 @@ def forward( # [sq, b, np, hn] --> [b, np, sq, hn] # context layer shape: [b, np, sq, hn] - output_size = (value_layer.size(1), - value_layer.size(2), - query_layer.size(0), - value_layer.size(3)) + output_size = ( + value_layer.size(1), + value_layer.size(2), + query_layer.size(0), + value_layer.size(3), + ) - # change view [sq, b * np, hn] - value_layer = value_layer.view(value_layer.size(0), output_size[0] * output_size[1], -1) + # change view [sq, b * np, hn] + value_layer = value_layer.view( + value_layer.size(0), output_size[0] * output_size[1], -1 + ) # change view [b * np, sq, sk] - attention_probs = attention_probs.view(output_size[0] * output_size[1], - output_size[2], -1) + attention_probs = attention_probs.view( + output_size[0] * output_size[1], output_size[2], -1 + ) - context_layer = torch.bmm(attention_probs, value_layer.unsqueeze(0).transpose(1, 2).squeeze(0)) + context_layer = torch.bmm( + attention_probs, value_layer.unsqueeze(0).transpose(1, 2).squeeze(0) + ) # change view [b, np, sq, hn] context_layer = context_layer.view(*output_size) @@ -208,8 +232,7 @@ def forward( context_layer = context_layer.permute(2, 0, 1, 3).contiguous() # # [sq, b, np, hn] --> [sq, b, hp] - new_context_layer_shape = context_layer.size()[:-2] + \ - (self.hidden_size,) + new_context_layer_shape = context_layer.size()[:-2] + (self.hidden_size,) context_layer = context_layer.view(*new_context_layer_shape) # ================= @@ -247,7 +270,9 @@ def __init__( self.layer_number = max(1, layer_number) assert self.hidden_size % self.num_attention_heads == 0 - self.hidden_size_per_attention_head = int(self.hidden_size // self.num_attention_heads) + self.hidden_size_per_attention_head = int( + self.hidden_size // self.num_attention_heads + ) self.query = torch.nn.Linear(self.hidden_size, self.hidden_size) self.key = torch.nn.Linear(self.hidden_size, self.hidden_size) @@ -257,7 +282,7 @@ def __init__( self.softmax = torch.nn.Softmax(dim=-1) self.dense = torch.nn.Linear(self.hidden_size, self.hidden_size) - + def forward( self, hidden_states, @@ -274,19 +299,22 @@ def forward( key_layer = self.key(hidden_states) value_layer = self.value(hidden_states) - new_query_layer_shape = query_layer.size()[:-1] + \ - (self.num_attention_heads, - self.hidden_size_per_attention_head) + new_query_layer_shape = query_layer.size()[:-1] + ( + self.num_attention_heads, + self.hidden_size_per_attention_head, + ) query_layer = query_layer.view(*new_query_layer_shape) - new_query_layer_shape = key_layer.size()[:-1] + \ - (self.num_attention_heads, - self.hidden_size_per_attention_head) + new_query_layer_shape = 
key_layer.size()[:-1] + ( + self.num_attention_heads, + self.hidden_size_per_attention_head, + ) key_layer = key_layer.view(*new_query_layer_shape) - new_query_layer_shape = value_layer.size()[:-1] + \ - (self.num_attention_heads, - self.hidden_size_per_attention_head) + new_query_layer_shape = value_layer.size()[:-1] + ( + self.num_attention_heads, + self.hidden_size_per_attention_head, + ) value_layer = value_layer.view(*new_query_layer_shape) # ================================== @@ -295,10 +323,10 @@ def forward( if layer_past is not None: past_key, past_value = layer_past - key_layer = torch.cat((past_key.type_as(key_layer), - key_layer), dim=0) - value_layer = torch.cat((past_value.type_as(value_layer), - value_layer), dim=0) + key_layer = torch.cat((past_key.type_as(key_layer), key_layer), dim=0) + value_layer = torch.cat( + (past_value.type_as(value_layer), value_layer), dim=0 + ) if get_key_value: present = (key_layer, value_layer) @@ -307,18 +335,28 @@ def forward( # =================================== # [b, np, sq, sk] - output_size = (query_layer.size(1), - query_layer.size(2), - query_layer.size(0), - key_layer.size(0)) + output_size = ( + query_layer.size(1), + query_layer.size(2), + query_layer.size(0), + key_layer.size(0), + ) # [s, b, np, hn] -> [s, b * np, hn] - query_layer = query_layer.contiguous().view(output_size[2], output_size[0] * output_size[1], -1) - key_layer = key_layer.contiguous().view(output_size[3], output_size[0] * output_size[1], -1) + query_layer = query_layer.contiguous().view( + output_size[2], output_size[0] * output_size[1], -1 + ) + key_layer = key_layer.contiguous().view( + output_size[3], output_size[0] * output_size[1], -1 + ) # Raw attention scores. [b * np, sq, sk] - matmul_result = torch.matmul(query_layer.transpose(0, 1), - key_layer.transpose(0, 1).transpose(1, 2)) / self.norm_factor + matmul_result = ( + torch.matmul( + query_layer.transpose(0, 1), key_layer.transpose(0, 1).transpose(1, 2) + ) + / self.norm_factor + ) # change view to [b, np, s, s] attention_scores = matmul_result.view(*output_size) @@ -331,14 +369,12 @@ def forward( with torch.no_grad(): if layer_past is not None: attention_mask = attention_mask[ - ..., - attention_scores.size(3) - 1, - :attention_scores.size(3)].unsqueeze(2) + ..., attention_scores.size(3) - 1, : attention_scores.size(3) + ].unsqueeze(2) else: attention_mask = attention_mask[ - ..., - :attention_scores.size(3), - :attention_scores.size(3)] + ..., : attention_scores.size(3), : attention_scores.size(3) + ] if context_length is not None: attention_mask = torch.clone(attention_mask) @@ -351,7 +387,7 @@ def forward( attention_probs = self.softmax(attention_scores.float()).half() else: attention_probs = self.softmax(attention_scores) - + # ========================= # Context layer. 
[sq, b, hp] # ========================= @@ -360,20 +396,27 @@ def forward( # [sq, b, np, hn] --> [b, np, sq, hn] # context layer shape: [b, np, sq, hn] - output_size = (value_layer.size(1), - value_layer.size(2), - query_layer.size(0), - value_layer.size(3)) + output_size = ( + value_layer.size(1), + value_layer.size(2), + query_layer.size(0), + value_layer.size(3), + ) # change view [sq, b * np, hn] - value_layer = value_layer.view(value_layer.size(0), output_size[0] * output_size[1], -1) + value_layer = value_layer.view( + value_layer.size(0), output_size[0] * output_size[1], -1 + ) # change view [b * np, sq, sk] - attention_probs = attention_probs.view(output_size[0] * output_size[1], - output_size[2], -1) + attention_probs = attention_probs.view( + output_size[0] * output_size[1], output_size[2], -1 + ) # matmul: [b * np, sq, hn] - context_layer = torch.bmm(attention_probs, value_layer.unsqueeze(0).transpose(1, 2).squeeze(0)) + context_layer = torch.bmm( + attention_probs, value_layer.unsqueeze(0).transpose(1, 2).squeeze(0) + ) # change view [b, np, sq, hn] context_layer = context_layer.view(*output_size) @@ -382,8 +425,7 @@ def forward( context_layer = context_layer.permute(2, 0, 1, 3).contiguous() # [sq, b, np, hn] --> [sq, b, hp] - new_context_layer_shape = context_layer.size()[:-2] + \ - (self.hidden_size,) + new_context_layer_shape = context_layer.size()[:-2] + (self.hidden_size,) context_layer = context_layer.view(*new_context_layer_shape) # ================= @@ -406,10 +448,10 @@ class TransformerLayer(torch.nn.Module): """ def __init__( - self, + self, hidden_size, num_attention_heads, - layer_number, + layer_number, layernorm_epsilon=1e-5, fp16=True, attention_softmax_in_fp32=True, @@ -420,19 +462,23 @@ def __init__( self.layer_number = layer_number # Layernorm on the input data. - self.input_layernorm = torch.nn.LayerNorm(hidden_size, - eps=self.layernorm_epsilon) + self.input_layernorm = torch.nn.LayerNorm( + hidden_size, eps=self.layernorm_epsilon + ) # Self attention. - self.attention = SelfAttention(hidden_size, - num_attention_heads, - layer_number, - fp16, - attention_softmax_in_fp32) + self.attention = SelfAttention( + hidden_size, + num_attention_heads, + layer_number, + fp16, + attention_softmax_in_fp32, + ) # Layernorm on the input data. - self.post_attention_layernorm = torch.nn.LayerNorm(self.hidden_size, - eps=self.layernorm_epsilon) + self.post_attention_layernorm = torch.nn.LayerNorm( + self.hidden_size, eps=self.layernorm_epsilon + ) self.mlp = MLP(self.hidden_size) def forward( @@ -450,12 +496,14 @@ def forward( layernorm_output = self.input_layernorm(hidden_states) # Self attention. - attention_output = self.attention(layernorm_output, - attention_mask, - layer_past=layer_past, - get_key_value=get_key_value, - prompt_length=prompt_length, - context_length=context_length) + attention_output = self.attention( + layernorm_output, + attention_mask, + layer_past=layer_past, + get_key_value=get_key_value, + prompt_length=prompt_length, + context_length=context_length, + ) if get_key_value: attention_output, presents = attention_output @@ -463,7 +511,7 @@ def forward( # Residual connection. 
residual = hidden_states layernorm_input = attention_output + residual - + # Use FP32 for Layernorm # layernorm_output = self.post_attention_layernorm(layernorm_input.float()).half() layernorm_output = self.post_attention_layernorm(layernorm_input) @@ -484,7 +532,7 @@ class TopQueryLayer(torch.nn.Module): """ def __init__( - self, + self, hidden_size, num_attention_heads, layer_number, @@ -497,16 +545,18 @@ def __init__( self.layer_number = layer_number # Use FP32 for Layernorm - self.input_layernorm = torch.nn.LayerNorm(self.hidden_size, - eps=self.layernorm_epsilon) + self.input_layernorm = torch.nn.LayerNorm( + self.hidden_size, eps=self.layernorm_epsilon + ) # Self attention. - self.attention = TopQuerySelfAttention(self.hidden_size, - self.num_attention_heads, - self.layer_number) + self.attention = TopQuerySelfAttention( + self.hidden_size, self.num_attention_heads, self.layer_number + ) # Layernorm on the input data. - self.post_attention_layernorm = torch.nn.LayerNorm(self.hidden_size, - eps=self.layernorm_epsilon) + self.post_attention_layernorm = torch.nn.LayerNorm( + self.hidden_size, eps=self.layernorm_epsilon + ) # MLP self.mlp = MLP(self.hidden_size) @@ -529,13 +579,15 @@ def forward( layernorm_output = self.input_layernorm(hidden_states) # Self attention. - attention_output = self.attention(layernorm_output, - query_hidden_state, - attention_mask, - layer_past=layer_past, - get_key_value=get_key_value, - prompt_length=prompt_length, - context_length=context_length) + attention_output = self.attention( + layernorm_output, + query_hidden_state, + attention_mask, + layer_past=layer_past, + get_key_value=get_key_value, + prompt_length=prompt_length, + context_length=context_length, + ) if get_key_value: attention_output, presents = attention_output @@ -543,7 +595,7 @@ def forward( # Residual connection. residual = hidden_states layernorm_input = attention_output + residual - + # Use FP32 for Layernorm # layernorm_output = self.post_attention_layernorm(layernorm_input.float()).half() layernorm_output = self.post_attention_layernorm(layernorm_input) @@ -585,22 +637,27 @@ def __init__( if self.num_unique_layers is None: self.num_unique_layers = self.num_layers - assert self.num_layers % self.num_unique_layers == 0, \ - 'number of layers should be divisible by number of unique layers' - + assert ( + self.num_layers % self.num_unique_layers == 0 + ), "number of layers should be divisible by number of unique layers" + # Transformer layers. 
def build_layer(layer_number): - return TransformerLayer(self.hidden_size, self.num_attention_heads, layer_number) + return TransformerLayer( + self.hidden_size, self.num_attention_heads, layer_number + ) self.layers = torch.nn.ModuleList( - [build_layer(i + 1) for i in range(self.num_unique_layers)]) + [build_layer(i + 1) for i in range(self.num_unique_layers)] + ) - self.topQueryLayer = TopQueryLayer(self.hidden_size, - self.num_attention_heads, - self.num_unique_layers) + self.topQueryLayer = TopQueryLayer( + self.hidden_size, self.num_attention_heads, self.num_unique_layers + ) - self.final_layernorm = torch.nn.LayerNorm(self.hidden_size, - eps=self.layernorm_epsilon) + self.final_layernorm = torch.nn.LayerNorm( + self.hidden_size, eps=self.layernorm_epsilon + ) def _get_layer_index(self, layer_number): return layer_number % self.num_unique_layers @@ -622,7 +679,6 @@ def forward( hidden_states = hidden_states.transpose(0, 1).contiguous() query_hidden_state = query_hidden_state.transpose(0, 1).contiguous() - if get_key_value: presents = [] for index in range(self.num_layers): @@ -630,12 +686,14 @@ def forward( past = None if layer_past is not None: past = layer_past[index] - hidden_states = layer(hidden_states, - attention_mask, - layer_past=past, - get_key_value=get_key_value, - prompt_length=prompt_length, - context_length=context_length) + hidden_states = layer( + hidden_states, + attention_mask, + layer_past=past, + get_key_value=get_key_value, + prompt_length=prompt_length, + context_length=context_length, + ) if get_key_value: hidden_states, present = hidden_states presents.append(present) @@ -650,13 +708,15 @@ def forward( past = None if layer_past is not None: past = layer_past[self.num_layers] - hidden_states = self.topQueryLayer(hidden_states_, - query_hidden_state, - attention_mask, - layer_past=past, - get_key_value=get_key_value, - prompt_length=prompt_length, - context_length=context_length) + hidden_states = self.topQueryLayer( + hidden_states_, + query_hidden_state, + attention_mask, + layer_past=past, + get_key_value=get_key_value, + prompt_length=prompt_length, + context_length=context_length, + ) if get_key_value: hidden_states, present = hidden_states @@ -696,35 +756,39 @@ def __init__( self.hidden_size = hidden_size self.vocab_size = vocab_size self.max_sequence_length = max_sequence_length - + # Word embeddings. self.word_embeddings = torch.nn.Embedding(self.vocab_size, self.hidden_size) - self._word_embeddings_key = 'word_embeddings' - + self._word_embeddings_key = "word_embeddings" + # Position embedding. - self.position_embeddings = torch.nn.Embedding(self.max_sequence_length, self.hidden_size) + self.position_embeddings = torch.nn.Embedding( + self.max_sequence_length, self.hidden_size + ) self.position_embeddings = self.position_embeddings.half() - self._position_embeddings_key = 'position_embeddings' - + self._position_embeddings_key = "position_embeddings" + def forward(self, input_ids, position_ids): # Embeddings. 
words_embeddings = self.word_embeddings(input_ids) position_embeddings = self.position_embeddings(position_ids) embeddings = words_embeddings + position_embeddings - + return embeddings - def state_dict_for_save_checkpoint(self, destination=None, prefix='', - keep_vars=False): + def state_dict_for_save_checkpoint( + self, destination=None, prefix="", keep_vars=False + ): """For easy load.""" state_dict_ = {} - state_dict_[self._word_embeddings_key] \ - = self.word_embeddings.state_dict(destination, prefix, keep_vars) - state_dict_[self._position_embeddings_key] \ - = self.position_embeddings.state_dict( - destination, prefix, keep_vars) - + state_dict_[self._word_embeddings_key] = self.word_embeddings.state_dict( + destination, prefix, keep_vars + ) + state_dict_[ + self._position_embeddings_key + ] = self.position_embeddings.state_dict(destination, prefix, keep_vars) + return state_dict_ def load_state_dict(self, state_dict, strict=True): @@ -737,10 +801,9 @@ def load_state_dict(self, state_dict, strict=True): # for backward compatibility. state_dict_ = {} for key in state_dict.keys(): - if 'word_embeddings' in key: - state_dict_[key.split('word_embeddings.')[1]] \ - = state_dict[key] - state_dict_["weight"] = state_dict_["weight"][:self.vocab_size] + if "word_embeddings" in key: + state_dict_[key.split("word_embeddings.")[1]] = state_dict[key] + state_dict_["weight"] = state_dict_["weight"][: self.vocab_size] self.word_embeddings.load_state_dict(state_dict_, strict=strict) # Position embedding. @@ -750,11 +813,10 @@ def load_state_dict(self, state_dict, strict=True): # for backward compatibility. state_dict_ = {} for key in state_dict.keys(): - if 'position_embeddings' in key: - state_dict_[key.split('position_embeddings.')[1]] \ - = state_dict[key] + if "position_embeddings" in key: + state_dict_[key.split("position_embeddings.")[1]] = state_dict[key] self.position_embeddings.load_state_dict(state_dict_, strict=strict) - + class QueryEmbedding(torch.nn.Module): """Language model embeddings. @@ -779,24 +841,27 @@ def __init__( self.max_sequence_length = max_sequence_length # Top query position embedding (serial). - self.top_query_embeddings = torch.nn.Embedding(self.max_sequence_length, self.hidden_size) + self.top_query_embeddings = torch.nn.Embedding( + self.max_sequence_length, self.hidden_size + ) self.top_query_embeddings = self.top_query_embeddings.half() - self._top_query_embeddings_key = 'top_query_embeddings' - + self._top_query_embeddings_key = "top_query_embeddings" + def forward(self, position_ids): # Embeddings. embeddings = self.top_query_embeddings(position_ids) - + return embeddings - def state_dict_for_save_checkpoint(self, destination=None, prefix='', - keep_vars=False): + def state_dict_for_save_checkpoint( + self, destination=None, prefix="", keep_vars=False + ): """For easy load.""" state_dict_ = {} - state_dict_[self._top_query_embeddings_key] \ - = self.top_query_embeddings.state_dict( - destination, prefix, keep_vars) + state_dict_[ + self._top_query_embeddings_key + ] = self.top_query_embeddings.state_dict(destination, prefix, keep_vars) return state_dict_ @@ -810,11 +875,10 @@ def load_state_dict(self, state_dict, strict=True): # for backward compatibility. 
state_dict_ = {} for key in state_dict.keys(): - if 'top_query_embeddings' in key: - state_dict_[key.split('top_query_embeddings.')[1]] \ - = state_dict[key] + if "top_query_embeddings" in key: + state_dict_[key.split("top_query_embeddings.")[1]] = state_dict[key] self.top_query_embeddings.load_state_dict(state_dict_, strict=strict) - + class TransformerLanguageModel(torch.nn.Module): """Transformer language model. @@ -848,32 +912,32 @@ def __init__( self.max_position_embeddings = max_position_embeddings # Embeddings - self.embedding = Embedding(self.hidden_size, - self.padded_vocab_size, - self.max_position_embeddings) - self._embedding_key = 'embedding' + self.embedding = Embedding( + self.hidden_size, self.padded_vocab_size, self.max_position_embeddings + ) + self._embedding_key = "embedding" # Query embeddings - self.topQueryEmbedding = QueryEmbedding(self.hidden_size, - self.padded_vocab_size, - self.max_position_embeddings) - self._topQueryEmbedding_key = 'topQueryEmbedding' + self.topQueryEmbedding = QueryEmbedding( + self.hidden_size, self.padded_vocab_size, self.max_position_embeddings + ) + self._topQueryEmbedding_key = "topQueryEmbedding" # Transformer - self.transformer = Transformer(self.hidden_size, - self.num_attention_heads, - self.num_layers) - self._transformer_key = 'transformer' + self.transformer = Transformer( + self.hidden_size, self.num_attention_heads, self.num_layers + ) + self._transformer_key = "transformer" def forward( - self, - input_ids, - position_ids, - attention_mask, - layer_past=None, - get_key_value=False, - prompt_length=None, - context_length=None, + self, + input_ids, + position_ids, + attention_mask, + layer_past=None, + get_key_value=False, + prompt_length=None, + context_length=None, ): # Embeddings. @@ -882,30 +946,39 @@ def forward( queryEmbedding_out = self.topQueryEmbedding(query_position_ids) # Transformer. - transformer_output = self.transformer(embedding_output, - queryEmbedding_out, - attention_mask, - layer_past=layer_past, - get_key_value=get_key_value, - prompt_length=prompt_length, - context_length=context_length) + transformer_output = self.transformer( + embedding_output, + queryEmbedding_out, + attention_mask, + layer_past=layer_past, + get_key_value=get_key_value, + prompt_length=prompt_length, + context_length=context_length, + ) return transformer_output - def state_dict_for_save_checkpoint(self, destination=None, prefix='', - keep_vars=False): + def state_dict_for_save_checkpoint( + self, destination=None, prefix="", keep_vars=False + ): """For easy load.""" state_dict_ = {} - state_dict_[self._embedding_key] \ - = self.embedding.state_dict_for_save_checkpoint( - destination, prefix, keep_vars) - state_dict_[self._topQueryEmbedding_key] \ - = self.topQueryEmbedding.state_dict_for_save_checkpoint( - destination, prefix, keep_vars) - state_dict_[self._transformer_key] \ - = self.transformer.state_dict_for_save_checkpoint( - destination, prefix, keep_vars) + state_dict_[ + self._embedding_key + ] = self.embedding.state_dict_for_save_checkpoint( + destination, prefix, keep_vars + ) + state_dict_[ + self._topQueryEmbedding_key + ] = self.topQueryEmbedding.state_dict_for_save_checkpoint( + destination, prefix, keep_vars + ) + state_dict_[ + self._transformer_key + ] = self.transformer.state_dict_for_save_checkpoint( + destination, prefix, keep_vars + ) return state_dict_ @@ -919,7 +992,7 @@ def load_state_dict(self, state_dict, strict=True): # for backward compatibility. 
state_dict_ = {} for key in state_dict.keys(): - if '_embeddings' in key: + if "_embeddings" in key: state_dict_[key] = state_dict[key] self.embedding.load_state_dict(state_dict_, strict=strict) @@ -929,7 +1002,7 @@ def load_state_dict(self, state_dict, strict=True): # for backward compatibility. state_dict_ = {} for key in state_dict.keys(): - if '_embeddings' in key: + if "_embeddings" in key: state_dict_[key] = state_dict[key] self.topQueryEmbedding.load_state_dict(state_dict_, strict=strict) @@ -940,8 +1013,8 @@ def load_state_dict(self, state_dict, strict=True): # for backward compatibility. state_dict_ = {} for key in state_dict.keys(): - if 'transformer.' in key: - state_dict_[key.split('transformer.')[1]] = state_dict[key] + if "transformer." in key: + state_dict_[key.split("transformer.")[1]] = state_dict[key] self.transformer.load_state_dict(state_dict_, strict=strict) @@ -957,14 +1030,16 @@ def __init__( max_position_embeddings, ): super(CodeGeeXModel, self).__init__() - - self.language_model = TransformerLanguageModel(hidden_size, - num_layers, - num_attention_heads, - padded_vocab_size, - max_position_embeddings) + + self.language_model = TransformerLanguageModel( + hidden_size, + num_layers, + num_attention_heads, + padded_vocab_size, + max_position_embeddings, + ) self._language_model_key = "language_model" - + def forward( self, input_ids, @@ -976,31 +1051,38 @@ def forward( context_length=None, ): # Language model. - lm_output = self.language_model(input_ids, - position_ids, - attention_mask, - layer_past=layer_past, - get_key_value=get_key_value, - prompt_length=prompt_length, - context_length=context_length) + lm_output = self.language_model( + input_ids, + position_ids, + attention_mask, + layer_past=layer_past, + get_key_value=get_key_value, + prompt_length=prompt_length, + context_length=context_length, + ) if get_key_value: lm_output, presents = lm_output - output = F.linear(lm_output, self.language_model.embedding.word_embeddings.weight.half()) - + output = F.linear( + lm_output, self.language_model.embedding.word_embeddings.weight.half() + ) + if get_key_value: output = [output, presents] return output - def state_dict_for_save_checkpoint(self, destination=None, prefix='', - keep_vars=False): + def state_dict_for_save_checkpoint( + self, destination=None, prefix="", keep_vars=False + ): state_dict_ = {} - state_dict_[self._language_model_key] \ - = self.language_model.state_dict_for_save_checkpoint( - destination, prefix, keep_vars) + state_dict_[ + self._language_model_key + ] = self.language_model.state_dict_for_save_checkpoint( + destination, prefix, keep_vars + ) return state_dict_ def load_state_dict(self, state_dict, strict=True): diff --git a/codegeex/torch/get_ckpt_qkv.py b/codegeex/torch/get_ckpt_qkv.py index 78693a4..2e25874 100644 --- a/codegeex/torch/get_ckpt_qkv.py +++ b/codegeex/torch/get_ckpt_qkv.py @@ -8,43 +8,76 @@ def main(): parser = argparse.ArgumentParser() - parser.add_argument("--load-path", - type=str, - default="/zhangpai24/workspace/ckpt_ms/ckpt_ms_213000_fp32_52224.pt") - parser.add_argument("--save-path", - type=str, - default="/zhangpai24/workspace/ckpt_ms/ckpt_ms_213000_qkv.pt") - + parser.add_argument( + "--load-path", + type=str, + default="/zhangpai24/workspace/ckpt_ms/ckpt_ms_213000_fp32_52224.pt", + ) + parser.add_argument( + "--save-path", + type=str, + default="/zhangpai24/workspace/ckpt_ms/ckpt_ms_213000_qkv.pt", + ) + args, _ = parser.parse_known_args() - + state_dict_path = args.load_path print("Loading state dict ...") sd = 
torch.load(state_dict_path, map_location="cpu") - + for i in range(40): if i < 39: - query_weight = sd['module']['language_model']['transformer'].pop(f'layers.{i}.attention.query.weight', None) - query_bias = sd['module']['language_model']['transformer'].pop(f'layers.{i}.attention.query.bias', None) - key_weight = sd['module']['language_model']['transformer'].pop(f'layers.{i}.attention.key.weight', None) - key_bias = sd['module']['language_model']['transformer'].pop(f'layers.{i}.attention.key.bias', None) - value_weight = sd['module']['language_model']['transformer'].pop(f'layers.{i}.attention.value.weight', None) - value_bias = sd['module']['language_model']['transformer'].pop(f'layers.{i}.attention.value.bias', None) + query_weight = sd["module"]["language_model"]["transformer"].pop( + f"layers.{i}.attention.query.weight", None + ) + query_bias = sd["module"]["language_model"]["transformer"].pop( + f"layers.{i}.attention.query.bias", None + ) + key_weight = sd["module"]["language_model"]["transformer"].pop( + f"layers.{i}.attention.key.weight", None + ) + key_bias = sd["module"]["language_model"]["transformer"].pop( + f"layers.{i}.attention.key.bias", None + ) + value_weight = sd["module"]["language_model"]["transformer"].pop( + f"layers.{i}.attention.value.weight", None + ) + value_bias = sd["module"]["language_model"]["transformer"].pop( + f"layers.{i}.attention.value.bias", None + ) qkv_weight = torch.cat([query_weight, key_weight, value_weight], dim=0) qkv_bias = torch.cat([query_bias, key_bias, value_bias]) - sd['module']['language_model']['transformer'][f'layers.{i}.attention.query_key_value.weight'] = qkv_weight - sd['module']['language_model']['transformer'][f'layers.{i}.attention.query_key_value.bias'] = qkv_bias + sd["module"]["language_model"]["transformer"][ + f"layers.{i}.attention.query_key_value.weight" + ] = qkv_weight + sd["module"]["language_model"]["transformer"][ + f"layers.{i}.attention.query_key_value.bias" + ] = qkv_bias else: - tq_key_weight = sd['module']['language_model']['transformer'].pop('topQueryLayer.attention.key.weight', None) - tq_key_bias = sd['module']['language_model']['transformer'].pop('topQueryLayer.attention.key.bias', None) - tq_value_weight = sd['module']['language_model']['transformer'].pop('topQueryLayer.attention.value.weight', None) - tq_value_bias = sd['module']['language_model']['transformer'].pop('topQueryLayer.attention.value.bias', None) + tq_key_weight = sd["module"]["language_model"]["transformer"].pop( + "topQueryLayer.attention.key.weight", None + ) + tq_key_bias = sd["module"]["language_model"]["transformer"].pop( + "topQueryLayer.attention.key.bias", None + ) + tq_value_weight = sd["module"]["language_model"]["transformer"].pop( + "topQueryLayer.attention.value.weight", None + ) + tq_value_bias = sd["module"]["language_model"]["transformer"].pop( + "topQueryLayer.attention.value.bias", None + ) tq_kv_weight = torch.cat([tq_key_weight, tq_value_weight], dim=0) tq_kv_bias = torch.cat([tq_key_bias, tq_value_bias]) - sd['module']['language_model']['transformer']['topQueryLayer.attention.key_value.weight'] = tq_kv_weight - sd['module']['language_model']['transformer']['topQueryLayer.attention.key_value.bias'] = tq_kv_bias - + sd["module"]["language_model"]["transformer"][ + "topQueryLayer.attention.key_value.weight" + ] = tq_kv_weight + sd["module"]["language_model"]["transformer"][ + "topQueryLayer.attention.key_value.bias" + ] = tq_kv_bias + save_ckpt_path = args.save_path torch.save(sd, save_ckpt_path) - -if __name__ == '__main__': 
+ + +if __name__ == "__main__": main() diff --git a/codegeex/torch/inference.py b/codegeex/torch/inference.py index 7a2feec..e7f3e8d 100644 --- a/codegeex/torch/inference.py +++ b/codegeex/torch/inference.py @@ -10,10 +10,10 @@ def get_ltor_masks_and_position_ids( - data, - eod_token, - reset_position_ids, - reset_attention_mask, + data, + eod_token, + reset_position_ids, + reset_attention_mask, ): """Build masks and position id for left to right model.""" @@ -65,9 +65,9 @@ def get_ltor_masks_and_position_ids( def get_batch( - context_tokens, - micro_batch_size, - eod_token, + context_tokens, + micro_batch_size, + eod_token, reset_position_ids=False, reset_attention_mask=False, ): @@ -125,15 +125,15 @@ def pad_batch(batch, pad_id, seq_length): def forward_step( - model, - tokens, - seq_length, - position_ids, - attention_mask, - layer_past=None, - get_key_value=None, - prompt_length=None, - context_length=None, + model, + tokens, + seq_length, + position_ids, + attention_mask, + layer_past=None, + get_key_value=None, + prompt_length=None, + context_length=None, ): # Forward pass through the model. output_tensor = model( @@ -156,28 +156,30 @@ def forward_step( def get_token_stream( - model, - tokenizer, - seq_length, - out_seq_length, - context_tokens, - return_scores: bool = False, - prompt_length: int = None, - micro_batch_size: int = None, - bad_ids: List = None, - temperature: float = 1.0, - topp: float = 1.0, - topk: int = 0.0, - greedy: bool = False, - recompute: bool = False, + model, + tokenizer, + seq_length, + out_seq_length, + context_tokens, + return_scores: bool = False, + prompt_length: int = None, + micro_batch_size: int = None, + bad_ids: List = None, + temperature: float = 1.0, + topp: float = 1.0, + topk: int = 0.0, + greedy: bool = False, + recompute: bool = False, ): - context_tokens, context_lengths = pad_batch(context_tokens, tokenizer.eos_token_id, seq_length) + context_tokens, context_lengths = pad_batch( + context_tokens, tokenizer.eos_token_id, seq_length + ) context_tokens_tensor = torch.cuda.LongTensor(context_tokens) context_length_tensor = torch.cuda.LongTensor(context_lengths) context_length = context_length_tensor.min().item() tokens, attention_mask, position_ids = get_batch( - context_tokens_tensor, + context_tokens_tensor, micro_batch_size, tokenizer.eos_token_id, ) @@ -215,23 +217,23 @@ def switch(val1, val2, boolean): def sample_sequence_batch( - model, - tokenizer, - context_tokens, - context_lengths, - attention_mask, - position_ids, - seq_length, - out_seq_length, - maxlen=None, - return_scores: bool = False, - prompt_length: int = None, - bad_ids: List = None, - temperature: float = 1.0, - topp: float = 1.0, - topk: int = 0.0, - recompute: bool = False, - greedy: bool = False, + model, + tokenizer, + context_tokens, + context_lengths, + attention_mask, + position_ids, + seq_length, + out_seq_length, + maxlen=None, + return_scores: bool = False, + prompt_length: int = None, + bad_ids: List = None, + temperature: float = 1.0, + topp: float = 1.0, + topk: int = 0.0, + recompute: bool = False, + greedy: bool = False, ): model.eval() with torch.no_grad(): @@ -257,30 +259,32 @@ def sample_sequence_batch( while context_length <= (maxlen): if recompute: - logits = model(tokens, - position_ids, - attention_mask, - prompt_length=prompt_length, - context_length=context_length, - ) + logits = model( + tokens, + position_ids, + attention_mask, + prompt_length=prompt_length, + context_length=context_length, + ) logits = logits[:, context_length - 1, :] else: if 
counter == 0: tokens2use = tokens[:, :context_length] positions2use = position_ids[:, :context_length] else: - tokens2use = tokens[:, context_length - 1].view( - batch_size, -1) + tokens2use = tokens[:, context_length - 1].view(batch_size, -1) positions2use = position_ids[:, context_length - 1].view( - batch_size, -1) - logits, layer_past = model(tokens2use, - positions2use, - attention_mask, - layer_past=layer_past, - get_key_value=True, - prompt_length=prompt_length, - context_length=context_length, - ) + batch_size, -1 + ) + logits, layer_past = model( + tokens2use, + positions2use, + attention_mask, + layer_past=layer_past, + get_key_value=True, + prompt_length=prompt_length, + context_length=context_length, + ) logits = logits[:, -1].view(batch_size, -1).contiguous() if bad_ids is not None: @@ -314,12 +318,12 @@ def sample_sequence_batch( lengths[just_finished.view(-1)] = context_length is_done = is_done | done_token done = torch.all(is_done) - + if return_scores: yield tokens, (lengths, scores) else: yield tokens, lengths - + context_length += 1 counter += 1 if done: diff --git a/deployment/server_gradio.py b/deployment/server_gradio.py index 96eb770..b7065b2 100644 --- a/deployment/server_gradio.py +++ b/deployment/server_gradio.py @@ -19,7 +19,7 @@ def model_provider(args): args.num_layers, args.num_attention_heads, args.padded_vocab_size, - args.max_position_embeddings + args.max_position_embeddings, ) return model @@ -78,7 +78,7 @@ def add_code_generation_args(parser): "--quantize", action="store_true", ) - + return parser @@ -86,11 +86,11 @@ def main(): parser = argparse.ArgumentParser() parser = add_code_generation_args(parser) args, _ = parser.parse_known_args() - + print("Loading tokenizer ...") tokenizer = CodeGeeXTokenizer( - tokenizer_path=args.tokenizer_path, - mode="codegeex-13b") + tokenizer_path=args.tokenizer_path, mode="codegeex-13b" + ) print("Loading state dict ...") state_dict = torch.load(args.load, map_location="cpu") @@ -106,18 +106,18 @@ def main(): model.cuda() def predict( - prompt, + prompt, lang, - seed, - out_seq_length, - temperature, - top_k, + seed, + out_seq_length, + temperature, + top_k, top_p, ): set_random_seed(seed) if lang.lower() in LANGUAGE_TAG: prompt = LANGUAGE_TAG[lang.lower()] + "\n" + prompt - + generated_code = codegeex.generate( model, tokenizer, @@ -142,57 +142,96 @@ def predict( gr.Markdown( """ - """) + """ + ) gr.Markdown( """

🏠 Homepage | 📖 Blog | 🪧 DEMO | 🛠 VS Code or Jetbrains Extensions | 💻 Source code | 🤖 Download Model

- """) + """ + ) gr.Markdown( """ We introduce CodeGeeX, a large-scale multilingual code generation model with 13 billion parameters, pre-trained on a large code corpus of more than 20 programming languages. CodeGeeX supports 15+ programming languages for both code generation and translation. CodeGeeX is open source, please refer to our [GitHub](https://github.com/THUDM/CodeGeeX) for more details. This is a minimal-functional DEMO, for other DEMOs like code translation, please visit our [Homepage](https://codegeex.cn). We also offer free [VS Code](https://marketplace.visualstudio.com/items?itemName=aminer.codegeex) or [Jetbrains](https://plugins.jetbrains.com/plugin/20587-codegeex) extensions for full functionality. - """) + """ + ) with gr.Row(): with gr.Column(): - prompt = gr.Textbox(lines=13, placeholder='Please enter the description or select an example input below.',label='Input') + prompt = gr.Textbox( + lines=13, + placeholder="Please enter the description or select an example input below.", + label="Input", + ) with gr.Row(): gen = gr.Button("Generate") clr = gr.Button("Clear") - outputs = gr.Textbox(lines=15, label='Output') + outputs = gr.Textbox(lines=15, label="Output") gr.Markdown( """ Generation Parameter - """) + """ + ) with gr.Row(): with gr.Column(): lang = gr.Radio( - choices=["C++", "C", "C#", "Python", "Java", "HTML", "PHP", "JavaScript", "TypeScript", "Go", - "Rust", - "SQL", "Kotlin", "R", "Fortran"], value='lang', label='Programming Language', - default="Python") + choices=[ + "C++", + "C", + "C#", + "Python", + "Java", + "HTML", + "PHP", + "JavaScript", + "TypeScript", + "Go", + "Rust", + "SQL", + "Kotlin", + "R", + "Fortran", + ], + value="lang", + label="Programming Language", + default="Python", + ) with gr.Column(): - seed = gr.Slider(maximum=10000, value=8888, step=1, label='Seed') + seed = gr.Slider(maximum=10000, value=8888, step=1, label="Seed") with gr.Row(): - out_seq_length = gr.Slider(maximum=1024, value=128, minimum=1, step=1, label='Output Sequence Length') - temperature = gr.Slider(maximum=1, value=0.2, minimum=0, label='Temperature') + out_seq_length = gr.Slider( + maximum=1024, + value=128, + minimum=1, + step=1, + label="Output Sequence Length", + ) + temperature = gr.Slider( + maximum=1, value=0.2, minimum=0, label="Temperature" + ) with gr.Row(): - top_k = gr.Slider(maximum=40, value=0, minimum=0, step=1, label='Top K') - top_p = gr.Slider(maximum=1, value=1.0, minimum=0, label='Top P') + top_k = gr.Slider( + maximum=40, value=0, minimum=0, step=1, label="Top K" + ) + top_p = gr.Slider(maximum=1, value=1.0, minimum=0, label="Top P") inputs = [prompt, lang, seed, out_seq_length, temperature, top_k, top_p] gen.click(fn=predict, inputs=inputs, outputs=outputs) clr.click(fn=lambda value: gr.update(value=""), inputs=clr, outputs=prompt) - gr_examples = gr.Examples(examples=examples, inputs=[prompt, lang], - label="Example Inputs (Click to insert an examplet it into the input box)", - examples_per_page=20) - + gr_examples = gr.Examples( + examples=examples, + inputs=[prompt, lang], + label="Example Inputs (Click to insert an examplet it into the input box)", + examples_per_page=20, + ) + demo.launch(server_port=6007) -if __name__ == '__main__': + +if __name__ == "__main__": with torch.no_grad(): - main() \ No newline at end of file + main() diff --git a/scripts/evaluate_humaneval_x.py b/scripts/evaluate_humaneval_x.py index eb91b9e..bebb6b6 100644 --- a/scripts/evaluate_humaneval_x.py +++ b/scripts/evaluate_humaneval_x.py @@ -2,33 +2,33 @@ import os 
from pathlib import Path from codegeex.benchmark.evaluate_humaneval_x import evaluate_functional_correctness -#GLOBALS -INPUT_FILE: str -LANGUAGE: str -N_WORKERS: int -TIMEOUT: int + +# GLOBALS +INPUT_FILE: str +LANGUAGE: str +N_WORKERS: int +TIMEOUT: int parser = argparse.ArgumentParser("Debugging evaluate humaneval_x") # Path to the .jsonl file that contains the generated codes. -parser.add_argument("-s","--samples", type=str) +parser.add_argument("-s", "--samples", type=str) # Target programming language, currently support one of ["python", "java", "cpp", "js", "go"] -parser.add_argument("-l","--language", default="python", type=str) +parser.add_argument("-l", "--language", default="python", type=str) # Number of parallel workers. -parser.add_argument("-w","--workers", default=64, type=int) +parser.add_argument("-w", "--workers", default=64, type=int) # Timeout in seconds. -parser.add_argument("-t","--timeout", default=5, type=int) +parser.add_argument("-t", "--timeout", default=5, type=int) args = parser.parse_args() INPUT_FILE = args.samples -LANGUAGE = args.language -N_WORKERS = args.workers -TIMEOUT= args.timeout - +LANGUAGE = args.language +N_WORKERS = args.workers +TIMEOUT = args.timeout SCRIPT_PATH: str = Path(os.path.abspath(__file__)) @@ -38,16 +38,32 @@ MAIN_DIR: str = os.path.dirname(SCRIPT_DIR) print(MAIN_DIR) -DATA_DIR=os.path.join(MAIN_DIR,"codegeex/benchmark/humaneval-x/" + LANGUAGE + "/data/humaneval_" + LANGUAGE + ".jsonl.gz") +DATA_DIR = os.path.join( + MAIN_DIR, + "codegeex/benchmark/humaneval-x/" + + LANGUAGE + + "/data/humaneval_" + + LANGUAGE + + ".jsonl.gz", +) print(DATA_DIR) -TMP_DIR=os.path.join(MAIN_DIR, "/codegeex/benchmark/humaneval-x/") +TMP_DIR = os.path.join(MAIN_DIR, "/codegeex/benchmark/humaneval-x/") -#Debugging -INPUT_FILE='/home/rog0d/Escritorio/CodeGeeX/generations/humaneval_rust_generations.jsonl.gz' -LANGUAGE='rust' -DATA_DIR=os.path.join(MAIN_DIR,"codegeex/benchmark/humaneval-x/" + LANGUAGE + "/data/humaneval_" + LANGUAGE + ".jsonl.gz") +# Debugging +INPUT_FILE = ( + "/home/rog0d/Escritorio/CodeGeeX/generations/humaneval_rust_generations.jsonl.gz" +) +LANGUAGE = "rust" +DATA_DIR = os.path.join( + MAIN_DIR, + "codegeex/benchmark/humaneval-x/" + + LANGUAGE + + "/data/humaneval_" + + LANGUAGE + + ".jsonl.gz", +) """ input_file: str = None, @@ -62,10 +78,10 @@ """ -evaluate_functional_correctness(input_file=INPUT_FILE, - n_workers=N_WORKERS, - tmp_dir=TMP_DIR, - problem_file=DATA_DIR, - timeout=300.0) - - +evaluate_functional_correctness( + input_file=INPUT_FILE, + n_workers=N_WORKERS, + tmp_dir=TMP_DIR, + problem_file=DATA_DIR, + timeout=300.0, +) diff --git a/setup.py b/setup.py index 3260672..9367170 100644 --- a/setup.py +++ b/setup.py @@ -22,5 +22,5 @@ "cpm_kernels", "deepspeed>0.6.1", ], - entry_points={} + entry_points={}, ) diff --git a/tests/test_inference.py b/tests/test_inference.py index c4267ab..5a1272b 100644 --- a/tests/test_inference.py +++ b/tests/test_inference.py @@ -17,9 +17,9 @@ def model_provider(args): args.num_layers, args.num_attention_heads, args.padded_vocab_size, - args.max_position_embeddings + args.max_position_embeddings, ) - + return model @@ -111,19 +111,19 @@ def add_code_generation_args(parser): "--interative", action="store_true", ) - + return parser - + def main(): parser = argparse.ArgumentParser() parser = add_code_generation_args(parser) args, _ = parser.parse_known_args() - + print("Loading tokenizer ...") tokenizer = CodeGeeXTokenizer( - tokenizer_path=args.tokenizer_path, - mode="codegeex-13b") + 
tokenizer_path=args.tokenizer_path, mode="codegeex-13b" + ) print("Loading state dict ...") state_dict = torch.load(args.load, map_location="cpu") @@ -138,16 +138,18 @@ def main(): model = quantize(model, weight_bit_width=8, backend="torch") model.cuda() torch.cuda.synchronize() - + with open(args.prompt_file, "r") as f: prompt = f.readlines() prompt = "".join(prompt) - + out_seq_lengths = [args.out_seq_length] - for out_seq_length in out_seq_lengths: + for out_seq_length in out_seq_lengths: print(f"Generating with out_seq_len {out_seq_length}...") while True: - print("\nPlease Input Query (Ctrl-D to save multiple lines, 'stop' to exit) >>> ") + print( + "\nPlease Input Query (Ctrl-D to save multiple lines, 'stop' to exit) >>> " + ) prompts = [] while True: try: @@ -158,10 +160,10 @@ def main(): prompt = "\n".join(prompts) prompt = prompt.strip() if not prompt: - print('Query should not be empty!') + print("Query should not be empty!") continue if prompt == "stop": - return + return try: t0 = time.perf_counter() generated_code = codegeex.generate( @@ -182,9 +184,9 @@ def main(): except (ValueError, FileNotFoundError) as e: print(e) continue - + print("Generation finished.") if __name__ == "__main__": - main() \ No newline at end of file + main() diff --git a/tests/test_inference_megatron.py b/tests/test_inference_megatron.py index 66bda37..57cf088 100644 --- a/tests/test_inference_megatron.py +++ b/tests/test_inference_megatron.py @@ -25,10 +25,9 @@ def set_random_seed(seed): def model_provider(pre_process=True, post_process=True): """Build the model.""" - + print_rank_0("Building CodeGeeX model ...") - model = CodeGeeXModel(num_tokentypes=0, - parallel_output=False) + model = CodeGeeXModel(num_tokentypes=0, parallel_output=False) return model @@ -71,7 +70,7 @@ def add_code_generation_args(parser): "--recompute", action="store_true", help="During generation recompute all attention " - "instead of using previously computed keys/values.", + "instead of using previously computed keys/values.", ) group.add_argument( "--ws-encoding-start-id", @@ -119,11 +118,11 @@ def add_code_generation_args(parser): action="store_true", ) group.add_argument( - '--bad-ids', + "--bad-ids", nargs="*", type=int, default=None, - help='Identify the type of programming language to generate', + help="Identify the type of programming language to generate", ) group.add_argument( "--quantize", @@ -137,9 +136,9 @@ def main(): initialize_megatron( extra_args_provider=add_code_generation_args, args_defaults={ - 'no_load_rng': True, - 'no_load_optim': True, - } + "no_load_rng": True, + "no_load_optim": True, + }, ) args = get_args() @@ -153,9 +152,9 @@ def main(): model = get_model(model_provider) if args.load is not None: _ = load_checkpoint(model, None, None) - + assert len(model) == 1, "Above condition should have caught this" - + model = model[0] model.eval() if args.fp16 and args.ln_fp16: @@ -166,13 +165,13 @@ def main(): with open(args.prompt_file, "r") as f: prompt = f.readlines() prompt = "".join(prompt) - + times = {} out_seq_lengths = [args.out_seq_length] micro_batch_size = args.micro_batch_size - for out_seq_length in out_seq_lengths: + for out_seq_length in out_seq_lengths: print_rank_0(f"Generating with out_seq_len {out_seq_length}...") - + times[out_seq_length] = [] for prompt in [prompt] * args.n_generation: t0 = time.perf_counter() @@ -196,27 +195,38 @@ def main(): for j in range(micro_batch_size): if is_finished[j]: continue - if generated_tokens[j].cpu().numpy()[-1] == tokenizer.eod or len( - 
generated_tokens[j]) >= out_seq_length: + if ( + generated_tokens[j].cpu().numpy()[-1] == tokenizer.eod + or len(generated_tokens[j]) >= out_seq_length + ): is_finished[j] = True generated_tokens_ = generated_tokens[j].cpu().numpy().tolist() - generated_code = tokenizer.detokenize(generated_tokens_[n_token_prompt:]) + generated_code = tokenizer.detokenize( + generated_tokens_[n_token_prompt:] + ) t1 = time.perf_counter() - print_rank_0(f"Total generation time: {t1 - t0}, # Tokens: {len(generated_tokens_) - n_token_prompt}") - print_rank_0(f"{(t1 - t0) / (len(generated_tokens_) - n_token_prompt)}s/token") + print_rank_0( + f"Total generation time: {t1 - t0}, # Tokens: {len(generated_tokens_) - n_token_prompt}" + ) + print_rank_0( + f"{(t1 - t0) / (len(generated_tokens_) - n_token_prompt)}s/token" + ) times[out_seq_length].append(t1 - t0) - print_rank_0("================================= Generated code:") + print_rank_0( + "================================= Generated code:" + ) print_rank_0(generated_code) t0 = time.perf_counter() - + if all(is_finished): break print_rank_0(times) for out_seq_length in times.keys(): print_rank_0(f"{out_seq_length}, {np.mean(times[out_seq_length])}") - + print_rank_0("Generation finished.") + if __name__ == "__main__": main() diff --git a/tests/test_inference_oneflow.py b/tests/test_inference_oneflow.py index 037849c..d0f507f 100644 --- a/tests/test_inference_oneflow.py +++ b/tests/test_inference_oneflow.py @@ -1,4 +1,3 @@ - import os import copy import time @@ -11,9 +10,11 @@ from codegeex.oneflow import CodeGeeXModel from codegeex.tokenizer import CodeGeeXTokenizer from codegeex.quantization import quantize_oneflow + os.environ["ONEFLOW_KERNEL_ENABLE_FUSED_LINEAR"] = "1" os.environ["ONEFLOW_LINEAR_EMBEDDING_SKIP_INIT"] = "1" + def model_provider(args): """Build the model.""" @@ -22,9 +23,9 @@ def model_provider(args): args.num_layers, args.num_attention_heads, args.padded_vocab_size, - args.max_position_embeddings + args.max_position_embeddings, ) - + return model @@ -112,19 +113,19 @@ def add_code_generation_args(parser): "--quantize", action="store_true", ) - + return parser - + def main(): parser = argparse.ArgumentParser() parser = add_code_generation_args(parser) args, _ = parser.parse_known_args() - + print("Loading tokenizer ...") tokenizer = CodeGeeXTokenizer( - tokenizer_path=args.tokenizer_path, - mode="codegeex-13b") + tokenizer_path=args.tokenizer_path, mode="codegeex-13b" + ) print("Loading state dict ...") state_dict = torch.load(args.load, map_location="cpu") @@ -142,14 +143,14 @@ def main(): with open(args.prompt_file, "r") as f: prompt = f.readlines() prompt = "".join(prompt) - + times = {} out_seq_lengths = [args.out_seq_length] micro_batch_size = args.micro_batch_size seq_length = args.max_position_embeddings - for out_seq_length in out_seq_lengths: + for out_seq_length in out_seq_lengths: print(f"Generating with out_seq_len {out_seq_length}...") - + times[out_seq_length] = [] for prompt in [prompt]: t0 = time.perf_counter() @@ -178,26 +179,37 @@ def main(): if is_finished[j]: continue generated_token_numpy = generated_tokens[j].numpy() - if generated_token_numpy[-1] == tokenizer.eos_token_id or len( - generated_tokens[j]) >= out_seq_length: + if ( + generated_token_numpy[-1] == tokenizer.eos_token_id + or len(generated_tokens[j]) >= out_seq_length + ): is_finished[j] = True generated_tokens_ = generated_token_numpy.tolist() - generated_code = tokenizer.decode_code(generated_tokens_[n_token_prompt:]) + generated_code = 
+                            generated_tokens_[n_token_prompt:]
+                        )
                         generated_code = "".join(generated_code)
                         t1 = time.perf_counter()
-                        print("Total generation time:", t1 - t0, "# Tokens:", len(generated_tokens_) - n_token_prompt)
-                        print(f"{(t1 - t0) / (len(generated_tokens_) - n_token_prompt)}s/token")
+                        print(
+                            "Total generation time:",
+                            t1 - t0,
+                            "# Tokens:",
+                            len(generated_tokens_) - n_token_prompt,
+                        )
+                        print(
+                            f"{(t1 - t0) / (len(generated_tokens_) - n_token_prompt)}s/token"
+                        )
                         times[out_seq_length].append(t1 - t0)
                         print("================================= Generated code:")
                         print(generated_code)
-                
+
                 if all(is_finished):
                     break
 
-    
+
     print(times)
     for out_seq_length in times.keys():
         print(out_seq_length, np.mean(times[out_seq_length]))
 
-    
+
     print("Generation finished.")
diff --git a/tests/test_inference_paddle.py b/tests/test_inference_paddle.py
index 4a6c21a..d85dcd3 100644
--- a/tests/test_inference_paddle.py
+++ b/tests/test_inference_paddle.py
@@ -1,4 +1,3 @@
-
 import os
 import copy
 import time
@@ -22,7 +21,7 @@ def model_provider(args):
         args.num_layers,
         args.num_attention_heads,
         args.padded_vocab_size,
-        args.max_position_embeddings
+        args.max_position_embeddings,
     )
     model.language_model.embedding.word_embeddings.to(dtype="float32")
     model.language_model.embedding.position_embeddings.to(dtype="float32")
@@ -31,10 +30,12 @@ def model_provider(args):
         i.input_layernorm.to(dtype="float32")
         i.post_attention_layernorm.to(dtype="float32")
     model.language_model.transformer.topQueryLayer.input_layernorm.to(dtype="float32")
-    model.language_model.transformer.topQueryLayer.post_attention_layernorm.to(dtype="float32")
+    model.language_model.transformer.topQueryLayer.post_attention_layernorm.to(
+        dtype="float32"
+    )
     model.language_model.transformer.final_layernorm.to(dtype="float32")
     paddle.set_default_dtype(old_dtype)
-    
+
     return model
 
 
@@ -122,19 +123,19 @@ def add_code_generation_args(parser):
         "--quantize",
         action="store_true",
     )
-    
+
     return parser
-    
+
 
 def main():
     parser = argparse.ArgumentParser()
     parser = add_code_generation_args(parser)
     args, _ = parser.parse_known_args()
-    
+
     print("Loading tokenizer ...")
     tokenizer = CodeGeeXTokenizer(
-        tokenizer_path=args.tokenizer_path, 
-        mode="codegeex-13b")
+        tokenizer_path=args.tokenizer_path, mode="codegeex-13b"
+    )
 
     print("Loading state dict ...")
     state_dict = paddle.load(args.load)
@@ -147,18 +148,18 @@ def main():
     model.to(dtype="float16")
     if args.quantize:
         raise NotImplementedError("quantize")
-    
+
     with open(args.prompt_file, "r") as f:
         prompt = f.readlines()
         prompt = "".join(prompt)
-    
+
     times = {}
     out_seq_lengths = [args.out_seq_length]
     micro_batch_size = args.micro_batch_size
     seq_length = args.max_position_embeddings
-    for out_seq_length in out_seq_lengths:    
+    for out_seq_length in out_seq_lengths:
         print(f"Generating with out_seq_len {out_seq_length}...")
-        
+
         times[out_seq_length] = []
         for prompt in [prompt]:
             t0 = time.perf_counter()
@@ -186,26 +187,37 @@ def main():
                 for j in range(micro_batch_size):
                     if is_finished[j]:
                         continue
-                    if generated_tokens[j].cpu().numpy()[-1] == tokenizer.eos_token_id or len(
-                            generated_tokens[j]) >= out_seq_length:
+                    if (
+                        generated_tokens[j].cpu().numpy()[-1] == tokenizer.eos_token_id
+                        or len(generated_tokens[j]) >= out_seq_length
+                    ):
                         is_finished[j] = True
                         generated_tokens_ = generated_tokens[j].cpu().numpy().tolist()
-                        generated_code = tokenizer.decode_code(generated_tokens_[n_token_prompt:])
+                        generated_code = tokenizer.decode_code(
+                            generated_tokens_[n_token_prompt:]
+                        )
                         generated_code = "".join(generated_code)
                         t1 = time.perf_counter()
-                        print("Total generation time:", t1 - t0, "# Tokens:", len(generated_tokens_) - n_token_prompt)
-                        print(f"{(t1 - t0) / (len(generated_tokens_) - n_token_prompt)}s/token")
+                        print(
+                            "Total generation time:",
+                            t1 - t0,
+                            "# Tokens:",
+                            len(generated_tokens_) - n_token_prompt,
+                        )
+                        print(
+                            f"{(t1 - t0) / (len(generated_tokens_) - n_token_prompt)}s/token"
+                        )
                         times[out_seq_length].append(t1 - t0)
                         print("================================= Generated code:")
                         print(generated_code)
-                
+
                 if all(is_finished):
                     break
 
-    
+
     print(times)
     for out_seq_length in times.keys():
         print(out_seq_length, np.mean(times[out_seq_length]))
 
-    
+
     print("Generation finished.")
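
All of the inference tests touched above share the same per-sample stopping rule: a batch slot is marked finished once its last emitted token equals the tokenizer's end-of-sequence id or the slot has produced out_seq_length tokens. The snippet below is a minimal, framework-agnostic sketch of that check on plain Python lists; the helper name update_finished and the example token ids are illustrative assumptions, not code from this patch.

# stopping_rule_sketch.py -- hypothetical helper mirroring the per-sample
# termination check used by the inference test scripts above: a sample stops
# once it emits the EOS token id or reaches out_seq_length tokens.
from typing import List, Sequence


def update_finished(
    generated_tokens: Sequence[Sequence[int]],
    is_finished: List[bool],
    eos_token_id: int,
    out_seq_length: int,
) -> List[bool]:
    for j, tokens in enumerate(generated_tokens):
        if is_finished[j]:
            continue  # this slot already stopped on an earlier step
        if tokens[-1] == eos_token_id or len(tokens) >= out_seq_length:
            is_finished[j] = True
    return is_finished


if __name__ == "__main__":
    # Slot 0 just emitted EOS (id 2); slot 1 is still under the length budget.
    print(
        update_finished(
            generated_tokens=[[5, 7, 2], [5, 7, 9]],
            is_finished=[False, False],
            eos_token_id=2,
            out_seq_length=8,
        )
    )  # -> [True, False]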