TF port of the Segment Anything Model (SAM) #22970
Conversation
def flatten(input, start_dim=0, end_dim=-1):
    # Replicates the behavior of torch.flatten in TF

    # If end_dim or start_dim is negative, count them from the end
    if end_dim < 0:
        end_dim += input.shape.rank
    if start_dim < 0:
        start_dim += input.shape.rank

    if start_dim == end_dim:
        return input

    in_shape = tf.shape(input)
    flattened_dim = tf.math.reduce_prod(in_shape[start_dim : end_dim + 1])
    out_shape = tf.concat([in_shape[:start_dim], [flattened_dim], in_shape[end_dim + 1 :]], axis=0)
    return tf.reshape(input, out_shape)
🥲
I have no idea why I didn't do this before now!
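As a sanity check, the intended equivalence is easy to verify directly (a quick sketch, assuming both frameworks are installed):

import numpy as np
import tensorflow as tf
import torch

x = np.random.rand(2, 3, 4, 5).astype(np.float32)
tf_out = flatten(tf.constant(x), start_dim=1, end_dim=2)
pt_out = torch.flatten(torch.tensor(x), start_dim=1, end_dim=2)
assert tuple(tf_out.shape) == tuple(pt_out.shape) == (2, 12, 5)
np.testing.assert_allclose(tf_out.numpy(), pt_out.numpy())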
@@ -418,6 +430,45 @@ def post_process_masks(
        return output_masks

    def post_process_masks_tf(
Have we started including separate post-processing ops in native TensorFlow? I thought they were NumPy only. This is indeed nice.
I wasn't sure about this - there's probably some code duplication in the processor I can remove.
Preprocessing is all in NumPy - this hasn't been extended to the postprocessing methods yet, mainly because I haven't dared tackle `torch.nn.functional.interpolate`, and partly because we haven't needed to yet.
That said - please don't have `post_processing_xxx_tf`! We don't use `decode_tf` for our tokenizers ;)
Could you rework the methods so there's a single `post_process_xxx` method and hidden framework-specific methods? i.e.
def post_process_masks(self, masks, ...):
    if is_torch_tensor(masks):
        return self._post_process_masks_pt(...)
    if is_tf_tensor(masks):
        return self._post_process_masks_tf(...)
    ...
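For context, the TF side of that dispatch needs a stand-in for `torch.nn.functional.interpolate`; a minimal sketch of one option using `tf.image.resize` (bilinear, NCHW layout assumed to match the PT code - illustrative, not the PR's exact implementation):

import tensorflow as tf

def resize_like_interpolate(masks, target_size):
    # torch.nn.functional.interpolate works on NCHW; tf.image.resize wants NHWC,
    # so transpose, resize, then transpose back
    masks = tf.transpose(masks, perm=[0, 2, 3, 1])
    masks = tf.image.resize(masks, target_size, method="bilinear")
    return tf.transpose(masks, perm=[0, 3, 1, 2])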
Sure! And sorry - I basically rushed through the processor code so I could get to the bit I was hype about (benchmarking GPT-4's translations)
Force-pushed from b9dd5a4 to b1f61bd.
This is now almost ready to go and the code should be ready for review! Remaining issues:
Why are there two different processing files, one of them not being imported everywhere?
The common tests should not be changed to have a higher tolerance, just override the right tests in proper test file.
Also cc @amyeroberts since you reviewed the PyTorch model extensively.
@@ -469,6 +540,38 @@ def generate_crop_boxes(
        image, target_size, crop_n_layers, overlap_ratio, points_per_crop, crop_n_points_downscale_factor, device
    )

    def generate_crop_boxes_tf(
Why have two functions that do the exact same thing?
Resolved as part of the general processor refactor!
breakpoint()
print()  # Need to check the input shapes here so I know where to pad them
Seems like this is leftover from debugging...
Resolved as part of the general processor refactor! (also oops, sorry)
@@ -0,0 +1,248 @@
# coding=utf-8
What is the purpose of this file?
Shh, it's gone now. We don't talk about `processing_tf_sam`.
@@ -0,0 +1,122 @@
# Copyright 2023 The HuggingFace Team. All rights reserved.
Why have a separate test file to test the same class?
Also gone now!
@@ -459,6 +459,9 @@ def test_model_from_pretrained(self):
        model = TFData2VecVisionModel.from_pretrained(model_name)
        self.assertIsNotNone(model)

    def test_pt_tf_model_equivalence(self, allow_missing_keys=True, tol=5e-4):
Why change the tolerance for this model?
Adding a tolerance argument to the base tests triggered the test to run for other models, which caused this test to fail. I'll investigate and see if it's necessary, though!
def call(self, x: tf.Tensor) -> tf.Tensor:
    if self.data_format == "channels_last":
        x = functional_layernorm(x, weight=self.weight, bias=self.bias, epsilon=self.eps)
    elif self.data_format == "channels_first":
        input_dtype = x.dtype
        x = tf.cast(x, tf.float32)
        u = tf.reduce_mean(x, axis=1, keepdims=True)
        s = tf.math.square(x - u)
        s = tf.reduce_mean(s, axis=1, keepdims=True)
        x = (x - u) / tf.math.sqrt(s + self.eps)
        x = tf.cast(x, input_dtype)
        x = self.weight[:, None, None] * x + self.bias[:, None, None]
    return x
Can we use more descriptive variable names?
This was copied straight from the PyTorch code, but on reflection I could probably refactor the whole thing out, because it was only there to deal with different memory orderings (whereas TensorFlow tensors are always contiguous and always have standard C memory ordering)
Done! I refactored the `functional_layernorm` function to handle alternate axes, and then just called that instead of this manual layernorm. Model output is unchanged and all integration tests still pass.
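A minimal sketch of what an axis-aware `functional_layernorm` can look like (argument names here are illustrative, not necessarily the PR's):

import tensorflow as tf

def functional_layernorm(inputs, weight, bias, epsilon=1e-6, axis=-1):
    # Normalize over the chosen axis in float32 for numerical stability
    input_dtype = inputs.dtype
    x = tf.cast(inputs, tf.float32)
    mean = tf.reduce_mean(x, axis=axis, keepdims=True)
    variance = tf.reduce_mean(tf.math.square(x - mean), axis=axis, keepdims=True)
    x = (x - mean) / tf.math.sqrt(variance + epsilon)
    x = tf.cast(x, input_dtype)
    # Reshape weight/bias so they broadcast along the normalized axis
    shape = [1] * len(inputs.shape)
    shape[axis] = inputs.shape[axis]
    return tf.reshape(weight, shape) * x + tf.reshape(bias, shape)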
# Matt: I think this sum is actually checking that the sparse prompt embeddings aren't an empty tensor
# with shape[1] == 0, so I'm going to replace this
Replace this with what?
Clarified that comment!
        self.config = config

    def build(self, input_shape):
        # TODO Matt: What is going on here? Why is a non-trainable weight randomly initialized?
To address.
I never figured this out, but it's the same in Torch, and both models give equivalent outputs. @ArthurZucker do you know why this weight is non-trainable?
Couldn't find any reference to this random embedding in the paper (in fact, the paper always mentions learned positional embeddings), but the same pattern is in the SAM codebase
This meme is all I can think of
Thanks for the review - about half of the comments relate to the processor code, which is definitely in need of a refactor, yes. Working on that now!
Looking good!
Left some general comments - mainly w.r.t. the processing code. I'd like there to be as little TF/PT-specific code as possible. For postprocessing it's OK, as a lot of postprocessing is still PyTorch-specific, but preprocessing should be (as much as possible) framework-agnostic.
For the processor, can you add pt_tf cross checks to make sure that TF postprocessed outputs are equivalent to the PT ones?
@@ -267,7 +267,7 @@ def prepare_numpy_arrays(inputs_dict):

     # overwrite from common since TFViTMAEForPretraining has random masking, we need to fix the noise
     # to generate masks during test
-    def check_pt_tf_models(self, tf_model, pt_model, tf_inputs_dict):
+    def check_pt_tf_models(self, tf_model, pt_model, tf_inputs_dict, tol=1e-5):
Do you need to add the `tol` argument here? Unless necessary, I'd avoid resetting the `tol` default in all the methods so we only need to update it in one place.
I refactored this and reverted all the changes in the common tests
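The eventual shape of that refactor (per the "Revert changes to common tests and just override check_pt_tf_outputs" commit) is roughly the following - a sketch assuming the common mixin's signature:

import unittest

from ..test_modeling_tf_common import TFModelTesterMixin  # import path as used in the transformers test suite

class TFSamModelTest(TFModelTesterMixin, unittest.TestCase):
    def check_pt_tf_outputs(self, tf_outputs, pt_outputs, model_class, tol=5e-4, name="outputs", attributes=None):
        # Loosen the PT/TF equivalence tolerance for SAM only; the common default stays untouched
        super().check_pt_tf_outputs(tf_outputs, pt_outputs, model_class, tol=tol, name=name, attributes=attributes)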
if output_hidden_states:
    vision_hidden_states = vision_outputs[1]
if output_attentions:
    vision_attentions = vision_outputs[-1]
Could we instead pass `return_dict=True` to `self.vision_encoder` and then explicitly access the values by name? I'm not a big fan of accessing by index here.
Done! (Also changed in the original PT code)
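Accessed by name, that looks roughly like this (a sketch using the standard output-class attributes; note that forcing `return_dict=True` is later reverted on the PT side, as discussed below):

vision_outputs = self.vision_encoder(
    pixel_values,
    output_attentions=output_attentions,
    output_hidden_states=output_hidden_states,
    return_dict=True,
)
image_embeddings = vision_outputs.last_hidden_state
vision_hidden_states = vision_outputs.hidden_states if output_hidden_states else None
vision_attentions = vision_outputs.attentions if output_attentions else None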
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
Why have these arguments?
I'm not clear about this one! Aren't these arguments common across most of our models?
I don't think so? Only SAM has `get_image_embeddings`, and all the other `get_xxx_embeddings` methods, as far as I can tell, just take `self`.
# Matt: The original Torch code checked that the sum of sparse_prompt_embeddings equalled 0. However, this only
# happens when the sparse prompt embeddings are an empty tensor with shape[1] == 0. I replaced
# it with an explicit shape check to avoid data-dependent control flow which breaks XLA.
:)
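Concretely, the change is something like this (a sketch; the concat axis is illustrative):

# Before - data-dependent control flow, which breaks XLA tracing:
#     if tf.reduce_sum(sparse_prompt_embeddings) != 0:
#         tokens = tf.concat([output_tokens, sparse_prompt_embeddings], axis=1)

# After - a static shape check that XLA can trace:
if sparse_prompt_embeddings.shape[1] != 0:
    tokens = tf.concat([output_tokens, sparse_prompt_embeddings], axis=1)
else:
    tokens = output_tokens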
@amyeroberts @sgugger I refactored all the changes to the common tests, and just overrode `check_pt_tf_outputs` in the SAM tests instead.
Looking good! 💪
Additional general comment: it seems like the Keras `training` argument is missing all around (`call` and the dropout layers)... but on the other hand, SAM is not trainable. Still, in case we add a training script, I'd add this quick future-proof change :D
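The quick future-proofing in question is the usual Keras pattern (layer names here are illustrative):

def call(self, hidden_states: tf.Tensor, training: bool = False) -> tf.Tensor:
    hidden_states = self.dense(hidden_states)
    # Dropout is only active when training=True, so the flag must be threaded through
    hidden_states = self.dropout(hidden_states, training=training)
    return hidden_states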
Force-pushed from 76cebb9 to 17536e4.
@gante I think all comments are now addressed, and I added the `training` argument. All comments from @amyeroberts and @sgugger should be addressed too - are you okay with going ahead and merging now once tests pass?
Thanks for all the work on this. @amyeroberts could you also have a look before this is merged?
 points (`torch.Tensor`, **optional**):
     point coordinates and labels to embed.
-boxes (`torch.Tensor`, **optionnal**):
+boxes (`torch.Tensor`, **optional**):
     boxes to embed
-masks (`torch.Tensor`, **optionnal**):
+masks (`torch.Tensor`, **optional**):
Since we are touching this, can you put the optionals in italics and not bold?
Done!
-    return_dict=return_dict,
+    return_dict=True,
This cannot be forced, as `return_dict=True` breaks jit compilation. This change needs reverting.
My bad - this was my suggestion, sorry @Rocketknight1!
Done!
        values.
        """

    def __init__(self, config, downsample_rate=None, **kwargs) -> None:
The `-> None` makes zero sense to me as a type annotation (I know it's what PEP says, but the init returns an instance of the class). Since there are no type annotations elsewhere, maybe just remove it?
Done! (for all classes across both the PT and TF files)
Nice! 🔥
Thanks for iterating, and in particular for spending the time to add equivalence tests for the processor and keep the image processing code tidy with the two frameworks 🤗
        self.assertTrue(np.all(tf_masks[0].numpy() == pt_masks[0].numpy()))

    def test_image_processor_equivalence(self):
🤗
        masks = outputs.pred_masks[0, 0, 0, 0, :3]

        self.assertTrue(np.allclose(scores[-1].numpy(), np.array(0.5798), atol=2e-4))
        self.assertTrue(np.allclose(masks.numpy(), np.array([-6.6381, -6.0734, -7.5308]), atol=1e-2))
This tolerance seems pretty high 👀
It's actually okay - the values for the scores are very large (usually in the range 5-30). A tolerance of 2e-4 for numbers that big is quite tight!
        processor = SamProcessor.from_pretrained("facebook/sam-vit-huge")

        raw_image = prepare_image()
        input_boxes = [[650, 900, 1000, 1250]]
Note in #23376 - `input_boxes` should be a list of list of ints.
Fixed!
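For illustration, the fixed call site nests the boxes one level deeper, per image (an assumption based on the note above and the "Fix input_boxes shapes" commit):

input_boxes = [[[650, 900, 1000, 1250]]]  # one image -> one box of [x1, y1, x2, y2]
inputs = processor(raw_image, input_boxes=input_boxes, return_tensors="tf")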
            The bounding boxes corresponding to segmentation masks
        amg_crops_nms_thresh (`float`, *optional*, defaults to 0.7):
            NMS threshold.
        """
I know this is just copying from the PT implementation - but it would be great to add info about what's returned to the docstring, as there are many objects.
I'll be honest that I don't understand it too well, lol. I'll leave that for a follow-up on the Torch end and copy the strings whenever they do it 😅
            self.hidden_size // 8, kernel_size=2, strides=2, name="upscale_conv2", data_format="channels_first"
        )
        self.upscale_layer_norm = TFSamLayerNorm(
            self.hidden_size // 4, data_format="channels_first", name="upscale_layer_norm"
The layer norm layer here should take `eps` from the config.
        self.layer_norm1 = TFSamLayerNorm(config.output_channels, name="layer_norm1")
        self.conv2 = tf.keras.layers.Conv2D(
            config.output_channels,
            kernel_size=3,
            padding="same",
            use_bias=False,
            name="conv2",
        )
        self.layer_norm2 = TFSamLayerNorm(config.output_channels, name="layer_norm2")
The layer norm layers here should take `eps` from the config.
The PyTorch version doesn't, and just uses the 1e-6 default kwarg value!
OK 👍
@@ -44,6 +51,12 @@

 if is_torchvision_available():
     from torchvision.ops.boxes import batched_nms

+if is_tf_available():
+    import tensorflow as tf
+    from tensorflow.experimental import numpy as tnp
Let's hope it's not too experimental 😬
`tnp` has been around since 2.4, I think we're safe!
ha! for TF I doubt it ;)
Force-pushed from 875cc35 to 3902969.
I think comments are addressed now - are we okay to merge?
I'm treating silence as agreement, merging!
* First commit
* Add auto-translation with GPT-4
* make fixup
* Add a functional layernorm for TF
* Add all the auxiliary imports etc.
* Add the extra processor and tests
* rebase to main
* Add all the needed fixes to the GPT code
* make fixup
* Make convolutions channels-last so they run on CPU
* make fixup
* Fix final issues
* Fix other models affected by test change
* Clarify comment on the sparse_prompt_embeddings check
* Refactor functional_layernorm, use shape_list in place of .shape in some places
* Remove deprecated torch-alike code
* Update tests/models/sam/test_modeling_tf_sam.py (Co-authored-by: amyeroberts <[email protected]>)
* Update tests/models/sam/test_modeling_tf_sam.py (Co-authored-by: amyeroberts <[email protected]>)
* Refactor processor with common methods and separated private methods
* make fixup
* Quietly delete the file that didn't do anything (sorry Sylvain)
* Refactor the processor tests into one file
* make fixup
* Clean up some unnecessary indirection
* Fix TF mask postprocessing
* Add more processor equivalence tests
* Refactor generate_crop_boxes to use framework-neutral np code
* Make the serving output correctly conditional
* Fix error message line length
* Use dict keys rather than indices internally in both TF and PT SAM call/forward
* Return dicts internally in the call/forward methods
* Revert changes to common tests and just override check_pt_tf_outputs
* Revert changes to other model tests
* Clarify comments for functional layernorm
* Add missing transpose from PT code
* Removed unused copied from in PT code
* Remove overrides for tests that don't exist in TF
* Fix transpose and update tests for PT and TF to check pred_masks
* Add training flag
* Update tests to use TF checkpoints
* Update index.mdx
* Add missing cross-test decorator
* Remove optional extra asterisks
* Revert return_dict changes in PT code
* Update src/transformers/models/sam/modeling_tf_sam.py (Co-authored-by: Sylvain Gugger <[email protected]>)
* Remove None return annotations on init methods
* Update tests/models/sam/test_processor_sam.py (Co-authored-by: amyeroberts <[email protected]>)
* Fix input_boxes shapes
* make fixup

Co-authored-by: amyeroberts <[email protected]>
Co-authored-by: Sylvain Gugger <[email protected]>
This is a first draft of the SAM port - will update this PR as I port tests and make sure everything is working okay. It's also a first proof-of-concept for full GPT-4 auto-translation from PyTorch: the entire `modeling_tf_sam.py` file was converted from PyTorch by GPT-4, with the exception of the imports at the top, because I haven't written a prompt for those yet.

Update: I checked over all of the code and fixed the issues in the GPT port. Equivalence tests all look good! This is almost ready to merge, but there are a few small issues left:

* `channels_first` doesn't actually work on CPU in TF
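The workaround for that last point (per the "Make convolutions channels-last so they run on CPU" commit) boils down to keeping Conv2D layers in their default channels-last layout and transposing at the boundaries - a sketch with illustrative names, not the PR's exact code:

import tensorflow as tf

# channels_first Conv2D has no CPU kernel in TF, so keep the layer channels-last
# and transpose around it instead
conv = tf.keras.layers.Conv2D(filters=256, kernel_size=3, padding="same")

def apply_conv_nchw(x: tf.Tensor) -> tf.Tensor:
    x = tf.transpose(x, perm=[0, 2, 3, 1])     # NCHW -> NHWC
    x = conv(x)
    return tf.transpose(x, perm=[0, 3, 1, 2])  # NHWC -> NCHW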