-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #1 from haraoka-screen/initial
Add files
- Loading branch information
Showing
10 changed files
with
1,434 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,46 @@ | ||
# This workflow will install Python dependencies, run tests and lint with a variety of Python versions | ||
# For more information see: https://docs.github.com/en/actions/automating-builds-and-tests/building-and-testing-python | ||
|
||
name: Python package | ||
|
||
on: | ||
push: | ||
branches: [ "main" ] | ||
pull_request: | ||
branches: [ "main" ] | ||
|
||
jobs: | ||
build: | ||
|
||
runs-on: ubuntu-latest | ||
strategy: | ||
fail-fast: false | ||
matrix: | ||
python-version: ["3.8", "3.9", "3.10", "3.11"] | ||
|
||
steps: | ||
- uses: actions/checkout@v3 | ||
- name: Set up Python ${{ matrix.python-version }} | ||
uses: actions/setup-python@v3 | ||
with: | ||
python-version: ${{ matrix.python-version }} | ||
- name: Install dependencies | ||
run: | | ||
python -m pip install --upgrade pip | ||
pip install flake8 pytest pytest-cov | ||
pip install -r requirements.txt | ||
sudo apt-get update | ||
sudo apt-get install r-base | ||
echo 'install.packages("muRty")' > install.r | ||
sudo Rscript install.r | ||
- name: Lint with flake8 | ||
run: | | ||
# stop the build if there are Python syntax errors or undefined names | ||
flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics | ||
# exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide | ||
flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics | ||
- name: Test with pytest | ||
run: | | ||
pytest -v --cov=lingd --cov-report=term-missing |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1 +1,39 @@ | ||
# lingd | ||
# LiNG Discovery Algorithm | ||
|
||
If you use LiNGD, install R and the muRty package and set the path to Rscript. | ||
|
||
## Installation | ||
|
||
```sh | ||
pip install git+https://github.com/cdt15/lingd.git | ||
``` | ||
|
||
## Usage | ||
|
||
```python | ||
from lingd import LiNGD | ||
|
||
# create instance and fit to the data. | ||
model = LiNGD() | ||
model.fit(X) | ||
|
||
# estimated adjacency matrices | ||
print(model.adjacency_matrices_) | ||
|
||
# cost of each matrices | ||
print(model.costs_) | ||
|
||
# stability of each matrices | ||
print(model.is_stables_) | ||
|
||
# bound of causal effects | ||
print(model.bound_of_causal_effect(1)) | ||
``` | ||
|
||
## Example | ||
|
||
[lingd/examples](./examples) | ||
|
||
## References | ||
|
||
* Gustavo Lacerda, Peter Spirtes, Joseph Ramsey, and Patrik O. Hoyer. **Discovering cyclic causal models by independent components analysis**. *In Proceedings of the Twenty-Fourth Conference on Uncertainty in Artificial Intelligence (UAI'08)*. AUAI Press, Arlington, Virginia, USA, 366– 374. |
Large diffs are not rendered by default.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,7 @@ | ||
from .lingd import LiNGD | ||
|
||
__all__ = [ | ||
"LiNGD", | ||
] | ||
|
||
__version__ = "1.0.0" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,206 @@ | ||
import numpy as np | ||
from sklearn.utils import check_scalar, check_array | ||
from sklearn.decomposition import FastICA | ||
|
||
import os | ||
import shutil | ||
import subprocess | ||
import tempfile | ||
|
||
|
||
class LiNGD: | ||
"""Implementation of LiNG Discovery algorithm [1]_ | ||
References | ||
---------- | ||
.. [1] Gustavo Lacerda, Peter Spirtes, Joseph Ramsey, and Patrik O. Hoyer. | ||
Discovering cyclic causal models by independent components analysis. | ||
In Proceedings of the Twenty-Fourth Conference on Uncertainty | ||
in Artificial Intelligence (UAI'08). AUAI Press, Arlington, Virginia, USA, 366–374. | ||
""" | ||
|
||
def __init__(self, k=5): | ||
""" Construct a LiNG-D model. | ||
Parameters | ||
---------- | ||
k : int, option (default=5) | ||
Number of candidate causal graphs to estimate. | ||
""" | ||
k = check_scalar(k, "k", int, min_val=1) | ||
|
||
self._k = k | ||
self._fitted = False | ||
|
||
def fit(self, X): | ||
""" Fit the model to X. | ||
Parameters | ||
---------- | ||
X : array-like, shape (n_samples, n_features) | ||
Training data, where ``n_samples`` is the number of samples | ||
and ``n_features`` is the number of features. | ||
Returns | ||
------- | ||
self : object | ||
Returns the instance itself. | ||
""" | ||
X = check_array(X) | ||
|
||
ica = FastICA() | ||
ica.fit_transform(X) | ||
W_ica = ica.components_ | ||
|
||
permutes, costs = self._run_murty(1 / np.abs(W_ica), k=self._k) | ||
costs = np.array(costs) | ||
|
||
estimated_Bs = [] | ||
for i, p in enumerate(permutes): | ||
PW_ica = np.zeros_like(W_ica) | ||
PW_ica[p] = W_ica | ||
|
||
D = np.diag(PW_ica)[:, np.newaxis] | ||
|
||
W_estimate = PW_ica / D | ||
B_estimate = np.eye(len(W_estimate)) - W_estimate | ||
|
||
estimated_Bs.append(B_estimate) | ||
estimated_Bs = np.array(estimated_Bs) | ||
|
||
is_stables = [] | ||
for B in estimated_Bs: | ||
values, _ = np.linalg.eig(B) | ||
is_stables.append(all(abs(values) < 1)) | ||
is_stables = np.array(is_stables) | ||
|
||
self._X = X | ||
self._adjacency_matrices = estimated_Bs | ||
self._costs = costs | ||
self._is_stables = is_stables | ||
self._fitted = True | ||
|
||
return self | ||
|
||
def bound_of_causal_effect(self, target_index): | ||
""" | ||
This method calculates the causal effect from target_index to each feature | ||
Parameters | ||
---------- | ||
target_index : int | ||
The index of the intervention target. | ||
Returns | ||
------- | ||
causal_effects : array-like, shape (k, n_features) | ||
list of causal effects. | ||
""" | ||
|
||
self._check_is_fitted() | ||
|
||
target_index = check_scalar( | ||
target_index, | ||
"target_index", | ||
int, | ||
min_val=0, | ||
max_val=self._X.shape[1] - 1 | ||
) | ||
|
||
aces = [] | ||
for B in self._adjacency_matrices: | ||
X1 = self._intervention(self._X, B, target_index, 1) | ||
X0 = self._intervention(self._X, B, target_index, 0) | ||
ace = (X1.mean(axis=0) - X0.mean(axis=0)).tolist() | ||
aces.append(ace) | ||
|
||
return np.array(aces) | ||
|
||
@property | ||
def adjacency_matrices_(self): | ||
self._check_is_fitted() | ||
return self._adjacency_matrices | ||
|
||
@property | ||
def costs_(self): | ||
self._check_is_fitted() | ||
return self._costs | ||
|
||
@property | ||
def is_stables_(self): | ||
self._check_is_fitted() | ||
return self._is_stables | ||
|
||
def _run_murty(self, X, k): | ||
# XXX: muRty occurs an error if X has 2 variables and k is greater than 2. | ||
if X.shape[1] == 2: | ||
k = 1 | ||
|
||
try: | ||
temp_dir = tempfile.mkdtemp() | ||
|
||
path = os.path.join(os.path.dirname(__file__), "murty.r") | ||
|
||
args = [f"--temp_dir={temp_dir}"] | ||
|
||
np.savetxt(os.path.join(temp_dir, "X.csv"), X, delimiter=",") | ||
np.savetxt(os.path.join(temp_dir, "k.csv"), [k], delimiter=",") | ||
|
||
# run | ||
ret = subprocess.run(["Rscript", path, *args], capture_output=True) | ||
if ret.returncode != 0: | ||
if ret.returncode == 2: | ||
msg = "muRty is not installed." | ||
else: | ||
msg = ret.stderr.decode() | ||
raise RuntimeError(msg) | ||
|
||
# retrieve result | ||
permutes = [] | ||
|
||
for f in os.listdir(temp_dir): | ||
if not f.startswith("solution"): | ||
continue | ||
|
||
solution = np.loadtxt(os.path.join(temp_dir, f), delimiter=",", skiprows=1) | ||
|
||
permute = [x[1] for x in np.argwhere(solution > 0)] | ||
permutes.append(permute) | ||
|
||
costs = np.loadtxt(os.path.join(temp_dir, "costs.csv"), delimiter=",", skiprows=1) | ||
costs = np.array(costs).flatten().tolist() | ||
except FileNotFoundError: | ||
raise RuntimeError("Rscript is not found.") | ||
except BaseException as e: | ||
raise RuntimeError(str(e)) | ||
finally: | ||
if os.path.exists(temp_dir): | ||
shutil.rmtree(temp_dir) | ||
|
||
return permutes, costs | ||
|
||
def _check_is_fitted(self): | ||
if not self._fitted: | ||
raise RuntimeError("This instance is not fitted yet. Call 'fit' with \ | ||
appropriate arguments before using this instance.") | ||
|
||
def _intervention(self, X, B, target_index, value): | ||
# estimate error terms | ||
e = ((np.eye(len(B)) - B) @ X.T).T | ||
|
||
# set the given intervention value | ||
e[:, target_index] = value | ||
|
||
# resample errors | ||
resample_index = np.random.choice(np.arange(len(e)), size=len(e)) | ||
e = e[resample_index] | ||
|
||
# remove edges | ||
B_ = B.copy() | ||
B_[target_index, :] = 0 | ||
|
||
# generate data | ||
A = np.linalg.inv(np.eye(len(B)) - B_) | ||
X_ = (A @ e.T).T | ||
|
||
return X_ |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,43 @@ | ||
if (!require(muRty)) { | ||
quit("no", 2) | ||
} | ||
|
||
library(muRty) | ||
|
||
# command args | ||
temp_dir <- NULL | ||
|
||
args <- commandArgs(trailingOnly=TRUE) | ||
for (arg in args) { | ||
s <- strsplit(arg, "=")[[1]] | ||
if (length(s) < 2) { | ||
next | ||
} | ||
|
||
if (s[1] == "--temp_dir") { | ||
temp_dir <- paste(s[2:length(s)], collapse="=") | ||
} | ||
} | ||
|
||
# function args | ||
path <- file.path(temp_dir, "X.csv") | ||
X <- read.csv(path, sep=',', header=FALSE) | ||
|
||
path <- file.path(temp_dir, "k.csv") | ||
k <- read.csv(path, sep=',', header=FALSE) | ||
|
||
# run muRty | ||
result <- get_k_best(mat=as.matrix(X), k_best=k) | ||
|
||
# write result | ||
count = 0 | ||
for (i in seq_along(result$solutions)) { | ||
filename <- paste("solutions", sprintf("%08d", i), ".csv", sep="") | ||
path <- file.path(temp_dir, filename) | ||
write.csv(result$solutions[i], path, row.names=FALSE) | ||
} | ||
|
||
path <- file.path(temp_dir, "costs.csv") | ||
write.csv(result$costs, path, row.names=FALSE) | ||
|
||
quit("no", 0) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
numpy | ||
scikit-learn |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,31 @@ | ||
import setuptools | ||
|
||
with open('README.md', 'r', encoding='utf-8') as fh: | ||
README = fh.read() | ||
|
||
import lingd | ||
|
||
VERSION = lingd.__version__ | ||
|
||
setuptools.setup( | ||
name='lingd', | ||
version=VERSION, | ||
description='LiNG discovery algorithm', | ||
long_description=README, | ||
long_description_content_type='text/markdown', | ||
install_requires=[ | ||
'numpy', | ||
'scikit-learn', | ||
], | ||
url='https://github.com/cdt15/lingd', | ||
packages=setuptools.find_packages(exclude=['tests', 'examples']), | ||
package_data={ | ||
'lingd': ['*.r'], | ||
}, | ||
classifiers=[ | ||
'Programming Language :: Python :: 3', | ||
'License :: OSI Approved :: MIT License', | ||
'Operating System :: OS Independent', | ||
], | ||
python_requires='>=3.8', | ||
) |
Empty file.
Oops, something went wrong.