diff --git a/dowhy/causal_identifier/complete_adjustment.py b/dowhy/causal_identifier/complete_adjustment.py index 5e387bd7b..15341b44f 100644 --- a/dowhy/causal_identifier/complete_adjustment.py +++ b/dowhy/causal_identifier/complete_adjustment.py @@ -2,72 +2,69 @@ import pywhy_graphs -class CompleteAdjustment: +def adjustable(self, G, X, Y, Z=None): - def __init__(self, graph, x, y, z=None): - self._graph = graph - self._X = x - self._Y = y - if z is None: - self._Z = set() - else: - self._Z = z + if Z is None: + Z = set() - def adjustable(self, G): - #check amenability - if not self._is_amenable(): - return False - - #check if z contains any node from the forbidden set + #check amenability + if not self._is_amenable(): + return False + + #check if z contains any node from the forbidden set - if not self._check_forbidden_set(): - return False + if not self._check_forbidden_set(): + return False + + #find the proper back-door graph + proper_back_door_graph = self._proper_backdoor_graph() - #find the proper back-door graph - proper_back_door_graph = self._proper_backdoor_graph() + #check if z m-seperates x and y in Gpbd + if not pywhy_graphs.m_seperated(proper_back_door_graph, X, Y, Z): + return False + + return True - #check if z m-seperates x and y in Gpbd - if not pywhy_graphs.m_seperated(proper_back_door_graph, self._X, self._Y, self._Z): +def _is_amenable(G, X, Y): + dp = G.directed_paths(G, X, Y) + pdp = pywhy_graphs.possibly_directed_paths(G, dp) + ppdp = pywhy_graphs.proper_paths(G, pdp) + visible_edges = frozenset(pywhy_graphs.get_visible_edges(G, X)) + for elem in ppdp: + first_edge = elem[0] + if first_edge in visible_edges and first_edge[0] in X: + continue + else: return False - - return True + return True + +def _check_forbidden_set(G,X,Y,Z): + + if Z is None: + Z = set() - def _is_amenable(self): - dp = self._graph.directed_paths(self._graph, self._X, self._Y) - pdp = pywhy_graphs.possibly_directed_paths(self._graph, dp) - ppdp = pywhy_graphs.proper_paths(self._graph, pdp) - visible_edges = frozenset(pywhy_graphs.get_visible_edges(self._graph, self._X)) - for elem in ppdp: - first_edge = elem[0] - if first_edge in visible_edges and first_edge[0] in self._X: - continue - else: - return False + forbidden_set = pywhy_graphs.find_forbidden_set(G, X, Y) + if len(Z.intersection(forbidden_set)) > 0: + return False + else: return True - - def _check_forbidden_set(self): - forbidden_set = pywhy_graphs.find_forbidden_set(self._graph, self._X, self._Y) - if len(self._Z.intersection(forbidden_set)) > 0: - return False - else: - return True - def _proper_backdoor_graph(self): - dp = self._graph.directed_paths(self._X, self._Y) - pdp = pywhy_graphs.possibly_directed_paths(self._graph, dp) - ppdp = pywhy_graphs.proper_paths(self._graph, pdp) - visible_edges = pywhy_graphs.get_visible_edges(self._graph) # assuming all are directed edges - x_vedges = [] - for elem in visible_edges: - if elem[0] in self._X: - x_vedges.append(elem) - x_vedges = frozenset(x_vedges) - all_edges = [] - for elem in ppdp: - all_edges.extend(elem) - all_edges = frozenset(all_edges) - to_remove = all_edges.intersection(x_vedges) - G = self._graph.copy() - for elem in to_remove: - G.remove_edge(elem[0], elem[1], G.directed_edge_name) - return G \ No newline at end of file +def _proper_backdoor_graph(G,X,Y): + dp = G.directed_paths(X, Y) + pdp = pywhy_graphs.possibly_directed_paths(G, dp) + ppdp = pywhy_graphs.proper_paths(G, pdp) + visible_edges = pywhy_graphs.get_visible_edges(G) # assuming all are directed edges + x_vedges = [] + for elem in visible_edges: + if elem[0] in X: + x_vedges.append(elem) + x_vedges = frozenset(x_vedges) + all_edges = [] + for elem in ppdp: + all_edges.extend(elem) + all_edges = frozenset(all_edges) + to_remove = all_edges.intersection(x_vedges) + G = G.copy() + for elem in to_remove: + G.remove_edge(elem[0], elem[1], G.directed_edge_name) + return G \ No newline at end of file