Skip to content

Commit 644150f

Browse files
committed
Make evaluate for universal polynomials more universal
1 parent fe402c9 commit 644150f

File tree

4 files changed

+56
-102
lines changed

4 files changed

+56
-102
lines changed

src/MPoly.jl

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1130,6 +1130,19 @@ function evaluate(a::MPolyRingElem{T}, vals::Vector{U}) where {T <: RingElement,
11301130
return a(vals...)
11311131
end
11321132

1133+
function (a::MPolyRingElem)(;kwargs...)
1134+
ss = symbols(parent(a))
1135+
vars = Array{Int}(undef, length(kwargs))
1136+
vals = Array{RingElement}(undef, length(kwargs))
1137+
for (i, (var, val)) in enumerate(kwargs)
1138+
vari = findfirst(isequal(var), ss)
1139+
vari === nothing && error("Given polynomial has no variable $var")
1140+
vars[i] = vari
1141+
vals[i] = val
1142+
end
1143+
return evaluate(a, vars, vals)
1144+
end
1145+
11331146
################################################################################
11341147
#
11351148
# Derivative

src/UnivPoly.jl

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,25 @@ function content(a::UniversalPolyRingElem)
88
return content(data(a))
99
end
1010

11+
###############################################################################
12+
#
13+
# Evaluation
14+
#
15+
###############################################################################
16+
17+
function (a::UniversalPolyRingElem)(;kwargs...)
18+
ss = symbols(parent(a))
19+
vars = Int[]
20+
vals = RingElement[]
21+
for (var, val) in kwargs
22+
vari = findfirst(isequal(var), ss)
23+
vari === nothing && continue
24+
push!(vars, vari)
25+
push!(vals, val)
26+
end
27+
return evaluate(a, vars, vals)
28+
end
29+
1130
###############################################################################
1231
#
1332
# Iterators

src/generic/UnivPoly.jl

Lines changed: 15 additions & 96 deletions
Original file line numberDiff line numberDiff line change
@@ -783,93 +783,26 @@ end
783783
#
784784
###############################################################################
785785

786-
function evaluate(a::UnivPoly{T}, A::Vector{T}) where {T <: RingElem}
787-
R = base_ring(a)
786+
function evaluate(a::UnivPoly, A::Vector{<:Union{NCRingElem, RingElement}})
787+
isempty(A) && error("Too few values")
788+
a2 = data(a)
789+
varidx = var_indices(a2)
790+
isempty(varidx) && return constant_coefficient(a2)
791+
vals = zeros(parent(A[1]), nvars(parent(a2)))
788792
n = length(A)
789-
num = nvars(parent(data(a)))
790-
if n > num
791-
n > nvars(parent(a)) && error("Too many values")
792-
if nvars(parent(data(a))) == 0
793-
return constant_coefficient(data(a))*one(parent(A[1]))
794-
end
795-
return evaluate(data(a), A[1:num])
796-
end
797-
if n < num
798-
A = vcat(A, [zero(R) for i = 1:num - n])
799-
end
800-
return evaluate(data(a), A)
801-
end
802-
803-
function evaluate(a::UnivPoly{T}, A::Vector{V}) where {T <: RingElement, V <: Union{Integer, Rational, AbstractFloat}}
804-
n = length(A)
805-
num = nvars(parent(data(a)))
806-
if n > num
807-
n > nvars(parent(a)) && error("Too many values")
808-
if nvars(parent(data(a))) == 0
809-
return constant_coefficient(data(a))*one(parent(A[1]))
810-
end
811-
return evaluate(data(a), A[1:num])
812-
end
813-
if n < num
814-
A = vcat(A, zeros(V, num - n))
815-
end
816-
return evaluate(data(a), A)
817-
end
818-
819-
function evaluate(a::UnivPoly{T}, A::Vector{V}) where {T <: RingElement, V <: RingElement}
820-
n = length(A)
821-
num = nvars(parent(data(a)))
822-
if n > num
823-
n > nvars(parent(a)) && error("Too many values")
824-
if nvars(parent(data(a))) == 0
825-
return constant_coefficient(data(a))*one(parent(A[1]))
826-
end
827-
return evaluate(data(a), A[1:num])
793+
for i in varidx
794+
i <= n || error("Number of variables does not match number of values")
795+
vals[i] = A[i]
828796
end
829-
if n < num
830-
if n == 0
831-
R = base_ring(a)
832-
return evaluate(data(a), [zero(R) for _ in 1:num])
833-
else
834-
R = parent(A[1])
835-
A = vcat(A, [zero(R) for _ in 1:num-n])
836-
return evaluate(data(a), A)
837-
end
838-
end
839-
return evaluate(data(a), A)
840-
end
841-
842-
function (a::UnivPoly{T})() where {T <: RingElement}
843-
return evaluate(a, T[])
797+
return evaluate(a2, vals)
844798
end
845799

846-
function (a::UnivPoly{T})(vals::T...) where {T <: RingElement}
847-
return evaluate(a, [vals...])
848-
end
849-
850-
function (a::UnivPoly{T})(val::V, vals::V...) where {T <: RingElement, V <: Union{Integer, Rational, AbstractFloat}}
800+
function (a::UnivPoly)(val::T, vals::T...) where T <: Union{NCRingElem, RingElement}
851801
return evaluate(a, [val, vals...])
852802
end
853803

854-
function (a::UnivPoly{T})(vals::NCRingElement...) where {T <: RingElement}
855-
A = [vals...]
856-
n = length(vals)
857-
num = nvars(parent(data(a)))
858-
if n > num
859-
n > nvars(parent(a)) && error("Too many values")
860-
if nvars(parent(data(a))) == 0
861-
return constant_coefficient(data(a))*one(parent(A[1]))
862-
end
863-
return data(a)(vals[1:num]...)
864-
end
865-
if n < num
866-
A = vcat(A, zeros(Int, num - n))
867-
end
868-
return data(a)(A...)
869-
end
870-
871-
function evaluate(a::UnivPoly{T}, vals::Vector{V}) where {T <: RingElement, V <: NCRingElem}
872-
return a(vals...)
804+
function (a::UnivPoly)()
805+
return evaluate(a, Int[])
873806
end
874807

875808
function evaluate(a::UnivPoly{T}, vars::Vector{Int}, vals::Vector{V}) where {T <: RingElement, V <: RingElement}
@@ -878,35 +811,21 @@ function evaluate(a::UnivPoly{T}, vars::Vector{Int}, vals::Vector{V}) where {T <
878811
vals2 = Vector{mpoly_type(T)}(undef, 0)
879812
num = nvars(parent(data(a)))
880813
S = parent(a)
881-
n = nvars(S)
814+
a2 = data(S(a))
882815
for i = 1:length(vars)
883-
vars[i] > n && error("Unknown variable")
884816
if vars[i] <= num
885817
push!(vars2, vars[i])
886818
push!(vals2, data(S(vals[i])))
887819
end
888820
end
889-
return UnivPoly(evaluate(data(S(a)), vars2, vals2), S)
821+
return UnivPoly(evaluate(a2, vars2, vals2), S)
890822
end
891823

892824
function evaluate(a::S, vars::Vector{S}, vals::Vector{V}) where {S <: UnivPoly{T}, V <: RingElement} where {T <: RingElement}
893825
varidx = Int[var_index(x) for x in vars]
894826
return evaluate(a, varidx, vals)
895827
end
896828

897-
function (a::Union{MPolyRingElem, UniversalPolyRingElem})(;kwargs...)
898-
ss = symbols(parent(a))
899-
vars = Array{Int}(undef, length(kwargs))
900-
vals = Array{RingElement}(undef, length(kwargs))
901-
for (i, (var, val)) in enumerate(kwargs)
902-
vari = findfirst(isequal(var), ss)
903-
vari === nothing && error("Given polynomial has no variable $var")
904-
vars[i] = vari
905-
vals[i] = val
906-
end
907-
return evaluate(a, vars, vals)
908-
end
909-
910829
########S,(a,b)=QQ[:a,:b]#######################################################################
911830
#
912831
# GCD

test/generic/UnivPoly-test.jl

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -974,20 +974,19 @@ end
974974
@test evaluate(f, V) == f(V...)
975975
@test evaluate(f, V) == f([ZZ(v) for v in V]...)
976976
@test evaluate(f, V) == f([U(v) for v in V]...)
977+
@test evaluate(f, V) == evaluate(f, collect(1:n), V)
977978

978979
@test evaluate(g, V) == evaluate(g, [R(v) for v in V])
979980
@test evaluate(g, V) == evaluate(g, [ZZ(v) for v in V])
980981
@test evaluate(g, V) == evaluate(g, [U(v) for v in V])
981982
@test evaluate(g, V) == g(V...)
982983
@test evaluate(g, V) == g([ZZ(v) for v in V]...)
983984
@test evaluate(g, V) == g([U(v) for v in V]...)
985+
@test evaluate(g, V) == evaluate(g, collect(1:n), V)
984986

985-
@test evaluate(h, V) == evaluate(h, [R(v) for v in V])
986-
@test evaluate(h, V) == evaluate(h, [ZZ(v) for v in V])
987-
@test evaluate(h, V) == evaluate(h, [U(v) for v in V])
988-
@test evaluate(h, V) == h(V...)
989-
@test evaluate(h, V) == h([ZZ(v) for v in V]...)
990-
@test evaluate(h, V) == h([U(v) for v in V]...)
987+
@test parent(evaluate(g, V)) == R
988+
@test parent(evaluate(g, collect(1:n), V)) == S
989+
@test evaluate(h, [1,2,3,4]) == evaluate(h, [1,2,3,5])
991990

992991
V = [rand(-10:10) for v in 1:2]
993992

@@ -999,16 +998,20 @@ end
999998
@test evaluate(f, [1, 3], [V[1], V[2]]) == evaluate(f, [1, 3], [ZZ(v) for v in V[1:2]])
1000999
@test evaluate(f, [1, 3], [V[1], V[2]]) == f(x=V[1], z=V[2])
10011000
@test evaluate(f, [1, 3], [V[1], V[2]]) == f(z=V[2], x=V[1])
1001+
@test evaluate(f, [1, 3], [V[1], V[2]]) == f(x=V[1], z=V[2], w=0)
1002+
@test parent(evaluate(f, [1, 3], [V[1], V[2]])) == S
10021003

10031004
@test evaluate(g, [1], [V[1]]) == evaluate(g, [1], [R(V[1])])
10041005
@test evaluate(g, [1], [V[1]]) == evaluate(g, [1], [ZZ(V[1])])
10051006
@test evaluate(g, [1, 3], [V[1], V[2]]) == evaluate(g, [1, 3], [R(v) for v in V[1:2]])
10061007
@test evaluate(g, [1, 3], [V[1], V[2]]) == evaluate(g, [1, 3], [ZZ(v) for v in V[1:2]])
1008+
@test parent(evaluate(g, [1, 3], [V[1], V[2]])) == S
10071009

10081010
@test evaluate(h, [1], [V[1]]) == evaluate(h, [1], [R(V[1])])
10091011
@test evaluate(h, [1], [V[1]]) == evaluate(h, [1], [ZZ(V[1])])
10101012
@test evaluate(h, [1, 3], [V[1], V[2]]) == evaluate(h, [1, 3], [R(v) for v in V[1:2]])
10111013
@test evaluate(h, [1, 3], [V[1], V[2]]) == evaluate(h, [1, 3], [ZZ(v) for v in V[1:2]])
1014+
@test parent(evaluate(h, [1, 3], [V[1], V[2]])) == S
10121015

10131016
@test evaluate(x, [1], [y]) == evaluate(z, [3], [y])
10141017
end

0 commit comments

Comments
 (0)