-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy path train_model.py
34 lines (29 loc) · 1.21 KB
/
train_model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
import tensorflow as tf
from tensorflow.keras import layers, models
import numpy as np
from data_loader import load_dataset
# Load the dataset
data, labels = load_dataset("data")

# Preprocess data: bring every point cloud to exactly `max_points` rows so all
# clouds can be stacked into a single (num_samples, max_points, channels) array.
max_points = 2048


def _fit_to_length(cloud, num_points):
    """Return *cloud* resized along axis 0 to exactly *num_points* rows.

    Clouds with fewer rows are zero-padded at the end, as before. Clouds with
    more rows are subsampled at evenly spaced indices; previously np.pad raised
    ValueError on the resulting negative pad width.
    """
    cloud = np.asarray(cloud)
    count = len(cloud)
    if count > num_points:
        # Evenly spaced rows (distinct, since the step exceeds 1 here) cover the
        # whole cloud. Plain head truncation could drop a contiguous region if the
        # loader emits points in some spatial order — not visible from here.
        keep = np.linspace(0, count - 1, num_points).astype(int)
        return cloud[keep]
    # NOTE(review): zero rows act as phantom points at the origin that still take
    # part in the model's max-pooling; keep in sync with inference preprocessing.
    return np.pad(cloud, ((0, num_points - count), (0, 0)), mode='constant')


if len(data) == 0:
    raise ValueError('load_dataset("data") returned no samples')
# np.stack (rather than np.array) raises a clear error if clouds still disagree
# in shape, e.g. a different number of channels per point.
data_padded = np.stack([_fit_to_length(d, max_points) for d in data])
# Define a PointNet-style classifier: a shared per-point MLP (1x1 convolutions),
# an order-invariant global max pool, then a dense classification head.
# NOTE: this is a simplified PointNet (no input/feature transform sub-networks),
# not PointNet++ — there is no hierarchical set abstraction or local grouping.
def create_pointnet_model(num_classes, num_points=None):
    """Build an uncompiled point-cloud classifier.

    Args:
        num_classes: Number of units in the softmax output layer.
        num_points: Number of points per input cloud. When None (the default)
            the module-level ``max_points`` is read at call time, matching the
            length the training data is padded to.

    Returns:
        A Keras ``Model`` mapping inputs of shape ``(batch, num_points, 3)`` to
        ``(batch, num_classes)`` softmax probabilities.
    """
    if num_points is None:
        num_points = max_points
    inputs = layers.Input(shape=(num_points, 3))
    # kernel_size=1 applies the same dense transform to every point independently.
    x = layers.Conv1D(64, kernel_size=1, activation='relu')(inputs)
    x = layers.Conv1D(128, kernel_size=1, activation='relu')(x)
    # Max over the point axis makes the global feature independent of point order.
    x = layers.GlobalMaxPooling1D()(x)
    x = layers.Dense(256, activation='relu')(x)
    x = layers.Dense(128, activation='relu')(x)
    outputs = layers.Dense(num_classes, activation='softmax')(x)
    model = models.Model(inputs=inputs, outputs=outputs)
    return model
# Create and compile the model
num_classes = 3

# sparse_categorical_crossentropy expects integer class ids in [0, num_classes).
# Out-of-range ids are not rejected up front: per TensorFlow's docs for sparse
# softmax cross-entropy they raise mid-training on CPU and give NaN loss on GPU.
# Fail fast here instead.
label_array = np.asarray(labels)
if label_array.size and (label_array.min() < 0 or label_array.max() >= num_classes):
    raise ValueError(
        f"labels must be in [0, {num_classes}); "
        f"got values in [{label_array.min()}, {label_array.max()}]"
    )

model = create_pointnet_model(num_classes)
model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy'])

# Train the model
model.fit(data_padded, label_array, epochs=20, batch_size=4)

# Persist twice: a single-file Keras model (the .h5 suffix selects the legacy
# HDF5 format) and an inference-oriented TensorFlow SavedModel export.
model.save("chess_pointnet_model.h5")
model.export("saved_model")
print("Model trained and saved!")