Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ci] Updates google code formatter to 1.22.0 #3149

Merged
merged 1 commit into from
May 2, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -285,6 +285,7 @@ static NDList convolution(
return input.getNDArrayInternal()
.convolution(input, weight, bias, stride, padding, dilation, groups);
}

/**
* A builder that can build any {@code Convolution} block.
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -197,6 +197,7 @@ static NDList deconvolution(
return input.getNDArrayInternal()
.deconvolution(input, weight, bias, stride, padding, outPadding, dilation, groups);
}

/**
* A builder that can build any {@code Deconvolution} block.
*
Expand Down Expand Up @@ -246,6 +247,7 @@ public T optPadding(Shape padding) {
this.padding = padding;
return self();
}

/**
* Sets the out_padding along each dimension. Defaults to 0 along each dimension.
*
Expand All @@ -256,6 +258,7 @@ public T optOutPadding(Shape outPadding) {
this.outPadding = outPadding;
return self();
}

/**
* Sets the dilation along each dimension. Defaults to 1 along each dimension.
*
Expand Down
1 change: 1 addition & 0 deletions api/src/main/java/ai/djl/nn/transformer/BertBlock.java
Original file line number Diff line number Diff line change
Expand Up @@ -283,6 +283,7 @@ public static final class Builder {
float hiddenDropoutProbability = 0.1f;
// float attentionDropoutProbability = 0.1f;
int maxSequenceLength = 512;

// float initializerRange = 0.02f;

private Builder() {}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -72,16 +72,22 @@ public final class ScaledDotProductAttentionBlock extends AbstractBlock {

/** Size of the Word-/Token-embeddings we use the attention on. */
private int embeddingSize;

/** Number of attention heads. */
private int headCount;

/** Pointwise Linear projection of the keys. */
private Linear keyProjection;

/** Pointwise Linear projection of the queries. */
private Linear queryProjection;

/** Pointwise Linear projection of the values. */
private Linear valueProjection;

/** Pointwise Linear projection of the results. */
private Linear resultProjection;

/** Dropout operation to be applied after probability calculation. */
private Dropout attentionProbsDropout;

