@@ -278,9 +278,10 @@ with weights proportional to `treatweights` within each relative time.
278278function agg (r:: RegressionBasedDIDResult{<:DynamicTreatment} , names= nothing ;
279279 bys= nothing , subset= nothing )
280280 inds = subset === nothing ? Colon () : _parse_subset (r, subset, false )
281- ptcells = treatcells (r)
282- bycells = view (ptcells, inds)
283- _parse_bycells! (getfield (bycells, :columns ), ptcells, bys)
281+ bycells = view (treatcells (r), inds)
282+ isempty (bycells) && throw (ArgumentError (
283+ " No coefficient left for aggregation after taking the subset" ))
284+ _parse_bycells! (getfield (bycells, :columns ), bycells, bys)
284285 names === nothing || (bycells = subcolumns (bycells, names, nomissing= false ))
285286
286287 tcells, rows = cellrows (bycells, findcell (bycells))
@@ -386,28 +387,43 @@ Base.getproperty(cr::ContrastResult, n::Symbol) = getproperty(_getmat(cr), n)
386387Base. parent (cr:: ContrastResult ) = getfield (cr, :rs )
387388
388389"""
389- contrast(r1::RegDIDResultOrAgg, rs::RegDIDResultOrAgg...)
390+ contrast(r1::RegDIDResultOrAgg, rs::RegDIDResultOrAgg...; kwargs )
390391
391392Construct a [`ContrastResult`](@ref) by collecting the computed least-square weights
392393from each of the [`RegDIDResultOrAgg`](@ref).
394+
395+ # Keywords
396+ - `subset=nothing`: indices for cells to be included (rows in output).
397+ - `coefs=nothing`: indices for coefficients from each result to be included (columns in output).
393398"""
394- function contrast (r1:: RegDIDResultOrAgg , rs:: RegDIDResultOrAgg... )
399+ function contrast (r1:: RegDIDResultOrAgg , rs:: RegDIDResultOrAgg... ;
400+ subset= nothing , coefs= nothing )
395401 has_lsweights (r1) && all (r-> has_lsweights (r), rs) || throw (ArgumentError (
396402 " Results must contain computed least-sqaure weights" ))
397403 ri = r1. lsweights. r
398- ncoef = ntreatcoef (r1)
399- m = r1. lsweights. m
400- for r in rs
401- r. lsweights. r == ri || throw (ArgumentError (
404+ rinds = subset === nothing ? Colon () : _parse_subset (ri, subset)
405+ # Make a copy to avoid accidentally overwriting cells for DIDResult
406+ ri = deepcopy (view (ri, rinds))
407+ dcoefs = coefs === nothing ? NamedTuple () : IdDict {Int,Any} (coefs)
408+ rs = RegDIDResultOrAgg[r1, rs... ]
409+ m = Vector {Array} (undef, length (rs)+ 1 )
410+ m[1 ] = view (r1. cellymeans, rinds)
411+ iresult = [0 ]
412+ icoef = [0 ]
413+ names = [" cellymeans" ]
414+ for (i, r) in enumerate (rs)
415+ view (r. lsweights. r, rinds) == ri || throw (ArgumentError (
402416 " Cells for least-square weights comparisons must be identical across the inputs" ))
403- ncoef += ntreatcoef (r)
417+ cinds = get (dcoefs, i, nothing )
418+ cinds = cinds === nothing ? (1 : ntreatcoef (r)) : _parse_subset (treatcells (r), cinds)
419+ m[i+ 1 ] = view (r. lsweights. m, rinds, cinds)
420+ ic = view (1 : ntreatcoef (r), cinds)
421+ push! (iresult, (i for k in 1 : length (ic)). .. )
422+ append! (icoef, ic)
423+ append! (names, view (treatnames (r), cinds))
404424 end
405- rs = RegDIDResultOrAgg[r1, rs... ]
406- m = hcat (r1. cellymeans, (r. lsweights. m for r in rs). .. )
407- rinds = vcat (0 , (fill (i+ 1 , ntreatcoef (r)) for (i, r) in enumerate (rs)). .. )
408- cinds = vcat (0 , (1 : ntreatcoef (r) for r in rs). .. )
409- names = vcat (" cellymeans" , (treatnames (r) for r in rs). .. )
410- ci = VecColumnTable ((iresult= rinds, icoef= cinds, name= names))
425+ m = hcat (m... )
426+ ci = VecColumnTable ((iresult= iresult, icoef= icoef, name= names))
411427 return ContrastResult (rs, TableIndexedMatrix (m, ri, ci))
412428end
413429
@@ -428,18 +444,8 @@ function Base.sort!(cr::ContrastResult; @nospecialize(kwargs...))
428444 return cr
429445end
430446
431- _parse_subset (cr:: ContrastResult , by:: Pair ) = (inds = apply (cr. r, by); return inds)
432-
433- function _parse_subset (cr:: ContrastResult , inds)
434- eltype (inds) <: Pair || return inds
435- inds = apply_and (cr. r, inds... )
436- return inds
437- end
438-
439- _parse_subset (:: ContrastResult , :: Colon ) = Colon ()
440-
441447function Base. view (cr:: ContrastResult , subset)
442- inds = _parse_subset (cr, subset)
448+ inds = _parse_subset (cr. r , subset)
443449 r = view (cr. r, inds)
444450 m = view (cr. m, inds, :)
445451 return ContrastResult (parent (cr), TableIndexedMatrix (m, r, cr. c))
0 commit comments