Commit
Re-sync with internal repository (#99)
Co-authored-by: Facebook Community Bot <[email protected]>
facebook-github-bot authored Feb 4, 2022
1 parent 9c322fa commit 82dc86c
Showing 17 changed files with 1,345 additions and 0 deletions.
4 changes: 4 additions & 0 deletions pplbench/ppls/beanmachine/__init__.py
@@ -0,0 +1,4 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
37 changes: 37 additions & 0 deletions pplbench/ppls/beanmachine/base_bm_impl.py
@@ -0,0 +1,37 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from abc import abstractmethod
from typing import Dict, List

import xarray as xr
from beanmachine.ppl.inference.monte_carlo_samples import MonteCarloSamples
from beanmachine.ppl.model.rv_identifier import RVIdentifier

from ..base_ppl_impl import BasePPLImplementation


class BaseBeanMachineImplementation(BasePPLImplementation):
    @abstractmethod
    def __init__(self, **model_attrs) -> None:
        ...

    @abstractmethod
    def data_to_observations(self, data: xr.Dataset) -> Dict:
        """Convert the model data into the observation format used by Bean Machine"""
        ...

    @abstractmethod
    def get_queries(self) -> List[RVIdentifier]:
        """
        :returns: The list of random variables that we are interested in for a
            particular inference
        """
        ...

    @abstractmethod
    def extract_data_from_bm(self, samples: MonteCarloSamples) -> xr.Dataset:
        """Convert the result of inference into a dataset for PPLBench."""
        ...
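
For orientation, here is a minimal sketch of what a concrete subclass of this interface can look like, using a hypothetical beta-Bernoulli model. The class name, the `y` data variable, and everything else below are illustrative assumptions and are not part of this commit:

# Hypothetical toy implementation of BaseBeanMachineImplementation; the model,
# its names, and the `y` data variable are assumptions for illustration only.
# The relative import assumes this sketch lives next to base_bm_impl.py.
from typing import Dict, List

import beanmachine.ppl as bm
import numpy as np
import torch
import torch.distributions as dist
import xarray as xr
from beanmachine.ppl.inference.monte_carlo_samples import MonteCarloSamples
from beanmachine.ppl.model.rv_identifier import RVIdentifier

from .base_bm_impl import BaseBeanMachineImplementation


class ToyBetaBernoulli(BaseBeanMachineImplementation):
    def __init__(self, n: int) -> None:
        self.n = n  # number of Bernoulli observations

    @bm.random_variable
    def rate(self) -> dist.Distribution:
        # Uniform prior on the success probability
        return dist.Beta(1.0, 1.0)

    @bm.random_variable
    def obs(self, i: int) -> dist.Distribution:
        return dist.Bernoulli(self.rate())

    def data_to_observations(self, data: xr.Dataset) -> Dict:
        # Map each observed random variable to its value from the dataset
        values = torch.tensor(data.y.values, dtype=torch.get_default_dtype())
        return {self.obs(i): values[i] for i in range(self.n)}

    def get_queries(self) -> List[RVIdentifier]:
        return [self.rate()]

    def extract_data_from_bm(self, samples: MonteCarloSamples) -> xr.Dataset:
        rate_samples = (
            samples.get_variable(self.rate(), include_adapt_steps=True)
            .detach()
            .numpy()
            .squeeze(0)
        )
        return xr.Dataset(
            {"rate": (["draw"], rate_samples)},
            coords={"draw": np.arange(rate_samples.shape[0])},
        )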
171 changes: 171 additions & 0 deletions pplbench/ppls/beanmachine/crowd_sourced_annotation.py
@@ -0,0 +1,171 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from typing import Dict, List

import beanmachine.ppl as bm
import numpy as np
import torch
import torch.distributions as dist
import xarray as xr
from beanmachine.ppl.inference.monte_carlo_samples import MonteCarloSamples

from .base_bm_impl import BaseBeanMachineImplementation


class CrowdSourcedAnnotation(BaseBeanMachineImplementation):
    def __init__(
        self,
        n: int,
        k: int,
        num_categories: int,
        expected_correctness: float,
        num_labels_per_item: int,
        concentration: float,
    ) -> None:
        """
        :param attrs: model arguments
        """
        self.n = n  # Total number of items
        self.k = k  # Total number of labelers
        self.num_categories = num_categories  # Number of label classes
        self.expected_correctness = (
            expected_correctness  # Prior belief on correctness of labelers
        )
        self.num_labels_per_item = num_labels_per_item  # Number of labels per item
        self.concentration = concentration  # Strength of prior on expected_correctness
        self.alpha = torch.ones(
            num_categories, num_categories, dtype=torch.get_default_dtype()
        )
        for c in range(self.num_categories):
            for c_prime in range(self.num_categories):
                self.alpha[c][c_prime] = (
                    self.expected_correctness
                    if (c == c_prime)
                    else ((1 - self.expected_correctness) / (self.num_categories - 1))
                )
        self.alpha *= self.concentration
        self.labelers = torch.zeros(n, num_labels_per_item, dtype=torch.int)

    @bm.random_variable
    def confusion_matrix(self, j: int, c: int) -> dist.Distribution:
        """
        Confusion matrix for each labeler (j) and category (c), where each row is a Dirichlet distribution.
        """
        return dist.Dirichlet(self.alpha[c])

    @bm.random_variable
    def prev(self) -> dist.Distribution:
        """
        Prevalence for each of the categories in a Dirichlet distribution so it adds up to 1.
        """
        return dist.Dirichlet(
            torch.ones(self.num_categories) * (1.0 / self.num_categories)
        )

    @bm.random_variable
    def true_label(self, i: int) -> dist.Distribution:
        """
        True label distribution for each item (i) given the prevalence.
        """
        return dist.Categorical(self.prev())

    @bm.random_variable
    def label(self, i: int, j: int) -> dist.Distribution:
        """
        Observed label distribution for each item (i) and label (j).
        """
        labeler = self.labelers[i, j].item()
        return dist.Categorical(
            self.confusion_matrix(labeler, self.true_label(i).item())
        )

    def data_to_observations(self, data: xr.Dataset) -> Dict:
        """
        Take data from the model generator and convert it into a dictionary that maps
        random variables to observations, which can be used by Bean Machine.

        :param data: A dataset from the model generator
        :returns: a dictionary that maps random variables to their corresponding
            observations
        """
        # transpose the dataset to ensure that it is the way we expect
        data = data.transpose("item", "item_label")

        labelers_val = torch.tensor(
            data.labelers.values, dtype=torch.get_default_dtype()
        )
        labels_val = torch.tensor(data.labels.values, dtype=torch.get_default_dtype())

        self.labelers = labelers_val
        observations = {}
        for i in range(self.n):
            for j in range(self.num_labels_per_item):
                observations[self.label(i, j)] = labels_val[i, j]

        return observations

    def get_queries(self) -> List:
        confusion_matrix_variables = []
        for j in range(self.k):
            for c in range(self.num_categories):
                confusion_matrix_variables.append(self.confusion_matrix(j, c))
        return [self.prev()] + confusion_matrix_variables

    def extract_data_from_bm(self, samples: MonteCarloSamples) -> xr.Dataset:
        """
        Take the output of Bean Machine and convert it into the format expected
        by PPLBench.

        :param samples: a MonteCarloSamples object returned by Bean Machine
        :returns: a dataset over inferred parameters
        """

        prev_samples = (
            samples.get_variable(self.prev(), include_adapt_steps=True).detach().numpy()
        ).squeeze(0)

        individual_reviewer_samples = []
        for j in range(self.k):
            category_samples = []
            for c in range(self.num_categories):
                category_samples.append(
                    # Shape is (1, num_iterations, num_categories)
                    samples.get_variable(
                        self.confusion_matrix(j, c), include_adapt_steps=True
                    )
                    .detach()
                    .numpy()
                )
            # Collect samples from every category for each reviewer.
            # The shape of np.stack(category_samples, axis=1).squeeze(0) is
            # (num_categories, num_iterations, num_categories).
            individual_reviewer_samples.append(
                np.stack(category_samples, axis=1).squeeze(0)
            )

        # Combine all reviewers to create the final confusion matrix.
        # The shape of individual_reviewer_samples is
        # (k, num_categories, num_iterations, num_categories).
        confusion_matrix_samples = np.array(individual_reviewer_samples)

        # Swap axes so they're in the correct order for the xr.Dataset format.
        # The final shape of confusion_matrix_samples is
        # (num_iterations, k, num_categories, num_categories).
        confusion_matrix_samples = np.swapaxes(
            np.swapaxes(confusion_matrix_samples, 0, 2), 1, 2
        )

        return xr.Dataset(
            {
                "prev": (["draw", "num_categories"], prev_samples),
                "confusion_matrix": (
                    ["draw", "labelers", "true_category", "observed_category"],
                    confusion_matrix_samples,
                ),
            },
            coords={
                "draw": np.arange(prev_samples.shape[0]),
                "num_categories": np.arange(prev_samples.shape[-1]),
                "labelers": np.arange(confusion_matrix_samples.shape[1]),
                "true_category": np.arange(confusion_matrix_samples.shape[-2]),
                "observed_category": np.arange(confusion_matrix_samples.shape[-1]),
            },
        )
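
As a rough usage sketch (not part of this commit), the class above is exercised by constructing it, converting a dataset into observations, running a Bean Machine sampler, and converting the samples back. The synthetic `train_data` below stands in for the output of PPLBench's model generator and is an assumption for illustration:

# Hypothetical driver sketch; the synthetic dataset is a placeholder for the
# model generator's output, with "labelers" and "labels" over ("item", "item_label").
import beanmachine.ppl as bm
import numpy as np
import xarray as xr

rng = np.random.default_rng(0)
n, k, num_labels_per_item, num_categories = 1000, 10, 3, 2
train_data = xr.Dataset(
    {
        "labelers": (
            ["item", "item_label"],
            rng.integers(0, k, size=(n, num_labels_per_item)),
        ),
        "labels": (
            ["item", "item_label"],
            rng.integers(0, num_categories, size=(n, num_labels_per_item)),
        ),
    }
)

model = CrowdSourcedAnnotation(
    n=n,
    k=k,
    num_categories=num_categories,
    expected_correctness=0.8,
    num_labels_per_item=num_labels_per_item,
    concentration=10.0,
)
observations = model.data_to_observations(train_data)
queries = model.get_queries()

# Any Bean Machine inference class could be used here; in practice the sampler
# choice comes from the benchmark config (see the example JSON files below).
samples = bm.GlobalNoUTurnSampler().infer(
    queries=queries,
    observations=observations,
    num_samples=100,
    num_chains=1,
)
posterior = model.extract_data_from_bm(samples)  # xr.Dataset with "prev" and "confusion_matrix"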
45 changes: 45 additions & 0 deletions pplbench/ppls/beanmachine/examples/crowd_sourced_annotation.json
@@ -0,0 +1,45 @@
{
    "model": {
        "class": "crowd_sourced_annotation.CrowdSourcedAnnotation",
        "args": {"n": 1000, "k": 10, "num_categories": 2, "expected_correctness": 0.8, "num_labels_per_item": 3, "concentration": 10}
    },
    "iterations": 1000,
    "trials": 2,
    "ppls": [
        {
            "name": "beanmachine",
            "inference": {
                "class": "inference.MCMC",
                "infer_args": {
                    "algorithm": "SingleSiteNewtonianMonteCarlo",
                    "real_space_alpha": 1.0,
                    "real_space_beta": 5.0
                }
            },
            "legend": {"color": "purple", "name": "bm-nmc"}
        },
        {
            "name": "beanmachine.graph",
            "inference": {
                "class": "inference.NMC"
            },
            "legend": {"color": "green", "name": "bmgraph-NMC"}
        },
        {
            "name": "stan",
            "inference": {
                "class": "inference.MCMC",
                "infer_args": {
                    "algorithm": "NUTS"
                }
            },
            "legend": {
                "color": "red",
                "name": "stan-nuts"
            }
        }
    ],
    "save_samples": true,
    "loglevels": {"beanmachine": "INFO", "pystan": "INFO", "pplbench": "INFO"},
    "figures": {"generate_pll": true, "suffix": "png"}
}
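
The "args" block above mirrors the keyword arguments of CrowdSourcedAnnotation.__init__ added in this commit. A small illustrative sketch of that mapping (the actual PPLBench harness resolves the "model" and "ppls" sections itself; the direct construction below is for illustration only):

# Illustrative only: the config's model args line up with the constructor of the
# Bean Machine implementation added in this commit.
import json

from pplbench.ppls.beanmachine.crowd_sourced_annotation import CrowdSourcedAnnotation

with open("pplbench/ppls/beanmachine/examples/crowd_sourced_annotation.json") as f:
    config = json.load(f)

model = CrowdSourcedAnnotation(**config["model"]["args"])
print(model.n, model.k, model.num_categories)  # 1000 10 2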
36 changes: 36 additions & 0 deletions pplbench/ppls/beanmachine/examples/logistic_regression.json
@@ -0,0 +1,36 @@
{
    "model": {
        "class": "logistic_regression.LogisticRegression",
        "args": {"n": 2000, "k": 10, "rho": 3.0}
    },
    "iterations": 1500,
    "trials": 2,
    "ppls": [
        {
            "name": "beanmachine",
            "inference": {
                "class": "inference.MCMC",
                "infer_args": {
                    "algorithm": "GlobalNoUTurnSampler"
                }
            },
            "legend": {"color": "purple", "name": "bm-nuts"}
        },
        {
            "name": "stan",
            "inference": {
                "class": "inference.MCMC",
                "infer_args": {
                    "algorithm": "NUTS"
                }
            },
            "legend": {
                "color": "red",
                "name": "stan-nuts"
            }
        }
    ],
    "save_samples": true,
    "loglevels": {"beanmachine": "INFO", "pystan": "INFO", "pplbench": "INFO"},
    "figures": {"generate_pll": true, "suffix": "png"}
}
45 changes: 45 additions & 0 deletions pplbench/ppls/beanmachine/examples/n_schools.json
@@ -0,0 +1,45 @@
{
    "model": {
        "class": "n_schools.NSchools",
        "args": {"n": 20000, "num_states": 20, "num_districts_per_state": 10}
    },
    "iterations": 1500,
    "trials": 2,
    "ppls": [
        {
            "name": "beanmachine",
            "inference": {
                "class": "inference.MCMC",
                "infer_args": {
                    "algorithm": "GlobalNoUTurnSampler"
                }
            },
            "legend": {"color": "red", "name": "bm-NUTS"}
        },
        {
            "name": "beanmachine.graph",
            "inference": {
                "class": "inference.NMC"
            },
            "legend": {"color": "green", "name": "bmgraph-NMC"}
        },
        {
            "name": "beanmachine.graph",
            "inference": {
                "class": "inference.GlobalMCMC"
            },
            "legend": {"color": "purple", "name": "bmgraph-NUTS"}
        },
        {
            "name": "stan",
            "inference": {
                "class": "inference.MCMC",
                "infer_args": {"algorithm": "NUTS"}
            },
            "legend": {"color": "blue", "name": "stan-NUTS"}
        }
    ],
    "save_samples": true,
    "loglevels": {"pystan": "INFO", "pplbench": "INFO"},
    "figures": {"generate_pll": true, "suffix": "png"}
}
55 changes: 55 additions & 0 deletions pplbench/ppls/beanmachine/examples/robust_regression.json
@@ -0,0 +1,55 @@
{
    "model": {
        "class": "robust_regression.RobustRegression",
        "args": {
            "n": 2000,
            "k": 10
        }
    },
    "iterations": 1500,
    "trials": 3,
    "ppls": [
        {
            "name": "stan",
            "inference": {
                "class": "inference.VI",
                "infer_args": {
                    "algorithm": "fullrank"
                }
            },
            "legend": {
                "color": "red",
                "name": "stan-VI-full"
            }
        },
        {
            "name": "stan",
            "inference": {
                "class": "inference.MCMC"
            },
            "legend": {
                "color": "blue"
            }
        },
        {
            "name": "beanmachine",
            "inference": {
                "class": "inference.MCMC",
                "infer_args": {
                    "algorithm": "GlobalNoUTurnSampler"
                }
            },
            "legend": {"color": "purple", "name": "bm-nuts"}
        }
    ],
    "save_samples": true,
    "loglevels": {
        "beanmachine": "INFO",
        "pystan": "INFO",
        "pplbench": "INFO"
    },
    "figures": {
        "generate_pll": true,
        "suffix": "png"
    }
}