Skip to content
This repository was archived by the owner on Mar 11, 2022. It is now read-only.

Commit 5e470d1

Browse files
authored
Support UnspecifiedParallel (#9)
1 parent 9de60bd commit 5e470d1

File tree

4 files changed

+59
-27
lines changed

4 files changed

+59
-27
lines changed

src/did.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,14 +10,14 @@ const RegressionBasedDID = DiffinDiffsEstimator{:RegressionBasedDID,
1010
const Reg = RegressionBasedDID
1111

1212
function valid_didargs(d::Type{Reg}, ::DynamicTreatment{SharpDesign},
13-
::TrendParallel{Unconditional, Exact}, args::Dict{Symbol,Any})
13+
::TrendOrUnspecifiedPR{Unconditional,Exact}, args::Dict{Symbol,Any})
1414
name = get(args, :name, "")::String
1515
treatintterms = haskey(args, :treatintterms) ? args[:treatintterms] : TermSet()
1616
xterms = haskey(args, :xterms) ? args[:xterms] : TermSet()
1717
solvelsweights = haskey(args, :lswtnames) || get(args, :solvelsweights, false)::Bool
1818
ntargs = (data=args[:data],
1919
tr=args[:tr]::DynamicTreatment{SharpDesign},
20-
pr=args[:pr]::TrendParallel{Unconditional, Exact},
20+
pr=args[:pr]::TrendOrUnspecifiedPR{Unconditional,Exact},
2121
yterm=args[:yterm]::AbstractTerm,
2222
treatname=args[:treatname]::Symbol,
2323
subset=get(args, :subset, nothing)::Union{BitVector,Nothing},

src/procedures.jl

Lines changed: 16 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -186,8 +186,7 @@ See also [`MakeTreatCols`](@ref).
186186
function maketreatcols(data, treatname::Symbol, treatintterms::TermSet,
187187
feM::Union{AbstractFixedEffectSolver, Nothing},
188188
weights::AbstractWeights, esample::BitVector,
189-
cohortinteracted::Bool, fetol::Real, femaxiter::Int,
190-
::Type{DynamicTreatment{SharpDesign}}, time::Symbol,
189+
cohortinteracted::Bool, fetol::Real, femaxiter::Int, time::Symbol,
191190
exc::Dict{Int,Int}, notreat::IdDict{ValidTimeType,Int})
192191

193192
nobs = sum(esample)
@@ -286,18 +285,19 @@ const MakeTreatCols = StatsStep{:MakeTreatCols, typeof(maketreatcols), true}
286285

287286
required(::MakeTreatCols) = (:data, :treatname, :treatintterms, :feM, :weights, :esample)
288287
default(::MakeTreatCols) = (cohortinteracted=true, fetol=1e-8, femaxiter=10000)
289-
transformed(::MakeTreatCols, @nospecialize(nt::NamedTuple)) = (typeof(nt.tr), nt.tr.time)
288+
# No need to consider typeof(tr) and typeof(pr) given the restrictions by valid_didargs
289+
transformed(::MakeTreatCols, @nospecialize(nt::NamedTuple)) = (nt.tr.time,)
290290

291-
combinedargs(step::MakeTreatCols, allntargs) =
292-
combinedargs(step, allntargs, typeof(allntargs[1].tr))
293-
294-
# Obtain the relative time periods excluded by all tr in allntargs
295-
function combinedargs(::MakeTreatCols, allntargs, ::Type{DynamicTreatment{SharpDesign}})
291+
# Obtain the relative time periods excluded by all tr
292+
# and the treatment groups excluded by all pr in allntargs
293+
function combinedargs(::MakeTreatCols, allntargs)
296294
exc = Dict{Int,Int}()
297295
notreat = IdDict{ValidTimeType,Int}()
298-
@inbounds for nt in allntargs
296+
for nt in allntargs
299297
foreach(x->_count!(exc, x), nt.tr.exc)
300-
foreach(x->_count!(notreat, x), nt.pr.e)
298+
if nt.pr isa TrendParallel
299+
foreach(x->_count!(notreat, x), nt.pr.e)
300+
end
301301
end
302302
nnt = length(allntargs)
303303
for (k, v) in exc
@@ -315,14 +315,18 @@ end
315315
Solve the least squares problem for regression coefficients and residuals.
316316
See also [`SolveLeastSquares`](@ref).
317317
"""
318-
function solveleastsquares!(tr::DynamicTreatment{SharpDesign}, pr::TrendParallel,
318+
function solveleastsquares!(tr::DynamicTreatment{SharpDesign}, pr::TrendOrUnspecifiedPR,
319319
yterm::AbstractTerm, xterms::TermSet, yxterms::Dict, yxcols::Dict,
320320
treatcells::VecColumnTable, treatcols::Vector,
321321
cohortinteracted::Bool, has_fe_intercept::Bool)
322322

323323
y = yxcols[yxterms[yterm]]
324324
if cohortinteracted
325-
tinds = .!((treatcells[2] .∈ (tr.exc,)).| (treatcells[1] .∈ (pr.e,)))
325+
if pr isa TrendParallel
326+
tinds = .!((treatcells[2] .∈ (tr.exc,)) .| (treatcells[1] .∈ (pr.e,)))
327+
else
328+
tinds = .!(treatcells[2] .∈ (tr.exc,))
329+
end
326330
else
327331
tinds = .!(treatcells[1] .∈ (tr.exc,))
328332
end

test/did.jl

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -94,6 +94,35 @@
9494
Fixed effects: none
9595
──────────────────────────────────────────────────────────────────────"""
9696

97+
r0 = @did(Reg, data=hrs, dynamic(:wave, -1), unspecifiedpr(),
98+
vce=Vcov.cluster(:hhidpn), yterm=term(:oop_spend), treatname=:wave_hosp,
99+
treatintterms=(), xterms=(fe(:wave),),
100+
cohortinteracted=false, solvelsweights=true)
101+
# Compare estimates with Stata
102+
# gen rel = wave - wave_hosp
103+
# gen irel?? = rel==??
104+
# reghdfe oop_spend irel*, a(wave) cluster(hhidpn)
105+
@test coef(r0) [-1029.0482, 245.20926, 188.59266, 3063.2707,
106+
1060.5317, 1152.3315, 1986.7811] atol=1e-4
107+
@test diag(vcov(r0)) [764612.86, 626165.68, 236556.13, 163459.83,
108+
130471.28, 294368.49, 677821.36] atol=1e0
109+
pv = VERSION < v"1.6.0" ? " <1e-8" : "<1e-08"
110+
@test sprint(show, MIME("text/plain"), r0) == """
111+
──────────────────────────────────────────────────────────────────────
112+
Summary of results: Regression-based DID
113+
──────────────────────────────────────────────────────────────────────
114+
Number of obs: 3280 Degrees of freedom: 12
115+
F-statistic: 9.12 p-value: $pv
116+
──────────────────────────────────────────────────────────────────────
117+
Sharp dynamic specification
118+
──────────────────────────────────────────────────────────────────────
119+
Relative time periods: 7 Excluded periods: -1
120+
──────────────────────────────────────────────────────────────────────
121+
Fixed effects: fe_wave
122+
──────────────────────────────────────────────────────────────────────
123+
Converged: true Singletons dropped: 0
124+
──────────────────────────────────────────────────────────────────────"""
125+
97126
sr = view(r, 1:3)
98127
@test coef(sr)[1] == r.coef[1]
99128
@test vcov(sr)[1] == r.vcov[1]

test/procedures.jl

Lines changed: 12 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -152,8 +152,7 @@ end
152152
pr = nevertreated(11)
153153
nt = (data=hrs, treatname=:wave_hosp, treatintterms=TermSet(), feM=nothing,
154154
weights=uweights(N), esample=trues(N), default(MakeTreatCols())...)
155-
ret = maketreatcols(nt..., typeof(tr), tr.time, Dict(-1=>1),
156-
IdDict{ValidTimeType,Int}(11=>1))
155+
ret = maketreatcols(nt..., tr.time, Dict(-1=>1), IdDict{ValidTimeType,Int}(11=>1))
157156
@test size(ret.cells) == (20, 2)
158157
@test length(ret.rows) == 20
159158
@test size(ret.treatcells) == (12, 2)
@@ -175,8 +174,7 @@ end
175174
@test all(w[ret.treatcells.wave_hosp.==10].==163)
176175

177176
nt = merge(nt, (treatintterms=TermSet(term(:male)),))
178-
ret1 = maketreatcols(nt..., typeof(tr), tr.time, Dict(-1=>1),
179-
IdDict{ValidTimeType,Int}(11=>1))
177+
ret1 = maketreatcols(nt..., tr.time, Dict(-1=>1), IdDict{ValidTimeType,Int}(11=>1))
180178
@test size(ret1.cells) == (40, 3)
181179
@test length(ret1.rows) == 40
182180
@test size(ret1.treatcells) == (24, 3)
@@ -195,8 +193,7 @@ end
195193
@test ret1.cellweights == ret1.cellcounts
196194

197195
nt = merge(nt, (cohortinteracted=false, treatintterms=TermSet()))
198-
ret2 = maketreatcols(nt..., typeof(tr), tr.time, Dict(-1=>1),
199-
IdDict{ValidTimeType,Int}(11=>1))
196+
ret2 = maketreatcols(nt..., tr.time, Dict(-1=>1), IdDict{ValidTimeType,Int}(11=>1))
200197
@test ret2.cells[1] == ret.cells[1]
201198
@test ret2.cells[2] == ret.cells[2]
202199
@test ret2.rows == ret.rows
@@ -210,8 +207,7 @@ end
210207
@test ret2.cellweights == ret2.cellcounts
211208

212209
nt = merge(nt, (treatintterms=TermSet(term(:male)),))
213-
ret3 = maketreatcols(nt..., typeof(tr), tr.time, Dict(-1=>1),
214-
IdDict{ValidTimeType,Int}(11=>1))
210+
ret3 = maketreatcols(nt..., tr.time, Dict(-1=>1), IdDict{ValidTimeType,Int}(11=>1))
215211
@test ret3.cells[1] == ret1.cells[1]
216212
@test ret3.cells[2] == ret1.cells[2]
217213
@test ret3.cells[3] == ret1.cells[3]
@@ -232,8 +228,7 @@ end
232228
feM = AbstractFixedEffectSolver{Float64}(fes, wt, Val{:cpu}, Threads.nthreads())
233229
nt = merge(nt, (data=df, feM=feM, weights=wt, esample=esample,
234230
treatintterms=TermSet(), cohortinteracted=true))
235-
ret = maketreatcols(nt..., typeof(tr), tr.time, Dict(-1=>1),
236-
IdDict{ValidTimeType,Int}(11=>1))
231+
ret = maketreatcols(nt..., tr.time, Dict(-1=>1), IdDict{ValidTimeType,Int}(11=>1))
237232
col = reshape(col[esample], N, 1)
238233
defaults = (default(MakeTreatCols())...,)
239234
_feresiduals!(col, feM, defaults[2:3]...)
@@ -254,9 +249,13 @@ end
254249
@test combinedargs(MakeTreatCols(), allntargs) ==
255250
(Dict{Int,Int}(), IdDict{ValidTimeType,Int}())
256251

252+
allntargs = NamedTuple[(tr=tr, pr=pr), (tr=tr, pr=unspecifiedpr())]
253+
@test combinedargs(MakeTreatCols(), allntargs) ==
254+
(Dict(-1=>2), IdDict{ValidTimeType,Int}())
255+
257256
df.wave = settime(Date.(hrs.wave), Year(1))
258257
df.wave_hosp = settime(Date.(hrs.wave_hosp), Year(1), start=Date(7))
259-
ret1 = maketreatcols(nt..., typeof(tr), tr.time,
258+
ret1 = maketreatcols(nt..., tr.time,
260259
Dict(-1=>1), IdDict{ValidTimeType,Int}(Date(11)=>1))
261260
@test ret1.cells[1] == Date.(ret.cells[1])
262261
@test ret1.cells[2] == Date.(ret.cells[2])
@@ -272,7 +271,7 @@ end
272271
df.wave = RotatingTimeArray(rot, hrs.wave)
273272
df.wave_hosp = RotatingTimeArray(rot, hrs.wave_hosp)
274273
e = rotatingtime((1,1,2), (10,11,11))
275-
ret2 = maketreatcols(nt..., typeof(tr), tr.time,
274+
ret2 = maketreatcols(nt..., tr.time,
276275
Dict(-1=>1), IdDict{ValidTimeType,Int}(c=>1 for c in e))
277276
@test ret2.cells[1] == sort!(append!((rotatingtime(r, ret.cells[1]) for r in (1,2))...))
278277
rt = append!((rotatingtime(r, 7:11) for r in (1,2))...)
@@ -284,7 +283,7 @@ end
284283
df.wave = settime(Date.(hrs.wave), Year(1), rotation=rot)
285284
df.wave_hosp = settime(Date.(hrs.wave_hosp), Year(1), start=Date(7), rotation=rot)
286285
e = rotatingtime((1,1,2), Date.((10,11,11)))
287-
ret3 = maketreatcols(nt..., typeof(tr), tr.time,
286+
ret3 = maketreatcols(nt..., tr.time,
288287
Dict(-1=>1), IdDict{ValidTimeType,Int}(c=>1 for c in e))
289288
@test ret3.cells[1].time == Date.(ret2.cells[1].time)
290289
@test ret3.cells[2].time == Date.(ret2.cells[2].time)

0 commit comments

Comments
 (0)