Accuracy is lost after save_weights/load_weights #20524

Open
Pandaaaa906 opened this issue Nov 20, 2024 · 3 comments

@Pandaaaa906

Keras version: 3

TensorFlow version

2.16.1

Current behavior?

1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 354ms/step - accuracy: 0.5000 - loss: 1.1560
[1.1560312509536743, 0.5]
Epoch 1/10
1/1 ━━━━━━━━━━━━━━━━━━━━ 1s 596ms/step - accuracy: 0.5000 - loss: 1.1560
Epoch 2/10
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 28ms/step - accuracy: 0.5000 - loss: 14.5018
Epoch 3/10
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 30ms/step - accuracy: 0.5000 - loss: 9.9714
Epoch 4/10
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 31ms/step - accuracy: 0.7500 - loss: 1.3363
Epoch 5/10
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 26ms/step - accuracy: 1.0000 - loss: 8.9407e-08
Epoch 6/10
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 29ms/step - accuracy: 1.0000 - loss: 4.7684e-07
Epoch 7/10
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 31ms/step - accuracy: 0.7500 - loss: 0.2545
Epoch 8/10
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 27ms/step - accuracy: 0.7500 - loss: 0.8729
Epoch 9/10
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 28ms/step - accuracy: 1.0000 - loss: 9.1682e-04
Epoch 10/10
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 25ms/step - accuracy: 1.0000 - loss: 2.6822e-07
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 27ms/step - accuracy: 1.0000 - loss: 0.0000e+00
[0.0, 1.0]
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 335ms/step - accuracy: 0.2500 - loss: 0.8475 # this should be acc 1.0 loss 0
[0.847506046295166, 0.25]

Standalone code to reproduce the issue

import tensorflow as tf

class CusModel(tf.keras.Model):
    def __init__(self):
        super().__init__()
        self.dense = tf.keras.layers.Dense(units=2, activation='softmax', name='output')

    def call(self, x):
        return self.dense(x)

dummy_data_x = tf.convert_to_tensor([[0, 0],
                [1, 0],
                [0, 1],
                [1, 1]])
dummy_data_y = tf.convert_to_tensor([0, 1, 0, 1])

model = CusModel()
model.compile(optimizer=tf.keras.optimizers.Adam(10.0),
              loss='sparse_categorical_crossentropy',
              metrics=['accuracy'])
print(model.evaluate(x=dummy_data_x, y=dummy_data_y))
model.fit(x=dummy_data_x, y=dummy_data_y, epochs=10)
print(model.evaluate(x=dummy_data_x, y=dummy_data_y))
model.save_weights('test_model.weights.h5')

# Re-create the model and load the saved weights (the new model has not been built or called yet)
model = CusModel()
model.load_weights('test_model.weights.h5')
model.compile(optimizer=tf.keras.optimizers.Adam(10.0),
              loss='sparse_categorical_crossentropy',
              metrics=['accuracy'])
print(model.evaluate(x=dummy_data_x, y=dummy_data_y))
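A minimal sketch of what the fresh model looks like right after `CusModel()`, assuming Keras 3 through `tf.keras` (TF 2.16, as in this report): the Dense layer creates its kernel and bias only when the model is first built or called, so at the point where `load_weights` runs above, the model owns no variables to load into.

```python
import numpy as np
import tensorflow as tf


class CusModel(tf.keras.Model):
    # Same model as in the report: one Dense layer created in __init__.
    def __init__(self):
        super().__init__()
        self.dense = tf.keras.layers.Dense(units=2, activation="softmax", name="output")

    def call(self, x):
        return self.dense(x)


model = CusModel()
# Nothing has been built yet, so the model owns no variables.
weights_before = len(model.weights)
print(weights_before)  # 0

# A single forward pass builds the Dense layer (kernel + bias).
model(np.zeros((1, 2), dtype="float32"))
weights_after = len(model.weights)
print(weights_after)  # 2
```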
@ghsanti
Contributor

ghsanti commented Nov 20, 2024

There are a few ways to fix it.

method 1

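Build the layer in the Model subclass (add a build method):

    def build(self, input_shape):  # method within your Model subclass
        self.dense.build(input_shape)

Then call model.build((None, 2)) on the new model before load_weights (and therefore before evaluate), so that the weights exist when they are loaded.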

method 2

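Alternatively, evaluate once before loading so that the model gets built, then load the weights and evaluate again:

print(model.evaluate(x=dummy_data_x, ...)) # builds the model, creating (random) weights
model.load_weights('test_model.weights.h5') # overwrites them with the saved weights
print(model.evaluate(x=dummy_data_x, ...)) # ok result

This works because the first evaluate call builds the layer for you, so load_weights has variables to load into.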
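An end-to-end sketch of method 2, assuming Keras 3 through `tf.keras`. The helper `make_model`, the learning rate of 0.1 and the 20 epochs are illustrative choices, not from the report. Keras may warn that it skips loading the optimizer variables (the restored optimizer has not been built yet); the model weights themselves are still loaded.

```python
import os
import tempfile

import numpy as np
import tensorflow as tf


class CusModel(tf.keras.Model):
    def __init__(self):
        super().__init__()
        self.dense = tf.keras.layers.Dense(units=2, activation="softmax", name="output")

    def call(self, x):
        return self.dense(x)


def make_model():
    # Hypothetical helper: same loss/metrics as the report, gentler learning rate.
    model = CusModel()
    model.compile(optimizer=tf.keras.optimizers.Adam(0.1),
                  loss="sparse_categorical_crossentropy",
                  metrics=["accuracy"])
    return model


x = np.array([[0, 0], [1, 0], [0, 1], [1, 1]], dtype="float32")
y = np.array([0, 1, 0, 1])

trained = make_model()
trained.fit(x, y, epochs=20, verbose=0)
path = os.path.join(tempfile.mkdtemp(), "test_model.weights.h5")
trained.save_weights(path)

restored = make_model()
restored.evaluate(x, y, verbose=0)  # first call builds the layer (random weights)
restored.load_weights(path)         # now there are variables to overwrite

# Identical weights give identical predictions.
np.testing.assert_allclose(trained.predict(x, verbose=0),
                           restored.predict(x, verbose=0), rtol=1e-5)
```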

@sonali-kumari1

Hi @Pandaaaa906,

Thanks for reporting the issue.

Instead of saving and loading only the weights, you can try model.save() and keras.models.load_model().
This way the model architecture, optimizer state, and weights are all saved.

Attaching gist for your reference.

@james77777778
Contributor

@Pandaaaa906
I have submitted a PR (which has been merged) to raise an error when calling save_weights and load_weights on an unbuilt model. You should consider adopting the methods provided by @ghsanti to resolve this issue.

Thanks @ghsanti for pointing this out.
