Skip to content
This repository was archived by the owner on Mar 11, 2022. It is now read-only.

Commit 2920470

Browse files
committed
Fix some issues
1 parent 9803fa1 commit 2920470

File tree

5 files changed

+36
-20
lines changed

5 files changed

+36
-20
lines changed

src/DiffinDiffsBase.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,12 @@
11
module DiffinDiffsBase
22

3-
using Combinatorics: combinations
43
using CSV: File
4+
using Combinatorics: combinations
55
using MacroTools: @capture, isexpr, postwalk
66
using Reexport
7+
using SplitApplyCombine: groupfind, groupview
78
@reexport using StatsModels
89
using StatsModels: TupleTerm
9-
using SplitApplyCombine: groupfind, groupview
1010
using Tables: istable, getcolumn
1111

1212
import Base: ==, show, union

src/StatsProcedures.jl

Lines changed: 12 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -331,15 +331,18 @@ _procedure(::StatsSpec{A,T}) where {A,T} = T
331331

332332
function (sp::StatsSpec{A,T})(;
333333
verbose::Bool=false, keep=nothing, keepall::Bool=false) where {A,T}
334-
args = deepcopy(sp.args)
335-
args = verbose ? merge(args, (verbose=true,)) : args
334+
args = verbose ? merge(sp.args, (verbose=true,)) : sp.args
336335
ntall = foldl(|>, T(), init=args)
337336
ntall = _result(T, ntall)
338337
if keepall
339338
return ntall
340-
elseif !isempty(ntall)
339+
else
341340
if keep === nothing
342-
return haskey(ntall, :result) ? ntall.result : ntall[end]
341+
if isempty(ntall)
342+
return nothing
343+
else
344+
return haskey(ntall, :result) ? ntall.result : ntall[end]
345+
end
343346
else
344347
# Cannot iterate Symbol
345348
if keep isa Symbol
@@ -348,11 +351,10 @@ function (sp::StatsSpec{A,T})(;
348351
eltype(keep)==Symbol ||
349352
throw(ArgumentError("expect Symbol or collections of Symbols for the value of option `keep`"))
350353
end
351-
!in(:result, keep) && haskey(ntall, :result) && (keep = (keep..., :result))
352-
return NamedTuple{keep}(ntall)
354+
in(:result, keep) || (keep = (keep..., :result))
355+
names = ((n for n in keep if haskey(ntall, n))...,)
356+
return NamedTuple{names}(ntall)
353357
end
354-
else
355-
return nothing
356358
end
357359
end
358360

@@ -393,7 +395,7 @@ function proceed(sps::AbstractVector{<:StatsSpec};
393395
nsps == 0 && throw(ArgumentError("expect a nonempty vector"))
394396
traces = Vector{NamedTuple}(undef, nsps)
395397
for i in 1:nsps
396-
traces[i] = deepcopy(sps[i].args)
398+
traces[i] = sps[i].args
397399
end
398400
gids = groupfind(r->_procedure(r)(), sps)
399401
steps = pool((p for p in keys(gids))...)
@@ -438,7 +440,7 @@ function proceed(sps::AbstractVector{<:StatsSpec};
438440
if keepall
439441
return traces
440442
elseif keep===nothing
441-
return [haskey(r, :result) ? r.result : r[end] for r in traces]
443+
return [haskey(r, :result) ? r.result : isempty(r) ? nothing : r[end] for r in traces]
442444
else
443445
# Cannot iterate Symbol
444446
if keep isa Symbol

src/did.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -116,15 +116,15 @@ The order of the arguments is irrelevant.
116116
- `args... kwargs...`: a list of arguments to be processed by [`parse_didargs`](@ref) and [`valid_didargs`](@ref).
117117
118118
# Notes
119-
When used outside [`@specset`](@ref),
119+
When expanded outside [`@specset`](@ref),
120120
a [`StatsSpec`](@ref) is constructed and then estimated by calling this instance.
121121
Options for [`StatsSpec`] can be provided in a bracket `[...]`
122122
as the first argument after `@did` with each option separated by white space.
123123
For options that take a Boolean value,
124124
specifying the name of the option is enough for setting the value to be true.
125125
By default, only a result object that is a subtype of [`DIDResult`](@ref) is returned.
126126
127-
When used inside [`@specset`](@ref),
127+
When expanded inside [`@specset`](@ref),
128128
`@did` informs [`@specset`](@ref) the methods for processing the arguments.
129129
Any option specified in the bracket is ignored.
130130

test/StatsProcedures.jl

Lines changed: 19 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -215,6 +215,8 @@ end
215215

216216
s6 = StatsSpec("", NP, NamedTuple())
217217
@test s6() === nothing
218+
@test s6(keepall=true) == NamedTuple()
219+
@test s6(keep=:result) == NamedTuple()
218220

219221
@test sprint(show, s1) == "name"
220222
@test sprint(show, s2) == "unnamed"
@@ -244,15 +246,15 @@ testformatter(nt::NamedTuple) = (haskey(nt, :name) ? nt.name : "", nt.p, (a=nt.a
244246
s5 = StatsSpec("s5", IP, (a="a", b="b"))
245247
s6 = StatsSpec("s6", CP, (a="a", b="b1"))
246248
s7 = StatsSpec("s7", CP, (a="a", b="b2"))
247-
s8 = StatsSpec("s8", EP, (a="a", b="b"))
248-
249+
s8 = StatsSpec("s8", NP, NamedTuple())
250+
s9 = StatsSpec("s9", EP, (a="a", b="b"))
251+
249252
@test proceed([s1]) == ["aab"]
250253
@test proceed([s1,s2], verbose=true) == ["aab", "aab"]
251254
@test proceed([s1,s3], verbose=true) == ["aab", "aab1"]
252255
@test proceed([s1,s4], verbose=true) == ["aab", "ab"]
253256
@test proceed([s1,s5], verbose=true) == ["aab", "aab"]
254257
@test proceed([s1,s4,s5], verbose=true) == ["aab", "ab", "aab"]
255-
@test_throws ArgumentError proceed(StatsSpec[])
256258

257259
@test proceed([s1], keep=:a) == [(a="a",result="aab")]
258260
@test proceed([s1], keep=[:a,:b]) == [(a="a", b="b", result="aab")]
@@ -275,7 +277,12 @@ testformatter(nt::NamedTuple) = (haskey(nt, :name) ? nt.name : "", nt.p, (a=nt.a
275277
@test ret[1].c === ret[2].c
276278
@test ret[1].result !== ret[2].result
277279

278-
@test_throws ErrorException proceed([s8])
280+
@test proceed([s8]) == [nothing]
281+
@test proceed([s8], keepall=true) == NamedTuple[NamedTuple()]
282+
@test proceed([s8], keep=:result) == NamedTuple[NamedTuple()]
283+
284+
@test_throws ArgumentError proceed(StatsSpec[])
285+
@test_throws ErrorException proceed([s9])
279286
end
280287

281288
@testset "_parse!" begin
@@ -351,4 +358,12 @@ end
351358
StatsSpec(testformatter(testparser(RP; a=a, b="b1"))...)(;)
352359
end
353360
@test r == ["a1a1b", "a1a1b1", "a2a2b", "a2a2b1", "a3a3b", "a3a3b1"]
361+
362+
r = @specset RP a="a1" begin
363+
StatsSpec(testformatter(testparser(; b="b"))...)(;)
364+
for i in 2:3
365+
StatsSpec(testformatter(testparser(; a="a"*"$i", b="b"))...)(;)
366+
end
367+
end
368+
@test r == ["a1a1b", "a2a2b", "a3a3b"]
354369
end

test/testutils.jl

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,9 +31,8 @@ testresult(::AbstractTreatment, ::String) = ((result="testresult",), false)
3131
const TestResult = StatsStep{:TestResult, typeof(testresult)}
3232
namedargs(::TestResult) = (tr=nothing, str=nothing)
3333

34-
const NotImplemented = DiffinDiffsEstimator{:NotImplemented, Tuple{}}
35-
3634
const TestDID = DiffinDiffsEstimator{:TestDID, Tuple{TestStep,TestResult}}
35+
const NotImplemented = DiffinDiffsEstimator{:NotImplemented, Tuple{}}
3736

3837
const TR = TestTreatment(:t, 0)
3938
const PR = TestParallel(0)

0 commit comments

Comments
 (0)