-
Notifications
You must be signed in to change notification settings - Fork 21
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Re-sync with internal repository (#99)
Co-authored-by: Facebook Community Bot <[email protected]>
- Loading branch information
1 parent
9c322fa
commit 82dc86c
Showing
17 changed files
with
1,345 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,4 @@ | ||
# Copyright (c) Meta Platforms, Inc. and affiliates. | ||
# | ||
# This source code is licensed under the MIT license found in the | ||
# LICENSE file in the root directory of this source tree. |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,37 @@ | ||
# Copyright (c) Meta Platforms, Inc. and affiliates. | ||
# | ||
# This source code is licensed under the MIT license found in the | ||
# LICENSE file in the root directory of this source tree. | ||
|
||
from abc import abstractmethod | ||
from typing import Dict, List | ||
|
||
import xarray as xr | ||
from beanmachine.ppl.inference.monte_carlo_samples import MonteCarloSamples | ||
from beanmachine.ppl.model.rv_identifier import RVIdentifier | ||
|
||
from ..base_ppl_impl import BasePPLImplementation | ||
|
||
|
||
class BaseBeanMachineImplementation(BasePPLImplementation):
    """
    Abstract interface for Bean Machine model implementations.

    Subclasses provide the conversion of model data into Bean Machine
    observations, the list of queried random variables, and the conversion
    of inference results back into a dataset for PPLBench.
    """

    @abstractmethod
    def __init__(self, **model_attrs) -> None:
        """
        :param model_attrs: model arguments, passed as keyword arguments
        """
        ...

    @abstractmethod
    def data_to_observations(self, data: xr.Dataset) -> Dict:
        """
        Convert the model data into observation format used by Bean Machine.

        :param data: the model data to convert
        :returns: the observations as a dictionary (presumably keyed by
            random variable -- confirm against the concrete implementations)
        """
        ...

    @abstractmethod
    def get_queries(self) -> List[RVIdentifier]:
        """
        :returns: The list of random variables that we are interested in for a
            particular inference
        """
        ...

    @abstractmethod
    def extract_data_from_bm(self, samples: MonteCarloSamples) -> xr.Dataset:
        """
        Convert the result of inference into a dataset for PPLBench.

        :param samples: the result of inference, as a MonteCarloSamples object
        :returns: a dataset built from the inference result
        """
        ...
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,171 @@ | ||
# Copyright (c) Meta Platforms, Inc. and affiliates. | ||
# | ||
# This source code is licensed under the MIT license found in the | ||
# LICENSE file in the root directory of this source tree. | ||
|
||
from typing import Dict, List | ||
|
||
import beanmachine.ppl as bm | ||
import numpy as np | ||
import torch | ||
import torch.distributions as dist | ||
import xarray as xr | ||
from beanmachine.ppl.inference.monte_carlo_samples import MonteCarloSamples | ||
|
||
from .base_bm_impl import BaseBeanMachineImplementation | ||
|
||
|
||
class CrowdSourcedAnnotation(BaseBeanMachineImplementation):
    """
    Bean Machine implementation of the crowd-sourced annotation model.

    Every labeler has one Dirichlet-distributed confusion-matrix row per true
    category, every item has a categorical true label drawn from a shared
    prevalence vector, and every observed label is drawn from the row of the
    assigned labeler's confusion matrix selected by the item's true label.
    """

    def __init__(
        self,
        n: int,
        k: int,
        num_categories: int,
        expected_correctness: float,
        num_labels_per_item: int,
        concentration: float,
    ) -> None:
        """
        :param n: total number of items
        :param k: total number of labelers
        :param num_categories: number of label classes
        :param expected_correctness: prior belief on correctness of labelers
        :param num_labels_per_item: number of labels per item
        :param concentration: strength of prior on expected_correctness
        """
        self.n = n
        self.k = k
        self.num_categories = num_categories
        self.expected_correctness = expected_correctness
        self.num_labels_per_item = num_labels_per_item
        self.concentration = concentration

        # Dirichlet concentration for each confusion-matrix row: the diagonal
        # holds `expected_correctness`, the remaining mass is spread evenly
        # over the other categories, and everything is scaled by
        # `concentration`.
        if num_categories > 1:
            off_diagonal = (1 - expected_correctness) / (num_categories - 1)
        else:
            # A single category has no off-diagonal entries; the value is
            # never visible and this avoids dividing by zero.
            off_diagonal = 0.0
        self.alpha = torch.full(
            (num_categories, num_categories),
            off_diagonal,
            dtype=torch.get_default_dtype(),
        )
        self.alpha.fill_diagonal_(expected_correctness)
        self.alpha *= self.concentration

        # Placeholder labeler assignment; replaced by `data_to_observations`.
        self.labelers = torch.zeros(n, num_labels_per_item, dtype=torch.int)

    @bm.random_variable
    def confusion_matrix(self, j: int, c: int) -> dist.Distribution:
        """
        Confusion matrix for each labeler (j) and category (c), where each row
        is a Dirichlet distribution.
        """
        return dist.Dirichlet(self.alpha[c])

    @bm.random_variable
    def prev(self) -> dist.Distribution:
        """
        Prevalence for each of the categories in a Dirichlet distribution so it
        adds up to 1.
        """
        return dist.Dirichlet(
            torch.ones(self.num_categories) * (1.0 / self.num_categories)
        )

    @bm.random_variable
    def true_label(self, i: int) -> dist.Distribution:
        """
        True label distribution for each item (i) given the prevalence.
        """
        return dist.Categorical(self.prev())

    @bm.random_variable
    def label(self, i: int, j: int) -> dist.Distribution:
        """
        Observed label distribution for the j-th label slot of item (i).
        """
        # Cast to a plain int so the labeler index handed to
        # `confusion_matrix` has the same type as the indices used in
        # `get_queries`, whatever dtype `self.labelers` is stored with.
        labeler = int(self.labelers[i, j].item())
        return dist.Categorical(
            self.confusion_matrix(labeler, self.true_label(i).item())
        )

    def data_to_observations(self, data: xr.Dataset) -> Dict:
        """
        Take data from the model generator and convert them to a dictionary that maps
        from random variables to observations, which could be used by Bean Machine.

        As a side effect this stores the labeler assignment from `data` on
        `self.labelers`, which `label` reads.

        :param data: A dataset from the model generator
        :returns: a dictionary that maps random variables to their corresponding
            observations
        """
        # transpose the dataset to ensure that it is the way we expect
        data = data.transpose("item", "item_label")

        # Labeler ids are indices, so keep them integral (the same dtype as
        # the placeholder created in `__init__`).
        self.labelers = torch.tensor(data.labelers.values, dtype=torch.int)
        labels_val = torch.tensor(data.labels.values, dtype=torch.get_default_dtype())

        return {
            self.label(i, j): labels_val[i, j]
            for i in range(self.n)
            for j in range(self.num_labels_per_item)
        }

    def get_queries(self) -> List:
        """
        :returns: the prevalence variable followed by one confusion-matrix
            variable per (labeler, category) pair
        """
        confusion_matrix_variables = [
            self.confusion_matrix(j, c)
            for j in range(self.k)
            for c in range(self.num_categories)
        ]
        return [self.prev()] + confusion_matrix_variables

    def extract_data_from_bm(self, samples: MonteCarloSamples) -> xr.Dataset:
        """
        Takes the output of Bean Machine and converts into a format expected
        by PPLBench.

        :param samples: a MonteCarloSamples object returned by Bean Machine
        :returns: a dataset over inferred parameters
        """

        def as_numpy(rv):
            # Shape is (1, num_iterations, num_categories)
            return (
                samples.get_variable(rv, include_adapt_steps=True).detach().numpy()
            )

        # Drop the leading axis: (num_iterations, num_categories)
        prev_samples = as_numpy(self.prev()).squeeze(0)

        # For each labeler, stack its per-category samples into an array of
        # shape (num_categories, num_iterations, num_categories).
        individual_reviewer_samples = [
            np.stack(
                [
                    as_numpy(self.confusion_matrix(j, c))
                    for c in range(self.num_categories)
                ],
                axis=1,
            ).squeeze(0)
            for j in range(self.k)
        ]

        # Combine all reviewers to create the final confusion matrix; the
        # shape is (k, num_categories, num_iterations, num_categories).
        confusion_matrix_samples = np.stack(individual_reviewer_samples, axis=0)

        # Reorder the axes for the xr.Dataset format; the final shape is
        # (num_iterations, k, num_categories, num_categories).
        confusion_matrix_samples = np.transpose(
            confusion_matrix_samples, (2, 0, 1, 3)
        )

        return xr.Dataset(
            {
                "prev": (["draw", "num_categories"], prev_samples),
                "confusion_matrix": (
                    ["draw", "labelers", "true_category", "observed_category"],
                    confusion_matrix_samples,
                ),
            },
            coords={
                "draw": np.arange(prev_samples.shape[0]),
                "num_categories": np.arange(prev_samples.shape[-1]),
                "labelers": np.arange(confusion_matrix_samples.shape[1]),
                "true_category": np.arange(confusion_matrix_samples.shape[-2]),
                "observed_category": np.arange(confusion_matrix_samples.shape[-1]),
            },
        )
45 changes: 45 additions & 0 deletions
45
pplbench/ppls/beanmachine/examples/crowd_sourced_annotation.json
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,45 @@ | ||
{ | ||
"model": { | ||
"class": "crowd_sourced_annotation.CrowdSourcedAnnotation", | ||
"args": {"n": 1000, "k": 10, "num_categories": 2, "expected_correctness": 0.8, "num_labels_per_item": 3, "concentration": 10} | ||
}, | ||
"iterations": 1000, | ||
"trials": 2, | ||
"ppls": [ | ||
{ | ||
"name": "beanmachine", | ||
"inference": { | ||
"class": "inference.MCMC", | ||
"infer_args": { | ||
"algorithm": "SingleSiteNewtonianMonteCarlo", | ||
"real_space_alpha": 1.0, | ||
"real_space_beta": 5.0 | ||
} | ||
}, | ||
"legend": {"color": "purple", "name": "bm-nmc"} | ||
}, | ||
{ | ||
"name": "beanmachine.graph", | ||
"inference": { | ||
"class": "inference.NMC" | ||
}, | ||
"legend": {"color": "green", "name": "bmgraph-NMC"} | ||
}, | ||
{ | ||
"name": "stan", | ||
"inference": { | ||
"class": "inference.MCMC", | ||
"infer_args": { | ||
"algorithm": "NUTS" | ||
} | ||
}, | ||
"legend": { | ||
"color": "red", | ||
"name": "stan-nuts" | ||
} | ||
} | ||
], | ||
"save_samples": true, | ||
"loglevels": {"beanmachine": "INFO", "pystan": "INFO", "pplbench": "INFO"}, | ||
"figures": {"generate_pll": true, "suffix": "png"} | ||
} |
36 changes: 36 additions & 0 deletions
36
pplbench/ppls/beanmachine/examples/logistic_regression.json
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,36 @@ | ||
{ | ||
"model": { | ||
"class": "logistic_regression.LogisticRegression", | ||
"args": {"n": 2000, "k": 10, "rho": 3.0} | ||
}, | ||
"iterations": 1500, | ||
"trials": 2, | ||
"ppls": [ | ||
{ | ||
"name": "beanmachine", | ||
"inference": { | ||
"class": "inference.MCMC", | ||
"infer_args": { | ||
"algorithm": "GlobalNoUTurnSampler" | ||
} | ||
}, | ||
"legend": {"color": "purple", "name": "bm-nuts"} | ||
}, | ||
{ | ||
"name": "stan", | ||
"inference": { | ||
"class": "inference.MCMC", | ||
"infer_args": { | ||
"algorithm": "NUTS" | ||
} | ||
}, | ||
"legend": { | ||
"color": "red", | ||
"name": "stan-nuts" | ||
} | ||
} | ||
], | ||
"save_samples": true, | ||
"loglevels": {"beanmachine": "INFO", "pystan": "INFO", "pplbench": "INFO"}, | ||
"figures": {"generate_pll": true, "suffix": "png"} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,45 @@ | ||
{ | ||
"model": { | ||
"class": "n_schools.NSchools", | ||
"args": {"n": 20000, "num_states": 20, "num_districts_per_state": 10} | ||
}, | ||
"iterations": 1500, | ||
"trials": 2, | ||
"ppls": [ | ||
{ | ||
"name": "beanmachine", | ||
"inference": { | ||
"class": "inference.MCMC", | ||
"infer_args": { | ||
"algorithm": "GlobalNoUTurnSampler" | ||
} | ||
}, | ||
"legend": {"color": "red", "name": "bm-NUTS"} | ||
}, | ||
{ | ||
"name": "beanmachine.graph", | ||
"inference": { | ||
"class": "inference.NMC" | ||
}, | ||
"legend": {"color": "green", "name": "bmgraph-NMC"} | ||
}, | ||
{ | ||
"name": "beanmachine.graph", | ||
"inference": { | ||
"class": "inference.GlobalMCMC" | ||
}, | ||
"legend": {"color": "purple", "name": "bmgraph-NUTS"} | ||
}, | ||
{ | ||
"name": "stan", | ||
"inference": { | ||
"class": "inference.MCMC", | ||
"infer_args": {"algorithm": "NUTS"} | ||
}, | ||
"legend": {"color": "blue", "name": "stan-NUTS"} | ||
} | ||
], | ||
"save_samples": true, | ||
"loglevels": {"pystan": "INFO", "pplbench": "INFO"}, | ||
"figures": {"generate_pll": true, "suffix": "png"} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,55 @@ | ||
{ | ||
"model": { | ||
"class": "robust_regression.RobustRegression", | ||
"args": { | ||
"n": 2000, | ||
"k": 10 | ||
} | ||
}, | ||
"iterations": 1500, | ||
"trials": 3, | ||
"ppls": [ | ||
{ | ||
"name": "stan", | ||
"inference": { | ||
"class": "inference.VI", | ||
"infer_args": { | ||
"algorithm": "fullrank" | ||
} | ||
}, | ||
"legend": { | ||
"color": "red", | ||
"name": "stan-VI-full" | ||
} | ||
}, | ||
{ | ||
"name": "stan", | ||
"inference": { | ||
"class": "inference.MCMC" | ||
}, | ||
"legend": { | ||
"color": "blue" | ||
} | ||
}, | ||
{ | ||
"name": "beanmachine", | ||
"inference": { | ||
"class": "inference.MCMC", | ||
"infer_args": { | ||
"algorithm": "GlobalNoUTurnSampler" | ||
} | ||
}, | ||
"legend": {"color": "purple", "name": "bm-nuts"} | ||
} | ||
], | ||
"save_samples": true, | ||
"loglevels": { | ||
"beanmachine": "INFO", | ||
"pystan": "INFO", | ||
"pplbench": "INFO" | ||
}, | ||
"figures": { | ||
"generate_pll": true, | ||
"suffix": "png" | ||
} | ||
} |
Oops, something went wrong.