Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
27 commits
Select commit Hold shift + click to select a range
0e4a88e
Added Loggers for optimizer, routine and line search removed tqdm
MPMPMPMPMPMPMP Oct 15, 2025
66cd0c1
fixes to the logger implementation
MPMPMPMPMPMPMP Oct 16, 2025
98603bf
fixes for logger + added some other logging
MPMPMPMPMPMPMP Oct 18, 2025
df74df4
add logging for expectation function
MPMPMPMPMPMPMP Oct 19, 2025
273387e
changed info -> debug in autosaving
MPMPMPMPMPMPMP Oct 19, 2025
4a932a7
another fix for the logger
MPMPMPMPMPMPMP Oct 20, 2025
be9ed47
this is only one very big unindent
MPMPMPMPMPMPMP Oct 20, 2025
061ad27
fix a bug when loading a config file
MPMPMPMPMPMPMP Oct 20, 2025
df51349
enhance warning message for insufficient convergence in optimize_peps…
MPMPMPMPMPMPMP Oct 20, 2025
f484860
Added logging to structure factor ctmrg
MPMPMPMPMPMPMP Oct 21, 2025
1906ff5
measuere time it takes to Autosave for logger
MPMPMPMPMPMPMP Oct 21, 2025
ccf67b4
add date to logger and log time in energy update
MPMPMPMPMPMPMP Oct 21, 2025
7b8e2b3
added sec to time
MPMPMPMPMPMPMP Oct 22, 2025
23cf09f
ensure logging configured in optimizer
MPMPMPMPMPMPMP Oct 22, 2025
fbc6a4a
Add delta to logging
MPMPMPMPMPMPMP Oct 23, 2025
208a3fc
fixed missing import
MPMPMPMPMPMPMP Oct 24, 2025
1a7c2dd
fixed a misstake in logger
MPMPMPMPMPMPMP Oct 28, 2025
4489c23
logger improvement
MPMPMPMPMPMPMP Oct 28, 2025
7177a99
show chi in optimization log messages, change info to warning in incr…
MPMPMPMPMPMPMP Oct 28, 2025
807c4ed
logger lineseach/ctmrg improvement
MPMPMPMPMPMPMP Oct 30, 2025
b093571
added tqdm logging
MPMPMPMPMPMPMP Nov 7, 2025
cadf89a
shortend step message and changed some levels
MPMPMPMPMPMPMP Nov 7, 2025
f3c8dc7
changed a symbol in logger
MPMPMPMPMPMPMP Nov 7, 2025
4f43ba1
fixed typo
MPMPMPMPMPMPMP Nov 7, 2025
7d73a8b
changed message
MPMPMPMPMPMPMP Nov 7, 2025
0bbc43a
changed message
MPMPMPMPMPMPMP Nov 17, 2025
0ae7a6b
fixed a logging message in the structure factor routine
MPMPMPMPMPMPMP Nov 27, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 0 additions & 7 deletions varipeps/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,4 @@

jax_config.update("jax_enable_x64", True)

from tqdm_loggable.tqdm_logging import tqdm_logging
import datetime

tqdm_logging.set_log_rate(datetime.timedelta(seconds=60))

del datetime
del tqdm_logging
del jax_config
68 changes: 66 additions & 2 deletions varipeps/config.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
from dataclasses import dataclass
from enum import Enum, IntEnum, auto, unique
from typing import TypeVar, Tuple, Any, Type, NoReturn
import logging

import numpy as np

from jax.tree_util import register_pytree_node_class

from typing import TypeVar, Tuple, Any, Type, NoReturn

T_VariPEPS_Config = TypeVar("T_VariPEPS_Config", bound="VariPEPS_Config")

Expand Down Expand Up @@ -54,6 +55,15 @@ class Slurm_Restart_Mode(IntEnum):
AUTOMATIC_RESTART = auto() #: Write restart script and start new slurm job with it


@unique
class LogLevel(IntEnum):
OFF = 0
ERROR = logging.ERROR
WARNING = logging.WARNING
INFO = logging.INFO
DEBUG = logging.DEBUG


