Skip to content

Commit 899b254

Browse files
committed
fixed a small bug
1 parent 38a04c4 commit 899b254

File tree

1 file changed

+33
-29
lines changed

1 file changed

+33
-29
lines changed

src/experiments/functions.jl

Lines changed: 33 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ using Parameters
77
using StatsBase
88

99
@with_kw mutable struct FixedParameters
10-
n_rounds::Int = 10
10+
n_rounds::Int = 50
1111
n_folds::Int = 5
1212
seed::Union{Nothing,Int} = nothing
1313
T::Int = 100
@@ -115,7 +115,7 @@ function choose_individuals(experiment::Experiment, recourse_systems::AbstractAr
115115
= probs(sys.model, sys.data.X)
116116
n_classes = size(sys.data.y, 1)
117117
if n_classes == 1
118-
cand_ = findall(vec(ŷ) .!= target)
118+
cand_ = findall(round.(vec(ŷ)) .!= target)
119119
else
120120
= Flux.onecold(ŷ, 1:n_classes)
121121
cand_ = findall(vec(ŷ) .!= target)
@@ -157,39 +157,43 @@ function update_experiment!(experiment::Experiment, recourse_system::RecourseSys
157157
T = args.T
158158
target = experiment.target
159159

160-
# Generate recourse:
161-
factuals = select_factual(counterfactual_data, chosen_individuals)
160+
if length(chosen_individuals) > 0
162161

163-
results = generate_counterfactual(
164-
factuals, target, counterfactual_data, M, generator;
165-
T=T, num_counterfactuals=experiment.num_counterfactuals, generative_model_params=args.generative_model_params,
166-
latent_space=args.latent_space
167-
)
162+
# Generate recourse:
163+
factuals = select_factual(counterfactual_data, chosen_individuals)
168164

169-
# Unwrap new data:
170-
indices_ = rand(1:experiment.num_counterfactuals, length(results)) # randomly draw from generated counterfactuals
171-
X′ = reduce(hcat, @.(selectdim(counterfactual(results), 3, indices_)))
172-
y′ = reduce(hcat, @.(selectdim(counterfactual_label(results), 3, indices_)))
165+
results = generate_counterfactual(
166+
factuals, target, counterfactual_data, M, generator;
167+
T=T, num_counterfactuals=experiment.num_counterfactuals, generative_model_params=args.generative_model_params,
168+
latent_space=args.latent_space
169+
)
173170

174-
# If for any counterfactuals the returned label is NaN, this is considered as invalid and the current label is not updated:
175-
chosen_individuals = chosen_individuals[vec(.!(isnan.(y′)))]
171+
# Unwrap new data:
172+
indices_ = rand(1:experiment.num_counterfactuals, length(results)) # randomly draw from generated counterfactuals
173+
X′ = reduce(hcat, @.(selectdim(counterfactual(results), 3, indices_)))
174+
y′ = reduce(hcat, @.(selectdim(counterfactual_label(results), 3, indices_)))
176175

177-
# Update data:
178-
X[:, chosen_individuals] = X′
179-
y[:, chosen_individuals] = y′
176+
# If for any counterfactuals the returned label is NaN, this is considered as invalid and the current label is not updated:
177+
chosen_individuals = chosen_individuals[vec(.!(isnan.(y′)))]
180178

181-
# Generative model:
182-
gen_mod = deepcopy(counterfactual_data.generative_model)
183-
if !isnothing(gen_mod)
184-
CounterfactualExplanations.GenerativeModels.retrain!(gen_mod, X, y)
185-
end
179+
# Update data:
180+
X[:, chosen_individuals] = X′
181+
y[:, chosen_individuals] = y′
186182

187-
# Update data, classifier and benchmark:
188-
recourse_system.data.X = X
189-
recourse_system.data.y = y
190-
recourse_system.data.generative_model = gen_mod
191-
recourse_system.model = Models.train(M, counterfactual_data)
192-
recourse_system.benchmark = vcat(recourse_system.benchmark, CounterfactualExplanations.Benchmark.benchmark(results))
183+
# Generative model:
184+
gen_mod = deepcopy(counterfactual_data.generative_model)
185+
if !isnothing(gen_mod)
186+
CounterfactualExplanations.GenerativeModels.retrain!(gen_mod, X, y)
187+
end
188+
189+
# Update data, classifier and benchmark:
190+
recourse_system.data.X = X
191+
recourse_system.data.y = y
192+
recourse_system.data.generative_model = gen_mod
193+
recourse_system.model = Models.train(M, counterfactual_data)
194+
recourse_system.benchmark = vcat(recourse_system.benchmark, CounterfactualExplanations.Benchmark.benchmark(results))
195+
196+
end
193197

194198
end
195199

0 commit comments

Comments
 (0)