Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix determinize issues #42

Merged
merged 2 commits into from
May 16, 2020
Merged

Conversation

qindazhu
Copy link
Collaborator

@qindazhu qindazhu commented May 16, 2020

Fixed some issues of Determinize. The alogrithm now can ouput FSA correctly.

Input:
ori_max

Output
new_max (2)

Noted that the weights_out and derivs_out are still not correct, I'll debug into it and make another PR to fix them.

Fsa b;
std::vector<float> b_arc_weights;
std::vector<std::vector<int32_t>> arc_derivs;
DeterminizePrunedMax(*max_wfsa_, 10, 100, &b, &b_arc_weights, &arc_derivs);
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just leave EXPECT statements empty for now, will add after fix issues of weights. please ignore this for now.

@@ -28,6 +28,7 @@ LogSumTracebackLink::LogSumTracebackLink(

int32_t GetMostRecentCommonAncestor(
std::unordered_set<LogSumTracebackState *> *cur_states) {
if (cur_states->size() == 1) return 1;
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

My intention was that it return 0 in this case; even the documentation says that. What happened if it returned zero?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

if it's 0, there label will not be used in DetStateToCompact
https://github.com/danpovey/k2/blob/75c4cd1b1fdd2a20009997c269b905b438575cdb/k2/csrc/determinize.h#L745-L746

Different states with same seq_len will mapped to a same det_state.

Comment on lines +387 to +388
std::priority_queue<std::shared_ptr<DetState<TracebackState>>,
std::vector<std::shared_ptr<DetState<TracebackState>>>,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What was the problem with unique_ptr here? I made it unique_ptr as it's more lightweight than shared_ptr (and this is only ever owned in one place, I believed.)

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

priority_queue.top() returns a const reference so we cannnot call std::move() on its returned value to construct any unqiue_ptr or shared_ptr.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I see; shared_ptr is fine.

@danpovey
Copy link
Collaborator

danpovey commented May 16, 2020 via email

@qindazhu qindazhu force-pushed the haowen-determinize-test branch from 75c4cd1 to cfd8b79 Compare May 16, 2020 12:30
@qindazhu
Copy link
Collaborator Author

Oh, yes, it should be zero as I have changed the state_id in DetStateToCompact to input_state_id (it is originally zero for any unnormalized states).

RE add elem->state_id, Can you show some details with code? I really didn't get your idea.

Revert it back to zero.

b = d.state_id * 103979 + d.seq_len;
int32_t input_state_id = d.elements.begin()->first;

uint64_t a = input_state_id + 17489 * d.seq_len,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Remove input_state_id here, just have 17489 * d.seq_len and so on... and I'll say below..

b = d.state_id * 103979 + d.seq_len;
int32_t input_state_id = d.elements.begin()->first;

uint64_t a = input_state_id + 17489 * d.seq_len,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Remove input_state_id here, just have 17489 * d.seq_len and so on... and I'll say below..

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

well, around line 776, insert:

    a += elem->state_id;   // This is `base_state`: the state from which we 
                                          // start (and accept the specified symbol sequence).

(we don't need it in b, actually).

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

actually, no, to reduce the chance of collisions, do instead:
a = elem->state_id + 14051 * a

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

.. I mean collisions on a, which is used for hashing. collisions should still be vanishingly rare if we take a and b together.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

oh, add the base_state, got it, thanks!

@@ -781,6 +768,7 @@ class DetStateMap {
b = symbol + 102983 * b;
elem = elem->prev_elements[0].prev_state;
}
a += elem->state_id;
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Two things: (1) after suggesting this I suggested a different version involving multiplying by a prime.. please use that one. And (2) in my original suggested code there was a comment; please include that!

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh, sorry for the comments (have dinner just now..)

RE prime: seems there's notification sync issue on my chrome, will add, thanks

std::vector<float> b_arc_weights;
std::vector<std::vector<std::pair<int32_t, float>>> arc_derivs;
DeterminizePrunedLogSum(*log_wfsa_, 10, 100, &b, &b_arc_weights, &arc_derivs);
}
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'd like, to have this code automatically:
(1) check that the result is deterministic,
(2) check that the result is equivalent to the original (in the appropriate semiring)

and if possible check that the arc_derivs make sense somehow, although that is more complex and can be left for now.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I can merge this PR when you fix the other issue, though; you can work on this later.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

will add all EXPECT statements after fix issues of weights (As I said before, the weights_out now are not correct)

@qindazhu qindazhu force-pushed the haowen-determinize-test branch from ba919b2 to d3b0ce8 Compare May 16, 2020 13:08
@danpovey danpovey merged commit 9ca4a69 into k2-fsa:master May 16, 2020
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants