diff --git a/api/src/main/java/ai/djl/nn/convolutional/Convolution.java b/api/src/main/java/ai/djl/nn/convolutional/Convolution.java index a049c20e2b7..edb1df359ec 100644 --- a/api/src/main/java/ai/djl/nn/convolutional/Convolution.java +++ b/api/src/main/java/ai/djl/nn/convolutional/Convolution.java @@ -285,6 +285,7 @@ static NDList convolution( return input.getNDArrayInternal() .convolution(input, weight, bias, stride, padding, dilation, groups); } + /** * A builder that can build any {@code Convolution} block. * diff --git a/api/src/main/java/ai/djl/nn/convolutional/Deconvolution.java b/api/src/main/java/ai/djl/nn/convolutional/Deconvolution.java index 419780a98d1..dc14af70777 100644 --- a/api/src/main/java/ai/djl/nn/convolutional/Deconvolution.java +++ b/api/src/main/java/ai/djl/nn/convolutional/Deconvolution.java @@ -197,6 +197,7 @@ static NDList deconvolution( return input.getNDArrayInternal() .deconvolution(input, weight, bias, stride, padding, outPadding, dilation, groups); } + /** * A builder that can build any {@code Deconvolution} block. * @@ -246,6 +247,7 @@ public T optPadding(Shape padding) { this.padding = padding; return self(); } + /** * Sets the out_padding along each dimension. Defaults to 0 along each dimension. * @@ -256,6 +258,7 @@ public T optOutPadding(Shape outPadding) { this.outPadding = outPadding; return self(); } + /** * Sets the dilation along each dimension. Defaults to 1 along each dimension. * diff --git a/api/src/main/java/ai/djl/nn/transformer/BertBlock.java b/api/src/main/java/ai/djl/nn/transformer/BertBlock.java index af2c5c175b5..627a74fd2f6 100644 --- a/api/src/main/java/ai/djl/nn/transformer/BertBlock.java +++ b/api/src/main/java/ai/djl/nn/transformer/BertBlock.java @@ -283,6 +283,7 @@ public static final class Builder { float hiddenDropoutProbability = 0.1f; // float attentionDropoutProbability = 0.1f; int maxSequenceLength = 512; + // float initializerRange = 0.02f; private Builder() {} diff --git a/api/src/main/java/ai/djl/nn/transformer/ScaledDotProductAttentionBlock.java b/api/src/main/java/ai/djl/nn/transformer/ScaledDotProductAttentionBlock.java index fa293f4b37a..a6ecbb0ec52 100644 --- a/api/src/main/java/ai/djl/nn/transformer/ScaledDotProductAttentionBlock.java +++ b/api/src/main/java/ai/djl/nn/transformer/ScaledDotProductAttentionBlock.java @@ -72,16 +72,22 @@ public final class ScaledDotProductAttentionBlock extends AbstractBlock { /** Size of the Word-/Token-embeddings we use the attention on. */ private int embeddingSize; + /** Number of attention heads. */ private int headCount; + /** Pointwise Linear projection of the keys. */ private Linear keyProjection; + /** Pointwise Linear projection of the queries. */ private Linear queryProjection; + /** Pointwise Linear projection of the values. */ private Linear valueProjection; + /** Pointwise Linear projection of the results. */ private Linear resultProjection; + /** Dropout operation to be applied after probability calculation. */ private Dropout attentionProbsDropout; @@ -119,6 +125,7 @@ private Linear buildProjection() { public Linear getKeyProjection() { return keyProjection; } + /** * Pointwise Linear projection of the queries. * @@ -127,6 +134,7 @@ public Linear getKeyProjection() { public Linear getQueryProjection() { return queryProjection; } + /** * Pointwise Linear projection of the values. * @@ -135,6 +143,7 @@ public Linear getQueryProjection() { public Linear getValueProjection() { return valueProjection; } + /** * Pointwise Linear projection of the results. * diff --git a/api/src/main/java/ai/djl/nn/transformer/TransformerEncoderBlock.java b/api/src/main/java/ai/djl/nn/transformer/TransformerEncoderBlock.java index f01cb1adc33..8d2d8d69b46 100644 --- a/api/src/main/java/ai/djl/nn/transformer/TransformerEncoderBlock.java +++ b/api/src/main/java/ai/djl/nn/transformer/TransformerEncoderBlock.java @@ -31,14 +31,19 @@ public class TransformerEncoderBlock extends AbstractBlock { /** The attention mechanism. */ private ScaledDotProductAttentionBlock selfAttentionBlock; + /** Dropout before residual & layer normalization. */ private Dropout selfAttentionDropout; + /** Normalization of attention output and residual. */ private BatchNorm attentionNorm; + /** Fully connected pointwise block for output projection. */ private PointwiseFeedForwardBlock pointWisefullyConnected; + /** Dropout after fully connected and before last residual & layer normalization. */ private Dropout fullyConnectedDropout; + /** Another normalization for the output and residual. */ private BatchNorm outputNorm; diff --git a/api/src/main/java/ai/djl/training/ParameterServer.java b/api/src/main/java/ai/djl/training/ParameterServer.java index 54261931efb..0dda061aa8f 100644 --- a/api/src/main/java/ai/djl/training/ParameterServer.java +++ b/api/src/main/java/ai/djl/training/ParameterServer.java @@ -39,6 +39,7 @@ default void update(String parameterId, NDArray[] params) { update(parameterId, grads, params); Arrays.stream(grads).forEach(NDArray::close); } + /** * Updates the parameter of a key from Parameter Server. * diff --git a/api/src/main/java/ai/djl/training/listener/EpochTrainingListener.java b/api/src/main/java/ai/djl/training/listener/EpochTrainingListener.java index e92110e9686..c2a8c7fdc58 100644 --- a/api/src/main/java/ai/djl/training/listener/EpochTrainingListener.java +++ b/api/src/main/java/ai/djl/training/listener/EpochTrainingListener.java @@ -35,6 +35,7 @@ public void onEpoch(Trainer trainer) { epochTime = System.nanoTime(); numEpochs++; } + /** {@inheritDoc} */ @Override public void onTrainingBegin(Trainer trainer) { diff --git a/engines/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/engine/MxOpParams.java b/engines/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/engine/MxOpParams.java index cea8fd2e6bc..61e48266c63 100644 --- a/engines/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/engine/MxOpParams.java +++ b/engines/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/engine/MxOpParams.java @@ -23,6 +23,7 @@ public class MxOpParams extends PairList { // mxnet cpu take index private static final String MXNET_CPU = "cpu(0)"; + /** * Sets the Shape parameter. * diff --git a/engines/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/jna/JnaUtils.java b/engines/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/jna/JnaUtils.java index d5930cdbaa9..3ec9aab027b 100644 --- a/engines/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/jna/JnaUtils.java +++ b/engines/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/jna/JnaUtils.java @@ -908,6 +908,7 @@ public static Pointer detachGradient(Pointer handle) { REFS.recycle(ref); return pointer; } + /* int MXNDArraySetGradState(Pointer handle, int state); diff --git a/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtNDArrayEx.java b/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtNDArrayEx.java index b7f92cbd1c3..e9eb8b2b771 100644 --- a/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtNDArrayEx.java +++ b/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtNDArrayEx.java @@ -437,6 +437,7 @@ public NDList layerNorm( manager.from(beta), eps)); } + /** {@inheritDoc} */ @Override public NDList batchNorm( diff --git a/engines/pytorch/pytorch-engine/src/test/java/ai/djl/pytorch/integration/ProfilerTest.java b/engines/pytorch/pytorch-engine/src/test/java/ai/djl/pytorch/integration/ProfilerTest.java index 88f1e9cd32f..6c60bf15e8c 100644 --- a/engines/pytorch/pytorch-engine/src/test/java/ai/djl/pytorch/integration/ProfilerTest.java +++ b/engines/pytorch/pytorch-engine/src/test/java/ai/djl/pytorch/integration/ProfilerTest.java @@ -43,7 +43,9 @@ public class ProfilerTest { @Test public void testProfiler() - throws MalformedModelException, ModelNotFoundException, IOException, + throws MalformedModelException, + ModelNotFoundException, + IOException, TranslateException { try (NDManager manager = NDManager.newBaseManager()) { ImageClassificationTranslator translator = diff --git a/engines/tensorflow/tensorflow-engine/src/main/java/ai/djl/tensorflow/engine/TfNDArrayEx.java b/engines/tensorflow/tensorflow-engine/src/main/java/ai/djl/tensorflow/engine/TfNDArrayEx.java index b7c9ec3e049..5160a6b1c79 100644 --- a/engines/tensorflow/tensorflow-engine/src/main/java/ai/djl/tensorflow/engine/TfNDArrayEx.java +++ b/engines/tensorflow/tensorflow-engine/src/main/java/ai/djl/tensorflow/engine/TfNDArrayEx.java @@ -398,6 +398,7 @@ public NDList layerNorm( NDArray input, Shape normalizedShape, NDArray gamma, NDArray beta, float eps) { throw new UnsupportedOperationException(); } + /** {@inheritDoc} */ @Override public NDList batchNorm( diff --git a/examples/src/main/java/ai/djl/examples/inference/BertClassification.java b/examples/src/main/java/ai/djl/examples/inference/BertClassification.java index 96404ab276d..dcea78b2561 100644 --- a/examples/src/main/java/ai/djl/examples/inference/BertClassification.java +++ b/examples/src/main/java/ai/djl/examples/inference/BertClassification.java @@ -65,7 +65,9 @@ public static void main(String[] args) throws IOException, ModelException, Trans } public static Classifications[] predict(List inputs) - throws MalformedModelException, ModelNotFoundException, IOException, + throws MalformedModelException, + ModelNotFoundException, + IOException, TranslateException { // refer to // https://medium.com/delvify/bert-rest-inference-from-the-fine-tuned-model-499997b32851 and diff --git a/examples/src/main/java/ai/djl/examples/inference/SentimentAnalysis.java b/examples/src/main/java/ai/djl/examples/inference/SentimentAnalysis.java index e83e27fcc87..c2ee112d52e 100644 --- a/examples/src/main/java/ai/djl/examples/inference/SentimentAnalysis.java +++ b/examples/src/main/java/ai/djl/examples/inference/SentimentAnalysis.java @@ -49,7 +49,9 @@ public static void main(String[] args) throws IOException, TranslateException, M } public static Classifications predict() - throws MalformedModelException, ModelNotFoundException, IOException, + throws MalformedModelException, + ModelNotFoundException, + IOException, TranslateException { String input = "I like DJL. DJL is the best DL framework!"; logger.info("input Sentence: {}", input); diff --git a/examples/src/main/java/ai/djl/examples/inference/UniversalSentenceEncoder.java b/examples/src/main/java/ai/djl/examples/inference/UniversalSentenceEncoder.java index 01d64e98b5e..4cc299f42bb 100644 --- a/examples/src/main/java/ai/djl/examples/inference/UniversalSentenceEncoder.java +++ b/examples/src/main/java/ai/djl/examples/inference/UniversalSentenceEncoder.java @@ -61,7 +61,9 @@ public static void main(String[] args) throws IOException, ModelException, Trans } public static float[][] predict(List inputs) - throws MalformedModelException, ModelNotFoundException, IOException, + throws MalformedModelException, + ModelNotFoundException, + IOException, TranslateException { String modelUrl = "https://storage.googleapis.com/tfhub-modules/google/universal-sentence-encoder/4.tar.gz"; diff --git a/examples/src/main/java/ai/djl/examples/inference/cv/PoseEstimation.java b/examples/src/main/java/ai/djl/examples/inference/cv/PoseEstimation.java index 1efc594773e..3a5866c7b55 100644 --- a/examples/src/main/java/ai/djl/examples/inference/cv/PoseEstimation.java +++ b/examples/src/main/java/ai/djl/examples/inference/cv/PoseEstimation.java @@ -70,7 +70,9 @@ public static List predict() throws IOException, ModelException, Transla } private static List predictPeopleInImage(Image img) - throws MalformedModelException, ModelNotFoundException, IOException, + throws MalformedModelException, + ModelNotFoundException, + IOException, TranslateException { Criteria criteria = @@ -106,7 +108,9 @@ private static List predictPeopleInImage(Image img) } private static List predictJointsForPeople(List people) - throws MalformedModelException, ModelNotFoundException, IOException, + throws MalformedModelException, + ModelNotFoundException, + IOException, TranslateException { // Use DJL MXNet model zoo model, model can be found: diff --git a/examples/src/main/java/ai/djl/examples/inference/cv/StyleTransfer.java b/examples/src/main/java/ai/djl/examples/inference/cv/StyleTransfer.java index 34fef8a89e6..cc493a9f61e 100644 --- a/examples/src/main/java/ai/djl/examples/inference/cv/StyleTransfer.java +++ b/examples/src/main/java/ai/djl/examples/inference/cv/StyleTransfer.java @@ -68,7 +68,9 @@ public static void main(String[] args) throws IOException, ModelException, Trans } public static Image transfer(Image image, Artist artist) - throws IOException, ModelNotFoundException, MalformedModelException, + throws IOException, + ModelNotFoundException, + MalformedModelException, TranslateException { // Use DJL PyTorch model zoo model, model can be found: // https://mlrepo.djl.ai/model/cv/image_generation/ai/djl/pytorch/cyclegan/0.0.1/style_xxxx.zip diff --git a/examples/src/main/java/ai/djl/examples/inference/nlp/RollingBatch.java b/examples/src/main/java/ai/djl/examples/inference/nlp/RollingBatch.java index a0aada16379..6751f5448d5 100644 --- a/examples/src/main/java/ai/djl/examples/inference/nlp/RollingBatch.java +++ b/examples/src/main/java/ai/djl/examples/inference/nlp/RollingBatch.java @@ -40,14 +40,18 @@ public final class RollingBatch { private RollingBatch() {} public static void main(String[] args) - throws ModelNotFoundException, MalformedModelException, IOException, + throws ModelNotFoundException, + MalformedModelException, + IOException, TranslateException { String[] ret = seqBatchSchedulerWithPyTorchContrastive(); logger.info("{}", ret[0]); } public static String[] seqBatchSchedulerWithPyTorchContrastive() - throws ModelNotFoundException, MalformedModelException, IOException, + throws ModelNotFoundException, + MalformedModelException, + IOException, TranslateException { String url = "https://djl-misc.s3.amazonaws.com/test/models/gpt2/gpt2_pt.zip"; diff --git a/examples/src/main/java/ai/djl/examples/inference/nlp/TextGeneration.java b/examples/src/main/java/ai/djl/examples/inference/nlp/TextGeneration.java index 676188412c7..1de43457610 100644 --- a/examples/src/main/java/ai/djl/examples/inference/nlp/TextGeneration.java +++ b/examples/src/main/java/ai/djl/examples/inference/nlp/TextGeneration.java @@ -42,7 +42,9 @@ public final class TextGeneration { private TextGeneration() {} public static void main(String[] args) - throws ModelNotFoundException, MalformedModelException, IOException, + throws ModelNotFoundException, + MalformedModelException, + IOException, TranslateException { String ret1 = generateTextWithPyTorchGreedy(); logger.info("{}", ret1); @@ -53,7 +55,9 @@ public static void main(String[] args) } public static String generateTextWithPyTorchGreedy() - throws ModelNotFoundException, MalformedModelException, IOException, + throws ModelNotFoundException, + MalformedModelException, + IOException, TranslateException { SearchConfig config = new SearchConfig(); config.setMaxSeqLength(60); @@ -88,7 +92,9 @@ public static String generateTextWithPyTorchGreedy() } public static String[] generateTextWithPyTorchContrastive() - throws ModelNotFoundException, MalformedModelException, IOException, + throws ModelNotFoundException, + MalformedModelException, + IOException, TranslateException { SearchConfig config = new SearchConfig(); config.setMaxSeqLength(60); @@ -120,7 +126,9 @@ public static String[] generateTextWithPyTorchContrastive() } public static String[] generateTextWithPyTorchBeam() - throws ModelNotFoundException, MalformedModelException, IOException, + throws ModelNotFoundException, + MalformedModelException, + IOException, TranslateException { SearchConfig config = new SearchConfig(); config.setMaxSeqLength(60); @@ -153,7 +161,9 @@ public static String[] generateTextWithPyTorchBeam() } public static String[] generateTextWithOnnxRuntimeBeam() - throws ModelNotFoundException, MalformedModelException, IOException, + throws ModelNotFoundException, + MalformedModelException, + IOException, TranslateException { SearchConfig config = new SearchConfig(); config.setMaxSeqLength(60); diff --git a/examples/src/main/java/ai/djl/examples/training/TrainSentimentAnalysis.java b/examples/src/main/java/ai/djl/examples/training/TrainSentimentAnalysis.java index 4ea222aeb20..d89129e7bf5 100644 --- a/examples/src/main/java/ai/djl/examples/training/TrainSentimentAnalysis.java +++ b/examples/src/main/java/ai/djl/examples/training/TrainSentimentAnalysis.java @@ -76,13 +76,17 @@ public final class TrainSentimentAnalysis { private TrainSentimentAnalysis() {} public static void main(String[] args) - throws IOException, ModelNotFoundException, MalformedModelException, + throws IOException, + ModelNotFoundException, + MalformedModelException, TranslateException { TrainSentimentAnalysis.runExample(args); } public static TrainingResult runExample(String[] args) - throws IOException, ModelNotFoundException, MalformedModelException, + throws IOException, + ModelNotFoundException, + MalformedModelException, TranslateException { Arguments arguments = new Arguments().parseArgs(args); if (arguments == null) { diff --git a/examples/src/main/java/ai/djl/examples/training/TrainWithOptimizers.java b/examples/src/main/java/ai/djl/examples/training/TrainWithOptimizers.java index ba1ea934fb5..f2db8c8a7c2 100644 --- a/examples/src/main/java/ai/djl/examples/training/TrainWithOptimizers.java +++ b/examples/src/main/java/ai/djl/examples/training/TrainWithOptimizers.java @@ -63,13 +63,19 @@ public final class TrainWithOptimizers { private TrainWithOptimizers() {} public static void main(String[] args) - throws IOException, ParseException, ModelNotFoundException, MalformedModelException, + throws IOException, + ParseException, + ModelNotFoundException, + MalformedModelException, TranslateException { TrainWithOptimizers.runExample(args); } public static TrainingResult runExample(String[] args) - throws IOException, ParseException, ModelNotFoundException, MalformedModelException, + throws IOException, + ParseException, + ModelNotFoundException, + MalformedModelException, TranslateException { OptimizerArguments arguments = (OptimizerArguments) new OptimizerArguments().parseArgs(args); diff --git a/examples/src/test/java/ai/djl/examples/inference/StyleTransferTest.java b/examples/src/test/java/ai/djl/examples/inference/StyleTransferTest.java index 8b4d4510a29..1981cda7643 100644 --- a/examples/src/test/java/ai/djl/examples/inference/StyleTransferTest.java +++ b/examples/src/test/java/ai/djl/examples/inference/StyleTransferTest.java @@ -30,7 +30,9 @@ public class StyleTransferTest { @Test public void testStyleTransfer() - throws IOException, ModelNotFoundException, MalformedModelException, + throws IOException, + ModelNotFoundException, + MalformedModelException, TranslateException { TestRequirements.engine("PyTorch"); diff --git a/extensions/fasttext/src/test/java/ai/djl/fasttext/CookingStackExchangeTest.java b/extensions/fasttext/src/test/java/ai/djl/fasttext/CookingStackExchangeTest.java index 4395ddf1a6c..e88efefc0a7 100644 --- a/extensions/fasttext/src/test/java/ai/djl/fasttext/CookingStackExchangeTest.java +++ b/extensions/fasttext/src/test/java/ai/djl/fasttext/CookingStackExchangeTest.java @@ -81,7 +81,9 @@ public void testTrainTextClassification() throws IOException { @Test public void testTextClassification() - throws IOException, MalformedModelException, ModelNotFoundException, + throws IOException, + MalformedModelException, + ModelNotFoundException, TranslateException { Criteria criteria = Criteria.builder() diff --git a/extensions/tokenizers/src/main/java/ai/djl/engine/rust/RsNDArray.java b/extensions/tokenizers/src/main/java/ai/djl/engine/rust/RsNDArray.java index 715d25033e1..443dd1392cc 100644 --- a/extensions/tokenizers/src/main/java/ai/djl/engine/rust/RsNDArray.java +++ b/extensions/tokenizers/src/main/java/ai/djl/engine/rust/RsNDArray.java @@ -62,6 +62,7 @@ public RsNDArray(RsNDManager manager, long handle) { RsNDArray(RsNDManager manager, long handle, DataType dataType) { this(manager, handle, dataType, null); } + /** * Constructs a Rust {@code NDArray} from a native handle (internal. Use {@link NDManager} * instead) with the data that is hold on Java side. diff --git a/extensions/tokenizers/src/main/java/ai/djl/engine/rust/RsNDArrayEx.java b/extensions/tokenizers/src/main/java/ai/djl/engine/rust/RsNDArrayEx.java index 453658a860f..f30037bf55a 100644 --- a/extensions/tokenizers/src/main/java/ai/djl/engine/rust/RsNDArrayEx.java +++ b/extensions/tokenizers/src/main/java/ai/djl/engine/rust/RsNDArrayEx.java @@ -402,6 +402,7 @@ public NDList layerNorm( NDArray input, Shape normalizedShape, NDArray gamma, NDArray beta, float eps) { throw new UnsupportedOperationException("Not implemented"); } + /** {@inheritDoc} */ @Override public NDList batchNorm( diff --git a/integration/src/main/java/ai/djl/integration/tests/model_zoo/classification/ResnetTest.java b/integration/src/main/java/ai/djl/integration/tests/model_zoo/classification/ResnetTest.java index 25be89ae738..6133c6ed553 100644 --- a/integration/src/main/java/ai/djl/integration/tests/model_zoo/classification/ResnetTest.java +++ b/integration/src/main/java/ai/djl/integration/tests/model_zoo/classification/ResnetTest.java @@ -138,7 +138,9 @@ public void testWithIntermediate() throws TranslateException { @Test public void testLoadPredict() - throws IOException, ModelNotFoundException, TranslateException, + throws IOException, + ModelNotFoundException, + TranslateException, MalformedModelException { try (ZooModel model = getModel()) { NoopTranslator translator = new NoopTranslator(Batchifier.STACK); diff --git a/model-zoo/src/main/java/ai/djl/basicmodelzoo/cv/classification/MobileNetV2.java b/model-zoo/src/main/java/ai/djl/basicmodelzoo/cv/classification/MobileNetV2.java index fe5e2d09f6c..4a817cd0dc6 100644 --- a/model-zoo/src/main/java/ai/djl/basicmodelzoo/cv/classification/MobileNetV2.java +++ b/model-zoo/src/main/java/ai/djl/basicmodelzoo/cv/classification/MobileNetV2.java @@ -44,6 +44,7 @@ public final class MobileNetV2 { public static final int MULTILENGTH = 7; private MobileNetV2() {} + /** * Builds a {@link Block} that represent an inverted residual Unit used in the implementation of * the MobileNetV2 Model. diff --git a/tools/gradle/java-formatter.gradle b/tools/gradle/java-formatter.gradle index 5bf1b92f350..426939803f5 100644 --- a/tools/gradle/java-formatter.gradle +++ b/tools/gradle/java-formatter.gradle @@ -6,7 +6,7 @@ buildscript { } } dependencies { - classpath 'com.google.googlejavaformat:google-java-format:1.15.0' + classpath 'com.google.googlejavaformat:google-java-format:1.22.0' } }