Skip to content

Commit b693467

Browse files
committed
after merge
2 parents 899b254 + a16b99b commit b693467

File tree

2 files changed

+8
-6
lines changed

2 files changed

+8
-6
lines changed

Project.toml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "AlgorithmicRecourseDynamics"
22
uuid = "3d1ede72-abb8-4340-bf8e-2ae06849b5ec"
33
authors = ["Patrick Altmeyer"]
4-
version = "0.1.0"
4+
version = "0.1.1"
55

66
[deps]
77
CSV = "336ed68f-0bac-5ca0-87d4-7b16caf5d00b"
@@ -27,6 +27,7 @@ StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
2727

2828
[compat]
2929
CSV = "0.10"
30+
CounterfactualExplanations = "0.1"
3031
DataFrames = "1"
3132
Distances = "0.10"
3233
Flux = "0.13"

src/experiments/functions.jl

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -173,12 +173,13 @@ function update_experiment!(experiment::Experiment, recourse_system::RecourseSys
173173
X′ = reduce(hcat, @.(selectdim(counterfactual(results), 3, indices_)))
174174
y′ = reduce(hcat, @.(selectdim(counterfactual_label(results), 3, indices_)))
175175

176-
# If for any counterfactuals the returned label is NaN, this is considered as invalid and the current label is not updated:
177-
chosen_individuals = chosen_individuals[vec(.!(isnan.(y′)))]
176+
# If for any counterfactuals the returned label is NaN, this is considered as invalid and the current label is not updated:
177+
valid_ces = vec(.!(isnan.(y′)))
178+
chosen_individuals = chosen_individuals[valid_ces]
178179

179-
# Update data:
180-
X[:, chosen_individuals] = X′
181-
y[:, chosen_individuals] = y′
180+
# Update data:
181+
X[:, chosen_individuals] = X′[:, valid_ces]
182+
y[:, chosen_individuals] = y′[:, valid_ces]
182183

183184
# Generative model:
184185
gen_mod = deepcopy(counterfactual_data.generative_model)

0 commit comments

Comments
 (0)