Skip to content

Commit

Permalink
Merge pull request #1 from andre-martins/master
Browse files Browse the repository at this point in the history
Fix sequence budget factor.
  • Loading branch information
nunonmg authored Feb 24, 2021
2 parents 40371e5 + 5df0c2f commit ed96c93
Show file tree
Hide file tree
Showing 4 changed files with 34 additions and 38 deletions.
48 changes: 19 additions & 29 deletions lpsmap/ad3ext/FactorSequenceBudget.h
Original file line number Diff line number Diff line change
Expand Up @@ -91,16 +91,9 @@ class FactorSequenceBudget : public GenericFactor {
values[0].resize(num_states);
path[0].resize(num_states);
for (int l = 0; l < num_states; ++l) {
if (l == 0) {
// The state that counts for the budget.
values[0][l].resize(2);
path[0][l].resize(2);
} else {
values[0][l].resize(1);
path[0][l].resize(1);
}
values[0][l].resize(1);
path[0][l].resize(1);
int bin = 0;
if (l == 0) ++bin;
values[0][l][bin] =
GetNodeScore(
0, l, variable_log_potentials,additional_log_potentials)
Expand All @@ -117,23 +110,18 @@ class FactorSequenceBudget : public GenericFactor {
path[i+1].resize(num_states);
for (int k = 0; k < num_states; ++k) {
int num_bins = (budget_ < i+1)? budget_+1 : i+2;
if (k == 0) {
// The state that counts for the budget.
// k == 0 and b = budget_ not allowed.
if (num_bins == budget_+1) --num_bins;
values[i+1][k].resize(num_bins+1);
path[i+1][k].resize(num_bins+1);
} else {
values[i+1][k].resize(num_bins);
path[i+1][k].resize(num_bins);
}
values[i+1][k].resize(num_bins);
path[i+1][k].resize(num_bins);
for (int b = 0; b < num_bins; ++b) {
double best_value = -std::numeric_limits<double>::infinity();
int best = -1;
for (int l = 0; l < num_states_[i]; ++l) {
if (l == 0 && b == 0) continue;
if (i > 0 && path[i][l][b] < 0) continue;
double val = values[i][l][b] +
int bin = b;
if (l == 0) --bin; // The state that counts for the budget.
if (bin < 0) continue;
if (bin >= path[i][l].size()) continue;
if (i > 0 && path[i][l][bin] < 0) continue;
double val = values[i][l][bin] +
GetEdgeScore(i+1, l, k, variable_log_potentials,
additional_log_potentials);
if (best < 0 || val > best_value) {
Expand All @@ -142,12 +130,11 @@ class FactorSequenceBudget : public GenericFactor {
}
}
int bin = b;
if (k == 0) ++bin;
values[i+1][k][bin] = best_value +
GetNodeScore(i+1, k, variable_log_potentials,
additional_log_potentials);
path[i+1][k][bin] = best;
//cout << "path[" << i+1 << "][" << k << "][" << bin << "] = " << best << endl;
assert(best >= 0);
}
}
}
Expand All @@ -159,9 +146,12 @@ class FactorSequenceBudget : public GenericFactor {
int num_bins = (budget_ < length)? budget_+1 : length+1;
for (int b = 0; b < num_bins; ++b) {
for (int l = 0; l < num_states_[length - 1]; ++l) {
if (l == 0 && b == 0) continue;
if (length > 1 && path[length-1][l][b] < 0) continue;
double val = values[length-1][l][b] +
int bin = b;
if (l == 0) --bin;
if (bin < 0) continue;
if (bin >= path[length-1][l].size()) continue;
if (length > 1 && path[length-1][l][bin] < 0) continue;
double val = values[length-1][l][bin] +
GetEdgeScore(length, l, 0, variable_log_potentials,
additional_log_potentials);
if (best < 0 || val > best_value) {
Expand All @@ -180,8 +170,8 @@ class FactorSequenceBudget : public GenericFactor {
for (int i = length - 1; i > 0; --i) {
//cout << "sequence[" << i << "] = " << (*sequence)[i] << endl;
//cout << b << endl;
(*sequence)[i - 1] = path[i][(*sequence)[i]][b];
if ((*sequence)[i] == 0) --b;
(*sequence)[i - 1] = path[i][(*sequence)[i]][b];
}
*value = best_value;
}
Expand Down Expand Up @@ -341,4 +331,4 @@ class FactorSequenceBudget : public GenericFactor {

} // namespace AD3

#endif // FACTOR_SEQUENCE_BUDGET
#endif // FACTOR_SEQUENCE_BUDGET
4 changes: 3 additions & 1 deletion lpsmap/ad3ext/sequencebudget.cpp

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

18 changes: 10 additions & 8 deletions lpsmap/test.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,15 +14,17 @@ def main():
transition = torch.zeros((n+1,2,2), requires_grad=True)
transition[:, 1, 1] = 1./temperature

fg = TorchFactorGraph()
u = fg.variable_from(x/temperature)
for budget in range(12):
fg = TorchFactorGraph()
u = fg.variable_from(x/temperature)

fg.add(SequenceBudget(u, transition, 2))
#fg.add(Sequence(u, transition))
#fg.add(Budget(u, 2))
fg.solve()
fg.add(SequenceBudget(u, transition, budget))
#fg.add(Sequence(u, transition))
#fg.add(Budget(u, 2))
fg.solve()

print("solution: \n", u.value[:,0])
print("solution: \n", u.value[:,0])
print(sum(u.value[:,0]))

if __name__ == '__main__':
main()
main()
2 changes: 2 additions & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,6 +200,8 @@ def install_for_development(self):
Extension('lpsmap.ad3qp.base', ["lpsmap/ad3qp/base.pyx"]),
Extension('lpsmap.ad3ext.sequence',
["lpsmap/ad3ext/sequence.pyx"]),
Extension('lpsmap.ad3ext.sequencebudget',
["lpsmap/ad3ext/sequencebudget.pyx"]),
Extension('lpsmap.ad3ext.tree',
["lpsmap/ad3ext/tree.pyx",
"lpsmap/ad3ext/DependencyDecoder.cpp"]),
Expand Down

0 comments on commit ed96c93

Please sign in to comment.