Skip to content

Commit 1deffd6

Browse files
authored
Change ParamResolver keys to be strings (#7714)
- This PR changes ParamResolvers to only use string keys. - Previously, ParamResolver keys could be strings or symbols. - If a symbol is detected, the ParamResolver will remake the dictionary using only strings. - This simplifies the logic and improves performance since sympy Symbols can be slower to create and hash. Performance implications: - Creation speed is unchanged unless you have symbols as keys, then creation takes about twice as long. (For 1000 parameters, this is a change from about 0.6ms to 1.3ms). - Lookup speed is now about ~120-130ns - Before it was: -- ~200ns for string lookup with string keys -- ~250ns for symbol lookup with string keys -- ~850ns for string lookup with symbol keys -- ~350ns for symbol lookup with symbol keys
1 parent 08a15c8 commit 1deffd6

File tree

1 file changed

+25
-13
lines changed

1 file changed

+25
-13
lines changed

cirq-core/cirq/study/resolver.py

Lines changed: 25 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -75,9 +75,20 @@ def __init__(self, param_dict: cirq.ParamResolverOrSimilarType = None) -> None:
7575

7676
self._param_hash: int | None = None
7777
self._param_dict = cast(ParamDictType, {} if param_dict is None else param_dict)
78+
self._param_dict_with_str_keys = self._param_dict
79+
generate_str_keys = False
7880
for key in self._param_dict:
79-
if isinstance(key, sympy.Expr) and not isinstance(key, sympy.Symbol):
80-
raise TypeError(f'ParamResolver keys cannot be (non-symbol) formulas ({key})')
81+
if isinstance(key, sympy.Expr):
82+
if isinstance(key, sympy.Symbol):
83+
generate_str_keys = True
84+
else:
85+
raise TypeError(f'ParamResolver keys cannot be (non-symbol) formulas ({key})')
86+
if generate_str_keys:
87+
# Remake dictionary with string keys for faster access
88+
self._param_dict_with_str_keys = {
89+
(key.name if isinstance(key, sympy.Symbol) else key): value
90+
for key, value in self._param_dict.items()
91+
}
8192
self._deep_eval_map: ParamDictType = {}
8293

8394
@property
@@ -119,22 +130,23 @@ def value_of(
119130
"""
120131

121132
# Handle string or symbol
122-
if isinstance(value, (str, sympy.Symbol)):
123-
string = value if isinstance(value, str) else value.name
124-
param_value = self._param_dict.get(string, _NOT_FOUND)
133+
original_value = value
134+
if isinstance(value, sympy.Symbol):
135+
value = value.name
136+
if isinstance(value, str):
137+
param_value = self._param_dict_with_str_keys.get(value, _NOT_FOUND)
138+
if isinstance(param_value, float):
139+
return param_value
125140
if param_value is _NOT_FOUND:
126-
symbol = value if isinstance(value, sympy.Symbol) else sympy.Symbol(value)
127-
param_value = self._param_dict.get(symbol, _NOT_FOUND)
128-
if param_value is _NOT_FOUND:
129-
# Symbol or string cannot be resolved if not in param dict; return as symbol.
130-
return symbol
141+
# Symbol or string cannot be resolved if not in param dict; return as symbol.
142+
return sympy.Symbol(value)
131143
v = _resolve_value(param_value)
132144
if v is not NotImplemented:
133145
return v
134146
if isinstance(param_value, str):
135147
param_value = sympy.Symbol(param_value)
136148
elif not isinstance(param_value, sympy.Basic):
137-
return value
149+
return original_value
138150
if recursive:
139151
param_value = self._value_of_recursive(value)
140152
return param_value
@@ -210,7 +222,7 @@ def _value_of_recursive(self, value: cirq.TParamKey) -> cirq.TParamValComplex:
210222
self._deep_eval_map[value] = _RECURSION_FLAG
211223

212224
v = self.value_of(value, recursive=False)
213-
if v == value:
225+
if v == value or (isinstance(v, sympy.Symbol) and v.name == value):
214226
self._deep_eval_map[value] = v
215227
else:
216228
self._deep_eval_map[value] = self.value_of(v, recursive=True)
@@ -278,7 +290,7 @@ def _from_json_dict_(cls, param_dict, **kwargs):
278290

279291

280292
def _resolve_value(val: Any) -> Any:
281-
if val is None or isinstance(val, float):
293+
if isinstance(val, float) or val is None:
282294
return val
283295
if isinstance(val, numbers.Number) and not isinstance(val, sympy.Basic):
284296
return val

0 commit comments

Comments
 (0)