-
Notifications
You must be signed in to change notification settings - Fork 525
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
test: more gradient optimizer tests #1217
Merged
Oceania2018
merged 6 commits into
SciSharp:master
from
novikov-alexander:alnovi/gradient_more_tests
Jun 19, 2024
Merged
Changes from all commits
Commits
Show all changes
6 commits
Select commit
Hold shift + click to select a range
d54f7a6
test: more gradients tests
novikov-alexander 43f43eb
Merge branch 'SciSharp:master' into alnovi/gradient_more_tests
novikov-alexander b3ce158
Update tensor_util.cs
novikov-alexander 18db147
Update GradientDescentOptimizerTests.cs
novikov-alexander b21a58a
Merge branch 'SciSharp:master' into alnovi/gradient_more_tests
novikov-alexander 483ac82
Update tensor_util.cs
novikov-alexander File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,5 +1,6 @@ | ||
using Microsoft.VisualStudio.TestTools.UnitTesting; | ||
using Microsoft.VisualStudio.TestTools.UnitTesting; | ||
using System; | ||
using System.Linq; | ||
using Tensorflow; | ||
using Tensorflow.NumPy; | ||
using static Tensorflow.Binding; | ||
|
@@ -67,6 +68,51 @@ public void TestBasic() | |
TestBasic<double>(); | ||
} | ||
|
||
private void TestMinimizeResourceVariable<T>() where T : struct | ||
{ | ||
var dtype = GetTypeForNumericType<T>(); | ||
|
||
// train.GradientDescentOptimizer is V1 only API. | ||
tf.Graph().as_default(); | ||
using (var sess = self.cached_session()) | ||
{ | ||
var var0 = tf.Variable(new[,] { { 1.0f, 2.0f } }, dtype: dtype); | ||
var var1 = tf.Variable(new[] { 3.0 }, dtype: dtype); | ||
var x = tf.constant(new[,] { { 4.0f }, { 5.0f } }, dtype: dtype); | ||
|
||
var pred = math_ops.matmul(var0, x) + var1; | ||
var loss = pred * pred; | ||
var sgd_op = tf.train.GradientDescentOptimizer(1.0f).minimize(loss); | ||
|
||
var global_variables = tf.global_variables_initializer(); | ||
sess.run(global_variables); | ||
|
||
sess.run(new[] { var0, var1 }); | ||
// Fetch params to validate initial values | ||
self.assertAllCloseAccordingToType<T>(new[,] { { 1.0, 2.0 } }, self.evaluate<T[,]>(var0)); | ||
self.assertAllCloseAccordingToType(new[] { 3.0 }, self.evaluate<T[]>(var1)); | ||
// Run 1 step of sgd | ||
sgd_op.run(); | ||
// Validate updated params | ||
var np_pred = 1.0 * 4.0 + 2.0 * 5.0 + 3.0; | ||
var np_grad = 2 * np_pred; | ||
self.assertAllCloseAccordingToType( | ||
new[,] { { 1.0 - np_grad * 4.0, 2.0 - np_grad * 5.0 } }, | ||
self.evaluate<T[,]>(var0)); | ||
self.assertAllCloseAccordingToType( | ||
new[] { 3.0 - np_grad }, | ||
self.evaluate<T[]>(var1)); | ||
} | ||
} | ||
|
||
[TestMethod] | ||
public void TestMinimizeResourceVariable() | ||
{ | ||
//TODO: add np.half | ||
TestMinimizeResourceVariable<float>(); | ||
TestMinimizeResourceVariable<double>(); | ||
} | ||
|
||
private void TestTensorLearningRate<T>() where T : struct | ||
{ | ||
var dtype = GetTypeForNumericType<T>(); | ||
|
@@ -115,5 +161,72 @@ public void TestTensorLearningRate() | |
TestTensorLearningRate<float>(); | ||
TestTensorLearningRate<double>(); | ||
} | ||
|
||
public void TestGradWrtRef<T>() where T : struct | ||
{ | ||
var dtype = GetTypeForNumericType<T>(); | ||
|
||
var graph = tf.Graph().as_default(); | ||
using (var sess = self.cached_session()) | ||
{ | ||
var opt = tf.train.GradientDescentOptimizer(3.0f); | ||
var values = new[] { 1.0, 3.0 }; | ||
var vars_ = values.Select( | ||
v => tf.Variable(new[] { v }, dtype: dtype) as IVariableV1 | ||
).ToList(); | ||
var grads_and_vars = opt.compute_gradients(tf.add(vars_[0], vars_[1]), vars_); | ||
sess.run(tf.global_variables_initializer()); | ||
foreach (var (grad, _) in grads_and_vars) | ||
self.assertAllCloseAccordingToType(new[] { 1.0 }, self.evaluate<T[]>(grad)); | ||
|
||
} | ||
} | ||
|
||
[TestMethod] | ||
public void TestGradWrtRef() | ||
{ | ||
TestGradWrtRef<float>(); | ||
TestGradWrtRef<double>(); | ||
} | ||
|
||
public void TestWithGlobalStep<T>() where T : struct | ||
{ | ||
var dtype = GetTypeForNumericType<T>(); | ||
|
||
tf.Graph().as_default(); | ||
using (var sess = self.cached_session()) | ||
{ | ||
var global_step = tf.Variable(0, trainable: false); | ||
var var0 = tf.Variable(new[] { 1.0, 2.0 }, dtype: dtype); | ||
var var1 = tf.Variable(new[] { 3.0, 4.0 }, dtype: dtype); | ||
var grads0 = tf.constant(new[] { 0.1, 0.1 }, dtype: dtype); | ||
var grads1 = tf.constant(new[] { 0.01, 0.01 }, dtype: dtype); | ||
var grads_and_vars = new[] { | ||
Tuple.Create(grads0, var0 as IVariableV1), | ||
Tuple.Create(grads1, var1 as IVariableV1) | ||
}; | ||
var sgd_op = tf.train.GradientDescentOptimizer(3.0f) | ||
.apply_gradients(grads_and_vars, global_step: global_step); | ||
|
||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @AsakusaRinne why does |
||
sess.run(tf.global_variables_initializer()); | ||
// Fetch params to validate initial values | ||
self.assertAllCloseAccordingToType(new[] { 1.0, 2.0 }, self.evaluate<T[]>(var0)); | ||
self.assertAllCloseAccordingToType(new[] { 3.0, 4.0 }, self.evaluate<T[]>(var1)); | ||
// Run 1 step of sgd | ||
sgd_op.run(); | ||
// Validate updated params and global_step | ||
self.assertAllCloseAccordingToType(new[] { 1.0 - 3.0 * 0.1, 2.0 - 3.0 * 0.1 }, self.evaluate<T[]>(var0)); | ||
self.assertAllCloseAccordingToType(new[] { 3.0 - 3.0 * 0.01, 4.0 - 3.0 * 0.01 }, self.evaluate<T[]>(var1)); | ||
Assert.AreEqual(1, self.evaluate<int>(global_step)); | ||
} | ||
|
||
} | ||
|
||
[TestMethod] | ||
public void TestWithGlobalStep() | ||
{ | ||
TestWithGlobalStep<float>(); | ||
TestWithGlobalStep<double>(); | ||
} | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@Wanglongzhi2001 Hm, now it calculates but the test doesn't pass. However, the code corresponds to TensorFlow original test. I have to check math there.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There just was a small typo, but I didn't have time to debug it :-D