Skip to content
This repository was archived by the owner on Mar 11, 2022. It is now read-only.

Commit 5b5942d

Browse files
authored
Add setindex! for ScaledArray and support missing values (#27)
1 parent 5b02fbe commit 5b5942d

File tree

7 files changed, with 113 additions and 21 deletions.

src/DiffinDiffsBase.jl

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -22,6 +22,7 @@ using Tables: AbstractColumns, istable, columnnames, getcolumn
2222

2323
import Base: ==, +, -, *, isless, show, parent, view, diff
2424
import Base: eltype, firstindex, lastindex, getindex, iterate, length, sym_in
25+
import Missings: allowmissing, disallowmissing
2526
import StatsBase: coef, vcov, confint, nobs, dof_residual, responsename, coefnames, weights,
2627
coeftable
2728
import StatsModels: concrete_term, schema, termvars, lag, lead

src/ScaledArrays.jl

Lines changed: 62 additions & 17 deletions
Original file line number | Diff line number | Diff line change
@@ -11,26 +11,31 @@ end
1111
An array type that stores data as indices of a range.
1212
1313
# Fields
14-
- `refs::RA<:AbstractArray{<:Any, N}`: an array of indices.
15-
- `pool::P<:AbstractRange{T}`: a range that covers all possible values stored by the array.
14+
- `refs::RA<:AbstractArray{R,N}`: an array of indices.
15+
- `pool::P<:AbstractRange`: a range that covers all possible values stored by the array.
1616
- `invpool::Dict{T,R}`: a map from array elements to indices of `pool`.
1717
"""
1818
mutable struct ScaledArray{T,R,N,RA,P} <: AbstractArray{T,N}
1919
refs::RA
2020
pool::P
2121
invpool::Dict{T,R}
22-
ScaledArray{T,R,N,RA,P}(rs::RefArray{RA}, pool::P, invpool::Dict{T,R}) where
23-
{T, R, N, RA<:AbstractArray{R, N}, P<:AbstractRange{T}} =
24-
new{T,R,N,RA,P}(rs.a, pool, invpool)
22+
function ScaledArray{T,R,N,RA,P}(rs::RefArray{RA}, pool::P, invpool::Dict{T,R}) where
23+
{T, R, N, RA<:AbstractArray{R, N}, P<:AbstractRange}
24+
eltype(P) == nonmissingtype(T) || throw(ArgumentError(
25+
"expect element type of pool being $(nonmissingtype(T)); got $(eltype(P))"))
26+
return new{T,R,N,RA,P}(rs.a, pool, invpool)
27+
end
2528
end
2629

27-
ScaledArray(rs::RefArray{RA}, pool::P, invpool::Dict{T,R}) where
28-
{T,R,RA<:AbstractArray{R},P} = ScaledArray{T,R,ndims(RA),RA,P}(rs, pool, invpool)
30+
ScaledArray(rs::RefArray{RA}, pool::P, invpool::Dict{T,R}) where {T,R,RA<:AbstractArray{R},P} =
31+
ScaledArray{T,R,ndims(RA),RA,P}(rs, pool, invpool)
2932

3033
const ScaledVector{T,R} = ScaledArray{T,R,1}
3134
const ScaledMatrix{T,R} = ScaledArray{T,R,2}
3235

33-
scale(sa::ScaledArray) = step(sa.pool)
36+
const ScaledArrOrSub = Union{ScaledArray, SubArray{<:Any, <:Any, <:ScaledArray}}
37+
38+
scale(sa::ScaledArrOrSub) = step(DataAPI.refpool(sa))
3439

3540
function _validmin(min, xmin, isstart::Bool)
3641
if min === nothing
@@ -178,6 +183,14 @@ ScaledArray(sa::ScaledArray, step=nothing; reftype::Type=eltype(refarray(sa)),
178183
Base.size(sa::ScaledArray) = size(sa.refs)
179184
Base.IndexStyle(::Type{<:ScaledArray{T,R,N,RA}}) where {T,R,N,RA} = IndexStyle(RA)
180185

186+
Base.similar(sa::ScaledArray{T,R}, dims::Dims=size(sa)) where {T,R} =
187+
ScaledArray(RefArray(ones(R, dims)), DataAPI.refpool(sa), Dict{T,R}())
188+
189+
Base.similar(sa::SubArray{<:Any, <:Any, <:ScaledArray{T,R}}, dims::Dims=size(sa)) where {T,R} =
190+
ScaledArray(RefArray(ones(R, dims)), DataAPI.refpool(sa), Dict{T,R}())
191+
192+
Base.similar(sa::ScaledArrOrSub, dims::Int...) = similar(sa, dims)
193+
181194
DataAPI.refarray(sa::ScaledArray) = sa.refs
182195
DataAPI.refvalue(sa::ScaledArray, n::Integer) = getindex(DataAPI.refpool(sa), n)
183196
DataAPI.refpool(sa::ScaledArray) = sa.pool
@@ -202,19 +215,51 @@ DataAPI.invrefpool(ssa::SubArray{<:Any, <:Any, <:ScaledArray}) =
202215
return @inbounds pool[n]
203216
end
204217

205-
@inline function Base.getindex(sa::ScaledArray, I...)
206-
refs = DataAPI.refarray(sa)
207-
@boundscheck checkbounds(refs, I...)
208-
@inbounds ns = refs[I...]
218+
Base.@propagate_inbounds function Base.getindex(sa::ScaledArrOrSub, I::AbstractVector)
219+
newrefs = DataAPI.refarray(sa)[I]
209220
pool = DataAPI.refpool(sa)
210-
N = length(pool)
211-
@boundscheck checkindex(Bool, 0:N, ns) || throw_boundserror(pool, ns)
212-
return @inbounds pool[ns]
221+
invpool = DataAPI.invrefpool(sa)
222+
return ScaledArray(RefArray(newrefs), pool, invpool)
223+
end
224+
225+
Base.@propagate_inbounds function Base.setindex!(sa::ScaledArray, val, ind::Int)
226+
invpool = DataAPI.invrefpool(sa)
227+
n = get(invpool, val, nothing)
228+
if n === nothing
229+
pool = DataAPI.refpool(sa)
230+
r = first(pool):step(pool):val
231+
last(r) > last(pool) && (sa.pool = r)
232+
n = length(r)
233+
invpool[val] = n
234+
end
235+
refs = DataAPI.refarray(sa)
236+
refs[ind] = n
237+
return sa
238+
end
239+
240+
Base.@propagate_inbounds function Base.setindex!(sa::ScaledArray, ::Missing, ind::Int)
241+
refs = DataAPI.refarray(sa)
242+
z = zero(eltype(refs))
243+
invpool = DataAPI.invrefpool(sa)
244+
invpool[missing] = z
245+
refs[ind] = z
246+
return sa
247+
end
248+
249+
allowmissing(sa::ScaledArray{T,R}) where {T,R} =
250+
ScaledArray(RefArray(sa.refs), sa.pool, convert(Dict{Union{T,Missing},R}, sa.invpool))
251+
252+
function disallowmissing(sa::ScaledArray{T,R}) where {T,R}
253+
T1 = nonmissingtype(T)
254+
any(x->iszero(x), sa.refs) && throw(ArgumentError("cannot convert missing to $T1"))
255+
delete!(sa.invpool, missing)
256+
return ScaledArray(RefArray(sa.refs), sa.pool, convert(Dict{T1,R}, sa.invpool))
213257
end
214258

215-
function Base.:(==)(x::ScaledArray, y::ScaledArray)
259+
function Base.:(==)(x::ScaledArrOrSub, y::ScaledArrOrSub)
216260
size(x) == size(y) || return false
217-
first(x.pool) == first(y.pool) && step(x.pool) == step(y.pool) && return x.refs == y.refs
261+
first(DataAPI.refpool(x)) == first(DataAPI.refpool(y)) &&
262+
scale(x) == scale(y) && return DataAPI.refarray(x) == DataAPI.refarray(y)
218263
eq = true
219264
for (p, q) in zip(x, y)
220265
# missing could arise

src/operations.jl

Lines changed: 9 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -94,7 +94,15 @@ function cellrows(cols::VecColumnTable, refrows::IdDict)
9494
ncol = length(cols)
9595
ncell = length(refrows)
9696
rows = Vector{Vector{Int}}(undef, ncell)
97-
columns = AbstractVector[Vector{eltype(c)}(undef, ncell) for c in cols]
97+
columns = Vector{AbstractVector}(undef, ncol)
98+
for i in 1:ncol
99+
c = cols[i]
100+
if typeof(c) <: ScaledArray || typeof(c) <: SubArray{<:Any,1,<:ScaledArray}
101+
columns[i] = similar(c, ncell)
102+
else
103+
columns[i] = Vector{eltype(c)}(undef, ncell)
104+
end
105+
end
98106
refs = Vector{keytype(refrows)}(undef, ncell)
99107
r = 0
100108
@inbounds for (k, v) in refrows

src/tables.jl

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -96,7 +96,7 @@ end
9696

9797
Base.values(cols::VecColumnTable) = _columns(cols)
9898
Base.haskey(cols::VecColumnTable, key::Symbol) = haskey(_lookup(cols), key)
99-
Base.haskey(cols::VecColumnTable, i::Int) = 0 < i < length(_names(cols))
99+
Base.haskey(cols::VecColumnTable, i::Int) = 0 < i <= length(_names(cols))
100100

101101
function Base.:(==)(x::VecColumnTable, y::VecColumnTable)
102102
size(x) == size(y) || return false

test/ScaledArrays.jl

Lines changed: 29 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -50,24 +50,51 @@ using DiffinDiffsBase: RefArray, validpool, scaledlabel
5050
@test sa5.refs == sa4.refs
5151
@test sa5.refs !== sa4.refs
5252

53+
sa6 = ScaledArray([missing, 1, 2], 1)
54+
@test eltype(sa6) == Union{Int,Missing}
55+
@test sa6.refs == 0:2
56+
@test eltype(sa6.refs) == Int32
57+
@test ismissing(sa6[1])
58+
5359
@test_throws ArgumentError ScaledArray(sa5, stop=7, usepool=false)
60+
@test_throws ArgumentError ScaledArray(RefArray(1:3), 1.0:0.5:3.0, Dict{Int,Int}())
5461

5562
@test size(sa) == (10,)
5663
@test IndexStyle(typeof(sa)) == IndexLinear()
5764

65+
@test similar(sa, 4) == ScaledArray(RefArray(ones(Int, 4)), sa.pool, Dict{Date,Int}())
66+
@test similar(sa, (4,)) == similar(view(sa, 1:5), 4) == similar(sa, 4)
67+
5868
@test refarray(sa) === sa.refs
5969
@test refvalue(sa, 1) == Date(1)
6070
@test refpool(sa) === sa.pool
6171
@test invrefpool(sa) === sa.invpool
6272

63-
ssa = view(sa, 3:4)
64-
@test refarray(ssa) == view(sa.refs, 3:4)
73+
ssa = view(sa, 3:5)
74+
@test refarray(ssa) == view(sa.refs, 3:5)
6575
@test refvalue(ssa, 1) == Date(1)
6676
@test refpool(ssa) === sa.pool
6777
@test invrefpool(ssa) === sa.invpool
78+
@test ssa == sa[3:5]
6879

6980
@test sa[1] == Date(1)
7081
@test sa[1:2] == sa[[1,2]] == sa[(1:10).<3] == Date.(1:2)
82+
@test sa[1:2] isa ScaledArray
83+
@test ssa[1:2] isa ScaledArray
84+
85+
sa[1] = Date(10)
86+
@test sa[1] == Date(10)
87+
sa[1:2] .= Date(100)
88+
@test sa[1:2] == [Date(100), Date(100)]
89+
@test last(sa.pool) == Date(100)
90+
@test_throws MethodError sa[1] = missing
91+
sa = allowmissing(sa)
92+
sa[1] = missing
93+
@test ismissing(sa[1])
94+
@test_throws ArgumentError disallowmissing(sa)
95+
sa[1] = Date(1)
96+
sa = disallowmissing(sa)
97+
@test eltype(sa) == Date
7198
end
7299

73100
@testset "scaledlabel" begin

test/operations.jl

Lines changed: 10 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -42,6 +42,16 @@ end
4242
sortslices(unique(hcat(hrs.wave, hrs.wave_hosp), dims=1), dims=1)
4343
@test propertynames(cells) == [:wave, :wave_hosp]
4444
@test rows[1] == intersect(findall(x->x==7, hrs.wave), findall(x->x==8, hrs.wave_hosp))
45+
46+
df = DataFrame(hrs)
47+
df.wave = ScaledArray(df.wave, 1)
48+
df.wave_hosp = ScaledArray(df.wave_hosp, 1)
49+
cols1 = subcolumns(df, (:wave, :wave_hosp))
50+
cells1, rows1 = cellrows(cols1, rows_dict)
51+
@test rows1 == rows
52+
@test cells1 == cells
53+
@test cells1.wave isa ScaledArray
54+
@test cells1.wave_hosp isa ScaledArray
4555
end
4656

4757
@testset "settime" begin

test/runtests.jl

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -11,6 +11,7 @@ using DiffinDiffsBase: @fieldequal, unpack, @unpack, checktable, hastreat, parse
1111
_totermset!, parse_didargs!, _treatnames, _parse_bycells!, _parse_subset, _nselected,
1212
treatindex, checktreatindex
1313
using LinearAlgebra: Diagonal
14+
using Missings: allowmissing, disallowmissing
1415
using PooledArrays: PooledArray
1516
using StatsBase: Weights, UnitWeights
1617
using StatsModels: termvars

0 commit comments

Comments
 (0)