Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 13 additions & 0 deletions src/MPoly.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1130,6 +1130,19 @@ function evaluate(a::MPolyRingElem{T}, vals::Vector{U}) where {T <: RingElement,
return a(vals...)
end

function (a::MPolyRingElem)(;kwargs...)
ss = symbols(parent(a))
vars = Array{Int}(undef, length(kwargs))
vals = Array{RingElement}(undef, length(kwargs))
for (i, (var, val)) in enumerate(kwargs)
vari = findfirst(isequal(var), ss)
vari === nothing && error("Given polynomial has no variable $var")
vars[i] = vari
vals[i] = val
end
return evaluate(a, vars, vals)
end

################################################################################
#
# Derivative
Expand Down
19 changes: 19 additions & 0 deletions src/UnivPoly.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,25 @@ function content(a::UniversalPolyRingElem)
return content(data(a))
end

###############################################################################
#
# Evaluation
#
###############################################################################

function (a::UniversalPolyRingElem)(;kwargs...)
ss = symbols(parent(a))
vars = Int[]
vals = RingElement[]
for (var, val) in kwargs
vari = findfirst(isequal(var), ss)
vari === nothing && continue
push!(vars, vari)
push!(vals, val)
end
return evaluate(a, vars, vals)
end

###############################################################################
#
# Iterators
Expand Down
111 changes: 15 additions & 96 deletions src/generic/UnivPoly.jl
Original file line number Diff line number Diff line change
Expand Up @@ -783,93 +783,26 @@ end
#
###############################################################################

function evaluate(a::UnivPoly{T}, A::Vector{T}) where {T <: RingElem}
R = base_ring(a)
function evaluate(a::UnivPoly, A::Vector{<:Union{NCRingElem, RingElement}})
isempty(A) && error("Too few values")
a2 = data(a)
varidx = var_indices(a2)
isempty(varidx) && return constant_coefficient(a2)
vals = zeros(parent(A[1]), nvars(parent(a2)))
n = length(A)
num = nvars(parent(data(a)))
if n > num
n > nvars(parent(a)) && error("Too many values")
if nvars(parent(data(a))) == 0
return constant_coefficient(data(a))*one(parent(A[1]))
end
return evaluate(data(a), A[1:num])
end
if n < num
A = vcat(A, [zero(R) for i = 1:num - n])
end
return evaluate(data(a), A)
end

function evaluate(a::UnivPoly{T}, A::Vector{V}) where {T <: RingElement, V <: Union{Integer, Rational, AbstractFloat}}
n = length(A)
num = nvars(parent(data(a)))
if n > num
n > nvars(parent(a)) && error("Too many values")
if nvars(parent(data(a))) == 0
return constant_coefficient(data(a))*one(parent(A[1]))
end
return evaluate(data(a), A[1:num])
end
if n < num
A = vcat(A, zeros(V, num - n))
end
return evaluate(data(a), A)
end

function evaluate(a::UnivPoly{T}, A::Vector{V}) where {T <: RingElement, V <: RingElement}
n = length(A)
num = nvars(parent(data(a)))
if n > num
n > nvars(parent(a)) && error("Too many values")
if nvars(parent(data(a))) == 0
return constant_coefficient(data(a))*one(parent(A[1]))
end
return evaluate(data(a), A[1:num])
for i in varidx
i <= n || error("Number of variables does not match number of values")
vals[i] = A[i]
end
if n < num
if n == 0
R = base_ring(a)
return evaluate(data(a), [zero(R) for _ in 1:num])
else
R = parent(A[1])
A = vcat(A, [zero(R) for _ in 1:num-n])
return evaluate(data(a), A)
end
end
return evaluate(data(a), A)
end

function (a::UnivPoly{T})() where {T <: RingElement}
return evaluate(a, T[])
return evaluate(a2, vals)
end

function (a::UnivPoly{T})(vals::T...) where {T <: RingElement}
return evaluate(a, [vals...])
end

function (a::UnivPoly{T})(val::V, vals::V...) where {T <: RingElement, V <: Union{Integer, Rational, AbstractFloat}}
function (a::UnivPoly)(val::T, vals::T...) where T <: Union{NCRingElem, RingElement}
return evaluate(a, [val, vals...])
end

function (a::UnivPoly{T})(vals::NCRingElement...) where {T <: RingElement}
A = [vals...]
n = length(vals)
num = nvars(parent(data(a)))
if n > num
n > nvars(parent(a)) && error("Too many values")
if nvars(parent(data(a))) == 0
return constant_coefficient(data(a))*one(parent(A[1]))
end
return data(a)(vals[1:num]...)
end
if n < num
A = vcat(A, zeros(Int, num - n))
end
return data(a)(A...)
end

function evaluate(a::UnivPoly{T}, vals::Vector{V}) where {T <: RingElement, V <: NCRingElem}
return a(vals...)
function (a::UnivPoly)()
return evaluate(a, Int[])
end

function evaluate(a::UnivPoly{T}, vars::Vector{Int}, vals::Vector{V}) where {T <: RingElement, V <: RingElement}
Expand All @@ -878,35 +811,21 @@ function evaluate(a::UnivPoly{T}, vars::Vector{Int}, vals::Vector{V}) where {T <
vals2 = Vector{mpoly_type(T)}(undef, 0)
num = nvars(parent(data(a)))
S = parent(a)
n = nvars(S)
a2 = data(S(a))
for i = 1:length(vars)
vars[i] > n && error("Unknown variable")
if vars[i] <= num
push!(vars2, vars[i])
push!(vals2, data(S(vals[i])))
end
end
return UnivPoly(evaluate(data(S(a)), vars2, vals2), S)
return UnivPoly(evaluate(a2, vars2, vals2), S)
end

function evaluate(a::S, vars::Vector{S}, vals::Vector{V}) where {S <: UnivPoly{T}, V <: RingElement} where {T <: RingElement}
varidx = Int[var_index(x) for x in vars]
return evaluate(a, varidx, vals)
end

function (a::Union{MPolyRingElem, UniversalPolyRingElem})(;kwargs...)
ss = symbols(parent(a))
vars = Array{Int}(undef, length(kwargs))
vals = Array{RingElement}(undef, length(kwargs))
for (i, (var, val)) in enumerate(kwargs)
vari = findfirst(isequal(var), ss)
vari === nothing && error("Given polynomial has no variable $var")
vars[i] = vari
vals[i] = val
end
return evaluate(a, vars, vals)
end

########S,(a,b)=QQ[:a,:b]#######################################################################
#
# GCD
Expand Down
15 changes: 9 additions & 6 deletions test/generic/UnivPoly-test.jl
Original file line number Diff line number Diff line change
Expand Up @@ -974,20 +974,19 @@ end
@test evaluate(f, V) == f(V...)
@test evaluate(f, V) == f([ZZ(v) for v in V]...)
@test evaluate(f, V) == f([U(v) for v in V]...)
@test evaluate(f, V) == evaluate(f, collect(1:n), V)

@test evaluate(g, V) == evaluate(g, [R(v) for v in V])
@test evaluate(g, V) == evaluate(g, [ZZ(v) for v in V])
@test evaluate(g, V) == evaluate(g, [U(v) for v in V])
@test evaluate(g, V) == g(V...)
@test evaluate(g, V) == g([ZZ(v) for v in V]...)
@test evaluate(g, V) == g([U(v) for v in V]...)
@test evaluate(g, V) == evaluate(g, collect(1:n), V)

@test evaluate(h, V) == evaluate(h, [R(v) for v in V])
@test evaluate(h, V) == evaluate(h, [ZZ(v) for v in V])
@test evaluate(h, V) == evaluate(h, [U(v) for v in V])
@test evaluate(h, V) == h(V...)
@test evaluate(h, V) == h([ZZ(v) for v in V]...)
@test evaluate(h, V) == h([U(v) for v in V]...)
@test parent(evaluate(g, V)) == R
@test parent(evaluate(g, collect(1:n), V)) == S
@test evaluate(h, [1,2,3,4]) == evaluate(h, [1,2,3,5])

V = [rand(-10:10) for v in 1:2]

Expand All @@ -999,16 +998,20 @@ end
@test evaluate(f, [1, 3], [V[1], V[2]]) == evaluate(f, [1, 3], [ZZ(v) for v in V[1:2]])
@test evaluate(f, [1, 3], [V[1], V[2]]) == f(x=V[1], z=V[2])
@test evaluate(f, [1, 3], [V[1], V[2]]) == f(z=V[2], x=V[1])
@test evaluate(f, [1, 3], [V[1], V[2]]) == f(x=V[1], z=V[2], w=0)
@test parent(evaluate(f, [1, 3], [V[1], V[2]])) == S

@test evaluate(g, [1], [V[1]]) == evaluate(g, [1], [R(V[1])])
@test evaluate(g, [1], [V[1]]) == evaluate(g, [1], [ZZ(V[1])])
@test evaluate(g, [1, 3], [V[1], V[2]]) == evaluate(g, [1, 3], [R(v) for v in V[1:2]])
@test evaluate(g, [1, 3], [V[1], V[2]]) == evaluate(g, [1, 3], [ZZ(v) for v in V[1:2]])
@test parent(evaluate(g, [1, 3], [V[1], V[2]])) == S

@test evaluate(h, [1], [V[1]]) == evaluate(h, [1], [R(V[1])])
@test evaluate(h, [1], [V[1]]) == evaluate(h, [1], [ZZ(V[1])])
@test evaluate(h, [1, 3], [V[1], V[2]]) == evaluate(h, [1, 3], [R(v) for v in V[1:2]])
@test evaluate(h, [1, 3], [V[1], V[2]]) == evaluate(h, [1, 3], [ZZ(v) for v in V[1:2]])
@test parent(evaluate(h, [1, 3], [V[1], V[2]])) == S

@test evaluate(x, [1], [y]) == evaluate(z, [3], [y])
end
Expand Down
Loading