Commit 407e0ee

Merge pull request #231 from lucidrains/residual-fsq-fix
allow for hard clamp in fsq, to ready for residual fsq pre-softclampi…
2 parents 976c3f2 + 3867f60 commit 407e0ee

File tree

4 files changed: +48 -22 lines changed

pyproject.toml

Lines changed: 2 additions & 2 deletions
@@ -1,6 +1,6 @@
 [project]
 name = "vector-quantize-pytorch"
-version = "1.25.2"
+version = "1.26.0"
 description = "Vector Quantization - Pytorch"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }
@@ -23,7 +23,7 @@ classifiers=[
 ]

 dependencies = [
-    "torch>=2.0",
+    "torch>=2.4",
     "einops>=0.8.0",
     "einx>=0.3.0",
 ]

tests/test_readme.py

Lines changed: 4 additions & 2 deletions
@@ -247,13 +247,15 @@ def test_directional_reparam():
     quantized, indices, _ = rq(x)

 @pytest.mark.parametrize('preserve_symmetry', (True, False))
+@pytest.mark.parametrize('bound_hard_clamp', (True, False))
 def test_fsq(
-    preserve_symmetry
+    preserve_symmetry,
+    bound_hard_clamp
 ):
     from vector_quantize_pytorch import FSQ

     levels = [8,5,5,5] # see 4.1 and A.4.1 in the paper
-    quantizer = FSQ(levels, preserve_symmetry = preserve_symmetry)
+    quantizer = FSQ(levels, preserve_symmetry = preserve_symmetry, bound_hard_clamp = bound_hard_clamp)

     x = torch.randn(1, 1024, 4) # 4 since there are 4 levels
     xhat, indices = quantizer(x)

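Outside of the pytest parametrization above, the new flag can be exercised directly. A minimal sketch, using only the constructor call and tensor shapes shown in the test (nothing else is assumed):

import torch
from vector_quantize_pytorch import FSQ

levels = [8, 5, 5, 5]                        # see 4.1 and A.4.1 in the FSQ paper
quantizer = FSQ(levels, bound_hard_clamp = True)

x = torch.randn(1, 1024, 4)                  # 4, since there are 4 levels
xhat, indices = quantizer(x)                 # xhat is snapped to the implicit codebook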
vector_quantize_pytorch/finite_scalar_quantization.py

Lines changed: 20 additions & 7 deletions
@@ -11,7 +11,7 @@
 import torch
 import torch.nn as nn
 from torch.nn import Module
-from torch import tensor, Tensor, int32
+from torch import tensor, Tensor, int32, tanh, atanh, clamp
 from torch.amp import autocast

 import einx
@@ -30,6 +30,9 @@ def default(*args):
             return arg
     return None

+def identity(t):
+    return t
+
 def maybe(fn):
     @wraps(fn)
     def inner(x, *args, **kwargs):
@@ -73,6 +76,7 @@ def __init__(
         force_quantization_f32 = True,
         preserve_symmetry = False,
         noise_dropout = 0.,
+        bound_hard_clamp = False # for residual fsq, if input is pre-softclamped to the right range
     ):
         super().__init__()

@@ -121,22 +125,31 @@ def __init__(
         self.allowed_dtypes = allowed_dtypes
         self.force_quantization_f32 = force_quantization_f32

-    def bound(self, z, eps = 1e-3):
+        # allow for a hard clamp
+
+        self.bound_hard_clamp = bound_hard_clamp
+
+    def bound(self, z, eps = 1e-3, hard_clamp = False):
         """ Bound `z`, an array of shape (..., d). """
+        maybe_tanh = tanh if not hard_clamp else partial(clamp, min = -1., max = 1.)
+        maybe_atanh = atanh if not hard_clamp else identity
+
         half_l = (self._levels - 1) * (1 + eps) / 2
         offset = torch.where(self._levels % 2 == 0, 0.5, 0.0)
-        shift = (offset / half_l).atanh()
-        bounded_z = (z + shift).tanh() * half_l - offset
+        shift = maybe_atanh(offset / half_l)
+        bounded_z = maybe_tanh(z + shift) * half_l - offset
         half_width = self._levels // 2
         return round_ste(bounded_z) / half_width

     # symmetry-preserving and noise-approximated quantization, section 3.2 in https://arxiv.org/abs/2411.19842

-    def symmetry_preserving_bound(self, z):
+    def symmetry_preserving_bound(self, z, hard_clamp = False):
         """ QL(x) = 2 / (L - 1) * [(L - 1) * (tanh(x) + 1) / 2 + 0.5] - 1 """
+        maybe_tanh = tanh if not hard_clamp else partial(clamp, min = -1., max = 1.)
+
         levels_minus_1 = (self._levels - 1)
         scale = 2. / levels_minus_1
-        bracket = (levels_minus_1 * (z.tanh() + 1) / 2.) + 0.5
+        bracket = (levels_minus_1 * (maybe_tanh(z) + 1) / 2.) + 0.5
         bracket = floor_ste(bracket)
         return scale * bracket - 1.

@@ -146,7 +159,7 @@ def quantize(self, z):
         shape, device, noise_dropout, preserve_symmetry = z.shape[0], z.device, self.noise_dropout, self.preserve_symmetry
         bound_fn = self.symmetry_preserving_bound if preserve_symmetry else self.bound

-        bounded_z = bound_fn(z)
+        bounded_z = bound_fn(z, hard_clamp = self.bound_hard_clamp)

         # determine where to add a random offset elementwise
         # if using noise dropout

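For reference, a standalone sketch of what the hard-clamp toggle changes in the bounding step. This is illustrative only and mirrors the diff above; the library's bound method additionally routes the rounding through round_ste (a straight-through round):

import torch
from functools import partial

def bound(z, levels, eps = 1e-3, hard_clamp = False):
    # hard_clamp = False: squash with tanh, as before
    # hard_clamp = True : assume the input was already soft-clamped upstream, just clip it
    maybe_tanh = torch.tanh if not hard_clamp else partial(torch.clamp, min = -1., max = 1.)
    maybe_atanh = torch.atanh if not hard_clamp else (lambda t: t)

    half_l = (levels - 1) * (1 + eps) / 2
    offset = torch.where(levels % 2 == 0, 0.5, 0.0)
    shift = maybe_atanh(offset / half_l)
    bounded_z = maybe_tanh(z + shift) * half_l - offset
    half_width = levels // 2
    return torch.round(bounded_z) / half_width  # library uses round_ste here

levels = torch.tensor([8, 5, 5, 5])
print(bound(torch.randn(2, 4), levels, hard_clamp = True))  # values land on the FSQ grid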
vector_quantize_pytorch/residual_fsq.py

Lines changed: 22 additions & 11 deletions
@@ -1,11 +1,11 @@
+from __future__ import annotations
+
 import random
 from math import ceil
 from functools import partial

-from typing import List
-
 import torch
-from torch import nn
+from torch import nn, tensor
 from torch.nn import Module, ModuleList
 import torch.nn.functional as F
 from torch.amp import autocast
@@ -52,14 +52,15 @@ class ResidualFSQ(Module):
     def __init__(
         self,
         *,
-        levels: List[int],
+        levels: list[int],
         num_quantizers,
         dim = None,
         is_channel_first = False,
         quantize_dropout = False,
         quantize_dropout_cutoff_index = 0,
         quantize_dropout_multiple_of = 1,
-        soft_clamp_input_value = None,
+        soft_clamp_input_value: float | list[float] | Tensor | None = None,
+        bound_hard_clamp = True,
         **kwargs
     ):
         super().__init__()
@@ -74,25 +75,24 @@ def __init__(
         self.is_channel_first = is_channel_first
         self.num_quantizers = num_quantizers

-        # soft clamping the input value
-
-        self.soft_clamp_input_value = soft_clamp_input_value
-
         # layers

         self.levels = levels
         self.layers = nn.ModuleList([])

-        levels_tensor = torch.Tensor(levels)
+        levels_tensor = tensor(levels)
+        assert (levels_tensor > 1).all()

         scales = []

         for ind in range(num_quantizers):
-            scales.append((levels_tensor - 1) ** -ind)
+            scales.append(levels_tensor.float() ** -ind)

         fsq = FSQ(
             levels = levels,
             dim = codebook_dim,
+            preserve_symmetry = True,
+            bound_hard_clamp = bound_hard_clamp,
             **kwargs
         )

@@ -111,6 +111,17 @@ def __init__(
         self.quantize_dropout_cutoff_index = quantize_dropout_cutoff_index
         self.quantize_dropout_multiple_of = quantize_dropout_multiple_of # encodec paper proposes structured dropout, believe this was set to 4

+        # soft clamping the input value
+
+        if bound_hard_clamp:
+            assert not exists(soft_clamp_input_value)
+            soft_clamp_input_value = 1 + (1 / (levels_tensor - 1))
+
+        if isinstance(soft_clamp_input_value, (list, float)):
+            soft_clamp_input_value = tensor(soft_clamp_input_value)
+
+        self.register_buffer('soft_clamp_input_value', soft_clamp_input_value, persistent = False)
+
     @property
     def codebooks(self):
         codebooks = [layer.implicit_codebook for layer in self.layers]

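The derived clamp range above follows from the symmetry-preserving grid, whose levels are spaced 2 / (L - 1) apart in [-1, 1]: 1 + 1 / (L - 1) sits half a grid step beyond the grid edge. A sketch of the new default behaviour; the dim and num_quantizers values are illustrative, and the printed attribute is the buffer registered in the __init__ above:

import torch
from vector_quantize_pytorch import ResidualFSQ

residual_fsq = ResidualFSQ(
    dim = 256,               # illustrative
    levels = [8, 5, 5, 5],
    num_quantizers = 4       # illustrative
)

# with bound_hard_clamp = True (the new default), the clamp range is derived as 1 + 1 / (levels - 1)
print(residual_fsq.soft_clamp_input_value)   # tensor([1.1429, 1.2500, 1.2500, 1.2500])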
0 commit comments