Commit 07de467

Updated to "Release of FiP" (#107)
1 parent 5abcee8 commit 07de467

40 files changed: +9721 −3 lines changed

examples/lightning_example/config/training.yaml

Lines changed: 0 additions & 2 deletions
@@ -1,5 +1,3 @@
-# This recreates the latest run:
-# The seed of the run was: 65384781
 seed_everything: true
 trainer:
   max_epochs: 2000

pyproject.toml

Lines changed: 1 addition & 1 deletion
@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "causica"
-version = "0.4.1"
+version = "0.4.2"
 description = ""
 readme = "README.md"
 authors = ["Microsoft Research - Causica"]

research_experiments/fip/README.md

Lines changed: 85 additions & 0 deletions
@@ -0,0 +1,85 @@
# A Fixed-Point Approach for Causal Generative Modeling (FiP)

[![Static Badge](https://img.shields.io/badge/paper-FiP-brightgreen?style=plastic&label=Paper&labelColor=yellow)](https://arxiv.org/pdf/2404.06969)

This repo implements FiP, proposed in the ICML 2024 paper "A Fixed-Point Approach for Causal Generative Modeling".

FiP is a transformer-based approach for learning Structural Causal Models (SCMs) from observational data. It relies on an equivalent formulation of SCMs that does not require Directed Acyclic Graphs (DAGs): an SCM is viewed as a fixed-point problem on the causally ordered variables. To infer topological orders (TOs), we amortize the learning of a TO inference method over synthetically generated datasets, training it to sequentially predict the leaves of the graphs seen during training.
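As a toy illustration of the fixed-point view (a sketch of ours, not code from this repo; FiP itself parameterizes the map with a transformer), the snippet below writes a three-variable SCM as `x = F(x, n)`. Because `F` is triangular with respect to a topological order, iterating it from any starting point reaches the unique fixed point in at most `d` steps:

```python
# Toy fixed-point view of an SCM; all names here are illustrative, not from this repo.
import numpy as np

def F(x, n):
    """A toy SCM x = F(x, n), triangular w.r.t. the order (x0, x1, x2)."""
    x0 = n[0]                        # root node: pure noise
    x1 = 0.5 * x[0] + n[1]           # depends only on earlier variables
    x2 = np.tanh(x[0] + x[1]) + n[2]
    return np.array([x0, x1, x2])

rng = np.random.default_rng(0)
n = rng.normal(size=3)               # exogenous noise
x = np.zeros(3)
for _ in range(3):                   # d iterations suffice for d ordered variables
    x = F(x, n)
print(x)                             # the SCM's sample for noise n
```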
## Dependency

We use [Poetry](https://python-poetry.org/) to manage the project dependencies; they are specified in the [pyproject.toml](pyproject.toml) file. To install Poetry, run:

```console
curl -sSL https://install.python-poetry.org | python3 -
```

To install the environment, run `poetry install` in the directory of the fip project.
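For example, assuming you start from the repository root (in this commit, the fip project lives under `research_experiments/fip`):

```console
cd research_experiments/fip
poetry install
```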
## Prepare the data

To reproduce the results reported in the [paper](https://arxiv.org/pdf/2404.06969), you first need to generate the data. A more detailed explanation of how to generate the data can be found in [README.md](src/fip/data_generation/README.md).

### AVICI / CSuite / Causal NF data generation

To generate the [AVICI](https://arxiv.org/abs/2205.12934) synthetic data, run the following command:

```console
bash src/fip/data_generation/avici_data.sh
```

This executes the [avici_data.py](src/fip/data_generation/avici_data.py) file to generate various datasets from the dataset distributions presented in [AVICI](https://arxiv.org/abs/2205.12934). The generated data will be saved in `src/fip/data`.
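As a quick sanity check, you can inspect a generated dataset with numpy. This is a hypothetical snippet of ours: the file names depend on the generation scripts, and the `er_linear_dag_scm/...` path is taken from a data-module config added in this commit:

```python
# Hypothetical inspection snippet; file names depend on the generation scripts.
from pathlib import Path
import numpy as np

data_dir = Path("src/fip/data/er_linear_dag_scm/total_nodes_5/seed_1")
for path in sorted(data_dir.glob("*.npy")):
    arr = np.load(path)
    print(path.name, arr.shape, arr.dtype)
```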
Similarly, to generate the [CSuite](https://arxiv.org/abs/2202.02195) and the [Causal NF](https://arxiv.org/abs/2306.05415) synthetic data, run the following commands:

```console
bash src/fip/data_generation/csuite_data.sh
bash src/fip/data_generation/normalizing_data.sh
```

## Run experiments

In the [launchers](src/fip/launchers) directory, we provide scripts to run the experiments reported in the paper. A more detailed explanation of how to use these files can be found in [README.md](src/fip/launchers/README.md).

### Zero-Shot Inference of TOs

To train the TO inference method on AVICI data, run the following command:

```console
python -m fip.launchers.amortization
```

The model as well as the config file will be saved in `src/fip/outputs`.
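The configs added in this commit use checkpoint callbacks with `dirpath: "./outputs/"`, `filename: "best_model"`, and `save_last: true`, so a finished run directory should contain files along these lines (the directory name and exact listing are our assumption):

```console
ls src/fip/outputs/<run_directory>
# e.g. best_model.ckpt  last.ckpt  config.yaml
```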
### Learn FiP with (Partial) Causal Knowledge

To train FiP when the DAG is known, run the following command:

```console
python -m fip.launchers.scm_learning_with_ground_truth \
    --ground_truth_case graph \
    --standardize
```

The model as well as the config file will be saved in `src/fip/outputs`. These commands assume that the datasets have been generated and saved in `src/fip/data`. If you want to train FiP when only the TO is known, replace `--ground_truth_case graph` with `--ground_truth_case perm`, as spelled out below.
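Spelled out, the TO-only variant reads:

```console
python -m fip.launchers.scm_learning_with_ground_truth \
    --ground_truth_case perm \
    --standardize
```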
### Learn FiP End-to-End

To train FiP end-to-end, run the following command:

```console
python -m fip.launchers.scm_learning_with_predicted_truth \
    --run_id <name_of_the_directory_containing_the_amortized_model> \
    --standardize
```

The model as well as the config file will be saved in `src/fip/outputs`. This command assumes that a TO inference model has been trained and saved in a directory located at `src/fip/outputs/<name_of_the_directory_containing_the_amortized_model>`. It also assumes that the datasets have been generated and saved in `src/fip/data`.
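As a concrete, hypothetical instantiation: if the amortized model from the zero-shot step was saved under `src/fip/outputs/my_amortized_run` (a directory name of our invention), the call would read:

```console
python -m fip.launchers.scm_learning_with_predicted_truth \
    --run_id my_amortized_run \
    --standardize
```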

research_experiments/fip/poetry.lock

Lines changed: 3729 additions & 0 deletions
research_experiments/fip/pyproject.toml

Lines changed: 18 additions & 0 deletions
@@ -0,0 +1,18 @@
[tool.poetry]
name = "fip"
version = "0.1.0"
description = "A Fixed-Point Approach for Causal Generative Modeling"
readme = "README.md"
authors = ["Meyer Scetbon", "Joel Jennings", "Agrin Hilmkil", "Cheng Zhang", "Chao Ma"]
packages = [
    {include = "fip", from = "src"}
]
license = "MIT"

[tool.poetry.dependencies]
python = "~3.10"
causica = "0.4.1"

[build-system]
requires = ["poetry-core>=1.0.0"]
build-backend = "poetry.core.masonry.api"

research_experiments/fip/src/fip/__init__.py

Whitespace-only changes.
Lines changed: 40 additions & 0 deletions
@@ -0,0 +1,40 @@
seed_everything: 234
model:
  class_path: fip.tasks.amortization.leaf_prediction.LeafPrediction
  init_args:
    learning_rate: 3e-4
    weight_decay: 5e-9
    d_model: 128
    num_heads: 8
    dim_key: 32
    num_layers: 4
    d_ff: 256
    dropout: 0.
    max_num_leaf: 100
    num_to_keep_training: 10
    distributed: false
    elimination_type: "self"
trainer:
  max_epochs: 2000
  accelerator: gpu
  check_val_every_n_epoch: 10
  log_every_n_steps: 10
  profiler: "simple"
  devices: 1
  accumulate_grad_batches: 1
best_checkpoint_callback:
  dirpath: "./outputs/"
  filename: "best_model"
  save_top_k: 1
  mode: "min"
  monitor: "val_loss"
  every_n_epochs: 1
last_checkpoint_callback:
  save_last: true
  save_top_k: 0 # only the last checkpoint is saved
early_stopping_callback:
  monitor: "val_loss"
  min_delta: 0.0001
  patience: 500
  verbose: False
  mode: "min"
Lines changed: 12 additions & 0 deletions
@@ -0,0 +1,12 @@
class_path: fip.data_modules.numpy_tensor_data_module.NumpyTensorDataModule
init_args:
  data_dir: "fip/data/er_linear_dag_scm/total_nodes_5/seed_1/"
  train_batch_size: 2000
  test_batch_size: 2000
  standardize: true
  with_true_graph: true
  split_data_noise: true
  dod: false
  num_workers: 0
  shuffle: true
  num_interventions: 0
Lines changed: 42 additions & 0 deletions
@@ -0,0 +1,42 @@
seed_everything: 234
model:
  class_path: fip.tasks.scm_learning_with_ground_truth.scm_learning_true_graph.SCMLearningTrueGraph
  init_args:
    lr: 1e-4
    weight_decay: 1e-10
    d_model: 128
    dim_key: 32
    num_heads: 8
    d_feedforward: 128
    total_nodes: 4
    total_layers: 2
    dropout_prob: 0.
    mask_type: "none"
    attn_type: "causal"
    cost_type: "dot_product"
    learnable_loss: false
    distributed: false
trainer:
  max_epochs: 1000
  accelerator: gpu
  devices: 1
  check_val_every_n_epoch: 1
  log_every_n_steps: 10
  inference_mode: false
  profiler: "simple"
early_stopping_callback:
  monitor: "val_loss"
  min_delta: 0.0001
  patience: 500
  verbose: False
  mode: "min"
best_checkpoint_callback:
  dirpath: "./outputs/"
  filename: "best_model"
  save_top_k: 1
  mode: "min"
  monitor: "val_loss"
  every_n_epochs: 1
last_checkpoint_callback:
  save_last: true
  save_top_k: 0 # only the last checkpoint is saved
Lines changed: 42 additions & 0 deletions
@@ -0,0 +1,42 @@
seed_everything: 5000
model:
  class_path: fip.tasks.scm_learning_with_ground_truth.scm_learning_true_perm.SCMLearningTruePerm
  init_args:
    lr: 1e-4
    weight_decay: 1e-10
    d_model: 128
    dim_key: 32
    num_heads: 8
    d_feedforward: 128
    total_nodes: 2
    total_layers: 2
    dropout_prob: 0.
    mask_type: "triang"
    attn_type: "causal"
    cost_type: "dot_product"
    learnable_loss: false
    distributed: false
trainer:
  max_epochs: 1000
  accelerator: gpu
  devices: 1
  check_val_every_n_epoch: 1
  log_every_n_steps: 10
  inference_mode: false
  profiler: "simple"
early_stopping_callback:
  monitor: "val_loss"
  min_delta: 0.0001
  patience: 500
  verbose: False
  mode: "min"
best_checkpoint_callback:
  dirpath: "./outputs/"
  filename: "best_model"
  save_top_k: 1
  mode: "min"
  monitor: "val_loss"
  every_n_epochs: 1
last_checkpoint_callback:
  save_last: true
  save_top_k: 0 # only the last checkpoint is saved
Lines changed: 43 additions & 0 deletions
@@ -0,0 +1,43 @@
seed_everything: 5000
model:
  class_path: fip.tasks.scm_learning_with_predicted_truth.scm_learning_predicted_leaf.SCMLearningPredLeaf
  init_args:
    lr: 1e-4
    weight_decay: 1e-10
    leaf_model_path: "./outputs/amortized_pred_checkpoint/leaf_predicition/best_model.ckpt"
    leaf_config_path: "./outputs/amortized_pred_checkpoint/leaf_predicition/config.yaml"
    d_model: 128
    dim_key: 32
    num_heads: 8
    d_feedforward: 128
    total_nodes: 4
    total_layers: 2
    dropout_prob: 0.
    mask_type: "triang"
    attn_type: "causal"
    cost_type: "dot_product"
    learnable_loss: false
    distributed: false
trainer:
  max_epochs: 1000
  accelerator: gpu
  devices: 1
  check_val_every_n_epoch: 1
  log_every_n_steps: 10
  inference_mode: false
early_stopping_callback:
  monitor: "val_loss"
  min_delta: 0.0001
  patience: 500
  verbose: False
  mode: "min"
best_checkpoint_callback:
  dirpath: "./outputs/"
  filename: "best_model"
  save_top_k: 1
  mode: "min"
  monitor: "val_loss"
  every_n_epochs: 1
last_checkpoint_callback:
  save_last: true
  save_top_k: 0 # only the last checkpoint is saved
Lines changed: 46 additions & 0 deletions
@@ -0,0 +1,46 @@
class_path: fip.data_modules.synthetic_data_module.SyntheticDataModule
init_args:
  sem_samplers:
    class_path: fip.data_generation.sem_factory.SemSamplerFactory
    init_args:
      node_nums: [10]
      noises: ['gaussian']
      graphs: ['er', 'sf_in', 'sf_out']
      funcs: ['rff']
      config_gaussian:
        low: 0.2
        high: 2.0
      config_er:
        edges_per_node: [1,2,3]
      config_sf:
        edges_per_node: [1,2,3]
        attach_power: [1.]
      config_linear:
        weight_low: 1.
        weight_high: 3.
        bias_low: -3.
        bias_high: 3.
      config_rff:
        num_rf: 100
        length_low: 7.
        length_high: 10.
        out_low: 10.
        out_high: 20.
        bias_low: -3.
        bias_high: 3.
  train_batch_size: 2
  test_batch_size: 8
  sample_dataset_size: 200
  standardize: false
  num_samples_used: 200
  num_workers: 23
  pin_memory: true
  persistent_workers: true
  prefetch_factor: 2
  factor_epoch: 16
  num_sems: 0
  shuffle: true
  num_interventions: 0
  num_intervention_samples: 0
  proportion_treatment: 0.
  sample_counterfactuals: false
