---
title: Jolly-Seber-Schwarz-Arnason
description: Estimating survival and abundance in PyMC
author:
  name: Philip T. Patton
  affiliation:
    - Marine Mammal Research Program
    - Hawaiʻi Institute of Marine Biology
date: today
bibliography: refs.bib
jupyter: pymc
---
In this notebook, I explore the Jolly-Seber-Schwarz-Arnason (JSSA) model for estimating survival and abundance using capture-recapture data. JSSA is very similar to the CJS framework, except that it also models entry into the population, permitting estimation of the superpopulation size. Like the CJS notebook, I have drawn considerable inspiration from Austin Rochford's [notebook](https://austinrochford.com/posts/2018-01-31-capture-recapture.html) on capture-recapture in PyMC, the [second chapter](https://github.com/philpatton/autocapture) of my dissertation (a work in progress), and @mccrea2014.
As a demonstration of the JSSA framework, I use the classic European dipper data of @lebreton1992. Since the data are in capture history format, I first convert the dataset into the $m$-array.
```{python}
from pymc.distributions.dist_math import factln
from scipy.linalg import circulant
import seaborn as sns
import numpy as np
import matplotlib.pyplot as plt
import arviz as az
import pymc as pm
import pytensor.tensor as pt
plt.rcParams['figure.dpi'] = 600
plt.style.use('fivethirtyeight')
plt.rcParams['axes.facecolor'] = 'white'
plt.rcParams['figure.facecolor'] = 'white'
plt.rcParams['axes.spines.left'] = False
plt.rcParams['axes.spines.right'] = False
plt.rcParams['axes.spines.top'] = False
plt.rcParams['axes.spines.bottom'] = False
sns.set_palette("tab10")

def create_recapture_array(history):
    """Create the recapture array from a capture history."""
    _, occasion_count = history.shape
    interval_count = occasion_count - 1

    recapture_array = np.zeros((interval_count, interval_count), int)
    for occasion in range(occasion_count - 1):

        # which individuals, captured at t, were later recaptured?
        captured_this_time = history[:, occasion] == 1
        captured_later = (history[:, (occasion + 1):] > 0).any(axis=1)
        now_and_later = captured_this_time & captured_later

        # when were they next recaptured?
        remaining_history = history[now_and_later, (occasion + 1):]
        next_capture_occasion = (remaining_history.argmax(axis=1)) + occasion

        # how many of them were there?
        ind, count = np.unique(next_capture_occasion, return_counts=True)
        recapture_array[occasion, ind] = count

    return recapture_array.astype(int)

def create_m_array(history):
    '''Create the m-array from a capture history.'''

    # leftmost column of the m-array
    number_released = history.sum(axis=0)

    # core of the m-array
    recapture_array = create_recapture_array(history)
    number_recaptured = recapture_array.sum(axis=1)

    # no animals that were released on the last occasion are recaptured
    number_recaptured = np.append(number_recaptured, 0)
    never_recaptured = number_released - number_recaptured

    # add a dummy row at the end to make everything stack
    # (np.vstack replaces np.row_stack, which is deprecated in NumPy 2.0)
    zeros = np.zeros(recapture_array.shape[1])
    recapture_array = np.vstack((recapture_array, zeros))

    # stack the relevant values into the m-array
    m_array = np.column_stack((number_released, recapture_array, never_recaptured))

    return m_array.astype(int)

def fill_lower_diag_ones(x):
    '''Fill the lower diagonal of a matrix with ones.'''
    return pt.triu(x) + pt.tril(pt.ones_like(x), k=-1)

dipper = np.loadtxt('dipper.csv', delimiter=',').astype(int)
dipper[:5]
```
```{python}
dipper_m = create_m_array(dipper)
dipper_m
```
The JSSA model requires modeling the number of unmarked animals that were released during an occasion. We can calculate this using the $m$-array by subtracting the number of marked animals who were released from the total number of released animals.
```{python}
recapture_array = create_recapture_array(dipper)
number_released = dipper_m[:,0]
number_marked_released = recapture_array.sum(axis=0)
# shift number_released to get the years to align
number_unmarked_released = number_released[1:] - number_marked_released
# add the number released on the first occasion
number_unmarked_released = np.concatenate(
([number_released[0]], number_unmarked_released)
)
number_unmarked_released
```
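As a cross-check on the subtraction above, the same counts can be read directly from the capture histories, because an animal is unmarked exactly at the occasion of its first capture. The sketch below is illustrative only: `toy_history` is a made-up array, not the dipper data, and it assumes every animal was captured at least once.

```python
import numpy as np

# hypothetical capture history: 5 animals (rows), 4 occasions (columns)
toy_history = np.array([
    [1, 0, 1, 0],
    [0, 1, 1, 0],
    [0, 1, 0, 0],
    [1, 1, 0, 1],
    [0, 0, 0, 1],
])
occasions = toy_history.shape[1]

# occasion of first capture for each animal (argmax returns the first 1)
first_capture = toy_history.argmax(axis=1)

# animals first caught, and therefore still unmarked, at each occasion
unmarked_by_first_capture = np.bincount(first_capture, minlength=occasions)

# alternative: total captures minus captures of previously seen animals
seen_before = (np.cumsum(toy_history, axis=1) - toy_history) > 0
marked_captures = (toy_history.astype(bool) & seen_before).sum(axis=0)
unmarked_by_subtraction = toy_history.sum(axis=0) - marked_captures

print(unmarked_by_first_capture)
print(unmarked_by_subtraction)
```

Both routes should give the same vector, and its sum should equal the number of distinct animals in the history.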
Similar to the CJS model, this model requires a number of tricks to vectorize the operations. Many pertain to the distribution of the unmarked individuals. Similar to the [occupancy notebook](https://philpatton.github.io/occ.html), I use a custom distribution to model the entrants into the population. Austin Rochford refers to this as an incomplete multinomial distribution.
```{python}
n, occasion_count = dipper.shape
interval_count = occasion_count - 1
# generate indices for the m_array
intervals = np.arange(interval_count)
row_indices = np.reshape(intervals, (interval_count, 1))
col_indices = np.reshape(intervals, (1, interval_count))
# matrix indicating the number of intervals between sampling occasions
intervals_between = np.clip(col_indices - row_indices, 0, np.inf)
# index for generating sequences like [[0], [0,1], [0,1,2]]
alive_yet_unmarked_index = circulant(np.arange(occasion_count))
```
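To see what the circulant index does, here is a small NumPy sketch with three occasions. The values of `beta` and `v` are made up for illustration (`v` stands in for the vector that the index is applied to in the model). Indexing with the circulant matrix and keeping the lower triangle yields, for each occasion, the terms `beta[j] * v[t - j]` for `j <= t`.

```python
import numpy as np
from scipy.linalg import circulant

occasions = 3  # hypothetical small example
index = circulant(np.arange(occasions))
# index is [[0, 2, 1], [1, 0, 2], [2, 1, 0]]

# made-up values for illustration only
beta = np.array([0.5, 0.3, 0.2])
v = np.array([1.0, 0.4, 0.1])

# row t, column j of v[index] holds v[(t - j) mod occasions];
# tril keeps only j <= t, so the wrapped-around entries are dropped
terms = np.tril(beta * v[index])
psi = terms.sum(axis=1)

# the same quantity written as an explicit loop
psi_loop = np.array([
    sum(beta[j] * v[t - j] for j in range(t + 1))
    for t in range(occasions)
])

print(psi, psi_loop)
```

The loop version is easier to read, while the indexed version avoids Python-level iteration inside the model graph.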
```{python}
def logp(x, n, p):

    x_last = n - x.sum()

    # calculate the logp for the observations
    res = factln(n) + pt.sum(x * pt.log(p) - factln(x)) \
        + x_last * pt.log(1 - p.sum()) - factln(x_last)

    # ensure that the good conditions are met.
    good_conditions = pt.all(x >= 0) & pt.all(x <= n) & (pt.sum(x) <= n) & \
        (n >= 0)
    res = pm.math.switch(good_conditions, res, -np.inf)

    return res
```
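One way to sanity check this log-probability is to compare it with an ordinary multinomial in which the animals never captured form one extra category. The sketch below is a NumPy/SciPy re-implementation for checking only, not the PyTensor function used in the model. The function name `incomplete_multinomial_logp` and the counts and probabilities are made up for illustration.

```python
import numpy as np
from scipy.special import gammaln
from scipy.stats import multinomial

def incomplete_multinomial_logp(x, n, p):
    """NumPy version of the log-probability, for checking only."""
    x = np.asarray(x, dtype=float)
    p = np.asarray(p, dtype=float)
    x_last = n - x.sum()
    # gammaln(k + 1) is the log-factorial of k
    return (
        gammaln(n + 1)
        + np.sum(x * np.log(p) - gammaln(x + 1))
        + x_last * np.log(1 - p.sum())
        - gammaln(x_last + 1)
    )

# made-up counts and cell probabilities
x = np.array([3, 2, 1])
p = np.array([0.2, 0.1, 0.05])
n = 10

ours = incomplete_multinomial_logp(x, n, p)

# full multinomial with the unobserved remainder as an extra category
full_x = np.append(x, n - x.sum())
full_p = np.append(p, 1 - p.sum())
reference = multinomial.logpmf(full_x, n=n, p=full_p)

print(ours, reference)
```

The two values should agree to floating-point precision, which reflects that the custom distribution is a multinomial whose final count is implied by `n` rather than observed.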
```{python}
# m-array for the CJS portion of the likelihood
cjs_marr = dipper_m[:-1,1:]
cjs_marr
```
Aside from the unmarked portion of the model, the JSSA model is essentially identical to the CJS model. In this version, I also model survival as time-varying, holding the other parameters constant: $p(\cdot)\phi(t)b_0(\cdot)$.
```{python}
#| fig-cap: Visual representation of the JSSA model
#| label: fig-jssa
# JSSA produces this warning. it's unclear why since it samples well
import warnings
warnings.filterwarnings(
"ignore",
message="Failed to infer_shape from Op AdvancedSubtensor"
)
with pm.Model() as jssa:

    ## Priors

    # catchability, survival, and pent
    p = pm.Uniform('p', 0., 1.)
    phi = pm.Uniform('phi', 0., 1., shape=interval_count)
    b0 = pm.Uniform('b0', 0., 1.)
    # beta = pm.Dirichlet('beta', np.ones(interval_count))

    # # only estimate first beta, others constant
    b_other = (1 - b0) / (interval_count)
    beta = pt.concatenate(
        ([b0], pt.repeat(b_other, interval_count))
    )

    # improper flat prior for N
    flat_dist = pm.Flat.dist()
    total_captured = number_unmarked_released.sum()
    N = pm.Truncated("N", flat_dist, lower=total_captured)

    ## Entry

    # add [1] to ensure the addition of the raw beta_0
    p_alive_yet_unmarked = pt.concatenate(
        ([1], pt.cumprod((1 - p) * phi))
    )

    # tril produces the [[0], [0,1], [0,1,2]] patterns for the recursion
    psi = pt.tril(
        beta * p_alive_yet_unmarked[alive_yet_unmarked_index]
    ).sum(axis=1)

    # distribution for the unmarked animals
    unmarked = pm.CustomDist(
        'Unmarked captures',
        N,
        psi * p,
        logp=logp,
        observed=number_unmarked_released
    )

    ## CJS

    # broadcast phi into a matrix
    phi_mat = pt.ones_like(recapture_array) * phi
    phi_mat = fill_lower_diag_ones(phi_mat)  # fill irrelevant values

    # probability of surviving between i and j in the m-array
    p_alive = pt.cumprod(phi_mat, axis=1)
    p_alive = pt.triu(p_alive)  # select relevant (upper triangle) values

    # p_not_cap: probability of not being captured between i and j
    p_not_cap = pt.triu((1 - p) ** intervals_between)

    # nu: probabilities associated with each cell in the m-array
    nu = p_alive * p_not_cap * p

    # probability for the animals that were never recaptured
    chi = 1 - nu.sum(axis=1)

    # combine the probabilities into a matrix
    chi = pt.reshape(chi, (interval_count, 1))
    marr_probs = pt.horizontal_stack(nu, chi)

    # distribution of the m-array
    marr = pm.Multinomial(
        'M-array',
        n=number_released[:-1],  # last count irrelevant for CJS
        p=marr_probs,
        observed=cjs_marr
    )

pm.model_to_graphviz(jssa)
```
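The CJS portion of the model builds the $m$-array cell probabilities with vectorized tensor operations. For intuition, the same probabilities can be written as explicit loops in NumPy. This is a sketch with made-up values of `phi` and `p` for a hypothetical study with four occasions, not the fitted dipper model.

```python
import numpy as np

# made-up parameter values: 4 occasions, so 3 intervals
phi = np.array([0.6, 0.5, 0.7])   # survival over each interval
p = 0.9                           # capture probability
intervals = len(phi)

# cell (i, j): released at occasion i, next recaptured at occasion j + 1
nu = np.zeros((intervals, intervals))
for i in range(intervals):
    for j in range(i, intervals):
        survive = np.prod(phi[i:j + 1])   # alive through intervals i..j
        missed = (1 - p) ** (j - i)       # not captured at occasions between
        nu[i, j] = survive * missed * p

# probability of never being recaptured after release at occasion i
chi = 1 - nu.sum(axis=1)

# each row is a complete multinomial probability vector
marr_probs = np.column_stack((nu, chi))
print(marr_probs.round(3))
```

Each row of `marr_probs` sums to one, which is what the multinomial likelihood for the $m$-array requires.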
```{python}
with jssa:
    jssa_idata = pm.sample()
```
```{python}
#| fig-cap: Traceplots for the dipper JSSA model. MLEs from the openCR package shown by vertical and horizontal lines.
#| label: fig-jssa_trace
phi_mle = [0.63, 0.46, 0.48, 0.62, 0.61, 0.58]
p_mle = [0.9]
b0_mle = [0.079]
N_mle = [310]
az.plot_trace(
jssa_idata,
figsize=(10, 8),
lines=[("phi", {}, phi_mle), ("p", {}, [p_mle]), ("N", {}, [N_mle]), ("b0", {}, [b0_mle])]
);
```
The traceplots include the maximum likelihood estimates from the model, which I estimated using the openCR package in R. Again, there is a high level of agreement between the two methods. I plot the survival estimates over time, and the posterior draws of $N$, $p$, and $b_0$.
```{python}
#| fig-cap: Violin plots of the posterior for apparent survival over time from the dipper JSSA model. Horizontal lines represent the posterior median.
#| label: fig-jssa_surv
fig, ax = plt.subplots(figsize=(6,4))
t = np.arange(1981, 1987)
phi_samps = az.extract(jssa_idata, var_names='phi').values.T
phi_median = np.median(phi_samps, axis=0)
ax.plot(t, phi_median, linestyle='dotted', color='lightgray', linewidth=2)
ax.violinplot(phi_samps, t, showmedians=True, showextrema=False)
ax.set_ylim((0,1))
ax.set_ylabel(r'Apparent survival $\phi$')
ax.set_title(r'Dipper JSSA')
plt.show()
```
```{python}
#| fig-cap: 'Posterior draws of $N,$ $b_0,$ and $p$ from the dipper JSSA model.'
#| label: fig-jssa_post
post = jssa_idata.posterior
# flatten the draws across chains into 1-D arrays
p_samps = post.p.to_numpy().flatten()
N_samps = post.N.to_numpy().flatten()
b_samps = post.b0.to_numpy().flatten()
# create the plot
fig, (ax0, ax1) = plt.subplots(1, 2, figsize=(8, 4), sharey=True)
# scatter the draws of N against p and b0
alph = 0.2
ax0.scatter(p_samps, N_samps, s=5, alpha=alph)
ax0.spines.right.set_visible(False)
ax0.spines.top.set_visible(False)
ax0.set_ylabel(r'$N$')
ax0.set_xlabel(r'$p$')
ax1.scatter(b_samps, N_samps, s=5, alpha=alph)
ax1.set_xlabel(r'$b_0$')
fig.suptitle('Dipper JSSA posterior draws')
plt.show()
```
```{python}
%load_ext watermark
%watermark -n -u -v -iv -w
```