Skip to content

Commit

Permalink
Fix Stochastic Gradient Descent Example
Browse files Browse the repository at this point in the history
The example that is currently in the docs does not run. dtype, penalty, lrate, loss are not defined. This new version sets the default values for the parameters of cumlSGD, and copies Mini Batch SGD Regression's dtype for pred_data['col1'], pred_data['col2']. When running this example, I also got slightly different values for the output, so these were also updated.
  • Loading branch information
tylerjthomas9 authored Nov 13, 2020
1 parent 77da916 commit 9d2635b
Showing 1 changed file with 11 additions and 11 deletions.
22 changes: 11 additions & 11 deletions python/cuml/solvers/sgd.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -137,15 +137,15 @@ class SGD(Base):
import cudf
from cuml.solvers import SGD as cumlSGD
X = cudf.DataFrame()
X['col1'] = np.array([1,1,2,2], dtype = np.float32)
X['col2'] = np.array([1,2,2,3], dtype = np.float32)
X['col1'] = np.array([1,1,2,2], dtype=np.float32)
X['col2'] = np.array([1,2,2,3], dtype=np.float32)
y = cudf.Series(np.array([1, 1, 2, 2], dtype=np.float32))
pred_data = cudf.DataFrame()
pred_data['col1'] = np.asarray([3, 2], dtype=dtype)
pred_data['col2'] = np.asarray([5, 5], dtype=dtype)
cu_sgd = cumlSGD(learning_rate=lrate, eta0=0.005, epochs=2000,
pred_data['col1'] = np.asarray([3, 2], dtype=np.float32)
pred_data['col2'] = np.asarray([5, 5], dtype=np.float32)
cu_sgd = cumlSGD(learning_rate='constant', eta0=0.005, epochs=2000,
fit_intercept=True, batch_size=2,
tol=0.0, penalty=penalty, loss=loss)
tol=0.0, penalty='none', loss='squared_loss')
cu_sgd.fit(X, y)
cu_pred = cu_sgd.predict(pred_data).to_array()
print(" cuML intercept : ", cu_sgd.intercept_)
Expand All @@ -156,11 +156,11 @@ class SGD(Base):
.. code-block:: python
cuML intercept : 0.004561662673950195
cuML coef : 0 0.9834546
1 0.010128272
dtype: float32
cuML predictions : [3.0055666 2.0221121]
cuML intercept : 0.0041877031326293945
cuML coef : 0 0.984174
1 0.009776
dtype: float32
cuML predictions : [3.005588 2.0214138]
Parameters
Expand Down

0 comments on commit 9d2635b

Please sign in to comment.