@@ -7,6 +7,10 @@ representation.
77
88include (" types.jl" )
99
10+ function _is_multiclass (y:: Vector{Int64} ):: Bool
11+ return length (unique (y)) > 2
12+ end
13+
1014"""
1115 from_vector(X::Matrix{Int}, y::Vector{Int}, names::Union{Vector{String}, Nothing} = nothing)
1216
@@ -26,11 +30,22 @@ function from_vector(X::Matrix{Int64}, y::Vector{Int64}, names::Union{Vector{Str
2630
2731 pos, neg, facts = String[], String[], String[]
2832
29- for (i, row) in enumerate (y)
30- if Bool (row)
31- push! (pos, " $(last (names)) (id$(i) )." )
32- else
33- push! (neg, " $(last (names)) (id$(i) )." )
33+ is_multiclass = _is_multiclass (y)
34+
35+ if is_multiclass
36+
37+ for (i, row) in enumerate (y)
38+ push! (pos, " $(last (names)) (id$(i) ,$(row) )." )
39+ end
40+
41+ else
42+
43+ for (i, row) in enumerate (y)
44+ if Bool (row)
45+ push! (pos, " $(last (names)) (id$(i) )." )
46+ else
47+ push! (neg, " $(last (names)) (id$(i) )." )
48+ end
3449 end
3550 end
3651
@@ -39,8 +54,14 @@ function from_vector(X::Matrix{Int64}, y::Vector{Int64}, names::Union{Vector{Str
3954 facts = vcat (facts, [" $(var) (id$(j) ,$(row) )." for (j, row) in enumerate (col)])
4055 end
4156
57+
4258 modes = [" $(name) (+id,#var$(name) )." for name in names[1 : end - 1 ]]
43- push! (modes, " $(last (names)) (+id)." )
59+
60+ if is_multiclass
61+ push! (modes, " $(last (names)) (+id,#classlabel)." )
62+ else
63+ push! (modes, " $(last (names)) (+id)." )
64+ end
4465
4566 return RelationalDataset ((pos, neg, facts)), modes
4667end
0 commit comments