diff --git a/src/MPoly.jl b/src/MPoly.jl index 27c90aa8fc..27b4263a8b 100644 --- a/src/MPoly.jl +++ b/src/MPoly.jl @@ -1130,6 +1130,19 @@ function evaluate(a::MPolyRingElem{T}, vals::Vector{U}) where {T <: RingElement, return a(vals...) end +function (a::MPolyRingElem)(;kwargs...) + ss = symbols(parent(a)) + vars = Array{Int}(undef, length(kwargs)) + vals = Array{RingElement}(undef, length(kwargs)) + for (i, (var, val)) in enumerate(kwargs) + vari = findfirst(isequal(var), ss) + vari === nothing && error("Given polynomial has no variable $var") + vars[i] = vari + vals[i] = val + end + return evaluate(a, vars, vals) +end + ################################################################################ # # Derivative diff --git a/src/UnivPoly.jl b/src/UnivPoly.jl index 2cb4d2d1f0..a368d722ad 100644 --- a/src/UnivPoly.jl +++ b/src/UnivPoly.jl @@ -8,6 +8,25 @@ function content(a::UniversalPolyRingElem) return content(data(a)) end +############################################################################### +# +# Evaluation +# +############################################################################### + +function (a::UniversalPolyRingElem)(;kwargs...) + ss = symbols(parent(a)) + vars = Int[] + vals = RingElement[] + for (var, val) in kwargs + vari = findfirst(isequal(var), ss) + vari === nothing && continue + push!(vars, vari) + push!(vals, val) + end + return evaluate(a, vars, vals) +end + ############################################################################### # # Iterators diff --git a/src/generic/UnivPoly.jl b/src/generic/UnivPoly.jl index 08a84d1b5d..518c597ec0 100644 --- a/src/generic/UnivPoly.jl +++ b/src/generic/UnivPoly.jl @@ -783,93 +783,26 @@ end # ############################################################################### -function evaluate(a::UnivPoly{T}, A::Vector{T}) where {T <: RingElem} - R = base_ring(a) +function evaluate(a::UnivPoly, A::Vector{<:Union{NCRingElem, RingElement}}) + isempty(A) && error("Too few values") + a2 = data(a) + varidx = var_indices(a2) + isempty(varidx) && return constant_coefficient(a2) + vals = zeros(parent(A[1]), nvars(parent(a2))) n = length(A) - num = nvars(parent(data(a))) - if n > num - n > nvars(parent(a)) && error("Too many values") - if nvars(parent(data(a))) == 0 - return constant_coefficient(data(a))*one(parent(A[1])) - end - return evaluate(data(a), A[1:num]) - end - if n < num - A = vcat(A, [zero(R) for i = 1:num - n]) - end - return evaluate(data(a), A) -end - -function evaluate(a::UnivPoly{T}, A::Vector{V}) where {T <: RingElement, V <: Union{Integer, Rational, AbstractFloat}} - n = length(A) - num = nvars(parent(data(a))) - if n > num - n > nvars(parent(a)) && error("Too many values") - if nvars(parent(data(a))) == 0 - return constant_coefficient(data(a))*one(parent(A[1])) - end - return evaluate(data(a), A[1:num]) - end - if n < num - A = vcat(A, zeros(V, num - n)) - end - return evaluate(data(a), A) -end - -function evaluate(a::UnivPoly{T}, A::Vector{V}) where {T <: RingElement, V <: RingElement} - n = length(A) - num = nvars(parent(data(a))) - if n > num - n > nvars(parent(a)) && error("Too many values") - if nvars(parent(data(a))) == 0 - return constant_coefficient(data(a))*one(parent(A[1])) - end - return evaluate(data(a), A[1:num]) + for i in varidx + i <= n || error("Number of variables does not match number of values") + vals[i] = A[i] end - if n < num - if n == 0 - R = base_ring(a) - return evaluate(data(a), [zero(R) for _ in 1:num]) - else - R = parent(A[1]) - A = vcat(A, [zero(R) for _ in 1:num-n]) - return evaluate(data(a), A) - end - end - return evaluate(data(a), A) -end - -function (a::UnivPoly{T})() where {T <: RingElement} - return evaluate(a, T[]) + return evaluate(a2, vals) end -function (a::UnivPoly{T})(vals::T...) where {T <: RingElement} - return evaluate(a, [vals...]) -end - -function (a::UnivPoly{T})(val::V, vals::V...) where {T <: RingElement, V <: Union{Integer, Rational, AbstractFloat}} +function (a::UnivPoly)(val::T, vals::T...) where T <: Union{NCRingElem, RingElement} return evaluate(a, [val, vals...]) end -function (a::UnivPoly{T})(vals::NCRingElement...) where {T <: RingElement} - A = [vals...] - n = length(vals) - num = nvars(parent(data(a))) - if n > num - n > nvars(parent(a)) && error("Too many values") - if nvars(parent(data(a))) == 0 - return constant_coefficient(data(a))*one(parent(A[1])) - end - return data(a)(vals[1:num]...) - end - if n < num - A = vcat(A, zeros(Int, num - n)) - end - return data(a)(A...) -end - -function evaluate(a::UnivPoly{T}, vals::Vector{V}) where {T <: RingElement, V <: NCRingElem} - return a(vals...) +function (a::UnivPoly)() + return evaluate(a, Int[]) end function evaluate(a::UnivPoly{T}, vars::Vector{Int}, vals::Vector{V}) where {T <: RingElement, V <: RingElement} @@ -878,15 +811,14 @@ function evaluate(a::UnivPoly{T}, vars::Vector{Int}, vals::Vector{V}) where {T < vals2 = Vector{mpoly_type(T)}(undef, 0) num = nvars(parent(data(a))) S = parent(a) - n = nvars(S) + a2 = data(S(a)) for i = 1:length(vars) - vars[i] > n && error("Unknown variable") if vars[i] <= num push!(vars2, vars[i]) push!(vals2, data(S(vals[i]))) end end - return UnivPoly(evaluate(data(S(a)), vars2, vals2), S) + return UnivPoly(evaluate(a2, vars2, vals2), S) end function evaluate(a::S, vars::Vector{S}, vals::Vector{V}) where {S <: UnivPoly{T}, V <: RingElement} where {T <: RingElement} @@ -894,19 +826,6 @@ function evaluate(a::S, vars::Vector{S}, vals::Vector{V}) where {S <: UnivPoly{T return evaluate(a, varidx, vals) end -function (a::Union{MPolyRingElem, UniversalPolyRingElem})(;kwargs...) - ss = symbols(parent(a)) - vars = Array{Int}(undef, length(kwargs)) - vals = Array{RingElement}(undef, length(kwargs)) - for (i, (var, val)) in enumerate(kwargs) - vari = findfirst(isequal(var), ss) - vari === nothing && error("Given polynomial has no variable $var") - vars[i] = vari - vals[i] = val - end - return evaluate(a, vars, vals) -end - ########S,(a,b)=QQ[:a,:b]####################################################################### # # GCD diff --git a/test/generic/UnivPoly-test.jl b/test/generic/UnivPoly-test.jl index 741f14d62f..3614c0907c 100644 --- a/test/generic/UnivPoly-test.jl +++ b/test/generic/UnivPoly-test.jl @@ -974,6 +974,7 @@ end @test evaluate(f, V) == f(V...) @test evaluate(f, V) == f([ZZ(v) for v in V]...) @test evaluate(f, V) == f([U(v) for v in V]...) + @test evaluate(f, V) == evaluate(f, collect(1:n), V) @test evaluate(g, V) == evaluate(g, [R(v) for v in V]) @test evaluate(g, V) == evaluate(g, [ZZ(v) for v in V]) @@ -981,13 +982,11 @@ end @test evaluate(g, V) == g(V...) @test evaluate(g, V) == g([ZZ(v) for v in V]...) @test evaluate(g, V) == g([U(v) for v in V]...) + @test evaluate(g, V) == evaluate(g, collect(1:n), V) - @test evaluate(h, V) == evaluate(h, [R(v) for v in V]) - @test evaluate(h, V) == evaluate(h, [ZZ(v) for v in V]) - @test evaluate(h, V) == evaluate(h, [U(v) for v in V]) - @test evaluate(h, V) == h(V...) - @test evaluate(h, V) == h([ZZ(v) for v in V]...) - @test evaluate(h, V) == h([U(v) for v in V]...) + @test parent(evaluate(g, V)) == R + @test parent(evaluate(g, collect(1:n), V)) == S + @test evaluate(h, [1,2,3,4]) == evaluate(h, [1,2,3,5]) V = [rand(-10:10) for v in 1:2] @@ -999,16 +998,20 @@ end @test evaluate(f, [1, 3], [V[1], V[2]]) == evaluate(f, [1, 3], [ZZ(v) for v in V[1:2]]) @test evaluate(f, [1, 3], [V[1], V[2]]) == f(x=V[1], z=V[2]) @test evaluate(f, [1, 3], [V[1], V[2]]) == f(z=V[2], x=V[1]) + @test evaluate(f, [1, 3], [V[1], V[2]]) == f(x=V[1], z=V[2], w=0) + @test parent(evaluate(f, [1, 3], [V[1], V[2]])) == S @test evaluate(g, [1], [V[1]]) == evaluate(g, [1], [R(V[1])]) @test evaluate(g, [1], [V[1]]) == evaluate(g, [1], [ZZ(V[1])]) @test evaluate(g, [1, 3], [V[1], V[2]]) == evaluate(g, [1, 3], [R(v) for v in V[1:2]]) @test evaluate(g, [1, 3], [V[1], V[2]]) == evaluate(g, [1, 3], [ZZ(v) for v in V[1:2]]) + @test parent(evaluate(g, [1, 3], [V[1], V[2]])) == S @test evaluate(h, [1], [V[1]]) == evaluate(h, [1], [R(V[1])]) @test evaluate(h, [1], [V[1]]) == evaluate(h, [1], [ZZ(V[1])]) @test evaluate(h, [1, 3], [V[1], V[2]]) == evaluate(h, [1, 3], [R(v) for v in V[1:2]]) @test evaluate(h, [1, 3], [V[1], V[2]]) == evaluate(h, [1, 3], [ZZ(v) for v in V[1:2]]) + @test parent(evaluate(h, [1, 3], [V[1], V[2]])) == S @test evaluate(x, [1], [y]) == evaluate(z, [3], [y]) end