-
-
Notifications
You must be signed in to change notification settings - Fork 4.4k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[MRG] Add Poincare model
- Loading branch information
Showing
27 changed files
with
112,130 additions
and
98 deletions.
There are no files selected for viewing
Large diffs are not rendered by default.
Oops, something went wrong.
Large diffs are not rendered by default.
Oops, something went wrong.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,197 @@ | ||
diff --git a/include/poincare_embedding.hpp b/include/poincare_embedding.hpp | ||
index f1f1c5f..e1d68cc 100644 | ||
--- a/include/poincare_embedding.hpp | ||
+++ b/include/poincare_embedding.hpp | ||
@@ -22,7 +22,7 @@ | ||
|
||
namespace poincare_disc{ | ||
|
||
- constexpr float EPS = 1e-6; | ||
+ // constexpr float EPS = 1e-6; | ||
|
||
/////////////////////////////////////////////////////////////////////////////////////////// | ||
// Initializer | ||
@@ -114,11 +114,12 @@ namespace poincare_disc{ | ||
return *this; | ||
} | ||
|
||
- Vector<real>& add_clip_(const real c, const Vector<real>& v, const real thresh=1.0-EPS) | ||
+ Vector<real>& add_clip_(const real c, const Vector<real>& v, const real eps) | ||
{ | ||
real uu = this->squared_sum(), uv = this->dot(v), vv = v.squared_sum(); | ||
real C = uu + 2*c*uv + c*c*vv; // resulting norm | ||
real scale = 1.0; | ||
+ real thresh = 1.0 - eps; | ||
if(C > thresh * thresh){ | ||
scale = thresh / sqrt(C); | ||
} | ||
@@ -132,7 +133,7 @@ namespace poincare_disc{ | ||
data_.get()[i] = (data_.get()[i] + c * v.data_.get()[i]) * scale; | ||
} | ||
} | ||
- assert(this->squared_sum() <= (thresh + EPS) * (thresh+EPS)); | ||
+ assert(this->squared_sum() <= (thresh + eps) * (thresh+eps)); | ||
return *this; | ||
} | ||
|
||
@@ -251,7 +252,7 @@ namespace poincare_disc{ | ||
using real = RealType; | ||
public: | ||
Distance(): u_(), v_(), uu_(), vv_(), uv_(), alpha_(), beta_(), gamma_() {} | ||
- real operator()(const Vector<real>& u, const Vector<real>& v) | ||
+ real operator()(const Vector<real>& u, const Vector<real>& v, real eps) | ||
{ | ||
u_ = u; | ||
v_ = v; | ||
@@ -259,11 +260,11 @@ namespace poincare_disc{ | ||
vv_ = v_.squared_sum(); | ||
uv_ = u_.dot(v_); | ||
alpha_ = 1 - uu_; | ||
- if(alpha_ <= 0){ alpha_ = EPS; } // TODO: ensure 0 <= uu_ <= 1-EPS; | ||
+ if(alpha_ <= 0){ alpha_ = eps; } // TODO: ensure 0 <= uu_ <= 1-eps; | ||
// if(!(alpha_ > 0)){ std::cout << "uu_: " << uu_ << ", alpha_: " << alpha_ << std::endl; } | ||
// assert(alpha_ > 0); | ||
beta_ = 1 - vv_; | ||
- if(beta_ <= 0){ beta_ = EPS; } // TODO: ensure 0 <= vv_ <= 1-EPS; | ||
+ if(beta_ <= 0){ beta_ = eps; } // TODO: ensure 0 <= vv_ <= 1-eps; | ||
// if(!(beta_ > 0)){ std::cout << "vv_: " << vv_ << ", beta_: " << beta_ << std::endl; } | ||
// assert(beta_ > 0); | ||
gamma_ = 1 + 2 * (uu_ - 2 * uv_ + vv_) / alpha_ / beta_; | ||
@@ -406,11 +407,14 @@ namespace poincare_disc{ | ||
char delim = '\t'; | ||
real lr0 = 0.01; // learning rate | ||
real lr1 = 0.0001; // learning rate | ||
+ real eps = 1e-6; // epsilon for numerical stability in clipping | ||
+ std::size_t burn_in = 0; // whether to use burn-in for initialization | ||
}; | ||
|
||
template <class RealType> | ||
- void clip(Vector<RealType>& v, const RealType& thresh = 1-EPS) | ||
+ void clip(Vector<RealType>& v, const RealType& eps) | ||
{ | ||
+ RealType thresh = 1.0 - eps; | ||
RealType vv = v.squared_sum(); | ||
if(vv >= thresh*thresh){ | ||
v.mult_(thresh / std::sqrt(vv)); | ||
@@ -500,7 +504,7 @@ namespace poincare_disc{ | ||
|
||
// clip | ||
for(std::size_t i = 0, I = embeddings.nrow(); i < I; ++i){ | ||
- clip(embeddings[i]); | ||
+ clip(embeddings[i], config.eps); | ||
} | ||
|
||
// construct negative sampler | ||
@@ -551,7 +555,7 @@ namespace poincare_disc{ | ||
auto i = left_indices[0] = itr->first; | ||
auto j = right_indices[0] = itr->second; | ||
|
||
- exp_neg_dist_values[0] = std::exp(-dists[0](embeddings[i], embeddings[j])); | ||
+ exp_neg_dist_values[0] = std::exp(-dists[0](embeddings[i], embeddings[j], config.eps)); | ||
for(std::size_t k = 0; k < config.neg_size; ++k){ | ||
#if SAMPLING_STRATEGY == LEFT_SAMPLING | ||
auto i = left_indices[k + 1] = negative_sampler(); | ||
@@ -563,7 +567,7 @@ namespace poincare_disc{ | ||
auto i = left_indices[k + 1] = negative_sampler(); | ||
auto j = right_indices[k + 1] = negative_sampler(); | ||
#endif | ||
- exp_neg_dist_values[k + 1] = std::exp(-dists[k + 1](embeddings[i], embeddings[j])); | ||
+ exp_neg_dist_values[k + 1] = std::exp(-dists[k + 1](embeddings[i], embeddings[j], config.eps)); | ||
} | ||
|
||
// compute gradient | ||
@@ -589,8 +593,8 @@ namespace poincare_disc{ | ||
// update | ||
for(std::size_t k = 0; k < 1 + config.neg_size; ++k){ | ||
auto i = left_indices[k], j = right_indices[k]; | ||
- embeddings[i].add_clip_(-lr(), left_grads[k]); | ||
- embeddings[j].add_clip_(-lr(), right_grads[k]); | ||
+ embeddings[i].add_clip_(-lr(), left_grads[k], config.eps); | ||
+ embeddings[j].add_clip_(-lr(), right_grads[k], config.eps); | ||
} | ||
|
||
lr.update(); | ||
@@ -634,6 +638,22 @@ namespace poincare_disc{ | ||
|
||
std::cout << "embedding size: " << embeddings.nrow() << " x " << embeddings.ncol() << std::endl; | ||
|
||
+ if(config.burn_in > 0){ | ||
+ std::cout << "burn in start" << std::endl; | ||
+ // burn in | ||
+ LinearLearningRate<real> lr_init(config.lr0 / 10., config.lr1 / 10., data_size * config.burn_in); | ||
+ for(std::size_t epoch = 0; epoch < config.burn_in; ++epoch){ | ||
+ std::cout << "epoch " << epoch+1 << "/" << config.max_epoch << " start" << std::endl; | ||
+ // std::cout << "random shuffle data" << std::endl; | ||
+ std::random_shuffle(data.begin(), data.end()); | ||
+ // single thread | ||
+ const unsigned int thread_seed = engine(); | ||
+ train_thread(embeddings, dict.counts(), data.begin(), data.end(), config, lr_init, 0, thread_seed); | ||
+ } | ||
+ | ||
+ std::cout << "burn in end" << std::endl; | ||
+ } | ||
+ | ||
// fit | ||
LinearLearningRate<real> lr(config.lr0, config.lr1, data_size * config.max_epoch); | ||
std::vector<std::pair<std::size_t, std::size_t> > fake_pairs(config.neg_size); | ||
diff --git a/src/poincare_embedding.cpp b/src/poincare_embedding.cpp | ||
index 7eb7293..a67626e 100644 | ||
--- a/src/poincare_embedding.cpp | ||
+++ b/src/poincare_embedding.cpp | ||
@@ -35,6 +35,8 @@ struct Arguments | ||
real uniform_range = 0.001; | ||
real lr0 = 0.01; | ||
real lr1 = 0.0001; | ||
+ real eps = 1e-6; // epsilon for numerical stability in clipping | ||
+ std::size_t burn_in = 0; // number of epochs for burn-in initialization | ||
}; | ||
|
||
Arguments parse_args(int narg, char** argv) | ||
@@ -92,7 +94,20 @@ Arguments parse_args(int narg, char** argv) | ||
if( x <= 0 ){ goto HELP; } | ||
result.uniform_range = static_cast<real>(x); | ||
continue; | ||
- }else if(arg == "-h" || arg == "--help"){ | ||
+ }else if(arg == "-E" || arg == "--epsilon"){ | ||
+ arg = argv[++i]; | ||
+ double x = std::stod(arg); | ||
+ if( x <= 0 ){ goto HELP; } | ||
+ result.eps = static_cast<real>(x); | ||
+ continue; | ||
+ }else if(arg == "-b" || arg == "--burn_in"){ | ||
+ arg = argv[++i]; | ||
+ int n = std::stoi(arg); | ||
+ if( n < 0 ){ goto HELP; } | ||
+ result.burn_in = static_cast<std::size_t>(n); | ||
+ continue; | ||
+ } | ||
+ else if(arg == "-h" || arg == "--help"){ | ||
goto HELP; | ||
} | ||
|
||
@@ -134,6 +149,8 @@ Arguments parse_args(int narg, char** argv) | ||
" -u, --uniform_range : float > 0 embedding uniform initializer range\n" | ||
" -l, --learning_rate_init : float > 0 initial learning rate\n" | ||
" -L, --learning_rate_final : float > 0 final learning rate\n" | ||
+ " -E, --epsilon : float > 0 epsilon for clipping vectors\n" | ||
+ " -b, --burn_in : int >= 0 number of epochs for burn-in (0 => no burn-in)\n" | ||
<< std::endl; | ||
exit(0); | ||
} | ||
@@ -159,6 +176,8 @@ int main(int narg, char** argv) | ||
config.dim = args.dim; | ||
config.lr0 = args.lr0; | ||
config.lr1 = args.lr1; | ||
+ config.eps = args.eps; | ||
+ config.burn_in = args.burn_in; | ||
config.initializer = UniformInitializer<real>(-args.uniform_range, args.uniform_range); | ||
|
||
std::cout << "settings:" << "\n" | ||
@@ -172,6 +191,7 @@ int main(int narg, char** argv) | ||
<< " " << "lr0 : " << config.lr0 << "\n" | ||
<< " " << "lr1 : " << config.lr1 << "\n" | ||
<< " " << "uniform_range : " << args.uniform_range << "\n" | ||
+ << " " << "epsilon : " << args.eps << "\n" | ||
<< std::endl; | ||
|
||
|
Oops, something went wrong.