@@ -18,167 +18,6 @@ Summary
1818 Annotations
1919
2020
21- Overview
22- --------
23-
24- Annotations store metadata about concepts including names, cardinalities, distribution
25- types, and custom attributes. They are required to initialize:
26-
27- - **Models ** (e.g., ConceptBottleneckModel): Specify concept structure and distributions
28- - **ConceptLoss **: Route to appropriate loss functions based on concept types
29- - **ConceptMetrics **: Organize metrics by concept and compute per-concept statistics
30-
31- Distribution information is critical - it tells the model how to represent each concept
32- (e.g., Bernoulli for binary, Categorical for multi-class, Normal for continuous).
33-
34- Distributions can be provided either:
35-
36- 1. **In annotations metadata ** (recommended): Include 'distribution' key in metadata
37- 2. **Via model's variable_distributions parameter **: Pass distributions at model initialization
38-
39- Quick Start
40- -----------
41-
42- **Option 1: Distributions in metadata (recommended) **
43-
44- .. code-block :: python
45-
46- from torch_concepts.annotations import AxisAnnotation, Annotations
47- from torch.distributions import Bernoulli, Categorical
48-
49- # Distributions included in annotations
50- ann = Annotations({
51- 1 : AxisAnnotation(
52- labels = [' is_round' , ' is_smooth' , ' color' , ' class_A' , ' class_B' ],
53- cardinalities = [1 , 1 , 3 , 1 , 1 ],
54- metadata = {
55- ' is_round' : {' type' : ' discrete' , ' distribution' : Bernoulli},
56- ' is_smooth' : {' type' : ' discrete' , ' distribution' : Bernoulli},
57- ' color' : {' type' : ' discrete' , ' distribution' : Categorical},
58- ' class_A' : {' type' : ' discrete' , ' distribution' : Bernoulli},
59- ' class_B' : {' type' : ' discrete' , ' distribution' : Bernoulli}
60- }
61- )
62- })
63-
64- # Use in model (no variable_distributions needed)
65- from torch_concepts.nn import ConceptBottleneckModel
66- model = ConceptBottleneckModel(
67- input_size = 256 ,
68- annotations = ann,
69- task_names = [' class_A' , ' class_B' ]
70- )
71-
72- # Use in loss
73- from torch_concepts.nn import ConceptLoss
74- from torch_concepts import GroupConfig
75- from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss
76-
77- loss_config = GroupConfig(
78- binary = BCEWithLogitsLoss(),
79- categorical = CrossEntropyLoss()
80- )
81- loss = ConceptLoss(annotations = ann[1 ], fn_collection = loss_config)
82-
83- # Use in metrics
84- from torch_concepts.nn import ConceptMetrics
85- from torchmetrics.classification import BinaryAccuracy, MulticlassAccuracy
86-
87- metrics_config = GroupConfig(
88- binary = {' accuracy' : BinaryAccuracy()},
89- categorical = {' accuracy' : MulticlassAccuracy}
90- )
91- metrics = ConceptMetrics(
92- annotations = ann[1 ],
93- fn_collection = metrics_config,
94- summary_metrics = True ,
95- perconcept_metrics = True
96- )
97-
98- **Option 2: Via variable_distributions dictionary **
99-
100- .. code-block :: python
101-
102- # Annotations without distributions
103- ann = Annotations({
104- 1 : AxisAnnotation(
105- labels = [' is_round' , ' is_smooth' , ' color' , ' class_A' , ' class_B' ],
106- cardinalities = [1 , 1 , 3 , 1 , 1 ],
107- metadata = {
108- ' is_round' : {' type' : ' discrete' },
109- ' is_smooth' : {' type' : ' discrete' },
110- ' color' : {' type' : ' discrete' },
111- ' class_A' : {' type' : ' discrete' },
112- ' class_B' : {' type' : ' discrete' }
113- }
114- )
115- })
116-
117- # Provide distributions at model init
118- variable_distributions = {
119- ' is_round' : Bernoulli,
120- ' is_smooth' : Bernoulli,
121- ' color' : Categorical,
122- ' class_A' : Bernoulli,
123- ' class_B' : Bernoulli
124- }
125-
126- model = ConceptBottleneckModel(
127- input_size = 256 ,
128- annotations = ann,
129- variable_distributions = variable_distributions,
130- task_names = [' class_A' , ' class_B' ]
131- )
132-
133- # Distributions added internally, then used in loss/metrics
134- loss = ConceptLoss(annotations = model.concept_annotations, fn_collection = loss_config)
135- metrics = ConceptMetrics(
136- annotations = model.concept_annotations,
137- fn_collection = metrics_config,
138- summary_metrics = True ,
139- perconcept_metrics = True
140- )
141-
142- **Option 3: Using GroupConfig for automatic type-based assignment **
143-
144- For models with many concepts of the same types, use ``GroupConfig `` to automatically assign distributions:
145-
146- .. code-block :: python
147-
148- from torch_concepts import GroupConfig
149-
150- # Annotations with concept types
151- ann = Annotations({
152- 1 : AxisAnnotation(
153- labels = [' is_round' , ' is_smooth' , ' color' , ' shape' , ' class_A' , ' class_B' ],
154- cardinalities = [1 , 1 , 3 , 4 , 1 , 1 ],
155- metadata = {
156- ' is_round' : {' type' : ' discrete' }, # binary (card=1)
157- ' is_smooth' : {' type' : ' discrete' }, # binary (card=1)
158- ' color' : {' type' : ' discrete' }, # categorical (card=3)
159- ' shape' : {' type' : ' discrete' }, # categorical (card=4)
160- ' class_A' : {' type' : ' discrete' }, # binary (card=1)
161- ' class_B' : {' type' : ' discrete' } # binary (card=1)
162- }
163- )
164- })
165-
166- # GroupConfig automatically assigns by concept type and cardinality
167- variable_distributions = GroupConfig(
168- binary = Bernoulli, # for cardinality=1
169- categorical = Categorical # for cardinality>1
170- )
171-
172- model = ConceptBottleneckModel(
173- input_size = 256 ,
174- annotations = ann,
175- variable_distributions = variable_distributions,
176- task_names = [' class_A' , ' class_B' ]
177- )
178-
179- This approach is ideal for large-scale datasets (e.g., CUB-200 with 312 attributes).
180-
181-
18221Class Documentation
18322-------------------
18423
0 commit comments