Skip to content
This repository was archived by the owner on Mar 11, 2022. It is now read-only.

Commit 9de60bd

Browse files
authored
Support RotatingTimeArray (#8)
1 parent 8879696 commit 9de60bd

File tree

3 files changed

+50
-11
lines changed

3 files changed

+50
-11
lines changed

src/procedures.jl

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -196,7 +196,11 @@ function maketreatcols(data, treatname::Symbol, treatintterms::TermSet,
196196
cols = subcolumns(data, cellnames, esample)
197197
cells, rows = cellrows(cols, findcell(cols))
198198

199-
rel = refarray(cells[2]) .- refarray(cells[1])
199+
if cells[1] isa RotatingTimeArray
200+
rel = refarray(cells[2].time) .- refarray(cells[1].time)
201+
else
202+
rel = refarray(cells[2]) .- refarray(cells[1])
203+
end
200204
kept = .!haskey.(Ref(exc), rel) .& .!haskey.(Ref(notreat), cells[1])
201205
treatrows = rows[kept]
202206
# Construct cells needed for treatment indicators
@@ -495,7 +499,7 @@ function solveleastsquaresweights(::DynamicTreatment{SharpDesign},
495499
end
496500
feM === nothing || _feresiduals!(d, feM, fetol, femaxiter)
497501
weights isa UnitWeights || (d .*= sqrt.(weights))
498-
lswtmat[i,:] .= (crossx \ (X'd))[1:nt]
502+
lswtmat[i,:] .= view((crossx \ (X'd)), 1:nt)
499503
end
500504
ycellmeans ./= ycellweights
501505
lswt = TableIndexedMatrix(lswtmat, lswtcells, treatcells)

test/did.jl

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -56,16 +56,25 @@
5656
──────────────────────────────────────────────────────────────────────"""
5757

5858
df = DataFrame(hrs)
59-
df.wave = Date.(df.wave)
60-
df.wave_hosp = Date.(df.wave_hosp)
61-
df.wave = settime(df, :wave, step=Year(1))
62-
df.wave_hosp = settime(df, :wave_hosp, start=Date(7), step=Year(1))
59+
df.wave = settime(Date.(hrs.wave), Year(1))
60+
df.wave_hosp = settime(Date.(hrs.wave_hosp), Year(1), start=Date(7))
6361
r1 = @did(Reg, data=df, dynamic(:wave, -1), notyettreated(Date(11)),
6462
vce=Vcov.cluster(:hhidpn), yterm=term(:oop_spend), treatname=:wave_hosp,
65-
treatintterms=(), xterms=(fe(:wave)+fe(:hhidpn)), solvelsweights=false)
63+
treatintterms=(), xterms=(fe(:wave)+fe(:hhidpn)), solvelsweights=true)
6664
@test coef(r1) coef(r)
6765
@test r1.coefnames[1] == "wave_hosp: 0008-01-01 & rel: 0"
6866

67+
rot = ifelse.(isodd.(hrs.hhidpn), 1, 2)
68+
df.wave = settime(Date.(hrs.wave), Year(1), rotation=rot)
69+
df.wave_hosp = settime(Date.(hrs.wave_hosp), Year(1), start=Date(7), rotation=rot)
70+
e = rotatingtime((1,2), Date(11))
71+
r2 = @did(Reg, data=df, dynamic(:wave, -1), notyettreated(e),
72+
vce=Vcov.cluster(:hhidpn), yterm=term(:oop_spend), treatname=:wave_hosp,
73+
treatintterms=(), xterms=(fe(:wave)+fe(:hhidpn)), solvelsweights=true)
74+
@test length(coef(r2)) == 18
75+
@test coef(r2)[1] 3790.7218412450593
76+
@test r2.coefnames[1] == "wave_hosp: 1_0008-01-01 & rel: 0"
77+
6978
r = @did(Reg, data=hrs, dynamic(:wave, -1), notyettreated([11]),
7079
vce=Vcov.cluster(:hhidpn), yterm=term(:oop_spend), treatname=:wave_hosp,
7180
treatintterms=(), cohortinteracted=false, lswtnames=(:wave_hosp, :wave))

test/procedures.jl

Lines changed: 30 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -254,10 +254,8 @@ end
254254
@test combinedargs(MakeTreatCols(), allntargs) ==
255255
(Dict{Int,Int}(), IdDict{ValidTimeType,Int}())
256256

257-
df.wave = Date.(df.wave)
258-
df.wave_hosp = Date.(df.wave_hosp)
259-
df.wave = settime(df, :wave, step=Year(1))
260-
df.wave_hosp = settime(df, :wave_hosp, start=Date(7), step=Year(1))
257+
df.wave = settime(Date.(hrs.wave), Year(1))
258+
df.wave_hosp = settime(Date.(hrs.wave_hosp), Year(1), start=Date(7))
261259
ret1 = maketreatcols(nt..., typeof(tr), tr.time,
262260
Dict(-1=>1), IdDict{ValidTimeType,Int}(Date(11)=>1))
263261
@test ret1.cells[1] == Date.(ret.cells[1])
@@ -270,6 +268,34 @@ end
270268
@test ret1.cellweights == ret.cellweights
271269
@test ret1.cellcounts == ret.cellcounts
272270

271+
rot = ifelse.(isodd.(hrs.hhidpn), 1, 2)
272+
df.wave = RotatingTimeArray(rot, hrs.wave)
273+
df.wave_hosp = RotatingTimeArray(rot, hrs.wave_hosp)
274+
e = rotatingtime((1,1,2), (10,11,11))
275+
ret2 = maketreatcols(nt..., typeof(tr), tr.time,
276+
Dict(-1=>1), IdDict{ValidTimeType,Int}(c=>1 for c in e))
277+
@test ret2.cells[1] == sort!(append!((rotatingtime(r, ret.cells[1]) for r in (1,2))...))
278+
rt = append!((rotatingtime(r, 7:11) for r in (1,2))...)
279+
@test ret2.cells[2] == repeat(rt, 4)
280+
@test size(ret2.treatcells) == (20, 2)
281+
@test sort!(unique(ret2.treatcells[1])) ==
282+
sort!(append!(rotatingtime(1, 8:9), rotatingtime(2, 8:10)))
283+
284+
df.wave = settime(Date.(hrs.wave), Year(1), rotation=rot)
285+
df.wave_hosp = settime(Date.(hrs.wave_hosp), Year(1), start=Date(7), rotation=rot)
286+
e = rotatingtime((1,1,2), Date.((10,11,11)))
287+
ret3 = maketreatcols(nt..., typeof(tr), tr.time,
288+
Dict(-1=>1), IdDict{ValidTimeType,Int}(c=>1 for c in e))
289+
@test ret3.cells[1].time == Date.(ret2.cells[1].time)
290+
@test ret3.cells[2].time == Date.(ret2.cells[2].time)
291+
@test ret3.rows == ret2.rows
292+
@test ret3.treatcells[1].time == Date.(ret2.treatcells[1].time)
293+
@test ret3.treatcells[2] == ret2.treatcells[2]
294+
@test ret3.treatrows == ret2.treatrows
295+
@test ret3.treatcols == ret2.treatcols
296+
@test ret3.cellweights == ret2.cellweights
297+
@test ret3.cellcounts == ret2.cellcounts
298+
273299
nt = merge(nt, (data=hrs, tr=tr, pr=pr))
274300
@test MakeTreatCols()(nt) == merge(nt, (cells=ret.cells, rows=ret.rows,
275301
treatcells=ret.treatcells, treatrows=ret.treatrows, treatcols=ret.treatcols,

0 commit comments

Comments
 (0)