Commit
add TrainClassifierCNN class to perform multiclass classification for Danielle/Kristin project
EhsanGharibNezhad committed Oct 23, 2024
1 parent 0016804 commit b5b5f41
Showing 1 changed file with 173 additions and 0 deletions.
173 changes: 173 additions & 0 deletions TelescopeML/DeepTrainer.py
@@ -326,3 +326,176 @@ def fit_cnn_model(self,
# test_score = model.evaluate(x = [self.X1_test, self.X2_test],
# y = [self.y1_test, self.y2_test, self.y3_test, self.y4_test],
# verbose=0)


class TrainClassifierCNN:
"""
Train a Convolutional Neural Network (CNN) model using a multiclass classification approach.
Parameters
-----------
X1_train : array
Row-StandardScaled input spectra for training.
X1_val : array
Row-StandardScaled input spectra for validation.
X1_test : array
Row-StandardScaled input spectra for testing.
X2_train : array
Col-StandardScaled Min Max of all rows of input spectra for training.
X2_val : array
Col-StandardScaled Min Max of all rows of input spectra for validation.
X2_test : array
Col-StandardScaled Min Max of all rows of input spectra for testing.
y_train : array
One-hot encoded labels for multiclass classification for training.
y_val : array
One-hot encoded labels for multiclass classification for validation.
y_test : array
One-hot encoded labels for multiclass classification for testing.
"""
def __init__(self,
X1_train, X1_val, X1_test, # Row-StandardScaled input spectra
X2_train, X2_val, X2_test, # Col-StandardScaled Min Max of all rows of input spectra
y_train, y_val, y_test # One-hot encoded labels for classification
):

self.X1_train, self.X1_val, self.X1_test = X1_train, X1_val, X1_test
self.X2_train, self.X2_val, self.X2_test = X2_train, X2_val, X2_test
self.y_train, self.y_val, self.y_test = y_train, y_val, y_test

def build_model(self, config):
"""
Build a CNN model for multiclass classification with the given hyperparameters.
Parameters
----------
config : dict
Dictionary containing hyperparameter settings for the CNN and FC layers.
Returns
-------
None
The built tf.keras.Model for multiclass classification is stored in self.model (a summary is also printed).
"""

Conv__filters = config['Conv__filters']
Conv__kernel_size = config['Conv__kernel_size']
Conv__MaxPooling1D = config['Conv__MaxPooling1D']
Conv__NumberLayers = config['Conv__NumberLayers']
Conv__NumberBlocks = config['Conv__NumberBlocks']

FC1__units = config['FC1__units']
FC1__dropout = config['FC1__dropout']
FC1__NumberLayers = config['FC1__NumberLayers']

FC2__units = config['FC2__units']
FC2__NumberLayers = config['FC2__NumberLayers']
FC2__dropout = config['FC2__dropout']
FC2__NumberBlocks = config['FC2__NumberBlocks']

lr = config['lr']
self.lr = lr

# Shape of the inputs
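# input_1 carries one spectrum with 104 wavelength channels (one feature per channel);
# input_2 carries two auxiliary scalars per example (presumably the row-wise Min/Max
# values described in the class docstring), so X2 arrays are expected to have shape (n_samples, 2).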
input_1 = tf.keras.layers.Input(shape=(104, 1))
input_2 = tf.keras.layers.Input(shape=(2,))

# Convolutional Blocks
model = input_1
for b in range(Conv__NumberBlocks):
for l in range(Conv__NumberLayers):
model = Conv1D(filters=Conv__filters * (b + l + 1) ** 2,
kernel_size=Conv__kernel_size,
padding='same',
activation='relu',
kernel_initializer='he_normal')(model)
model = MaxPooling1D(pool_size=(Conv__MaxPooling1D))(model)
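# The filter count grows quadratically with block/layer depth. For example, with
# Conv__filters=32, Conv__NumberBlocks=2, Conv__NumberLayers=2, the four Conv1D layers use
# 32*(0+0+1)**2=32, 32*(0+1+1)**2=128, 32*(1+0+1)**2=128, and 32*(1+1+1)**2=288 filters.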

# Flatten Layer
model = Flatten()(model)

# Fully Connected Layers (before concatenation)
for l in range(FC1__NumberLayers):
model = Dense(FC1__units * (l + 1) ** 2, activation='relu',
kernel_initializer='he_normal')(model)
model = Dropout(FC1__dropout)(model)

# Concatenation Layer
model = tf.keras.layers.concatenate([model, input_2])

# Fully Connected Layers after concatenation
for b in range(FC2__NumberBlocks):
for l in range(FC2__NumberLayers):
model = Dense(FC2__units * (b + l + 1) ** 2, activation='relu',
kernel_initializer='he_normal')(model)
model = Dropout(FC2__dropout)(model)

# Output Layer for Multiclass Classification
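# config['output_units'] must equal the number of classes; softmax yields one probability per class.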
output = Dense(units=config['output_units'], activation='softmax', name='output')(model)

# Create the model
model = tf.keras.Model(inputs=[input_1, input_2], outputs=[output])

self.model = model

# Print model summary
model.summary()

def fit_cnn_model(self, batch_size=32, epochs=3):
"""
Fit the CNN model for multiclass classification.
Parameters
----------
batch_size : int, default 32
The number of samples per gradient update.
epochs : int, default 3
Number of epochs to train the model.
Returns
-------
history : tf.keras.callbacks.History
Training history (loss and accuracy values for training and validation).
"""
# Compile the model with categorical crossentropy loss and Adam optimizer
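# Note: categorical_crossentropy expects one-hot encoded targets; for integer class
# labels, 'sparse_categorical_crossentropy' would be the matching loss.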
self.model.compile(loss='categorical_crossentropy',
optimizer=Adam(learning_rate=self.lr),
metrics=['accuracy'])

# Early stopping to prevent overfitting
early_stop = EarlyStopping(monitor='val_loss', patience=10, restore_best_weights=True)

# Train the model
history = self.model.fit(x=[self.X1_train, self.X2_train],
y=self.y_train,
batch_size=batch_size,
validation_data=([self.X1_val, self.X2_val], self.y_val),
epochs=epochs,
verbose=1,
callbacks=[early_stop])

self.history = history
return history

# Example usage with hyperparameter configuration:
# config = {
# 'Conv__filters': 32,
# 'Conv__kernel_size': 3,
# 'Conv__MaxPooling1D': 2,
# 'Conv__NumberLayers': 2,
# 'Conv__NumberBlocks': 2,
# 'FC1__units': 64,
# 'FC1__dropout': 0.5,
# 'FC1__NumberLayers': 2,
# 'FC2__units': 32,
# 'FC2__dropout': 0.5,
# 'FC2__NumberLayers': 2,
# 'FC2__NumberBlocks': 2,
# 'lr': 0.001,
# 'output_units': 10 # For 10-class classification
# }
# model = TrainClassifierCNN(X1_train, X1_val, X1_test, X2_train, X2_val, X2_test, y_train, y_val, y_test)
# model.build_model(config)
# history = model.fit_cnn_model(batch_size=64, epochs=50)
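# A minimal sketch of label/input preparation (assuming integer labels y_train_int and
# 2-D spectra arrays; these names are illustrative, not part of the class above).
# Categorical cross-entropy expects one-hot targets and input_1 expects spectra shaped
# (104, 1), so something like the following would bring the data into the expected shapes:
# import numpy as np
# import tensorflow as tf
# y_train = tf.keras.utils.to_categorical(y_train_int, num_classes=config['output_units'])
# X1_train = np.asarray(X1_train).reshape(-1, 104, 1)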
