
Commit 0761c95

Merge branch 'factors' of https://github.com/pyc-team/pytorch_concepts into factors
2 parents 4eff755 + aeed36b commit 0761c95

28 files changed: 439 additions and 242 deletions

README.md

Lines changed: 15 additions & 14 deletions
Original file line number | Diff line number | Diff line change
@@ -1,5 +1,5 @@
11
<p align="center">
2-
<img src="doc/_static/img/pyc_logo.png" alt="PyC Logo" width="40%">
2+
<img src="https://raw.githubusercontent.com/pyc-team/pytorch_concepts/refs/heads/factors/doc/_static/img/pyc_logo.png" alt="PyC Logo" width="40%">
33
</p>
44

55
<p align="center">
@@ -10,12 +10,12 @@
1010
</p>
1111

1212
<p align="center">
13-
<a href="https://pytorch-concepts.readthedocs.io/en/latest/guides/installation.html">🚀 Getting Started</a> -
14-
<a href="https://pytorch-concepts.readthedocs.io/">📚 Documentation</a> -
13+
<a href="https://pytorch-concepts.readthedocs.io/en/latest/guides/installation.html">🚀 Getting Started</a> -
14+
<a href="https://pytorch-concepts.readthedocs.io/">📚 Documentation</a> -
1515
<a href="https://pytorch-concepts.readthedocs.io/en/latest/guides/using.html">💻 User guide</a>
1616
</p>
1717

18-
<img src="doc/_static/img/logos/pyc.svg" width="20px" align="center"> PyC is a library built upon <img src="doc/_static/img/logos/pytorch.svg" width="20px" align="center"> PyTorch and <img src="doc/_static/img/logos/lightning.svg" width="20px" align="center"> Pytorch Lightning to easily implement **interpretable and causally transparent deep learning models**.
18+
<img src="https://raw.githubusercontent.com/pyc-team/pytorch_concepts/refs/heads/factors/doc/_static/img/logos/pyc.svg" width="20px"> PyC is a library built upon <img src="https://raw.githubusercontent.com/pyc-team/pytorch_concepts/refs/heads/factors/doc/_static/img/logos/pytorch.svg" width="20px" align="center"> PyTorch and <img src="https://raw.githubusercontent.com/pyc-team/pytorch_concepts/refs/heads/factors/doc/_static/img/logos/lightning.svg" width="20px" align="center"> Pytorch Lightning to easily implement **interpretable and causally transparent deep learning models**.
1919
The library provides primitives for layers (encoders, predictors, special layers), probabilistic models, and APIs for running experiments at scale.
2020

2121
The name of the library stands for both
@@ -26,7 +26,7 @@ The name of the library stands for both
2626

2727
# Quick Start
2828

29-
You can install PyC with core dependencies from [PyPI](https://pypi.org/project/pytorch-concepts/):
29+
You can install <img src="https://raw.githubusercontent.com/pyc-team/pytorch_concepts/refs/heads/factors/doc/_static/img/logos/pyc.svg" width="20px"> PyC with core dependencies from [PyPI](https://pypi.org/project/pytorch-concepts/):
3030

3131
```bash
3232
pip install pytorch-concepts
@@ -38,19 +38,19 @@ After installation, you can import it in your Python scripts as:
3838
import torch_concepts as pyc
3939
```
4040

41-
Follow our [user guide](https://pytorch-concepts.readthedocs.io/en/latest/guides/using.html) to get started with building interpretable models using PyC!
41+
Follow our [user guide](https://pytorch-concepts.readthedocs.io/en/latest/guides/using.html) to get started with building interpretable models using <img src="https://raw.githubusercontent.com/pyc-team/pytorch_concepts/refs/heads/factors/doc/_static/img/logos/pyc.svg" width="20px"> PyC!
4242

4343
---
4444

45-
# <img src="doc/_static/img/logos/pyc.svg" width="20px" align="center"> PyC Software Stack
45+
# <img src="https://raw.githubusercontent.com/pyc-team/pytorch_concepts/refs/heads/factors/doc/_static/img/logos/pyc.svg" width="20px"> PyC Software Stack
4646
The library is organized to be modular and accessible at different levels of abstraction:
47-
- <img src="doc/_static/img/logos/conceptarium.svg" width="20px" align="center"> **Conceptarium (No-code API). Use case: applications and benchmarking.** These APIs allow to easily run large-scale highly parallelized and standardized experiments by interfacing with configuration files. Built on top of <img src="doc/_static/img/logos/hydra-head.svg" width="20px" align="center"> Hydra and <img src="doc/_static/img/logos/wandb.svg" width="20px" align="center"> WandB.
48-
- **High-level APIs. Use case: use out-of-the-box state-of-the-art models.** These APIs allow to instantiate use implemented models with 1 line of code. This interface is built in <img src="doc/_static/img/logos/lightning.svg" width="20px" align="center"> Pytorch Lightning to easily standardize training and evaluation.
47+
- <img src="https://raw.githubusercontent.com/pyc-team/pytorch_concepts/refs/heads/factors/doc/_static/img/logos/conceptarium.svg" width="20px" align="center"> **Conceptarium (No-code API). Use case: applications and benchmarking.** These APIs allow to easily run large-scale highly parallelized and standardized experiments by interfacing with configuration files. Built on top of <img src="https://raw.githubusercontent.com/pyc-team/pytorch_concepts/refs/heads/factors/doc/_static/img/logos/hydra-head.svg" width="20px" align="center"> Hydra and <img src="https://raw.githubusercontent.com/pyc-team/pytorch_concepts/refs/heads/factors/doc/_static/img/logos/wandb.svg" width="20px" align="center"> WandB.
48+
- **High-level APIs. Use case: use out-of-the-box state-of-the-art models.** These APIs allow to instantiate use implemented models with 1 line of code. This interface is built in <img src="https://raw.githubusercontent.com/pyc-team/pytorch_concepts/refs/heads/factors/doc/_static/img/logos/lightning.svg" width="20px" align="center"> Pytorch Lightning to easily standardize training and evaluation.
4949
- **Mid-level APIs. Use case: build custom interpretable and causally transparent probabilistic graphical models.** These APIs allow to build new interpretable probabilistic models and run efficient tensorial probabilistic inference.
50-
- **Low-level APIs. Use case: assemble custom interpretable architectures.** These APIs allow to build architectures from basic interpretable layers in a plain <img src="doc/_static/img/logos/pytorch.svg" width="20px" align="center"> PyTorch-like interface. These APIs also include metrics, losses, and datasets.
50+
- **Low-level APIs. Use case: assemble custom interpretable architectures.** These APIs allow to build architectures from basic interpretable layers in a plain <img src="https://raw.githubusercontent.com/pyc-team/pytorch_concepts/refs/heads/factors/doc/_static/img/logos/pytorch.svg" width="20px" align="center"> PyTorch-like interface. These APIs also include metrics, losses, and datasets.
5151

5252
<p align="center">
53-
<img src="doc/_static/img/pyc_software_stack.png" alt="PyC Software Stack" width="90%">
53+
<img src="https://raw.githubusercontent.com/pyc-team/pytorch_concepts/refs/heads/factors/doc/_static/img/pyc_software_stack.png" alt="PyC Software Stack" width="90%">
5454
</p>
5555

5656
---
@@ -96,9 +96,10 @@ Reference authors: [Pietro Barbiero](http://www.pietrobarbiero.eu/), [Giovanni D
9696
This project is supported by the following organizations:
9797

9898
<p align="center">
99-
<img src="doc/_static/img/funding/fwo_kleur.png" alt="FWO - Research Foundation Flanders" height="60" style="margin: 20px;">
99+
<img src="https://raw.githubusercontent.com/pyc-team/pytorch_concepts/refs/heads/factors/doc/_static/img/funding/fwo_kleur.png" alt="FWO - Research Foundation Flanders" height="60" style="margin: 20px;">
100100
&nbsp;&nbsp;&nbsp;&nbsp;
101-
<img src="doc/_static/img/funding/hasler.png" alt="Hasler Foundation" height="60" style="margin: 20px;">
101+
<img src="https://raw.githubusercontent.com/pyc-team/pytorch_concepts/refs/heads/factors/doc/_static/img/funding/hasler.png" alt="Hasler Foundation" height="60" style="margin: 20px;">
102102
&nbsp;&nbsp;&nbsp;&nbsp;
103-
<img src="doc/_static/img/funding/snsf.png" alt="SNSF - Swiss National Science Foundation" height="60" style="margin: 20px;">
103+
<img src="https://raw.githubusercontent.com/pyc-team/pytorch_concepts/refs/heads/factors/doc/_static/img/funding/snsf.png" alt="SNSF - Swiss National Science Foundation" height="60" style="margin: 20px;">
104104
</p>
105+

doc/guides/using_low_level.rst

Lines changed: 16 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -66,16 +66,23 @@ takes as input both ``Endogenous`` and ``Exogenous`` representations and produce
6666

6767
.. code-block:: python
6868
69-
pyc.nn.HyperLinearCUC(in_features_endogenous=10, in_features_exogenous=7,
70-
embedding_size=24, out_features=3)
69+
pyc.nn.HyperLinearCUC(
70+
in_features_endogenous=10,
71+
in_features_exogenous=7,
72+
embedding_size=24,
73+
out_features=3
74+
)
7175
7276
As a final example, graph learners are a special layers that learn relationships between concepts.
7377
They do not follow the standard naming convention of encoders and predictors, but their purpose should be
7478
clear from their name.
7579

7680
.. code-block:: python
7781
78-
wanda = pyc.nn.WANDAGraphLearner(['c1', 'c2', 'c3'], ['task A', 'task B', 'task C'])
82+
wanda = pyc.nn.WANDAGraphLearner(
83+
['c1', 'c2', 'c3'],
84+
['task A', 'task B', 'task C']
85+
)
7986
8087
8188
Step 1: Import Libraries
@@ -152,9 +159,7 @@ Train with both concept and task supervision:
152159
import torch.nn.functional as F
153160
154161
# Compute losses
155-
concept_loss = F.binary_cross_entropy_with_endogenous(
156-
concept_endogenous, concept_labels
157-
)
162+
concept_loss = F.binary_cross_entropy(torch.sigmoid(concept_endogenous), concept_labels)
158163
task_loss = F.cross_entropy(task_endogenous, task_labels)
159164
total_loss = task_loss + 0.5 * concept_loss
160165
@@ -183,9 +188,11 @@ The context manager takes two main arguments: **strategies** and **policies**.
183188
policy = UniformPolicy(out_features=n_concepts)
184189
185190
# Apply intervention to encoder
186-
with intervention(policies=policy,
187-
strategies=strategy,
188-
target_concepts=[0, 2]) as new_encoder_layer:
191+
with intervention(
192+
policies=policy,
193+
strategies=strategy,
194+
target_concepts=[0, 2]
195+
) as new_encoder_layer:
189196
intervened_concepts = new_encoder_layer(input=x)
190197
intervened_tasks = model['predictor'](endogenous=intervened_concepts)
191198
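
Note on the loss change in this file: the hunk replaces the call to `F.binary_cross_entropy_with_endogenous` with `F.binary_cross_entropy(torch.sigmoid(concept_endogenous), concept_labels)`. PyTorch also provides `F.binary_cross_entropy_with_logits`, which accepts raw scores and fuses the sigmoid with the loss. The sketch below is not taken from the commit; it uses only the standard library and hypothetical helper names (`bce_from_probability`, `bce_from_logit`) to illustrate, for a single scalar, that the two formulations agree on moderate scores and that the fused form stays finite on large ones.

```python
import math


def sigmoid(x: float) -> float:
    """Logistic function; adequate for the moderate inputs used here."""
    return 1.0 / (1.0 + math.exp(-x))


def bce_from_probability(logit: float, target: float) -> float:
    """Binary cross-entropy computed after an explicit sigmoid.

    Mirrors the shape of binary_cross_entropy(sigmoid(x), y) for one element.
    Hypothetical helper, for illustration only.
    """
    p = sigmoid(logit)
    return -(target * math.log(p) + (1.0 - target) * math.log(1.0 - p))


def bce_from_logit(logit: float, target: float) -> float:
    """Binary cross-entropy computed directly from the raw score.

    Uses the algebraically equivalent form
    max(x, 0) - x * y + log(1 + exp(-|x|)), which never takes log(0).
    Hypothetical helper, for illustration only.
    """
    return max(logit, 0.0) - logit * target + math.log1p(math.exp(-abs(logit)))


# Both formulations agree on moderate scores.
for logit, target in [(0.5, 1.0), (0.5, 0.0), (-2.0, 1.0)]:
    print(logit, target, bce_from_probability(logit, target), bce_from_logit(logit, target))

# The fused form remains finite for a large score.
print(bce_from_logit(40.0, 0.0))
```

In double precision the sigmoid of 40 rounds to exactly 1.0, so the explicit-sigmoid form would have to evaluate `log(0)` for a target of 0. If the concept scores in the guide can grow large, the logits-based loss may be the safer choice; whether that matters for this guide is a judgment for the authors.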

doc/guides/using_mid_level_causal.rst

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -56,8 +56,8 @@ Structural Equation Models
5656
.. code-block:: python
5757
5858
sem_model = ProbabilisticModel(
59-
variables=[exogenous_var, genotype_var, ...],
60-
parametric_cpds=[exogenous_cpd, genotype_cpd, ...]
59+
variables=[exogenous_var, genotype_var],
60+
parametric_cpds=[exogenous_cpd, genotype_cpd]
6161
)
6262
6363
Interventions
@@ -78,9 +78,9 @@ For example, to set ``smoking`` to 0 (prevent smoking) and query the effect on d
7878
)
7979
8080
with intervention(
81-
policies=UniformPolicy(out_features=1),
82-
strategies=smoking_strategy_0,
83-
target_concepts=["smoking"]
81+
policies=UniformPolicy(out_features=1),
82+
strategies=smoking_strategy_0,
83+
target_concepts=["smoking"]
8484
):
8585
intervened_results_0 = inference_engine.query(
8686
query_concepts=["genotype", "smoking", "tar", "cancer"],
@@ -258,9 +258,9 @@ Perform do-interventions to estimate causal effects:
258258
)
259259
260260
with intervention(
261-
policies=UniformPolicy(out_features=1),
262-
strategies=smoking_strategy_0,
263-
target_concepts=["smoking"]
261+
policies=UniformPolicy(out_features=1),
262+
strategies=smoking_strategy_0,
263+
target_concepts=["smoking"]
264264
):
265265
intervened_results_0 = inference_engine.query(
266266
query_concepts=["genotype", "smoking", "tar", "cancer"],
@@ -275,9 +275,9 @@ Perform do-interventions to estimate causal effects:
275275
)
276276
277277
with intervention(
278-
policies=UniformPolicy(out_features=1),
279-
strategies=smoking_strategy_1,
280-
target_concepts=["smoking"]
278+
policies=UniformPolicy(out_features=1),
279+
strategies=smoking_strategy_1,
280+
target_concepts=["smoking"]
281281
):
282282
intervened_results_1 = inference_engine.query(
283283
query_concepts=["genotype", "smoking", "tar", "cancer"],
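
Note on the `ProbabilisticModel` change in this file: the hunk drops the trailing `...` from the `variables` and `parametric_cpds` lists. In Python, `...` is the `Ellipsis` literal, so a list written as `[a, b, ...]` is valid syntax and contains `Ellipsis` as a real third element rather than marking an omission. The snippet below is a small illustration using placeholder strings in place of the guide's variables; it does not call any PyC API.

```python
# Placeholder strings stand in for the variable objects used in the guide.
with_ellipsis = ["exogenous_var", "genotype_var", ...]
without_ellipsis = ["exogenous_var", "genotype_var"]

# The ellipsis is stored as an actual list element.
print(len(with_ellipsis))              # three elements
print(with_ellipsis[-1] is Ellipsis)   # the last one is the Ellipsis singleton
print(len(without_ellipsis))           # two elements
```

A reader who copies the older snippet verbatim would therefore pass `Ellipsis` into the constructor alongside the real variables, which is presumably not what the example intended.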

doc/guides/using_mid_level_proba.rst

Lines changed: 23 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -31,22 +31,29 @@ At this API level, models are represented as probabilistic models where:
3131

3232
.. code-block:: python
3333
34-
concepts = pyc.EndogenousVariable(concepts=["c1", "c2", "c3"], parents=[],
35-
distribution=torch.distributions.RelaxedBernoulli)
34+
concepts = pyc.EndogenousVariable(
35+
concepts=["c1", "c2", "c3"],
36+
parents=[],
37+
distribution=torch.distributions.RelaxedBernoulli
38+
)
3639
3740
- ``ParametricCPD`` objects represent conditional probability distributions (CPDs) between variables in the probabilistic model and are parameterized by |pyc_logo| PyC layers. For instance we can define a list of three parametric CPDs for the above concepts as:
3841

3942
.. code-block:: python
4043
41-
concept_cpd = pyc.nn.ParametricCPD(concepts=["c1", "c2", "c3"],
42-
parametrization=pyc.nn.LinearZC(in_features=10, out_features=3))
44+
concept_cpd = pyc.nn.ParametricCPD(
45+
concepts=["c1", "c2", "c3"],
46+
parametrization=pyc.nn.LinearZC(in_features=10, out_features=3)
47+
)
4348
4449
- ``ProbabilisticModel`` objects are a collection of variables and CPDs. For instance we can define a model as:
4550

4651
.. code-block:: python
4752
48-
probabilistic_model = pyc.nn.ProbabilisticModel(variables=concepts,
49-
parametric_cpds=concept_cpd)
53+
probabilistic_model = pyc.nn.ProbabilisticModel(
54+
variables=concepts,
55+
parametric_cpds=concept_cpd
56+
)
5057
5158
Inference
5259
^^^^^^^^^
@@ -55,8 +62,11 @@ Inference is performed using efficient tensorial probabilistic inference algorit
5562

5663
.. code-block:: python
5764
58-
inference_engine = pyc.nn.AncestralSamplingInference(probabilistic_model=probabilistic_model,
59-
graph_learner=wanda, temperature=1.)
65+
inference_engine = pyc.nn.AncestralSamplingInference(
66+
probabilistic_model=probabilistic_model,
67+
graph_learner=wanda,
68+
temperature=1.
69+
)
6070
predictions = inference_engine.query(["c1"], evidence={'input': x})
6171
6272
@@ -203,9 +213,11 @@ Perform do-calculus interventions:
203213
)
204214
205215
# Apply intervention to encoder
206-
with intervention(policies=policy,
207-
strategies=strategy,
208-
target_concepts=["round", "smooth"]):
216+
with intervention(
217+
policies=policy,
218+
strategies=strategy,
219+
target_concepts=["round", "smooth"]
220+
):
209221
intervened_predictions = inference_engine.query(
210222
query_concepts=["round", "smooth", "bright", "class_A", "class_B"],
211223
evidence={'input': x}

doc/modules/low_level_api.rst

Lines changed: 23 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -82,16 +82,24 @@ takes as input both ``Endogenous`` and ``Exogenous`` representations and produce
8282

8383
.. code-block:: python
8484
85-
pyc.nn.HyperLinearCUC(in_features_endogenous=10, in_features_exogenous=7,
86-
embedding_size=24, out_features=3)
85+
pyc.nn.HyperLinearCUC(
86+
in_features_endogenous=10,
87+
in_features_exogenous=7,
88+
embedding_size=24,
89+
out_features=3
90+
)
8791
8892
As a final example, graph learners are a special layers that learn relationships between concepts.
8993
They do not follow the standard naming convention of encoders and predictors, but their purpose should be
9094
clear from their name.
9195

9296
.. code-block:: python
9397
94-
wanda = pyc.nn.WANDAGraphLearner(['c1', 'c2', 'c3'], ['task A', 'task B', 'task C'])
98+
wanda = pyc.nn.WANDAGraphLearner(
99+
['c1', 'c2', 'c3'],
100+
['task A', 'task B', 'task C']
101+
)
102+
95103
96104
Models
97105
^^^^^^^^^^^
@@ -123,8 +131,10 @@ At this API level, there are two types of inference that can be performed:
123131

124132
.. code-block:: python
125133
126-
int_strategy = pyc.nn.DoIntervention(model=concept_bottleneck_model["encoder"],
127-
constants=-10)
134+
int_strategy = pyc.nn.DoIntervention(
135+
model=concept_bottleneck_model["encoder"],
136+
constants=-10
137+
)
128138
129139
**Intervention Policies**: define the order/set of concepts to intervene on e.g., we can intervene on all concepts uniformly:
130140

@@ -136,10 +146,13 @@ At this API level, there are two types of inference that can be performed:
136146

137147
.. code-block:: python
138148
139-
with pyc.nn.intervention(policies=int_policy,
140-
strategies=int_strategy,
141-
target_concepts=[0, 2]) as new_encoder_layer:
142-
149+
with pyc.nn.intervention(
150+
policies=int_policy,
151+
strategies=int_strategy,
152+
target_concepts=[0, 2]
153+
) as new_encoder_layer:
143154
endogenous_concepts = new_encoder_layer(input=x)
144-
endogenous_tasks = concept_bottleneck_model['predictor'](endogenous=endogenous_concepts)
155+
endogenous_tasks = concept_bottleneck_model['predictor'](
156+
endogenous=endogenous_concepts
157+
)
145158

doc/modules/mid_level_api.rst

Lines changed: 23 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -40,22 +40,29 @@ At this API level, models are represented as probabilistic models where:
4040

4141
.. code-block:: python
4242
43-
concepts = pyc.EndogenousVariable(concepts=["c1", "c2", "c3"], parents=[],
44-
distribution=torch.distributions.RelaxedBernoulli)
43+
concepts = pyc.EndogenousVariable(
44+
concepts=["c1", "c2", "c3"],
45+
parents=[],
46+
distribution=torch.distributions.RelaxedBernoulli
47+
)
4548
4649
- ``ParametricCPD`` objects represent conditional probability distributions (CPDs) between variables in the probabilistic model and are parameterized by |pyc_logo| PyC layers. For instance we can define a list of three parametric CPDs for the above concepts as:
4750

4851
.. code-block:: python
4952
50-
concept_cpd = pyc.nn.ParametricCPD(concepts=["c1", "c2", "c3"],
51-
parametrization=pyc.nn.LinearZC(in_features=10, out_features=3))
53+
concept_cpd = pyc.nn.ParametricCPD(
54+
concepts=["c1", "c2", "c3"],
55+
parametrization=pyc.nn.LinearZC(in_features=10, out_features=3)
56+
)
5257
5358
- ``ProbabilisticModel`` objects are a collection of variables and CPDs. For instance we can define a model as:
5459

5560
.. code-block:: python
5661
57-
probabilistic_model = pyc.nn.ProbabilisticModel(variables=concepts,
58-
parametric_cpds=concept_cpd)
62+
probabilistic_model = pyc.nn.ProbabilisticModel(
63+
variables=concepts,
64+
parametric_cpds=concept_cpd
65+
)
5966
6067
Inference
6168
^^^^^^^^^
@@ -64,8 +71,11 @@ Inference is performed using efficient tensorial probabilistic inference algorit
6471

6572
.. code-block:: python
6673
67-
inference_engine = pyc.nn.AncestralSamplingInference(probabilistic_model=probabilistic_model,
68-
graph_learner=wanda, temperature=1.)
74+
inference_engine = pyc.nn.AncestralSamplingInference(
75+
probabilistic_model=probabilistic_model,
76+
graph_learner=wanda,
77+
temperature=1.
78+
)
6979
predictions = inference_engine.query(["c1"], evidence={'input': x})
7080
7181
@@ -106,8 +116,8 @@ Structural Equation Models
106116
.. code-block:: python
107117
108118
sem_model = ProbabilisticModel(
109-
variables=[exogenous_var, genotype_var, ...],
110-
parametric_cpds=[exogenous_cpd, genotype_cpd, ...]
119+
variables=[exogenous_var, genotype_var],
120+
parametric_cpds=[exogenous_cpd, genotype_cpd]
111121
)
112122
113123
Interventions
@@ -128,9 +138,9 @@ For example, to set ``smoking`` to 0 (prevent smoking) and query the effect on d
128138
)
129139
130140
with intervention(
131-
policies=UniformPolicy(out_features=1),
132-
strategies=smoking_strategy_0,
133-
target_concepts=["smoking"]
141+
policies=UniformPolicy(out_features=1),
142+
strategies=smoking_strategy_0,
143+
target_concepts=["smoking"]
134144
):
135145
intervened_results_0 = inference_engine.query(
136146
query_concepts=["genotype", "smoking", "tar", "cancer"],
