-
Notifications
You must be signed in to change notification settings - Fork 0
/
sklearntest.py
78 lines (64 loc) · 2.46 KB
/
sklearntest.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
import matplotlib.pyplot as plt
# unused but required import for doing 3d projections with matplotlib < 3.2
import mpl_toolkits.mplot3d # noqa: F401
from matplotlib import ticker
from matplotlib.backends.backend_agg import FigureCanvasAgg
from matplotlib.figure import Figure
import numpy as np
from PIL import Image
from sklearn import datasets, manifold
# Number of points to draw for the synthetic data set.
n_samples = 1500

# Generate the S-curve samples with a fixed random_state. The second return
# value is what the plotting calls below use to color each point.
S_points, S_color = datasets.make_s_curve(n_samples=n_samples, random_state=0)
def plot_3d(points, points_color, title, filename="plot1.png"):
    """Render a 3-D scatter plot off-screen and save it as an image file.

    Parameters
    ----------
    points : array-like
        Coordinates whose transpose unpacks into exactly three sequences
        (x, y, z) -- presumably shape (n_samples, 3); confirm with callers.
    points_color : array-like
        Per-point values passed to ``scatter`` as ``c`` and shown on the
        horizontal colorbar.
    title : str
        Text used as the figure's suptitle.
    filename : str, default "plot1.png"
        Destination path for the rendered image.
    """
    x, y, z = points.T

    fig, ax = plt.subplots(
        figsize=(6, 6),
        facecolor="white",
        tight_layout=True,
        subplot_kw={"projection": "3d"},
    )
    try:
        fig.suptitle(title, size=16)
        col = ax.scatter(x, y, z, c=points_color, s=50, alpha=0.8)
        ax.view_init(azim=-60, elev=9)

        # One major tick per unit on every axis.
        ax.xaxis.set_major_locator(ticker.MultipleLocator(1))
        ax.yaxis.set_major_locator(ticker.MultipleLocator(1))
        ax.zaxis.set_major_locator(ticker.MultipleLocator(1))

        fig.colorbar(
            col, ax=ax, orientation="horizontal", shrink=0.6, aspect=60, pad=0.01
        )

        # Draw through an Agg canvas and read the RGBA buffer back, so the
        # image is produced without opening a window.
        canvas = FigureCanvasAgg(fig)
        canvas.draw()
        rgba = np.asarray(canvas.buffer_rgba())
        Image.fromarray(rgba).save(filename)
    finally:
        # Figures created via pyplot stay registered until closed; release
        # this one even if rendering or saving failed.
        plt.close(fig)
def plot_2d(points, points_color, title, filename="plot2.png"):
    """Render a titled 2-D scatter plot off-screen and save it as an image.

    Parameters
    ----------
    points : array-like
        Coordinates forwarded to ``add_2d_scatter`` -- presumably shape
        (n_samples, 2); confirm with callers.
    points_color : array-like
        Per-point color values forwarded to ``add_2d_scatter``.
    title : str
        Text used as the figure's suptitle.
    filename : str, default "plot2.png"
        Destination path for the rendered image.
    """
    # Create exactly one figure. Previously a second plt.subplots() call
    # replaced the titled figure, so the saved image had no title and the
    # first figure was left open.
    fig, ax = plt.subplots(figsize=(3, 3), facecolor="white", constrained_layout=True)
    try:
        fig.suptitle(title, size=16)
        add_2d_scatter(ax, points, points_color)

        # Draw through an Agg canvas and read the RGBA buffer back, so the
        # image is produced without opening a window.
        canvas = FigureCanvasAgg(fig)
        canvas.draw()
        rgba = np.asarray(canvas.buffer_rgba())
        Image.fromarray(rgba).save(filename)
    finally:
        # Figures created via pyplot stay registered until closed; release
        # this one even if rendering or saving failed.
        plt.close(fig)
def add_2d_scatter(ax, points, points_color, title=None):
    """Draw a colored 2-D scatter on ``ax`` and hide its tick labels.

    ``points`` must transpose into exactly two coordinate sequences;
    ``points_color`` is passed to ``scatter`` as ``c``. ``title`` is handed
    to ``ax.set_title`` as-is, including the default ``None``.
    """
    xs, ys = points.T
    ax.scatter(xs, ys, c=points_color, s=50, alpha=0.8)
    ax.set_title(title)

    # Give each axis its own NullFormatter so no tick labels are rendered.
    for axis in (ax.xaxis, ax.yaxis):
        axis.set_major_formatter(ticker.NullFormatter())
# Save a 3-D view of the raw samples before embedding them.
plot_3d(S_points, S_color, "Original S-curve samples")

# neighborhood which is used to recover the locally linear structure
# (defined here but not referenced by the MDS call below)
n_neighbors = 12
# number of coordinates for the manifold
n_components = 2

# Configure multidimensional scaling with a fixed random_state.
md_scaling = manifold.MDS(
    n_components=n_components,
    max_iter=12,
    n_init=4,
    random_state=0,
    normalized_stress=False,
)

# Fit on the 3-D samples, keep the low-dimensional coordinates, and plot them.
S_scaling = md_scaling.fit_transform(S_points)
plot_2d(S_scaling, S_color, "Multidimensional scaling")