diff --git a/examples/hlb_cifar10.py b/examples/hlb_cifar10.py index 6de979996728d..ea073e9d5d199 100644 --- a/examples/hlb_cifar10.py +++ b/examples/hlb_cifar10.py @@ -49,13 +49,13 @@ def __call__(self, x): x = x.float() x = self.norm1(x) x = x.cast(dtypes.default_float) - x = x.gelu() + x = x.quick_gelu() residual = x x = self.conv2(x) x = x.float() x = self.norm2(x) x = x.cast(dtypes.default_float) - x = x.gelu() + x = x.quick_gelu() return x + residual @@ -64,7 +64,7 @@ def __init__(self, W): self.whitening = W self.net = [ nn.Conv2d(12, 32, kernel_size=1, bias=False), - lambda x: x.gelu(), + lambda x: x.quick_gelu(), ConvGroup(32, 64), ConvGroup(64, 256), ConvGroup(256, 512),