Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

replace AuxLabels with Array2 #57

Merged
merged 1 commit into from
Jun 11, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 33 additions & 4 deletions k2/csrc/array.h
Original file line number Diff line number Diff line change
Expand Up @@ -35,15 +35,18 @@ struct StridedPtr {
: data(data), stride(stride) {}
StridedPtr(const StridedPtr &other)
: data(other.data), stride(other.stride) {}
StridedPtr &operator=(const StridedPtr &other) {
StridedPtr tmp(other);
std::swap(data, tmp.data);
std::swap(stride, tmp.stride);
StridedPtr(StridedPtr &&other) { this->Swap(other); }
StridedPtr &operator=(StridedPtr other) {
this->Swap(other);
return *this;
}
bool operator==(const StridedPtr &other) const {
return data == other.data && stride == other.stride;
}
void Swap(StridedPtr &other) {
std::swap(data, other.data);
std::swap(stride, other.stride);
}
};

/*
Expand All @@ -69,6 +72,10 @@ struct Array2 {
using PtrT = Ptr;
using ValueType = typename std::iterator_traits<Ptr>::value_type;

Array2() = default;
Array2(IndexT size1, IndexT *indexes, IndexT size2, PtrT data)
: size1(size1), indexes(indexes), size2(size2), data(data) {}

IndexT size1;
IndexT *indexes; // indexes[0,1,...size1] should be defined; note, this
// means the array must be of at least size1+1. We
Expand All @@ -85,6 +92,17 @@ struct Array2 {

bool Empty() const { return size1 == 0; }

// just to replace `Swap` functions for Fsa and AuxLabels for now,
// may delete it if we finally find that we don't need to call it.
void Swap(Array2 &other) {
std::swap(size1, other.size1);
std::swap(size2, other.size2);
std::swap(indexes, other.indexes);
// it's OK here for Ptr=StridedPtr as we have specialized
// std::swap for StridedPtr
std::swap(data, other.data);
}

/* initialized definition:
An Array2 object is initialized if its `size1` member and `size2` member
are set and its `indexes` and `data` pointer allocated, and the values of
Expand Down Expand Up @@ -197,10 +215,21 @@ struct Array2Storage {

namespace std {
template <typename T, typename I>

struct iterator_traits<k2::StridedPtr<T, I>> {
typedef T value_type;
};

template <typename T, typename I>
void swap(k2::StridedPtr<T, I> &lhs, k2::StridedPtr<T, I> &rhs) {
lhs.Swap(rhs);
}

template <typename Ptr, typename I>
void swap(k2::Array2<Ptr, I> &lhs, k2::Array2<Ptr, I> &rhs) {
lhs.Swap(rhs);
}

} // namespace std

#endif // K2_CSRC_ARRAY_H_
188 changes: 107 additions & 81 deletions k2/csrc/aux_labels.cc
Original file line number Diff line number Diff line change
Expand Up @@ -37,9 +37,9 @@ static void CountExtraStates(const k2::Fsa &fsa_in,
auto &states = *num_extra_states;
for (int32_t i = 0; i != fsa_in.arcs.size(); ++i) {
const auto &arc = fsa_in.arcs[i];
int32_t pos_start = labels_in.start_pos[i];
int32_t pos_end = labels_in.start_pos[i + 1];
states[arc.dest_state] += std::max(0, pos_end - pos_start - 1);
int32_t begin = labels_in.indexes[i];
int32_t end = labels_in.indexes[i + 1];
states[arc.dest_state] += std::max(0, end - begin - 1);
}
}

Expand Down Expand Up @@ -87,129 +87,152 @@ static void MapStates(const std::vector<int32_t> &num_extra_states,

namespace k2 {

void Swap(AuxLabels *labels1, AuxLabels *labels2) {
CHECK_NOTNULL(labels1);
CHECK_NOTNULL(labels2);
std::swap(labels1->start_pos, labels2->start_pos);
std::swap(labels1->labels, labels2->labels);
void AuxLabels1Mapper::GetSizes(Array2Size<int32_t> *aux_size) {
CHECK_NOTNULL(aux_size);
aux_size->size1 = arc_map_.size();
int32_t num_labels = 0;
for (const auto &arc_index : arc_map_) {
int32_t begin = labels_in_.indexes[arc_index];
int32_t end = labels_in_.indexes[arc_index + 1];
num_labels += end - begin;
}
aux_size->size2 = num_labels;
}

void MapAuxLabels1(const AuxLabels &labels_in,
const std::vector<int32_t> &arc_map, AuxLabels *labels_out) {
void AuxLabels1Mapper::GetOutput(AuxLabels *labels_out) {
CHECK_NOTNULL(labels_out);
auto &start_pos = labels_out->start_pos;
auto &labels = labels_out->labels;
start_pos.clear();
start_pos.reserve(arc_map.size() + 1);
labels.clear();
auto &start_pos = labels_out->indexes;
auto &labels = labels_out->data;
int32_t num_labels = 0;
int32_t i = 0;
for (; i != arc_map_.size(); ++i) {
start_pos[i] = num_labels;
const auto arc_index = arc_map_[i];
int32_t begin = labels_in_.indexes[arc_index];
int32_t end = labels_in_.indexes[arc_index + 1];
for (auto it = begin; it != end; ++it) {
labels[num_labels++] = labels_in_.data[it];
}
}
start_pos[i] = num_labels;
}

void AuxLabels2Mapper::GetSizes(Array2Size<int32_t> *aux_size) {
CHECK_NOTNULL(aux_size);
aux_size->size1 = arc_map_.size();
int32_t num_labels = 0;
auto labels_in_iter_begin = labels_in.labels.begin();
for (const auto &arc_index : arc_map) {
start_pos.push_back(num_labels);
int32_t pos_start = labels_in.start_pos[arc_index];
int32_t pos_end = labels_in.start_pos[arc_index + 1];
labels.insert(labels.end(), labels_in_iter_begin + pos_start,
labels_in_iter_begin + pos_end);
num_labels += pos_end - pos_start;
for (const auto &arc_indexes : arc_map_) {
for (const auto &arc_index : arc_indexes) {
int32_t begin = labels_in_.indexes[arc_index];
int32_t end = labels_in_.indexes[arc_index + 1];
num_labels += end - begin;
}
}
start_pos.push_back(num_labels);
aux_size->size2 = num_labels;
}

void MapAuxLabels2(const AuxLabels &labels_in,
const std::vector<std::vector<int32_t>> &arc_map,
AuxLabels *labels_out) {
void AuxLabels2Mapper::GetOutput(AuxLabels *labels_out) {
CHECK_NOTNULL(labels_out);
auto &start_pos = labels_out->start_pos;
auto &labels = labels_out->labels;
start_pos.clear();
start_pos.reserve(arc_map.size() + 1);
labels.clear();

auto &start_pos = labels_out->indexes;
auto &labels = labels_out->data;
int32_t num_labels = 0;
auto labels_in_iter_begin = labels_in.labels.begin();
for (const auto &arc_indexes : arc_map) {
start_pos.push_back(num_labels);
for (const auto &arc_index : arc_indexes) {
int32_t pos_start = labels_in.start_pos[arc_index];
int32_t pos_end = labels_in.start_pos[arc_index + 1];
labels.insert(labels.end(), labels_in_iter_begin + pos_start,
labels_in_iter_begin + pos_end);
num_labels += pos_end - pos_start;
int32_t i = 0;
for (; i != arc_map_.size(); ++i) {
start_pos[i] = num_labels;
for (const auto &arc_index : arc_map_[i]) {
int32_t begin = labels_in_.indexes[arc_index];
int32_t end = labels_in_.indexes[arc_index + 1];
for (auto it = begin; it != end; ++it) {
labels[num_labels++] = labels_in_.data[it];
}
}
}
start_pos.push_back(num_labels);
start_pos[i] = num_labels;
}

void FstInverter::GetSizes(Array2Size<int32_t> *fsa_size,
Array2Size<int32_t> *aux_size) {
CHECK_NOTNULL(fsa_size);
CHECK_NOTNULL(aux_size);
int32_t num_extra_states = 0;
int32_t num_arcs = 0;
int32_t num_non_eps_labels = 0;
for (int32_t i = 0; i != fsa_in_.arcs.size(); ++i) {
const auto &arc = fsa_in_.arcs[i];
int32_t begin = labels_in_.indexes[i];
int32_t end = labels_in_.indexes[i + 1];
num_extra_states += std::max(0, end - begin - 1);
num_arcs += std::max(1, end - begin);
if (arc.label != kEpsilon) ++num_non_eps_labels;
}
fsa_size->size1 = num_extra_states + fsa_in_.NumStates();
fsa_size->size2 = num_arcs;
aux_size->size1 = num_arcs;
aux_size->size2 = num_non_eps_labels;
}

void InvertFst(const Fsa &fsa_in, const AuxLabels &labels_in, Fsa *fsa_out,
AuxLabels *aux_labels_out) {
void FstInverter::GetOutput(Fsa *fsa_out, AuxLabels *labels_out) {
CHECK_NOTNULL(fsa_out);
CHECK_NOTNULL(aux_labels_out);
fsa_out->arc_indexes.clear();
fsa_out->arcs.clear();
aux_labels_out->start_pos.clear();
aux_labels_out->labels.clear();

if (IsEmpty(fsa_in)) {
aux_labels_out->start_pos.push_back(0);
CHECK_NOTNULL(labels_out);

if (IsEmpty(fsa_in_)) {
labels_out->indexes[0] = 0;
return;
}

auto num_states_in = fsa_in.NumStates();
auto num_states_in = fsa_in_.NumStates();
// get the number of extra states we need to create for each state
// in fsa_in when inverting
std::vector<int32_t> num_extra_states(num_states_in, 0);
CountExtraStates(fsa_in, labels_in, &num_extra_states);
CountExtraStates(fsa_in_, labels_in_, &num_extra_states);

// map state in fsa_in to state in fsa_out
std::vector<int32_t> state_map(num_states_in, 0);
std::vector<int32_t> state_ids(num_states_in, 0);
MapStates(num_extra_states, &state_map, &state_ids);

// a maximal approximation
int32_t num_arcs_out = labels_in.labels.size() + fsa_in.arcs.size();
// TODO(haowen): replace with fsa_out->size2
std::vector<Arc> arcs;
arcs.reserve(num_arcs_out);
// `+1` for the end position of the last arc's olabel sequence
arcs.reserve(labels_out->size1);
std::vector<int32_t> start_pos;
start_pos.reserve(num_arcs_out + 1);
start_pos.reserve(labels_out->size1 + 1);
std::vector<int32_t> labels;
labels.reserve(fsa_in.arcs.size());
int32_t final_state_in = fsa_in.FinalState();
labels.reserve(labels_out->size2);
int32_t final_state_in = fsa_in_.FinalState();

int32_t num_non_eps_ilabel_processed = 0;
start_pos.push_back(0);
for (auto i = 0; i != fsa_in.arcs.size(); ++i) {
const auto &arc = fsa_in.arcs[i];
int32_t pos_start = labels_in.start_pos[i];
int32_t pos_end = labels_in.start_pos[i + 1];
for (auto i = 0; i != fsa_in_.arcs.size(); ++i) {
const auto &arc = fsa_in_.arcs[i];
int32_t pos_begin = labels_in_.indexes[i];
int32_t pos_end = labels_in_.indexes[i + 1];
int32_t src_state = arc.src_state;
int32_t dest_state = arc.dest_state;
if (dest_state == final_state_in) {
// every arc entering the final state must have exactly
// one olabel == kFinalSymbol
CHECK_EQ(pos_start + 1, pos_end);
CHECK_EQ(labels_in.labels[pos_start], kFinalSymbol);
CHECK_EQ(pos_begin + 1, pos_end);
CHECK_EQ(labels_in_.data[pos_begin], kFinalSymbol);
}
if (pos_end - pos_start <= 1) {
if (pos_end - pos_begin <= 1) {
int32_t curr_label =
(pos_end - pos_start == 0) ? kEpsilon : labels_in.labels[pos_start];
(pos_end - pos_begin == 0) ? kEpsilon : labels_in_.data[pos_begin];
arcs.emplace_back(state_map[src_state], state_map[dest_state],
curr_label);
} else {
// expand arcs with olabels
arcs.emplace_back(state_map[src_state], state_ids[dest_state] + 1,
labels_in.labels[pos_start]);
labels_in_.data[pos_begin]);
start_pos.push_back(num_non_eps_ilabel_processed);
for (int32_t pos = pos_start + 1; pos < pos_end - 1; ++pos) {
for (int32_t pos = pos_begin + 1; pos < pos_end - 1; ++pos) {
++state_ids[dest_state];
arcs.emplace_back(state_ids[dest_state], state_ids[dest_state] + 1,
labels_in.labels[pos]);
labels_in_.data[pos]);
start_pos.push_back(num_non_eps_ilabel_processed);
}
++state_ids[dest_state];
arcs.emplace_back(state_ids[dest_state], state_map[arc.dest_state],
labels_in.labels[pos_end - 1]);
labels_in_.data[pos_end - 1]);
}
// push non-epsilon ilabel in fsa_in as olabel of fsa_out
if (arc.label != kEpsilon) {
Expand All @@ -219,15 +242,18 @@ void InvertFst(const Fsa &fsa_in, const AuxLabels &labels_in, Fsa *fsa_out,
start_pos.push_back(num_non_eps_ilabel_processed);
}

labels.resize(labels.size());
arcs.resize(arcs.size());
start_pos.resize(start_pos.size());
// any failure indicates there are some errors
// TODO(haowen): replace with fsa_out->size2
CHECK_EQ(arcs.size(), labels_out->size1);
CHECK_EQ(start_pos.size(), labels_out->size1 + 1);
CHECK_EQ(labels.size(), labels_out->size2);

std::vector<int32_t> arc_map;
ReorderArcs(arcs, fsa_out, &arc_map);
AuxLabels labels_tmp;
labels_tmp.start_pos = std::move(start_pos);
labels_tmp.labels = std::move(labels);
MapAuxLabels1(labels_tmp, arc_map, aux_labels_out);
AuxLabels labels_tmp(labels_out->size1, start_pos.data(), labels_out->size2,
labels.data());
AuxLabels1Mapper aux_mapper(labels_tmp, arc_map);
// don't need to call `GetSizes` here as `labels_out` has been initialized
aux_mapper.GetOutput(labels_out);
}
} // namespace k2
Loading