Skip to content

Commit c2dbc0a

Browse files
committed
Adding changes related to SET command
1 parent 654e09e commit c2dbc0a

File tree

4 files changed

+275
-32
lines changed

4 files changed

+275
-32
lines changed

redis/commands/core.py

Lines changed: 57 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@
5050
extract_expire_flags,
5151
)
5252

53-
from .helpers import list_or_args
53+
from .helpers import at_most_one_value_set, list_or_args
5454

5555
if TYPE_CHECKING:
5656
import redis.asyncio.client
@@ -1732,8 +1732,8 @@ def __delitem__(self, name: KeyT):
17321732
def delex(
17331733
self,
17341734
name: KeyT,
1735-
ifeq: Optional[EncodableT] = None,
1736-
ifne: Optional[EncodableT] = None,
1735+
ifeq: Optional[Union[bytes, str]] = None,
1736+
ifne: Optional[Union[bytes, str]] = None,
17371737
ifdeq: Optional[str] = None, # hex digest
17381738
ifdne: Optional[str] = None, # hex digest
17391739
) -> int:
@@ -1752,6 +1752,8 @@ def delex(
17521752
and a condition is specified.
17531753
ValueError: if more than one condition is provided.
17541754
1755+
1756+
Requires Redis 8.4 or greater.
17551757
For more information, see https://redis.io/commands/delex
17561758
"""
17571759
conds = [x is not None for x in (ifeq, ifne, ifdeq, ifdne)]
@@ -1886,6 +1888,8 @@ def digest(self, name: KeyT) -> Optional[str]:
18861888
Raises:
18871889
- ResponseError if key exists but is not a string
18881890
1891+
1892+
Requires Redis 8.4 or greater.
18891893
For more information, see https://redis.io/commands/digest
18901894
"""
18911895
# Bulk string response is already handled (bytes/str based on decode_responses)
@@ -1939,8 +1943,7 @@ def getex(
19391943
19401944
For more information, see https://redis.io/commands/getex
19411945
"""
1942-
opset = {ex, px, exat, pxat}
1943-
if len(opset) > 2 or len(opset) > 1 and persist:
1946+
if not at_most_one_value_set((ex, px, exat, pxat, persist)):
19441947
raise DataError(
19451948
"``ex``, ``px``, ``exat``, ``pxat``, "
19461949
"and ``persist`` are mutually exclusive."
@@ -2128,8 +2131,7 @@ def msetex(
21282131
Available since Redis 8.4
21292132
For more information, see https://redis.io/commands/msetex
21302133
"""
2131-
opset = {ex, px, exat, pxat}
2132-
if len(opset) > 2 or len(opset) > 1 and keepttl:
2134+
if not at_most_one_value_set((ex, px, exat, pxat, keepttl)):
21332135
raise DataError(
21342136
"``ex``, ``px``, ``exat``, ``pxat``, "
21352137
"and ``keepttl`` are mutually exclusive."
@@ -2395,6 +2397,10 @@ def set(
23952397
get: bool = False,
23962398
exat: Optional[AbsExpiryT] = None,
23972399
pxat: Optional[AbsExpiryT] = None,
2400+
ifeq: Optional[Union[bytes, str]] = None,
2401+
ifne: Optional[Union[bytes, str]] = None,
2402+
ifdeq: Optional[str] = None, # hex digest of current value
2403+
ifdne: Optional[str] = None, # hex digest of current value
23982404
) -> ResponseT:
23992405
"""
24002406
Set the value at key ``name`` to ``value``
@@ -2422,35 +2428,67 @@ def set(
24222428
``pxat`` sets an expire flag on key ``name`` for ``ex`` milliseconds,
24232429
specified in unix time.
24242430
2431+
``ifeq`` set the value at key ``name`` to ``value`` only if the current
2432+
value exactly matches the argument.
2433+
If key doesn’t exist - it won’t be created.
2434+
(Requires Redis 8.4 or greater)
2435+
2436+
``ifne`` set the value at key ``name`` to ``value`` only if the current
2437+
value does not exactly match the argument.
2438+
If key doesn’t exist - it will be created.
2439+
(Requires Redis 8.4 or greater)
2440+
2441+
``ifdeq`` set the value at key ``name`` to ``value`` only if the current
2442+
value XXH3 hex digest exactly matches the argument.
2443+
If key doesn’t exist - it won’t be created.
2444+
(Requires Redis 8.4 or greater)
2445+
2446+
``ifdne`` set the value at key ``name`` to ``value`` only if the current
2447+
value XXH3 hex digest does not exactly match the argument.
2448+
If key doesn’t exist - it will be created.
2449+
(Requires Redis 8.4 or greater)
2450+
24252451
For more information, see https://redis.io/commands/set
24262452
"""
2427-
opset = {ex, px, exat, pxat}
2428-
if len(opset) > 2 or len(opset) > 1 and keepttl:
2453+
2454+
if not at_most_one_value_set((ex, px, exat, pxat, keepttl)):
24292455
raise DataError(
24302456
"``ex``, ``px``, ``exat``, ``pxat``, "
24312457
"and ``keepttl`` are mutually exclusive."
24322458
)
24332459

2434-
if nx and xx:
2435-
raise DataError("``nx`` and ``xx`` are mutually exclusive.")
2460+
# Enforce mutual exclusivity among all conditional switches.
2461+
if not at_most_one_value_set((nx, xx, ifeq, ifne, ifdeq, ifdne)):
2462+
raise DataError(
2463+
"``nx``, ``xx``, ``ifeq``, ``ifne``, ``ifdeq``, ``ifdne`` are mutually exclusive."
2464+
)
24362465

24372466
pieces: list[EncodableT] = [name, value]
24382467
options = {}
24392468

2440-
pieces.extend(extract_expire_flags(ex, px, exat, pxat))
2441-
2442-
if keepttl:
2443-
pieces.append("KEEPTTL")
2444-
2469+
# Conditional modifier (exactly one at most)
24452470
if nx:
24462471
pieces.append("NX")
2447-
if xx:
2472+
elif xx:
24482473
pieces.append("XX")
2474+
elif ifeq is not None:
2475+
pieces.extend(("IFEQ", ifeq))
2476+
elif ifne is not None:
2477+
pieces.extend(("IFNE", ifne))
2478+
elif ifdeq is not None:
2479+
pieces.extend(("IFDEQ", ifdeq))
2480+
elif ifdne is not None:
2481+
pieces.extend(("IFDNE", ifdne))
24492482

24502483
if get:
24512484
pieces.append("GET")
24522485
options["get"] = True
24532486

2487+
pieces.extend(extract_expire_flags(ex, px, exat, pxat))
2488+
2489+
if keepttl:
2490+
pieces.append("KEEPTTL")
2491+
24542492
return self.execute_command("SET", *pieces, **options)
24552493

24562494
def __setitem__(self, name: KeyT, value: EncodableT):
@@ -5257,8 +5295,7 @@ def hgetex(
52575295
if not keys:
52585296
raise DataError("'hgetex' should have at least one key provided")
52595297

5260-
opset = {ex, px, exat, pxat}
5261-
if len(opset) > 2 or len(opset) > 1 and persist:
5298+
if not at_most_one_value_set((ex, px, exat, pxat, persist)):
52625299
raise DataError(
52635300
"``ex``, ``px``, ``exat``, ``pxat``, "
52645301
"and ``persist`` are mutually exclusive."
@@ -5403,8 +5440,7 @@ def hsetex(
54035440
"'items' must contain a list of key/value pairs."
54045441
)
54055442

5406-
opset = {ex, px, exat, pxat}
5407-
if len(opset) > 2 or len(opset) > 1 and keepttl:
5443+
if not at_most_one_value_set((ex, px, exat, pxat, keepttl)):
54085444
raise DataError(
54095445
"``ex``, ``px``, ``exat``, ``pxat``, "
54105446
"and ``keepttl`` are mutually exclusive."

redis/commands/helpers.py

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import copy
22
import random
33
import string
4-
from typing import List, Tuple
4+
from typing import Any, Iterable, List, Tuple
55

66
import redis
77
from redis.typing import KeysT, KeyT
@@ -96,3 +96,22 @@ def get_protocol_version(client):
9696
return client.connection_pool.connection_kwargs.get("protocol")
9797
elif isinstance(client, redis.cluster.AbstractRedisCluster):
9898
return client.nodes_manager.connection_kwargs.get("protocol")
99+
100+
101+
def at_most_one_value_set(iterable: Iterable[Any]):
102+
"""
103+
Checks that at most one of the values in the iterable is truthy.
104+
105+
Args:
106+
iterable: An iterable of values to check.
107+
108+
Returns:
109+
True if at most one value is truthy, False otherwise.
110+
111+
Raises:
112+
Might raise an error if the values in iterable are not boolean-compatible.
113+
For example if the type of the values implement
114+
__len__ or __bool__ methods and they raise an error.
115+
"""
116+
values = (bool(x) for x in iterable)
117+
return sum(values) <= 1

tests/test_asyncio/test_commands.py

Lines changed: 101 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111

1212
import pytest
1313
import pytest_asyncio
14-
from redis import RedisClusterException, ResponseError
14+
from redis import DataError, RedisClusterException, ResponseError
1515
import redis
1616
from redis import exceptions
1717
from redis._parsers.helpers import (
@@ -1125,10 +1125,10 @@ async def test_delex_ifdeq_and_ifdne(self, r, val):
11251125

11261126
@skip_if_server_version_lt("8.3.224")
11271127
async def test_delex_pipeline(self, r):
1128-
await r.mset({"p1": b"A", "p2": b"B"})
1128+
await r.mset({"p1{45}": b"A", "p2{45}": b"B"})
11291129
p = r.pipeline()
1130-
p.delex("p1", ifeq=b"A")
1131-
p.delex("p2", ifne=b"B") # false → 0
1130+
p.delex("p1{45}", ifeq=b"A")
1131+
p.delex("p2{45}", ifne=b"B") # false → 0
11321132
p.delex("nope") # nonexistent → 0
11331133
out = await p.execute()
11341134
assert out == [1, 0, 0]
@@ -1253,7 +1253,7 @@ async def test_digest_response_when_available(self, r, value):
12531253

12541254
@skip_if_server_version_lt("8.3.224")
12551255
async def test_pipeline_digest(self, r):
1256-
k1, k2 = "k:d1", "k:d2"
1256+
k1, k2 = "k:d1{42}", "k:d2{42}"
12571257
await r.mset({k1: b"A", k2: b"B"})
12581258
p = r.pipeline()
12591259
p.digest(k1)
@@ -1849,6 +1849,102 @@ async def test_set_keepttl(self, r: redis.Redis):
18491849
assert await r.get("a") == b"2"
18501850
assert 0 < await r.ttl("a") <= 10
18511851

1852+
@skip_if_server_version_lt("8.3.224")
1853+
async def test_set_ifeq_true_sets_and_returns_true(self, r):
1854+
await r.delete("k")
1855+
await r.set("k", b"foo")
1856+
assert await r.set("k", b"bar", ifeq=b"foo") is True
1857+
assert await r.get("k") == b"bar"
1858+
1859+
@skip_if_server_version_lt("8.3.224")
1860+
async def test_set_ifeq_false_does_not_set_returns_none(self, r):
1861+
await r.delete("k")
1862+
await r.set("k", b"foo")
1863+
assert await r.set("k", b"bar", ifeq=b"nope") is None
1864+
assert await r.get("k") == b"foo"
1865+
1866+
@skip_if_server_version_lt("8.3.224")
1867+
async def test_set_ifne_true_sets(self, r):
1868+
await r.delete("k")
1869+
await r.set("k", b"foo")
1870+
assert await r.set("k", b"bar", ifne=b"zzz") is True
1871+
assert await r.get("k") == b"bar"
1872+
1873+
@skip_if_server_version_lt("8.3.224")
1874+
async def test_set_ifne_false_does_not_set(self, r):
1875+
await r.delete("k")
1876+
await r.set("k", b"foo")
1877+
assert await r.set("k", b"bar", ifne=b"foo") is None
1878+
assert await r.get("k") == b"foo"
1879+
1880+
@skip_if_server_version_lt("8.3.224")
1881+
async def test_set_ifeq_when_key_missing_does_not_create(self, r):
1882+
await r.delete("k")
1883+
assert await r.set("k", b"bar", ifeq=b"foo") is None
1884+
assert await r.exists("k") == 0
1885+
1886+
@skip_if_server_version_lt("8.3.224")
1887+
async def test_set_ifne_when_key_missing_creates(self, r):
1888+
await r.delete("k")
1889+
assert await r.set("k", b"bar", ifne=b"foo") is True
1890+
assert await r.get("k") == b"bar"
1891+
1892+
@skip_if_server_version_lt("8.3.224")
1893+
@pytest.mark.parametrize("val", [b"", b"abc", b"The quick brown fox"])
1894+
async def test_set_ifdeq_and_ifdne(self, r, val):
1895+
await r.delete("k")
1896+
await r.set("k", val)
1897+
d = await self._server_xxh3_digest(r, "k")
1898+
assert d is not None
1899+
1900+
# IFDEQ must match to set; if key missing => won't create
1901+
assert await r.set("k", b"X", ifdeq=d) is True
1902+
assert await r.get("k") == b"X"
1903+
1904+
await r.delete("k")
1905+
# key missing + IFDEQ => not created
1906+
assert await r.set("k", b"Y", ifdeq=d) is None
1907+
assert await r.exists("k") == 0
1908+
1909+
# IFDNE: create when missing, and set when digest differs
1910+
assert await r.set("k", b"bar", ifdne=d) is True
1911+
prev_d = await self._server_xxh3_digest(r, "k")
1912+
assert prev_d is not None
1913+
# If digest equal → do not set
1914+
assert await r.set("k", b"zzz", ifdne=prev_d) is None
1915+
assert await r.get("k") == b"bar"
1916+
1917+
@skip_if_server_version_lt("8.3.224")
1918+
async def test_set_with_get_returns_previous_value(self, r):
1919+
await r.delete("k")
1920+
# when key didn’t exist → returns None, and key is created if condition allows it
1921+
prev = await r.set("k", b"v1", get=True, ifne=b"any") # IFNE on missing creates
1922+
assert prev is None
1923+
# subsequent GET returns previous value, regardless of whether set occurs
1924+
prev2 = await r.set(
1925+
"k", b"v2", get=True, ifeq=b"v1"
1926+
) # matches → set; returns "v1"
1927+
assert prev2 == b"v1"
1928+
prev3 = await r.set(
1929+
"k", b"v3", get=True, ifeq=b"no"
1930+
) # no set; returns previous "v2"
1931+
assert prev3 == b"v2"
1932+
assert await r.get("k") == b"v2"
1933+
1934+
@skip_if_server_version_lt("8.3.224")
1935+
async def test_set_mutual_exclusion_client_side(self, r):
1936+
await r.delete("k")
1937+
with pytest.raises(DataError):
1938+
await r.set("k", b"v", nx=True, ifeq=b"x")
1939+
with pytest.raises(DataError):
1940+
await r.set("k", b"v", ifdeq="aa", ifdne="bb")
1941+
with pytest.raises(DataError):
1942+
await r.set("k", b"v", ex=1, px=1)
1943+
with pytest.raises(DataError):
1944+
await r.set("k", b"v", exat=1, pxat=1)
1945+
with pytest.raises(DataError):
1946+
await r.set("k", b"v", ex=1, exat=1)
1947+
18521948
async def test_setex(self, r: redis.Redis):
18531949
assert await r.setex("a", 60, "1")
18541950
assert await r.get("a") == b"1"

0 commit comments

Comments
 (0)