@dataclass
@register_pytree_node_class
class VariPEPS_Config:
Expand Down Expand Up @@ -234,6 +244,27 @@ class VariPEPS_Config:
Type of wavevector to be used (only positive/symmetric interval/...).
slurm_restart_mode (:obj:`Slurm_Restart_Mode`):
Mode of operation to restart slurm job if maximal runtime is reached.
log_level_global (:obj:`LogLevel`):
Global logging level for the 'varipeps' package logger.
log_level_optimizer (:obj:`LogLevel`):
Logging level for 'varipeps.optimizer'.
log_level_ctmrg (:obj:`LogLevel`):
Logging level for 'varipeps.ctmrg'.
log_level_line_search (:obj:`LogLevel`):
Logging level for 'varipeps.line_search'.
log_level_expectation (:obj:`LogLevel`):
Logging level for 'varipeps.expectation'.
log_to_console (:obj:`bool`):
Enable standard console logging (StreamHandler).
Ignored when :obj:`VariPEPS_Config.log_tqdm` is True.
log_to_file (:obj:`bool`):
Enable logging to file.
log_file (:obj:`str`):
Filename for logging to file (used when :obj:`VariPEPS_Config.log_to_file` is True).
log_tqdm (:obj:`bool`):
Enable tqdm-based console logging. If True, messages from
'varipeps.optimizer' update a tqdm progress bar, while other modules
log via tqdm.write. File logging settings still apply.
"""

# AD config
Expand Down Expand Up @@ -322,6 +353,17 @@ class VariPEPS_Config:
# Slurm
slurm_restart_mode: Slurm_Restart_Mode = Slurm_Restart_Mode.WRITE_NEED_RESTART_FILE

# Logging configuration
log_level_global: LogLevel = LogLevel.INFO
log_level_optimizer: LogLevel = LogLevel.INFO
log_level_ctmrg: LogLevel = LogLevel.INFO
log_level_line_search: LogLevel = LogLevel.INFO
log_level_expectation: LogLevel = LogLevel.INFO
log_to_console: bool = True
log_to_file: bool = False
log_file: str = "varipeps.log"
log_tqdm: bool = False #: Enable tqdm-based console logging

def update(self, name: str, value: Any) -> NoReturn:
self.__setattr__(name, value)

Expand Down Expand Up @@ -358,12 +400,33 @@ def __setattr__(self, name: str, value: Any) -> NoReturn:
elif (
field.type is bool
and hasattr(value, "dtype")
and np.isdtype(value.dtype, np.bool)
and np.issubdtype(value.dtype, np.bool_)
and value.size == 1
):
if value.ndim > 0:
value = value.reshape(-1)[0]
value = bool(value)
elif isinstance(field.type, type) and issubclass(field.type, Enum):
# Accept ints/np.int64 or enum names for Enum fields
if isinstance(value, field.type):
pass
elif isinstance(value, (int,)) or (
hasattr(value, "dtype")
and np.issubdtype(value.dtype, np.integer)
and value.size == 1
):
if hasattr(value, "ndim") and value.ndim > 0:
value = value.reshape(-1)[0]
value = field.type(int(value))
elif isinstance(value, str):
try:
value = field.type[value]
except KeyError:
value = field.type(int(value))
else:
raise TypeError(
f"Type mismatch for option '{name}', got '{type(value)}', expected '{field.type}'."
)
else:
raise TypeError(
f"Type mismatch for option '{name}', got '{type(value)}', expected '{field.type}'."
Expand Down Expand Up @@ -407,6 +470,7 @@ class ConfigModuleWrapper:
"Projector_Method",
"Wavevector_Type",
"Slurm_Restart_Mode",
"LogLevel",
"VariPEPS_Config",
"config",
}
Expand Down
61 changes: 46 additions & 15 deletions varipeps/ctmrg/routine.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,11 @@
from jax import jit, custom_vjp, vjp, tree_util
from jax.lax import cond, while_loop
import jax.debug as jdebug
import logging
import time
import jax

logger = logging.getLogger("varipeps.ctmrg")

from varipeps import varipeps_config, varipeps_global_state
from varipeps.peps import PEPS_Tensor, PEPS_Tensor_Split_Transfer, PEPS_Unit_Cell
Expand Down Expand Up @@ -515,9 +520,8 @@ def corner_svd_func(old, new, old_corner, conv_eps, config):
eps,
config,
)

if config.ctmrg_print_steps:
debug_print("CTMRG: {}: {}", count, measure)
if logger.isEnabledFor(logging.DEBUG):
jax.debug.callback(lambda cnt, msr: logger.debug(f"CTMRG: Step {cnt}: {msr}"), count, measure, ordered=True)
if config.ctmrg_verbose_output:
jax.debug.callback(print_verbose, verbose_data, ordered=True)

Expand Down Expand Up @@ -620,9 +624,9 @@ def calc_ctmrg_env(
best_norm_smallest_S = None
best_truncation_eps = None
have_been_increased = False

while True:
tmp_count = 0
t0 = time.perf_counter()
corner_singular_vals = None

while tmp_count < varipeps_config.ctmrg_max_steps and (
Expand Down Expand Up @@ -720,6 +724,17 @@ def calc_ctmrg_env(
else:
converged = False
end_count = tmp_count

if not converged and logger.isEnabledFor(logging.WARNING):
logger.warning(
"CTMRG: ❌ did not converge, took %.2f seconds. (Steps: %d, Smallest SVD Norm: %.3e)",
time.perf_counter() - t0, end_count, norm_smallest_S
)
elif logger.isEnabledFor(logging.INFO):
logger.info(
"CTMRG: ✅ converged, took %.2f seconds. (Steps: %d, Smallest SVD Norm: %.3e)",
time.perf_counter() - t0, end_count, norm_smallest_S
)

if converged and (
working_unitcell[0, 0][0][0].chi > best_chi or best_result is None
Expand Down Expand Up @@ -751,9 +766,9 @@ def calc_ctmrg_env(
working_unitcell = working_unitcell.change_chi(new_chi)
initial_unitcell = initial_unitcell.change_chi(new_chi)

if varipeps_config.ctmrg_print_steps:
debug_print(
"CTMRG: Increasing chi to {} since smallest SVD Norm was {}.",
if logger.isEnabledFor(logging.INFO):
logger.info(
"Increasing chi to %d since smallest SVD Norm was %.3e.",
new_chi,
norm_smallest_S,
)
Expand Down Expand Up @@ -785,9 +800,9 @@ def calc_ctmrg_env(
if not new_chi in already_tried_chi:
working_unitcell = working_unitcell.change_chi(new_chi)

if varipeps_config.ctmrg_print_steps:
debug_print(
"CTMRG: Decreasing chi to {} since smallest SVD Norm was {} or routine did not converge.",
if logger.isEnabledFor(logging.INFO):
logger.info(
"Decreasing chi to %d since smallest SVD Norm was %.3e or routine did not converge.",
new_chi,
norm_smallest_S,
)
Expand All @@ -809,9 +824,9 @@ def calc_ctmrg_env(
new_truncation_eps
<= varipeps_config.ctmrg_increase_truncation_eps_max_value
):
if varipeps_config.ctmrg_print_steps:
debug_print(
"CTMRG: Increasing SVD truncation eps to {}.",
if logger.isEnabledFor(logging.INFO):
logger.info(
"Increasing SVD truncation eps to %.1e.",
new_truncation_eps,
)
varipeps_global_state.ctmrg_effective_truncation_eps = (
Expand Down Expand Up @@ -884,6 +899,8 @@ def calc_ctmrg_env_fwd(
Internal helper function of custom VJP to calculate the values in
the forward sweep.
"""
if logger.isEnabledFor(logging.DEBUG):
logger.debug("Custom VJP: Starting forward CTMRG calculation.")
new_unitcell, last_truncation_eps, norm_smallest_S = calc_ctmrg_env_custom_rule(
peps_tensors, unitcell, _return_truncation_eps=True
)
Expand Down Expand Up @@ -937,8 +954,8 @@ def _ctmrg_rev_while_body(carry):

