@@ -7,7 +7,7 @@ using Parameters
77using StatsBase
88
99@with_kw mutable struct FixedParameters
10- n_rounds:: Int = 10
10+ n_rounds:: Int = 50
1111 n_folds:: Int = 5
1212 seed:: Union{Nothing,Int} = nothing
1313 T:: Int = 100
@@ -115,7 +115,7 @@ function choose_individuals(experiment::Experiment, recourse_systems::AbstractAr
115115 ŷ = probs (sys. model, sys. data. X)
116116 n_classes = size (sys. data. y, 1 )
117117 if n_classes == 1
118- cand_ = findall (vec (ŷ) .!= target)
118+ cand_ = findall (round .( vec (ŷ) ) .!= target)
119119 else
120120 ŷ = Flux. onecold (ŷ, 1 : n_classes)
121121 cand_ = findall (vec (ŷ) .!= target)
@@ -157,39 +157,43 @@ function update_experiment!(experiment::Experiment, recourse_system::RecourseSys
157157 T = args. T
158158 target = experiment. target
159159
160- # Generate recourse:
161- factuals = select_factual (counterfactual_data, chosen_individuals)
160+ if length (chosen_individuals) > 0
162161
163- results = generate_counterfactual (
164- factuals, target, counterfactual_data, M, generator;
165- T= T, num_counterfactuals= experiment. num_counterfactuals, generative_model_params= args. generative_model_params,
166- latent_space= args. latent_space
167- )
162+ # Generate recourse:
163+ factuals = select_factual (counterfactual_data, chosen_individuals)
168164
169- # Unwrap new data:
170- indices_ = rand (1 : experiment. num_counterfactuals, length (results)) # randomly draw from generated counterfactuals
171- X′ = reduce (hcat, @. (selectdim (counterfactual (results), 3 , indices_)))
172- y′ = reduce (hcat, @. (selectdim (counterfactual_label (results), 3 , indices_)))
165+ results = generate_counterfactual (
166+ factuals, target, counterfactual_data, M, generator;
167+ T= T, num_counterfactuals= experiment. num_counterfactuals, generative_model_params= args. generative_model_params,
168+ latent_space= args. latent_space
169+ )
173170
174- # If for any counterfactuals the returned label is NaN, this is considered as invalid and the current label is not updated:
175- chosen_individuals = chosen_individuals[vec (.! (isnan .(y′)))]
171+ # Unwrap new data:
172+ indices_ = rand (1 : experiment. num_counterfactuals, length (results)) # randomly draw from generated counterfactuals
173+ X′ = reduce (hcat, @. (selectdim (counterfactual (results), 3 , indices_)))
174+ y′ = reduce (hcat, @. (selectdim (counterfactual_label (results), 3 , indices_)))
176175
177- # Update data:
178- X[:, chosen_individuals] = X′
179- y[:, chosen_individuals] = y′
176+ # If for any counterfactuals the returned label is NaN, this is considered as invalid and the current label is not updated:
177+ chosen_individuals = chosen_individuals[vec (.! (isnan .(y′)))]
180178
181- # Generative model:
182- gen_mod = deepcopy (counterfactual_data. generative_model)
183- if ! isnothing (gen_mod)
184- CounterfactualExplanations. GenerativeModels. retrain! (gen_mod, X, y)
185- end
179+ # Update data:
180+ X[:, chosen_individuals] = X′
181+ y[:, chosen_individuals] = y′
186182
187- # Update data, classifier and benchmark:
188- recourse_system. data. X = X
189- recourse_system. data. y = y
190- recourse_system. data. generative_model = gen_mod
191- recourse_system. model = Models. train (M, counterfactual_data)
192- recourse_system. benchmark = vcat (recourse_system. benchmark, CounterfactualExplanations. Benchmark. benchmark (results))
183+ # Generative model:
184+ gen_mod = deepcopy (counterfactual_data. generative_model)
185+ if ! isnothing (gen_mod)
186+ CounterfactualExplanations. GenerativeModels. retrain! (gen_mod, X, y)
187+ end
188+
189+ # Update data, classifier and benchmark:
190+ recourse_system. data. X = X
191+ recourse_system. data. y = y
192+ recourse_system. data. generative_model = gen_mod
193+ recourse_system. model = Models. train (M, counterfactual_data)
194+ recourse_system. benchmark = vcat (recourse_system. benchmark, CounterfactualExplanations. Benchmark. benchmark (results))
195+
196+ end
193197
194198end
195199
0 commit comments