-
Notifications
You must be signed in to change notification settings - Fork 106
/
Copy pathImageClassificationKeras.cs
102 lines (87 loc) · 3.14 KB
/
ImageClassificationKeras.cs
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
using System.Collections.Generic;
using Tensorflow;
using Tensorflow.Keras;
using static Tensorflow.Binding;
using static Tensorflow.KerasApi;
using Tensorflow.Keras.Utils;
using System.IO;
using Tensorflow.Keras.Engine;
namespace TensorFlowNET.Examples;
/// <summary>
/// This tutorial shows how to classify images of flowers.
/// https://www.tensorflow.org/tutorials/images/classification
/// </summary>
public class ImageClassificationKeras : SciSharpExample, IExample
{
int batch_size = 32;
int epochs = 10;
Shape img_dim = (64, 64);
IDatasetV2 train_ds, val_ds;
Model model;
public ExampleConfig InitConfig()
=> Config = new ExampleConfig
{
Name = "Image Classification (Keras)",
Enabled = true
};
public bool Run()
{
tf.enable_eager_execution();
PrepareData();
BuildModel();
Train();
return true;
}
public override void BuildModel()
{
int num_classes = 5;
// var normalization_layer = tf.keras.layers.Rescaling(1.0f / 255);
var layers = keras.layers;
var myLayers = new List<ILayer>
{
layers.Rescaling(1.0f / 255, input_shape: (img_dim.dims[0], img_dim.dims[1], 3)),
layers.Conv2D(16, 3, padding: "same", activation: keras.activations.Relu),
layers.MaxPooling2D(),
/*layers.Conv2D(32, 3, padding: "same", activation: "relu"),
layers.MaxPooling2D(),
layers.Conv2D(64, 3, padding: "same", activation: "relu"),
layers.MaxPooling2D(),*/
layers.Flatten(),
layers.Dense(128, activation: keras.activations.Relu),
layers.Dense(num_classes)
};
model = keras.Sequential(myLayers);
model.compile(optimizer: keras.optimizers.Adam(),
loss: keras.losses.SparseCategoricalCrossentropy(from_logits: true),
metrics: new[] { "accuracy" });
model.summary();
}
public override void Train()
{
model.fit(train_ds, validation_data: val_ds, epochs: epochs);
}
public override void PrepareData()
{
string fileName = "flower_photos.tgz";
string url = $"https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz";
string data_dir = Path.Combine(Path.GetTempPath(), "flower_photos");
Web.Download(url, data_dir, fileName);
Compress.ExtractTGZ(Path.Join(data_dir, fileName), data_dir);
data_dir = Path.Combine(data_dir, "flower_photos");
// convert to tensor
train_ds = keras.preprocessing.image_dataset_from_directory(data_dir,
validation_split: 0.2f,
subset: "training",
seed: 123,
image_size: img_dim,
batch_size: batch_size);
val_ds = keras.preprocessing.image_dataset_from_directory(data_dir,
validation_split: 0.2f,
subset: "validation",
seed: 123,
image_size: img_dim,
batch_size: batch_size);
train_ds = train_ds.shuffle(1000).prefetch(buffer_size: -1);
val_ds = val_ds.prefetch(buffer_size: -1);
}
}