Skip to content

Commit d41e2d7

Browse files
committed
Let loss_compare.py check the repo cleaness
This will prevent errors when later doing git checkout ghstack-source-id: 0725199 Pull-Request: #2062
1 parent c8ebd7a commit d41e2d7

File tree

1 file changed

+33
-0
lines changed

1 file changed

+33
-0
lines changed

scripts/loss_compare.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -301,6 +301,33 @@ def print_configuration(
301301
# =============================================================================
302302

303303

304+
def check_git_clean_state() -> None:
305+
"""Check if git working directory is clean before switching commits.
306+
307+
Raises SystemExit if there are uncommitted changes or untracked files.
308+
"""
309+
result = subprocess.run(
310+
["git", "status", "--porcelain"],
311+
capture_output=True,
312+
text=True,
313+
check=True,
314+
)
315+
316+
if result.stdout.strip():
317+
log_print("Error: Git working directory is not clean")
318+
log_print(" Cannot switch commits with uncommitted changes")
319+
log_print("")
320+
log_print("Modified/untracked files:")
321+
for line in result.stdout.strip().split("\n"):
322+
log_print(f" {line}")
323+
log_print("")
324+
log_print("Please commit, stash, or discard your changes before running this script")
325+
log_print(" - To commit: git add -A && git commit -m 'message'")
326+
log_print(" - To stash: git stash")
327+
log_print(" - To discard: git checkout -- . && git clean -fd")
328+
sys.exit(1)
329+
330+
304331
def checkout_commit(commit: str, commit_name: str) -> None:
305332
"""Checkout git commit."""
306333
if commit != ".":
@@ -840,6 +867,12 @@ def main() -> None:
840867
args.job_dump_folder,
841868
)
842869

870+
# Check if git working directory is clean before switching commits
871+
# Skip check if both commits are "." (comparing configs on same commit)
872+
needs_git_checkout = args.baseline_commit != "." or args.test_commit != "."
873+
if needs_git_checkout:
874+
check_git_clean_state()
875+
843876
create_seed_checkpoint(
844877
enable_seed_checkpoint,
845878
args.baseline_config,

0 commit comments

Comments
 (0)