diff --git a/smcpp/model.py b/smcpp/model.py index 527fd2e..245c6a9 100644 --- a/smcpp/model.py +++ b/smcpp/model.py @@ -235,7 +235,16 @@ def for_pop(self, pid): return self.model1 else: assert i == 1 - return _concat_models(self.model1, self.model2, self.split, pid) + assert self.model1.N0 == self.model2.N0 + assert self.model1._spline_class is self.model2._spline_class + k1, k2 = [np.searchsorted(m.knots, self.split) for m in (self.model1, self.model2)] + kts = np.r_[self.model2.knots[:k2], [self.split], self.model1.knots[k1 + 1:]] + m = SMCModel(kts, self.model1.N0, self.model2._spline_class, self.model2.pid) + m[:k2] = self.model2[:k2] + m[k2] = ad.admath.log(self.model1(self.split).item()) + m[k2 + 1:] = self.model1[k1 + 1:] + return m + # return _concat_models(self.model1, self.model2, self.split) # Propagate changes from submodels up @targets('model update') @@ -338,7 +347,7 @@ def __setitem__(self, coords, x): self._models[a][cc] = x -def _concat_models(m1, m2, t, pid): +def _concat_models(m1, m2, t): if m1.N0 != m2.N0: raise RuntimeException() cs1 = np.cumsum(m1.s) diff --git a/smcpp/plotting.py b/smcpp/plotting.py index dee1847..3910844 100644 --- a/smcpp/plotting.py +++ b/smcpp/plotting.py @@ -85,12 +85,6 @@ def g(x, y, label, data=data, **kwargs): # if not logy: # y *= 1e-3 series.append([None, x2, y2, ax.scatter, off, m.N0, g]) - if split: - for i in 1, 2: - x = series[-i][1] - coords = x <= mb.split - for j in 1, 2: - series[-i][j] = series[-i][j][coords] else: x = np.cumsum(d['s']) x = np.insert(x, 0, 0)[:-1]