Expand Down Expand Up @@ -119,6 +125,7 @@ private Linear buildProjection() {
public Linear getKeyProjection() {
return keyProjection;
}

/**
* Pointwise Linear projection of the queries.
*
Expand All @@ -127,6 +134,7 @@ public Linear getKeyProjection() {
public Linear getQueryProjection() {
return queryProjection;
}

/**
* Pointwise Linear projection of the values.
*
Expand All @@ -135,6 +143,7 @@ public Linear getQueryProjection() {
public Linear getValueProjection() {
return valueProjection;
}

/**
* Pointwise Linear projection of the results.
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,14 +31,19 @@ public class TransformerEncoderBlock extends AbstractBlock {

/** The attention mechanism. */
private ScaledDotProductAttentionBlock selfAttentionBlock;

/** Dropout before residual & layer normalization. */
private Dropout selfAttentionDropout;

/** Normalization of attention output and residual. */
private BatchNorm attentionNorm;

/** Fully connected pointwise block for output projection. */
private PointwiseFeedForwardBlock pointWisefullyConnected;

/** Dropout after fully connected and before last residual & layer normalization. */
private Dropout fullyConnectedDropout;

/** Another normalization for the output and residual. */
private BatchNorm outputNorm;

Expand Down
1 change: 1 addition & 0 deletions api/src/main/java/ai/djl/training/ParameterServer.java
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ default void update(String parameterId, NDArray[] params) {
update(parameterId, grads, params);
Arrays.stream(grads).forEach(NDArray::close);
}

/**
* Updates the parameter of a key from Parameter Server.
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ public void onEpoch(Trainer trainer) {
epochTime = System.nanoTime();
numEpochs++;
}

/** {@inheritDoc} */
@Override
public void onTrainingBegin(Trainer trainer) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ public class MxOpParams extends PairList<String, Object> {

// mxnet cpu take index
private static final String MXNET_CPU = "cpu(0)";

/**
* Sets the Shape parameter.
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -908,6 +908,7 @@ public static Pointer detachGradient(Pointer handle) {
REFS.recycle(ref);
return pointer;
}

/*
int MXNDArraySetGradState(Pointer handle, int state);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -437,6 +437,7 @@ public NDList layerNorm(
manager.from(beta),
eps));
}

/** {@inheritDoc} */
@Override
public NDList batchNorm(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,9 @@ public class ProfilerTest {

@Test
public void testProfiler()
throws MalformedModelException, ModelNotFoundException, IOException,
throws MalformedModelException,
ModelNotFoundException,
IOException,
TranslateException {
try (NDManager manager = NDManager.newBaseManager()) {
ImageClassificationTranslator translator =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -398,6 +398,7 @@ public NDList layerNorm(
NDArray input, Shape normalizedShape, NDArray gamma, NDArray beta, float eps) {
throw new UnsupportedOperationException();
}

/** {@inheritDoc} */
@Override
public NDList batchNorm(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,9 @@ public static void main(String[] args) throws IOException, ModelException, Trans
}

public static Classifications[] predict(List<String> inputs)
throws MalformedModelException, ModelNotFoundException, IOException,
throws MalformedModelException,
ModelNotFoundException,
IOException,
TranslateException {
// refer to
// https://medium.com/delvify/bert-rest-inference-from-the-fine-tuned-model-499997b32851 and
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,9 @@ public static void main(String[] args) throws IOException, TranslateException, M
}

public static Classifications predict()
throws MalformedModelException, ModelNotFoundException, IOException,
throws MalformedModelException,
ModelNotFoundException,
IOException,
TranslateException {
String input = "I like DJL. DJL is the best DL framework!";
logger.info("input Sentence: {}", input);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,9 @@ public static void main(String[] args) throws IOException, ModelException, Trans
}

public static float[][] predict(List<String> inputs)
throws MalformedModelException, ModelNotFoundException, IOException,
throws MalformedModelException,
ModelNotFoundException,
IOException,
TranslateException {
String modelUrl =
"https://storage.googleapis.com/tfhub-modules/google/universal-sentence-encoder/4.tar.gz";
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,9 @@ public static List<Joints> predict() throws IOException, ModelException, Transla
}

private static List<Image> predictPeopleInImage(Image img)
throws MalformedModelException, ModelNotFoundException, IOException,
throws MalformedModelException,
ModelNotFoundException,
IOException,
TranslateException {

Criteria<Image, DetectedObjects> criteria =
Expand Down Expand Up @@ -106,7 +108,9 @@ private static List<Image> predictPeopleInImage(Image img)
}

private static List<Joints> predictJointsForPeople(List<Image> people)
throws MalformedModelException, ModelNotFoundException, IOException,
throws MalformedModelException,
ModelNotFoundException,
IOException,
TranslateException {

// Use DJL MXNet model zoo model, model can be found:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,9 @@ public static void main(String[] args) throws IOException, ModelException, Trans
}

public static Image transfer(Image image, Artist artist)
throws IOException, ModelNotFoundException, MalformedModelException,
throws IOException,
ModelNotFoundException,
MalformedModelException,
TranslateException {
// Use DJL PyTorch model zoo model, model can be found:
// https://mlrepo.djl.ai/model/cv/image_generation/ai/djl/pytorch/cyclegan/0.0.1/style_xxxx.zip
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,14 +40,18 @@ public final class RollingBatch {
private RollingBatch() {}

public static void main(String[] args)
throws ModelNotFoundException, MalformedModelException, IOException,
throws ModelNotFoundException,
MalformedModelException,
IOException,
TranslateException {
String[] ret = seqBatchSchedulerWithPyTorchContrastive();
logger.info("{}", ret[0]);
}

public static String[] seqBatchSchedulerWithPyTorchContrastive()
throws ModelNotFoundException, MalformedModelException, IOException,
throws ModelNotFoundException,
MalformedModelException,
IOException,
TranslateException {
String url = "https://djl-misc.s3.amazonaws.com/test/models/gpt2/gpt2_pt.zip";

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,9 @@ public final class TextGeneration {
private TextGeneration() {}

public static void main(String[] args)
throws ModelNotFoundException, MalformedModelException, IOException,
throws ModelNotFoundException,
MalformedModelException,
IOException,
TranslateException {
String ret1 = generateTextWithPyTorchGreedy();
logger.info("{}", ret1);
Expand All @@ -53,7 +55,9 @@ public static void main(String[] args)
}

public static String generateTextWithPyTorchGreedy()
throws ModelNotFoundException, MalformedModelException, IOException,
throws ModelNotFoundException,
MalformedModelException,
IOException,
TranslateException {
SearchConfig config = new SearchConfig();
config.setMaxSeqLength(60);
Expand Down Expand Up @@ -88,7 +92,9 @@ public static String generateTextWithPyTorchGreedy()
}

public static String[] generateTextWithPyTorchContrastive()
throws ModelNotFoundException, MalformedModelException, IOException,
throws ModelNotFoundException,
MalformedModelException,
IOException,
TranslateException {
SearchConfig config = new SearchConfig();
config.setMaxSeqLength(60);
Expand Down Expand Up @@ -120,7 +126,9 @@ public static String[] generateTextWithPyTorchContrastive()
}

public static String[] generateTextWithPyTorchBeam()
throws ModelNotFoundException, MalformedModelException, IOException,
throws ModelNotFoundException,
MalformedModelException,
IOException,
TranslateException {
SearchConfig config = new SearchConfig();
config.setMaxSeqLength(60);
Expand Down Expand Up @@ -153,7 +161,9 @@ public static String[] generateTextWithPyTorchBeam()
}

public static String[] generateTextWithOnnxRuntimeBeam()
throws ModelNotFoundException, MalformedModelException, IOException,
throws ModelNotFoundException,
MalformedModelException,
IOException,
TranslateException {
SearchConfig config = new SearchConfig();
config.setMaxSeqLength(60);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -76,13 +76,17 @@ public final class TrainSentimentAnalysis {
private TrainSentimentAnalysis() {}

public static void main(String[] args)
throws IOException, ModelNotFoundException, MalformedModelException,
throws IOException,
ModelNotFoundException,
MalformedModelException,
TranslateException {
TrainSentimentAnalysis.runExample(args);
}

public static TrainingResult runExample(String[] args)
throws IOException, ModelNotFoundException, MalformedModelException,
throws IOException,
ModelNotFoundException,
MalformedModelException,
TranslateException {
Arguments arguments = new Arguments().parseArgs(args);
if (arguments == null) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -63,13 +63,19 @@ public final class TrainWithOptimizers {
private TrainWithOptimizers() {}

public static void main(String[] args)
throws IOException, ParseException, ModelNotFoundException, MalformedModelException,
throws IOException,
ParseException,
ModelNotFoundException,
MalformedModelException,
TranslateException {
TrainWithOptimizers.runExample(args);
}

public static TrainingResult runExample(String[] args)
throws IOException, ParseException, ModelNotFoundException, MalformedModelException,
throws IOException,
ParseException,
ModelNotFoundException,
MalformedModelException,
TranslateException {
OptimizerArguments arguments =
(OptimizerArguments) new OptimizerArguments().parseArgs(args);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,9 @@ public class StyleTransferTest {

@Test
public void testStyleTransfer()
throws IOException, ModelNotFoundException, MalformedModelException,
throws IOException,
ModelNotFoundException,
MalformedModelException,
TranslateException {
TestRequirements.engine("PyTorch");

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,9 @@ public void testTrainTextClassification() throws IOException {

@Test
public void testTextClassification()
throws IOException, MalformedModelException, ModelNotFoundException,
throws IOException,
MalformedModelException,
ModelNotFoundException,
TranslateException {
Criteria<String, Classifications> criteria =
Criteria.builder()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ public RsNDArray(RsNDManager manager, long handle) {
RsNDArray(RsNDManager manager, long handle, DataType dataType) {
this(manager, handle, dataType, null);
}

/**
* Constructs a Rust {@code NDArray} from a native handle (internal. Use {@link NDManager}
* instead) with the data that is hold on Java side.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -402,6 +402,7 @@ public NDList layerNorm(
NDArray input, Shape normalizedShape, NDArray gamma, NDArray beta, float eps) {
throw new UnsupportedOperationException("Not implemented");
}

/** {@inheritDoc} */
@Override
public NDList batchNorm(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,9 @@ public void testWithIntermediate() throws TranslateException {

@Test
public void testLoadPredict()
throws IOException, ModelNotFoundException, TranslateException,
throws IOException,
ModelNotFoundException,
TranslateException,
MalformedModelException {
try (ZooModel<Image, Classifications> model = getModel()) {
NoopTranslator translator = new NoopTranslator(Batchifier.STACK);
Expand Down
Loading
Loading