Skip to content

Commit 3560b37

Browse files
committed
update documentation of high-level and conceptarium
1 parent 95073b9 commit 3560b37

File tree

8 files changed

+907
-1139
lines changed

8 files changed

+907
-1139
lines changed

doc/guides/using_high_level.rst

Lines changed: 892 additions & 312 deletions
Large diffs are not rendered by default.

doc/modules/annotations.rst

Lines changed: 0 additions & 161 deletions
Original file line numberDiff line numberDiff line change
@@ -18,167 +18,6 @@ Summary
1818
Annotations
1919

2020

21-
Overview
22-
--------
23-
24-
Annotations store metadata about concepts including names, cardinalities, distribution
25-
types, and custom attributes. They are required to initialize:
26-
27-
- **Models** (e.g., ConceptBottleneckModel): Specify concept structure and distributions
28-
- **ConceptLoss**: Route to appropriate loss functions based on concept types
29-
- **ConceptMetrics**: Organize metrics by concept and compute per-concept statistics
30-
31-
Distribution information is critical - it tells the model how to represent each concept
32-
(e.g., Bernoulli for binary, Categorical for multi-class, Normal for continuous).
33-
34-
Distributions can be provided either:
35-
36-
1. **In annotations metadata** (recommended): Include 'distribution' key in metadata
37-
2. **Via model's variable_distributions parameter**: Pass distributions at model initialization
38-
39-
Quick Start
40-
-----------
41-
42-
**Option 1: Distributions in metadata (recommended)**
43-
44-
.. code-block:: python
45-
46-
from torch_concepts.annotations import AxisAnnotation, Annotations
47-
from torch.distributions import Bernoulli, Categorical
48-
49-
# Distributions included in annotations
50-
ann = Annotations({
51-
1: AxisAnnotation(
52-
labels=['is_round', 'is_smooth', 'color', 'class_A', 'class_B'],
53-
cardinalities=[1, 1, 3, 1, 1],
54-
metadata={
55-
'is_round': {'type': 'discrete', 'distribution': Bernoulli},
56-
'is_smooth': {'type': 'discrete', 'distribution': Bernoulli},
57-
'color': {'type': 'discrete', 'distribution': Categorical},
58-
'class_A': {'type': 'discrete', 'distribution': Bernoulli},
59-
'class_B': {'type': 'discrete', 'distribution': Bernoulli}
60-
}
61-
)
62-
})
63-
64-
# Use in model (no variable_distributions needed)
65-
from torch_concepts.nn import ConceptBottleneckModel
66-
model = ConceptBottleneckModel(
67-
input_size=256,
68-
annotations=ann,
69-
task_names=['class_A', 'class_B']
70-
)
71-
72-
# Use in loss
73-
from torch_concepts.nn import ConceptLoss
74-
from torch_concepts import GroupConfig
75-
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss
76-
77-
loss_config = GroupConfig(
78-
binary=BCEWithLogitsLoss(),
79-
categorical=CrossEntropyLoss()
80-
)
81-
loss = ConceptLoss(annotations=ann[1], fn_collection=loss_config)
82-
83-
# Use in metrics
84-
from torch_concepts.nn import ConceptMetrics
85-
from torchmetrics.classification import BinaryAccuracy, MulticlassAccuracy
86-
87-
metrics_config = GroupConfig(
88-
binary={'accuracy': BinaryAccuracy()},
89-
categorical={'accuracy': MulticlassAccuracy}
90-
)
91-
metrics = ConceptMetrics(
92-
annotations=ann[1],
93-
fn_collection=metrics_config,
94-
summary_metrics=True,
95-
perconcept_metrics=True
96-
)
97-
98-
**Option 2: Via variable_distributions dictionary**
99-
100-
.. code-block:: python
101-
102-
# Annotations without distributions
103-
ann = Annotations({
104-
1: AxisAnnotation(
105-
labels=['is_round', 'is_smooth', 'color', 'class_A', 'class_B'],
106-
cardinalities=[1, 1, 3, 1, 1],
107-
metadata={
108-
'is_round': {'type': 'discrete'},
109-
'is_smooth': {'type': 'discrete'},
110-
'color': {'type': 'discrete'},
111-
'class_A': {'type': 'discrete'},
112-
'class_B': {'type': 'discrete'}
113-
}
114-
)
115-
})
116-
117-
# Provide distributions at model init
118-
variable_distributions = {
119-
'is_round': Bernoulli,
120-
'is_smooth': Bernoulli,
121-
'color': Categorical,
122-
'class_A': Bernoulli,
123-
'class_B': Bernoulli
124-
}
125-
126-
model = ConceptBottleneckModel(
127-
input_size=256,
128-
annotations=ann,
129-
variable_distributions=variable_distributions,
130-
task_names=['class_A', 'class_B']
131-
)
132-
133-
# Distributions added internally, then used in loss/metrics
134-
loss = ConceptLoss(annotations=model.concept_annotations, fn_collection=loss_config)
135-
metrics = ConceptMetrics(
136-
annotations=model.concept_annotations,
137-
fn_collection=metrics_config,
138-
summary_metrics=True,
139-
perconcept_metrics=True
140-
)
141-
142-
**Option 3: Using GroupConfig for automatic type-based assignment**
143-
144-
For models with many concepts of the same types, use ``GroupConfig`` to automatically assign distributions:
145-
146-
.. code-block:: python
147-
148-
from torch_concepts import GroupConfig
149-
150-
# Annotations with concept types
151-
ann = Annotations({
152-
1: AxisAnnotation(
153-
labels=['is_round', 'is_smooth', 'color', 'shape', 'class_A', 'class_B'],
154-
cardinalities=[1, 1, 3, 4, 1, 1],
155-
metadata={
156-
'is_round': {'type': 'discrete'}, # binary (card=1)
157-
'is_smooth': {'type': 'discrete'}, # binary (card=1)
158-
'color': {'type': 'discrete'}, # categorical (card=3)
159-
'shape': {'type': 'discrete'}, # categorical (card=4)
160-
'class_A': {'type': 'discrete'}, # binary (card=1)
161-
'class_B': {'type': 'discrete'} # binary (card=1)
162-
}
163-
)
164-
})
165-
166-
# GroupConfig automatically assigns by concept type and cardinality
167-
variable_distributions = GroupConfig(
168-
binary=Bernoulli, # for cardinality=1
169-
categorical=Categorical # for cardinality>1
170-
)
171-
172-
model = ConceptBottleneckModel(
173-
input_size=256,
174-
annotations=ann,
175-
variable_distributions=variable_distributions,
176-
task_names=['class_A', 'class_B']
177-
)
178-
179-
This approach is ideal for large-scale datasets (e.g., CUB-200 with 312 attributes).
180-
181-
18221
Class Documentation
18322
-------------------
18423

