Skip to content

Commit 374e5ca

Browse files
committed
ConceptMetrics: metrics refactor + GroupConfig: flexible/modular init of metrics and losses.
1 parent 17a5ee6 commit 374e5ca

File tree

27 files changed

+1846
-1276
lines changed

27 files changed

+1846
-1276
lines changed

conceptarium/conf/_default.yaml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
11
defaults:
22
- dataset: asia
33
- model: cbm_joint
4-
- loss: _default
4+
- loss: standard
5+
- metrics: standard
56
- _self_
67

78
# =============================================================

conceptarium/conf/loss/_default.yaml

Lines changed: 0 additions & 13 deletions
This file was deleted.
Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
# =============================================================
2+
# Loss settings
3+
# =============================================================
4+
_target_: "torch_concepts.nn.ConceptLoss"
5+
6+
fn_collection:
7+
_target_: "torch_concepts.nn.modules.utils.GroupConfig"
8+
binary:
9+
_target_: "torch.nn.BCEWithLogitsLoss"
10+
categorical:
11+
_target_: "torch.nn.CrossEntropyLoss"
12+
# continuous:
13+
# ... not supported yet
Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,16 @@
1+
# =============================================================
2+
# Loss settings
3+
# =============================================================
14
_target_: "torch_concepts.nn.WeightedConceptLoss"
25

36
weight: 0.8 # weight applied to concepts, (1-weight) applied to task
4-
task_names: ${model.task_names}
5-
fn_collection:
6-
discrete:
7-
binary:
8-
path: "torch.nn.BCEWithLogitsLoss"
9-
kwargs: {}
10-
categorical:
11-
path: "torch.nn.CrossEntropyLoss"
12-
kwargs: {}
7+
task_names: ${dataset.default_task_names}
138

9+
fn_collection:
10+
_target_: "torch_concepts.nn.modules.utils.GroupConfig"
11+
binary:
12+
_target_: "torch.nn.BCEWithLogitsLoss"
13+
categorical:
14+
_target_: "torch.nn.CrossEntropyLoss"
1415
# continuous:
15-
# ... not supported yet
16+
# ... not supported yet
Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
# =============================================================
2+
# Metrics settings
3+
# =============================================================
4+
_target_: "torch_concepts.nn.ConceptMetrics"
5+
6+
# tracking of summary metrics for each concept type
7+
summary_metrics: true
8+
# tracking of metrics for each individual concept
9+
# `true` for all concepts, list of concept names, or `false` for none
10+
# ${dataset.default_task_names} for tracking tasks individually
11+
perconcept_metrics: true
12+
13+
fn_collection:
14+
_target_: "torch_concepts.nn.modules.utils.GroupConfig"
15+
binary:
16+
accuracy:
17+
_target_: "torchmetrics.classification.BinaryAccuracy"
18+
categorical:
19+
accuracy:
20+
- _target_: "hydra.utils.get_class"
21+
path: "torchmetrics.classification.MulticlassAccuracy"
22+
- average: 'micro'
23+
# continuous:
24+
# ... not supported yet

conceptarium/conf/model/_commons.yaml

Lines changed: 3 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
defaults:
2-
- metrics: _default
32
- _self_
43

54

@@ -53,16 +52,10 @@ optim_kwargs:
5352
# factor: 0.2
5453

5554

55+
# TODO: implement this
5656
# =============================================================
57-
# Metrics settings
57+
# Training settings
5858
# =============================================================
59-
# tracking of summary metrics for each concept type
60-
summary_metrics: true
61-
# tracking of metrics for each individual concept
62-
# `true` for all concepts, list of concept names, or `false` for none
63-
perconcept_metrics: ${dataset.default_task_names}
64-
65-
# TODO: implement this
6659
# train_interv_prob: 0.1
6760
# test_interv_policy: nodes_true # levels_true, levels_pred, nodes_true, nodes_pred, random
68-
# test_interv_noise: 0.
61+
# test_interv_noise: 0.
Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
defaults:
2+
- _commons
3+
- _self_
4+
5+
_target_: "torch_concepts.nn.ConceptBottleneckModel_Independent"
6+
7+
task_names: ${dataset.default_task_names}
8+
9+
inference:
10+
_target_: "torch_concepts.nn.DeterministicInference"
11+
_partial_: true

conceptarium/conf/model/metrics/_default.yaml

Lines changed: 0 additions & 19 deletions
This file was deleted.

conceptarium/conf/sweep.yaml

Lines changed: 8 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -9,23 +9,22 @@ hydra:
99
# standard grid search
1010
params:
1111
seed: 1
12-
dataset: asia
12+
dataset: asia, sachs, insurance
1313
model: cbm
14-
# loss: weighted
15-
# loss.weight: 0.99
14+
#loss: standard, weighted
1615

1716
model:
18-
summary_metrics: true
19-
perconcept_metrics: true # or ${dataset.default_task_names}
20-
# train_interv_prob: 0.8
21-
# test_interv_noise: 0.8 # for bndatasets only
2217
optim_kwargs:
23-
lr: 0.001
18+
lr: 0.01
19+
20+
metrics:
21+
summary_metrics: true
22+
perconcept_metrics: true #${dataset.default_task_names}
2423

2524
trainer:
2625
logger: null
2726
max_epochs: 200
28-
patience: 30
27+
patience: 20
2928

3029
matmul_precision: medium
3130

conceptarium/run_experiment.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,10 @@ def main(cfg: DictConfig) -> None:
4444
# ----------------------------------
4545
logger.info("----------------------INIT MODEL-------------------------------------")
4646
loss = instantiate(cfg.loss, annotations=datamodule.annotations, _convert_="all")
47-
model = instantiate(cfg.model, annotations=datamodule.annotations, loss=loss, _convert_="all")
47+
logger.info(loss)
48+
metrics = instantiate(cfg.metrics, annotations=datamodule.annotations, _convert_="all")
49+
logger.info(metrics)
50+
model = instantiate(cfg.model, annotations=datamodule.annotations, loss=loss, metrics=metrics, _convert_="all")
4851

4952
logger.info("----------------------BEGIN TRAINING---------------------------------")
5053
try:

0 commit comments

Comments
 (0)