Skip to content

Commit 949ad8f

Browse files
committed
Make evaluate for universal polynomials more universal
1 parent 0a6e99f commit 949ad8f

File tree

1 file changed

+21
-84
lines changed

1 file changed

+21
-84
lines changed

src/generic/UnivPoly.jl

Lines changed: 21 additions & 84 deletions
Original file line numberDiff line numberDiff line change
@@ -787,110 +787,44 @@ end
787787
#
788788
###############################################################################
789789

790-
function evaluate(a::UnivPoly{T}, A::Vector{T}) where {T <: RingElem}
791-
R = base_ring(a)
792-
n = length(A)
793-
num = nvars(parent(data(a)))
794-
if n > num
795-
n > nvars(parent(a)) && error("Too many values")
796-
if nvars(parent(data(a))) == 0
797-
return constant_coefficient(data(a))*one(parent(A[1]))
798-
end
799-
return evaluate(data(a), A[1:num])
790+
function evaluate(a::UnivPoly, A::Vector{<:Union{NCRingElem, RingElement}})
791+
a2 = data(a)
792+
varidx = Int[var_index(x) for x in vars(a2)]
793+
if isempty(A)
794+
isempty(varidx) || error("Number of variables does not match number of values")
795+
return constant_coefficient(a2)
800796
end
801-
if n < num
802-
A = vcat(A, [zero(R) for i = 1:num - n])
803-
end
804-
return evaluate(data(a), A)
805-
end
806-
807-
function evaluate(a::UnivPoly{T}, A::Vector{V}) where {T <: RingElement, V <: Union{Integer, Rational, AbstractFloat}}
797+
vals = zeros(parent(A[1]), nvars(parent(a2)))
808798
n = length(A)
809-
num = nvars(parent(data(a)))
810-
if n > num
811-
n > nvars(parent(a)) && error("Too many values")
812-
if nvars(parent(data(a))) == 0
813-
return constant_coefficient(data(a))*one(parent(A[1]))
814-
end
815-
return evaluate(data(a), A[1:num])
816-
end
817-
if n < num
818-
A = vcat(A, zeros(V, num - n))
799+
for i in varidx
800+
i <= n || error("Number of variables does not match number of values")
801+
vals[i] = A[i]
819802
end
820-
return evaluate(data(a), A)
803+
return evaluate(a2, vals)
821804
end
822805

823-
function evaluate(a::UnivPoly{T}, A::Vector{V}) where {T <: RingElement, V <: RingElement}
824-
n = length(A)
825-
num = nvars(parent(data(a)))
826-
if n > num
827-
n > nvars(parent(a)) && error("Too many values")
828-
if nvars(parent(data(a))) == 0
829-
return constant_coefficient(data(a))*one(parent(A[1]))
830-
end
831-
return evaluate(data(a), A[1:num])
832-
end
833-
if n < num
834-
if n == 0
835-
R = base_ring(a)
836-
return evaluate(data(a), [zero(R) for _ in 1:num])
837-
else
838-
R = parent(A[1])
839-
A = vcat(A, [zero(R) for _ in 1:num-n])
840-
return evaluate(data(a), A)
841-
end
842-
end
843-
return evaluate(data(a), A)
844-
end
845-
846-
function (a::UnivPoly{T})() where {T <: RingElement}
847-
return evaluate(a, T[])
848-
end
849-
850-
function (a::UnivPoly{T})(vals::T...) where {T <: RingElement}
851-
return evaluate(a, [vals...])
852-
end
853-
854-
function (a::UnivPoly{T})(val::V, vals::V...) where {T <: RingElement, V <: Union{Integer, Rational, AbstractFloat}}
806+
function (a::UnivPoly)(val::T, vals::T...) where T <: Union{NCRingElem, RingElement}
855807
return evaluate(a, [val, vals...])
856808
end
857809

858-
function (a::UnivPoly{T})(vals::Union{NCRingElem, RingElement}...) where {T <: RingElement}
859-
A = [vals...]
860-
n = length(vals)
861-
num = nvars(parent(data(a)))
862-
if n > num
863-
n > nvars(parent(a)) && error("Too many values")
864-
if nvars(parent(data(a))) == 0
865-
return constant_coefficient(data(a))*one(parent(A[1]))
866-
end
867-
return data(a)(vals[1:num]...)
868-
end
869-
if n < num
870-
A = vcat(A, zeros(Int, num - n))
871-
end
872-
return data(a)(A...)
873-
end
874-
875-
function evaluate(a::UnivPoly{T}, vals::Vector{V}) where {T <: RingElement, V <: NCRingElem}
876-
return a(vals...)
810+
function (a::UnivPoly)()
811+
return evaluate(a, Int[])
877812
end
878813

879814
function evaluate(a::UnivPoly{T}, vars::Vector{Int}, vals::Vector{V}) where {T <: RingElement, V <: RingElement}
880815
length(vars) != length(vals) && error("Numbers of variables and values do not match")
881816
vars2 = Vector{Int}(undef, 0)
882817
vals2 = Vector{mpoly_type(T)}(undef, 0)
883-
num = nvars(parent(data(a)))
884818
S = parent(a)
885-
n = nvars(S)
819+
a2 = data(S(a))
820+
num = nvars(S)
886821
for i = 1:length(vars)
887-
vars[i] > n && error("Unknown variable")
888822
if vars[i] <= num
889823
push!(vars2, vars[i])
890824
push!(vals2, data(S(vals[i])))
891825
end
892826
end
893-
return UnivPoly(evaluate(data(S(a)), vars2, vals2), S)
827+
return UnivPoly(evaluate(a2, vars2, vals2), S)
894828
end
895829

896830
function evaluate(a::S, vars::Vector{S}, vals::Vector{V}) where {S <: UnivPoly{T}, V <: RingElement} where {T <: RingElement}
@@ -904,7 +838,10 @@ function (a::Union{MPolyRingElem, UniversalPolyRingElem})(;kwargs...)
904838
vals = Array{RingElement}(undef, length(kwargs))
905839
for (i, (var, val)) in enumerate(kwargs)
906840
vari = findfirst(isequal(var), ss)
907-
vari === nothing && error("Given polynomial has no variable $var")
841+
if vari === nothing
842+
isa(a, MPolyRingElem) && error("Given polynomial has no variable $var")
843+
continue
844+
end
908845
vars[i] = vari
909846
vals[i] = val
910847
end

0 commit comments

Comments
 (0)