Skip to content

Commit 74f3126

Browse files
adienesfingolfinSeelengrab
authored
clear up several more inconsistencies among gcd , gcdx and invmod (#59678)
test suite of invariances satisfied among `gcd`, `gcdx`, `invmod`: ``` function can_compute_gcd(::Type{T}, a, b) where {T} (a > b) && return can_compute_gcd(T, b, a) if T <: Signed return a != typemin(T) || (b != typemin(T) && b != 0) else return (a > typemin(a)) || iszero(a) end end check_bezout(T::Type{U}, a, b, d, u, v) where U = u * a + v * b == d function check_bezout(T::Type{S}, a, b, d, u, v) where {S<:Signed} W = widen(T) return W(u) * W(a) + W(v) * W(b) == W(d) end bad_values = Tuple{Integer, Integer}[] let Ts_signed = (Int8, Int16, Int32, Int64) Ts_unsigned = (UInt8, UInt16, UInt32, UInt64) Ts = (Ts_signed..., Ts_unsigned...) T = UInt128 test_values = T[] nbits = (sizeof(T) << 3) append!(test_values, [-one(T), zero(T), one(T), typemin(T), typemax(T)]) append!(test_values, reduce(vcat, [ [T(1) << i, (T(1) << i) - 1, (T(1) << i) + 1] for i in 0:nbits ])) append!(test_values, reduce(vcat, [ [T(3)^div(i,2), T(3)^div(i,2) - 1, T(3)^div(i,2) + 1] for i in 0:2:nbits ])) function do_tests(a, b) T = Base.promote_typeof(a, b) if !can_compute_gcd(T, a, b) return else try gcd(a, b) gcdx(a, b) catch push!(bad_values, (a, b)) return end end g = gcd(a, b) d, u, v = gcdx(a, b) properties = ( ispositive(g) && g === d && (typeof(g) == typeof(d)) && check_bezout(T, a, b, d, u, v) && mod(a, g) == mod(b, g) == 0 ) if isone(g) && !iszero(b) ai = invmod(a, b) properties = properties && ( (mod(one(b), b) == mod(widemul(a, ai), b)) && ((sign(ai) == sign(b)) || (sign(ai) == 0)) && (ai === mod(ai, b)) && (typeof(ai) == typeof(mod(a, b))) && iszero(div(ai, b)) ) end if !properties println("$(typeof(a))($a), $(typeof(b))($b)") push!(bad_values, (a, b)) end end for T1 in Ts As = test_values .% T1 for T2 in Ts Bs = test_values .% T2 for a in As for b in Bs iszero(a) && iszero(b) && continue do_tests(a, b) end end end end end ``` --------- Co-authored-by: Max Horn <max@quendi.de> Co-authored-by: Sukera <11753998+Seelengrab@users.noreply.github.com>
1 parent 11c517e commit 74f3126

File tree

2 files changed

+86
-33
lines changed

2 files changed

+86
-33
lines changed

base/intfuncs.jl

Lines changed: 61 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -143,13 +143,19 @@ function lcm(a::T, b::T) where T<:Integer
143143
end
144144
end
145145

146+
function _promote_mixed_signs(a::Signed, b::Unsigned)
147+
# handle the case a == typemin(typeof(a)) if R != typeof(a)
148+
R = promote_typeof(a, b)
149+
promote(abs(a % signed(R)), b)
150+
end
151+
146152
gcd(a::Integer) = checked_abs(a)
147153
gcd(a::Rational) = checked_abs(a.num) // a.den
148154
lcm(a::Union{Integer,Rational}) = gcd(a)
149-
gcd(a::Unsigned, b::Signed) = gcd(promote(a, abs(b))...)
150-
gcd(a::Signed, b::Unsigned) = gcd(promote(abs(a), b)...)
155+
gcd(a::Unsigned, b::Signed) = gcd(b, a)
156+
gcd(a::Signed, b::Unsigned) = gcd(_promote_mixed_signs(a, b)...)
151157
lcm(a::Unsigned, b::Signed) = lcm(promote(a, abs(b))...)
152-
lcm(a::Signed, b::Unsigned) = lcm(promote(abs(a), b)...)
158+
lcm(a::Signed, b::Unsigned) = lcm(_promote_mixed_signs(a, b)...)
153159
gcd(a::Real, b::Real) = gcd(promote(a,b)...)
154160
lcm(a::Real, b::Real) = lcm(promote(a,b)...)
155161
gcd(a::Real, b::Real, c::Real...) = gcd(a, gcd(b, c...))
@@ -223,22 +229,31 @@ julia> gcdx(15, 12, 20)
223229
their `typemax`, and the identity then holds only via the unsigned
224230
integers' modulo arithmetic.
225231
"""
226-
Base.@assume_effects :terminates_locally function gcdx(a::Integer, b::Integer)
227-
T = promote_type(typeof(a), typeof(b))
228-
a == b == 0 && return (zero(T), zero(T), zero(T))
232+
Base.@assume_effects :terminates_locally function gcdx(a::T, b::T) where {T<:Integer}
233+
if iszero(a) && iszero(b)
234+
return (zero(T), zero(T), zero(T))
235+
elseif isone(abs(b))
236+
# handles (typemin(::Signed), -1)
237+
return (one(T), zero(T), b)
238+
elseif isone(abs(a))
239+
return (one(T), a, zero(T))
240+
end
229241
# a0, b0 = a, b
230242
s0, s1 = oneunit(T), zero(T)
231243
t0, t1 = s1, s0
232244
# The loop invariant is: s0*a0 + t0*b0 == a && s1*a0 + t1*b0 == b
233-
x = a % T
234-
y = b % T
235-
while y != 0
236-
q, r = divrem(x, y)
237-
x, y = y, r
245+
while !iszero(b)
246+
q, r = divrem(a, b)
247+
a, b = b, r
238248
s0, s1 = s1, s0 - q*s1
239249
t0, t1 = t1, t0 - q*t1
240250
end
241-
x < 0 ? (-x, -s0, -t0) : (x, s0, t0)
251+
# for cases like abs(Int8(-128))
252+
if isnegative(a) && isnegative(abs(a))
253+
throw(DomainError((a, b), LazyString("gcd not representable in ", T)))
254+
else
255+
return isnegative(a) ? (abs(a), -s0, -t0) : (a, s0, t0)
256+
end
242257
end
243258
gcdx(a::Real, b::Real) = gcdx(promote(a,b)...)
244259
gcdx(a::T, b::T) where T<:Real = throw(MethodError(gcdx, (a,b)))
@@ -254,11 +269,12 @@ function gcdx(a::Real, b::Real, cs::Real...)
254269
d′, x, ys... = gcdx(d, cs...)
255270
return d′, i*x, j*x, ys...
256271
end
272+
257273
function gcdx(a::Signed, b::Unsigned)
258-
R = promote_type(typeof(a), typeof(b))
259-
_a = a % signed(R) # handle the case a == typemin(typeof(a)) if R != typeof(a)
260-
d, u, v = gcdx(promote(abs(_a), b)...)
261-
d, flipsign(u, a), v
274+
R = promote_typeof(a, b)
275+
d, u, v = gcdx(promote(abs(a % signed(R)), b)...)
276+
flip_typemin = isnegative(a) & (R <: Signed)
277+
d, flipsign(u, a - flip_typemin), v
262278
end
263279
function gcdx(a::Unsigned, b::Signed)
264280
d, v, u = gcdx(b, a)
@@ -287,24 +303,38 @@ julia> invmod(5, 6)
287303
```
288304
"""
289305
function invmod(n::Integer, m::Integer)
306+
# The postcondition is: mod(widemul(result, n), m) == mod(one(T), m) && iszero(div(result, m))
290307
iszero(m) && throw(DomainError(m, "`m` must not be 0."))
291-
if n isa Signed && hastypemax(typeof(n))
292-
# work around inconsistencies in gcdx
293-
# https://github.com/JuliaLang/julia/issues/33781
294-
T = promote_type(typeof(n), typeof(m))
295-
n == typemin(typeof(n)) && m == typeof(n)(-1) && return T(0)
296-
n == typeof(n)(-1) && m == typemin(typeof(n)) && return T(-1)
308+
R = promote_typeof(n, m)
309+
if R <: Signed
310+
x = _bezout_coef(n, m)
311+
return mod(x, m)
312+
else
313+
S = signed(R)
314+
if !hastypemax(S) || (n <= typemax(S)) && (m <= typemax(S))
315+
x = _bezout_coef(n % S, m % S)
316+
317+
# this branch is only hit if R <: Unsigned, so we don't have
318+
# to worry about abs(typemin(::Signed)) overflow. If `m` is
319+
# signed then `x` must be unsigned, and thus never negative
320+
isnegative(x) && (x += abs(m))
321+
return mod(x % R, m)
322+
else
323+
# since gcdx only promises bezout w.r.t overflow for unsigned ints,
324+
# we have to widen to a signed type
325+
W = widen(S)
326+
x = _bezout_coef(n % W, m % W)
327+
t = mod(x, m % W)
328+
isnegative(m) && (t -= m)
329+
return mod(t % R, m)
330+
end
297331
end
298-
g, x, y = gcdx(n, m)
332+
end
333+
334+
function _bezout_coef(n, m)
335+
g, x, _ = gcdx(n, m)
299336
g != 1 && throw(DomainError((n, m), LazyString("Greatest common divisor is ", g, ".")))
300-
# Note that m might be negative here.
301-
if x isa Unsigned && hastypemax(typeof(x)) && x > typemax(x)>>1
302-
# x might have wrapped if it would have been negative
303-
# adding back m forces a correction
304-
x += m
305-
end
306-
# The postcondition is: mod(result * n, m) == mod(T(1), m) && div(result, m) == 0
307-
return mod(x, m)
337+
return x
308338
end
309339

310340
"""

test/intfuncs.jl

Lines changed: 25 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -217,6 +217,19 @@ end
217217
d, u, v = gcdx(x, y)
218218
@test x*u + y*v == d
219219

220+
for T in (Int8, Int16, Int32, Int64, Int128)
221+
@test_throws DomainError gcdx(typemin(T), typemin(T))
222+
@test_throws DomainError gcdx(typemin(T), T(0))
223+
@test_throws DomainError gcdx(T(0), typemin(T))
224+
d, u, v = gcdx(typemin(T), T(-1))
225+
@test d == T(1)
226+
@test typemin(T) * u + T(-1) * v == T(1)
227+
@test gcdx(T(-1), typemin(T)) == (d, v, u)
228+
d, u, v = gcdx(typemin(T), T(1))
229+
@test d == T(1)
230+
@test typemin(T) * u + T(1) * v == T(1)
231+
@test gcdx(T(1), typemin(T)) == (d, v, u)
232+
end
220233
end
221234

222235
# issue #58025
@@ -244,7 +257,7 @@ end
244257

245258
@test gcdx(Int16(-32768), Int8(-128)) === (Int16(128), Int16(0), Int16(-1))
246259
@test gcdx(Int8(-128), UInt16(256)) === (0x0080, 0xffff, 0x0000)
247-
@test_broken gcd(Int8(-128), UInt16(256)) === 0x0080
260+
@test gcd(Int8(-128), UInt16(256)) === 0x0080
248261
end
249262

250263
@testset "gcd/lcm/gcdx for custom types" begin
@@ -294,10 +307,20 @@ end
294307
# Verify issue described in PR 58010 is fixed
295308
@test invmod(UInt8(3), UInt16(50000)) === 0x411b
296309

310+
@test invmod(0x00000001, Int8(-128)) === Int32(-127)
311+
@test invmod(0xffffffff, Int8(-38)) === Int32(-15)
312+
@test invmod(Int8(-1), 0xffffffff) === 0xfffffffe
313+
@test invmod(Int32(-1), typemin(Int64)) === Int64(-1)
314+
@test invmod(0x3e81, Int16(-5716)) === Int16(-2407)
315+
297316
for T in (Int8, UInt8)
298317
for x in typemin(T):typemax(T)
299318
for m in typemin(T):typemax(T)
300-
if m != 0 && try gcdx(x, m)[1] == 1 catch _ true end
319+
if !(
320+
iszero(m) ||
321+
iszero(mod(x, m)) && !isone(abs(m)) ||
322+
!isone(gcd(x, m))
323+
)
301324
y = invmod(x, m)
302325
@test mod(widemul(y, x), m) == mod(1, m)
303326
@test div(y, m) == 0

0 commit comments

Comments
 (0)