count += 1

if config.ad_custom_print_steps:
debug_print("Custom VJP: {}: {}", count, measure)
if logger.isEnabledFor(logging.DEBUG):
jax.debug.callback(lambda cnt, msr: logger.debug(f"Custom VJP: Step {cnt}: {msr}"), count, measure, ordered=True)
if config.ad_custom_verbose_output:
jax.debug.callback(print_verbose, verbose_data, ordered=True, ad=True)

Expand Down Expand Up @@ -1009,17 +1026,31 @@ def calc_ctmrg_env_rev(
Internal helper function of custom VJP to calculate the gradient in
the backward sweep.
"""
if logger.isEnabledFor(logging.DEBUG):
logger.debug("Custom VJP: Starting reverse CTMRG calculation.")
unitcell_bar, _ = input_bar
peps_tensors, new_unitcell, input_unitcell, last_truncation_eps = res

varipeps_global_state.ctmrg_effective_truncation_eps = last_truncation_eps

if logger.isEnabledFor(logging.WARNING):
t0 = time.perf_counter()
t_bar, converged, end_count = _ctmrg_rev_workhorse(
peps_tensors, new_unitcell, unitcell_bar, varipeps_config, varipeps_global_state
)

varipeps_global_state.ctmrg_effective_truncation_eps = None

if not converged and logger.isEnabledFor(logging.WARNING):
logger.warning(
"Custom VJP: ❌ did not converge, took %.2f seconds. (Steps: %d)",
time.perf_counter() - t0, end_count
)
elif logger.isEnabledFor(logging.INFO):
logger.info(
"Custom VJP: ✅ converged, took %.2f seconds. (Steps: %d)",
time.perf_counter() - t0, end_count
)
if end_count == varipeps_config.ad_custom_max_steps and not converged:
raise CTMRGGradientNotConvergedError

Expand Down
40 changes: 27 additions & 13 deletions varipeps/ctmrg/structure_factor_routine.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,10 @@
from jax import jit, custom_vjp, vjp, tree_util
from jax.lax import cond, while_loop
import jax.debug as jdebug
import logging
import time

logger = logging.getLogger("varipeps.ctmrg")

from varipeps import varipeps_config, varipeps_global_state
from varipeps.peps import PEPS_Tensor, PEPS_Unit_Cell
Expand Down Expand Up @@ -125,8 +129,8 @@ def _ctmrg_body_func_structure_factor(carry):
measure = jnp.linalg.norm(corner_svd - last_corner_svd)
converged = measure < eps

if config.ctmrg_print_steps:
debug_print("CTMRG: {}: {}", count, measure)
if logger.isEnabledFor(logging.DEBUG):
jdebug.callback(lambda cnt, msr: logger.debug(f"CTMRG: Step {cnt}: {msr}"), count, measure, ordered=True)
if config.ctmrg_verbose_output:
for ti, ctm_enum_i, diff in verbose_data:
debug_print(
Expand Down Expand Up @@ -244,6 +248,7 @@ def calc_ctmrg_env_structure_factor(
norm_smallest_S = jnp.nan
already_tried_chi = {working_unitcell[0, 0][0][0].chi}

t0 = time.perf_counter()
while True:
tmp_count = 0
corner_singular_vals = None
Expand Down Expand Up @@ -304,6 +309,17 @@ def calc_ctmrg_env_structure_factor(
)
)

if not converged and logger.isEnabledFor(logging.WARNING):
logger.warning(
"CTMRG (SF): ❌ did not converge, took %.2f seconds. (Steps: %d, Smallest SVD Norm: %.3e)",
time.perf_counter() - t0, end_count, norm_smallest_S
)
elif logger.isEnabledFor(logging.INFO):
logger.info(
"CTMRG (SF): ✅ converged, took %.2f seconds. (Steps: %d, Smallest SVD Norm: %.3e)",
time.perf_counter() - t0, end_count, norm_smallest_S
)

current_truncation_eps = (
varipeps_config.ctmrg_truncation_eps
if varipeps_global_state.ctmrg_effective_truncation_eps is None
Expand All @@ -326,15 +342,14 @@ def calc_ctmrg_env_structure_factor(
working_unitcell = working_unitcell.change_chi(new_chi)
initial_unitcell = initial_unitcell.change_chi(new_chi)

if varipeps_config.ctmrg_print_steps:
debug_print(
"CTMRG: Increasing chi to {} since smallest SVD Norm was {}.",
if logger.isEnabledFor(logging.INFO):
logger.info(
"CTMRG (SF): Increasing chi to %d since smallest SVD Norm was %.3e.",
new_chi,
norm_smallest_S,
)

already_tried_chi.add(new_chi)

continue
elif (
varipeps_config.ctmrg_heuristic_decrease_chi
Expand All @@ -351,15 +366,14 @@ def calc_ctmrg_env_structure_factor(
if not new_chi in already_tried_chi:
working_unitcell = working_unitcell.change_chi(new_chi)

if varipeps_config.ctmrg_print_steps:
debug_print(
"CTMRG: Decreasing chi to {} since smallest SVD Norm was {}.",
if logger.isEnabledFor(logging.INFO):
logger.info(
"CTMRG (SF): Decreasing chi to %d since smallest SVD Norm was %.3e.",
new_chi,
norm_smallest_S,
)

already_tried_chi.add(new_chi)

continue

if (
Expand All @@ -375,9 +389,9 @@ def calc_ctmrg_env_structure_factor(
new_truncation_eps
<= varipeps_config.ctmrg_increase_truncation_eps_max_value
):
if varipeps_config.ctmrg_print_steps:
debug_print(
"CTMRG: Increasing SVD truncation eps to {}.",
if logger.isEnabledFor(logging.INFO):
logger.info(
"CTMRG (SF): Increasing SVD truncation eps to %g.",
new_truncation_eps,
)
varipeps_global_state.ctmrg_effective_truncation_eps = (
Expand Down
Loading