doc/modules/high_level_api.rst

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
High-level API
2-
==============
2+
=====================
33

4-
High-level APIs allow you to quickly build and train concept-based models using pre-configured components and minimal code.
4+
High-level API models allow you to quickly build and train concept-based models using pre-configured components and minimal code.
55

66
.. |pyc_logo| image:: https://raw.githubusercontent.com/pyc-team/pytorch_concepts/refs/heads/factors/doc/_static/img/logos/pyc.svg
77
:width: 20px
@@ -156,7 +156,7 @@ Configure losses and metrics using ``GroupConfig`` to automatically handle mixed
156156
)
157157
158158
concept_loss = ConceptLoss(
159-
annotations=annotations[1], # AxisAnnotation for concepts
159+
annotations=annotations,
160160
fn_collection=loss_config
161161
)
162162
@@ -174,7 +174,7 @@ Configure losses and metrics using ``GroupConfig`` to automatically handle mixed
174174
)
175175
176176
concept_metrics = ConceptMetrics(
177-
annotations=annotations[1],
177+
annotations=annotations,
178178
fn_collection=metrics_config,
179179
summary_metrics=True, # Compute average across concepts
180180
perconcept_metrics=True # Compute per-concept metrics

doc/modules/nn.loss.rst

Lines changed: 0 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -28,66 +28,6 @@ Summary
2828
WeightedMSELoss
2929

3030

31-
Overview
32-
--------
33-
34-
High-level losses automatically route to appropriate loss functions based on concept types (binary, categorical, continuous) using annotation metadata.
35-
36-
Quick Start
37-
-----------
38-
39-
.. code-block:: python
40-
41-
import torch
42-
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss
43-
from torch_concepts import Annotations, AxisAnnotation, GroupConfig
44-
from torch_concepts.nn import ConceptLoss, ConceptBottleneckModel
45-
from torch.distributions import Bernoulli, Categorical
46-
47-
# Define annotations with mixed types
48-
ann = Annotations({
49-
1: AxisAnnotation(
50-
labels=['is_round', 'is_smooth', 'color', 'class_A'],
51-
cardinalities=[1, 1, 3, 1],
52-
metadata={
53-
'is_round': {'type': 'discrete', 'distribution': Bernoulli},
54-
'is_smooth': {'type': 'discrete', 'distribution': Bernoulli},
55-
'color': {'type': 'discrete', 'distribution': Categorical},
56-
'class_A': {'type': 'discrete', 'distribution': Bernoulli}
57-
}
58-
)
59-
})
60-
61-
# Configure loss functions by concept type using GroupConfig
62-
loss_config = GroupConfig(
63-
binary=BCEWithLogitsLoss(),
64-
categorical=CrossEntropyLoss()
65-
)
66-
67-
# Automatic routing by concept type
68-
loss = ConceptLoss(annotations=ann[1], fn_collection=loss_config)
69-
70-
# Use in Lightning training
71-
model = ConceptBottleneckModel(
72-
input_size=256,
73-
annotations=ann,
74-
task_names=['class_A'],
75-
loss=loss,
76-
optim_class=torch.optim.AdamW,
77-
optim_kwargs={'lr': 0.001}
78-
)
79-
80-
# Manual usage
81-
predictions = torch.randn(32, 6) # batch_size=32, 2 binary + 3 categorical + 1 binary
82-
targets = torch.cat([
83-
torch.randint(0, 2, (32, 2)), # binary targets
84-
torch.randint(0, 3, (32, 1)), # categorical target (class indices)
85-
torch.randint(0, 2, (32, 1)) # binary target
86-
], dim=1)
87-
88-
loss_value = loss(predictions, targets)
89-
90-
9131
Class Documentation
9232
-------------------
9333

0 commit comments

Comments
 (0)