Skip to content

Commit

Permalink
[ci] Updates google code formatter to 1.22.0 (#3149)
Browse files Browse the repository at this point in the history
  • Loading branch information
frankfliu authored May 2, 2024
1 parent 5ca2d02 commit a56d916
Show file tree
Hide file tree
Showing 28 changed files with 94 additions and 22 deletions.
1 change: 1 addition & 0 deletions api/src/main/java/ai/djl/nn/convolutional/Convolution.java
Original file line number Diff line number Diff line change
Expand Up @@ -285,6 +285,7 @@ static NDList convolution(
return input.getNDArrayInternal()
.convolution(input, weight, bias, stride, padding, dilation, groups);
}

/**
* A builder that can build any {@code Convolution} block.
*
Expand Down
3 changes: 3 additions & 0 deletions api/src/main/java/ai/djl/nn/convolutional/Deconvolution.java
Original file line number Diff line number Diff line change
Expand Up @@ -197,6 +197,7 @@ static NDList deconvolution(
return input.getNDArrayInternal()
.deconvolution(input, weight, bias, stride, padding, outPadding, dilation, groups);
}

/**
* A builder that can build any {@code Deconvolution} block.
*
Expand Down Expand Up @@ -246,6 +247,7 @@ public T optPadding(Shape padding) {
this.padding = padding;
return self();
}

/**
* Sets the out_padding along each dimension. Defaults to 0 along each dimension.
*
Expand All @@ -256,6 +258,7 @@ public T optOutPadding(Shape outPadding) {
this.outPadding = outPadding;
return self();
}

/**
* Sets the dilation along each dimension. Defaults to 1 along each dimension.
*
Expand Down
1 change: 1 addition & 0 deletions api/src/main/java/ai/djl/nn/transformer/BertBlock.java
Original file line number Diff line number Diff line change
Expand Up @@ -283,6 +283,7 @@ public static final class Builder {
float hiddenDropoutProbability = 0.1f;
// float attentionDropoutProbability = 0.1f;
int maxSequenceLength = 512;

// float initializerRange = 0.02f;

private Builder() {}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -72,16 +72,22 @@ public final class ScaledDotProductAttentionBlock extends AbstractBlock {

/** Size of the Word-/Token-embeddings we use the attention on. */
private int embeddingSize;

/** Number of attention heads. */
private int headCount;

/** Pointwise Linear projection of the keys. */
private Linear keyProjection;

/** Pointwise Linear projection of the queries. */
private Linear queryProjection;

/** Pointwise Linear projection of the values. */
private Linear valueProjection;

/** Pointwise Linear projection of the results. */
private Linear resultProjection;

/** Dropout operation to be applied after probability calculation. */
private Dropout attentionProbsDropout;

Expand Down Expand Up @@ -119,6 +125,7 @@ private Linear buildProjection() {
public Linear getKeyProjection() {
return keyProjection;
}

/**
* Pointwise Linear projection of the queries.
*
Expand All @@ -127,6 +134,7 @@ public Linear getKeyProjection() {
public Linear getQueryProjection() {
return queryProjection;
}

/**
* Pointwise Linear projection of the values.
*
Expand All @@ -135,6 +143,7 @@ public Linear getQueryProjection() {
public Linear getValueProjection() {
return valueProjection;
}

/**
* Pointwise Linear projection of the results.
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,14 +31,19 @@ public class TransformerEncoderBlock extends AbstractBlock {

/** The attention mechanism. */
private ScaledDotProductAttentionBlock selfAttentionBlock;

/** Dropout before residual & layer normalization. */
private Dropout selfAttentionDropout;

/** Normalization of attention output and residual. */
private BatchNorm attentionNorm;

/** Fully connected pointwise block for output projection. */
private PointwiseFeedForwardBlock pointWisefullyConnected;

/** Dropout after fully connected and before last residual & layer normalization. */
private Dropout fullyConnectedDropout;

/** Another normalization for the output and residual. */
private BatchNorm outputNorm;

Expand Down
1 change: 1 addition & 0 deletions api/src/main/java/ai/djl/training/ParameterServer.java
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ default void update(String parameterId, NDArray[] params) {
update(parameterId, grads, params);
Arrays.stream(grads).forEach(NDArray::close);
}

/**
* Updates the parameter of a key from Parameter Server.
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ public void onEpoch(Trainer trainer) {
epochTime = System.nanoTime();
numEpochs++;
}

/** {@inheritDoc} */
@Override
public void onTrainingBegin(Trainer trainer) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ public class MxOpParams extends PairList<String, Object> {

// mxnet cpu take index
private static final String MXNET_CPU = "cpu(0)";

/**
* Sets the Shape parameter.
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -908,6 +908,7 @@ public static Pointer detachGradient(Pointer handle) {
REFS.recycle(ref);
return pointer;
}

/*
int MXNDArraySetGradState(Pointer handle, int state);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -437,6 +437,7 @@ public NDList layerNorm(
manager.from(beta),
eps));
}

/** {@inheritDoc} */
@Override
public NDList batchNorm(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,9 @@ public class ProfilerTest {

@Test
public void testProfiler()
throws MalformedModelException, ModelNotFoundException, IOException,
throws MalformedModelException,
ModelNotFoundException,
IOException,
TranslateException {
try (NDManager manager = NDManager.newBaseManager()) {
ImageClassificationTranslator translator =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -398,6 +398,7 @@ public NDList layerNorm(
NDArray input, Shape normalizedShape, NDArray gamma, NDArray beta, float eps) {
throw new UnsupportedOperationException();
}

/** {@inheritDoc} */
@Override
public NDList batchNorm(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,9 @@ public static void main(String[] args) throws IOException, ModelException, Trans
}

public static Classifications[] predict(List<String> inputs)
throws MalformedModelException, ModelNotFoundException, IOException,
throws MalformedModelException,
ModelNotFoundException,
IOException,
TranslateException {
// refer to
// https://medium.com/delvify/bert-rest-inference-from-the-fine-tuned-model-499997b32851 and
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,9 @@ public static void main(String[] args) throws IOException, TranslateException, M
}

public static Classifications predict()
throws MalformedModelException, ModelNotFoundException, IOException,
throws MalformedModelException,
ModelNotFoundException,
IOException,
TranslateException {
String input = "I like DJL. DJL is the best DL framework!";
logger.info("input Sentence: {}", input);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,9 @@ public static void main(String[] args) throws IOException, ModelException, Trans
}

public static float[][] predict(List<String> inputs)
throws MalformedModelException, ModelNotFoundException, IOException,
throws MalformedModelException,
ModelNotFoundException,
IOException,
TranslateException {
String modelUrl =
"https://storage.googleapis.com/tfhub-modules/google/universal-sentence-encoder/4.tar.gz";
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,9 @@ public static List<Joints> predict() throws IOException, ModelException, Transla
}

private static List<Image> predictPeopleInImage(Image img)
throws MalformedModelException, ModelNotFoundException, IOException,
throws MalformedModelException,
ModelNotFoundException,
IOException,
TranslateException {

Criteria<Image, DetectedObjects> criteria =
Expand Down Expand Up @@ -106,7 +108,9 @@ private static List<Image> predictPeopleInImage(Image img)
}

private static List<Joints> predictJointsForPeople(List<Image> people)
throws MalformedModelException, ModelNotFoundException, IOException,
throws MalformedModelException,
ModelNotFoundException,
IOException,
TranslateException {

// Use DJL MXNet model zoo model, model can be found:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,9 @@ public static void main(String[] args) throws IOException, ModelException, Trans
}

public static Image transfer(Image image, Artist artist)
throws IOException, ModelNotFoundException, MalformedModelException,
throws IOException,
ModelNotFoundException,
MalformedModelException,
TranslateException {
// Use DJL PyTorch model zoo model, model can be found:
// https://mlrepo.djl.ai/model/cv/image_generation/ai/djl/pytorch/cyclegan/0.0.1/style_xxxx.zip
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,14 +40,18 @@ public final class RollingBatch {
private RollingBatch() {}

public static void main(String[] args)
throws ModelNotFoundException, MalformedModelException, IOException,
throws ModelNotFoundException,
MalformedModelException,
IOException,
TranslateException {
String[] ret = seqBatchSchedulerWithPyTorchContrastive();
logger.info("{}", ret[0]);
}

public static String[] seqBatchSchedulerWithPyTorchContrastive()
throws ModelNotFoundException, MalformedModelException, IOException,
throws ModelNotFoundException,
MalformedModelException,
IOException,
TranslateException {
String url = "https://djl-misc.s3.amazonaws.com/test/models/gpt2/gpt2_pt.zip";

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,9 @@ public final class TextGeneration {
private TextGeneration() {}

public static void main(String[] args)
throws ModelNotFoundException, MalformedModelException, IOException,
throws ModelNotFoundException,
MalformedModelException,
IOException,
TranslateException {
String ret1 = generateTextWithPyTorchGreedy();
logger.info("{}", ret1);
Expand All @@ -53,7 +55,9 @@ public static void main(String[] args)
}

public static String generateTextWithPyTorchGreedy()
throws ModelNotFoundException, MalformedModelException, IOException,
throws ModelNotFoundException,
MalformedModelException,
IOException,
TranslateException {
SearchConfig config = new SearchConfig();
config.setMaxSeqLength(60);
Expand Down Expand Up @@ -88,7 +92,9 @@ public static String generateTextWithPyTorchGreedy()
}

public static String[] generateTextWithPyTorchContrastive()
throws ModelNotFoundException, MalformedModelException, IOException,
throws ModelNotFoundException,
MalformedModelException,
IOException,
TranslateException {
SearchConfig config = new SearchConfig();
config.setMaxSeqLength(60);
Expand Down Expand Up @@ -120,7 +126,9 @@ public static String[] generateTextWithPyTorchContrastive()
}

public static String[] generateTextWithPyTorchBeam()
throws ModelNotFoundException, MalformedModelException, IOException,
throws ModelNotFoundException,
MalformedModelException,
IOException,
TranslateException {
SearchConfig config = new SearchConfig();
config.setMaxSeqLength(60);
Expand Down Expand Up @@ -153,7 +161,9 @@ public static String[] generateTextWithPyTorchBeam()
}

public static String[] generateTextWithOnnxRuntimeBeam()
throws ModelNotFoundException, MalformedModelException, IOException,
throws ModelNotFoundException,
MalformedModelException,
IOException,
TranslateException {
SearchConfig config = new SearchConfig();
config.setMaxSeqLength(60);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -76,13 +76,17 @@ public final class TrainSentimentAnalysis {
private TrainSentimentAnalysis() {}

public static void main(String[] args)
throws IOException, ModelNotFoundException, MalformedModelException,
throws IOException,
ModelNotFoundException,
MalformedModelException,
TranslateException {
TrainSentimentAnalysis.runExample(args);
}

public static TrainingResult runExample(String[] args)
throws IOException, ModelNotFoundException, MalformedModelException,
throws IOException,
ModelNotFoundException,
MalformedModelException,
TranslateException {
Arguments arguments = new Arguments().parseArgs(args);
if (arguments == null) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -63,13 +63,19 @@ public final class TrainWithOptimizers {
private TrainWithOptimizers() {}

public static void main(String[] args)
throws IOException, ParseException, ModelNotFoundException, MalformedModelException,
throws IOException,
ParseException,
ModelNotFoundException,
MalformedModelException,
TranslateException {
TrainWithOptimizers.runExample(args);
}

public static TrainingResult runExample(String[] args)
throws IOException, ParseException, ModelNotFoundException, MalformedModelException,
throws IOException,
ParseException,
ModelNotFoundException,
MalformedModelException,
TranslateException {
OptimizerArguments arguments =
(OptimizerArguments) new OptimizerArguments().parseArgs(args);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,9 @@ public class StyleTransferTest {

@Test
public void testStyleTransfer()
throws IOException, ModelNotFoundException, MalformedModelException,
throws IOException,
ModelNotFoundException,
MalformedModelException,
TranslateException {
TestRequirements.engine("PyTorch");

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,9 @@ public void testTrainTextClassification() throws IOException {

@Test
public void testTextClassification()
throws IOException, MalformedModelException, ModelNotFoundException,
throws IOException,
MalformedModelException,
ModelNotFoundException,
TranslateException {
Criteria<String, Classifications> criteria =
Criteria.builder()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ public RsNDArray(RsNDManager manager, long handle) {
RsNDArray(RsNDManager manager, long handle, DataType dataType) {
this(manager, handle, dataType, null);
}

/**
* Constructs a Rust {@code NDArray} from a native handle (internal. Use {@link NDManager}
* instead) with the data that is hold on Java side.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -402,6 +402,7 @@ public NDList layerNorm(
NDArray input, Shape normalizedShape, NDArray gamma, NDArray beta, float eps) {
throw new UnsupportedOperationException("Not implemented");
}

/** {@inheritDoc} */
@Override
public NDList batchNorm(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,9 @@ public void testWithIntermediate() throws TranslateException {

@Test
public void testLoadPredict()
throws IOException, ModelNotFoundException, TranslateException,
throws IOException,
ModelNotFoundException,
TranslateException,
MalformedModelException {
try (ZooModel<Image, Classifications> model = getModel()) {
NoopTranslator translator = new NoopTranslator(Batchifier.STACK);
Expand Down
Loading

0 comments on commit a56d916

Please sign in to comment.