diff --git a/samza-api/src/main/java/org/apache/samza/application/StreamApplication.java b/samza-api/src/main/java/org/apache/samza/application/StreamApplication.java index eeece1007f89a..a26c5af7f38d0 100644 --- a/samza-api/src/main/java/org/apache/samza/application/StreamApplication.java +++ b/samza-api/src/main/java/org/apache/samza/application/StreamApplication.java @@ -24,11 +24,43 @@ /** - * This interface defines a template for stream application that user will implement to create operator DAG in {@link StreamGraph}. + * This interface defines a template for stream application that user will implement to initialize operator DAG in {@link StreamGraph}. + * + *

+ * User program implements {@link StreamApplication#init(StreamGraph, Config)} method to initialize the transformation logic + * from all input streams to output streams. A simple user code example is shown below: + *

+ * + *
{@code
+ * public class PageViewCounterExample implements StreamApplication {
+ *   // max timeout is 60 seconds
+ *   private static final MAX_TIMEOUT = 60000;
+ *
+ *   public void init(StreamGraph graph, Config config) {
+ *     MessageStream pageViewEvents = graph.getInputStream("pageViewEventStream", (k, m) -> (PageViewEvent) m);
+ *     OutputStream pageViewEventFilteredStream = graph
+ *       .getOutputStream("pageViewEventFiltered", m -> m.memberId, m -> m);
+ *
+ *     pageViewEvents
+ *       .filter(m -> !(m.getMessage().getEventTime() < System.currentTimeMillis() - MAX_TIMEOUT))
+ *       .sendTo(pageViewEventFilteredStream);
+ *   }
+ *
+ *   // local execution mode
+ *   public static void main(String[] args) {
+ *     CommandLine cmdLine = new CommandLine();
+ *     Config config = cmdLine.loadConfig(cmdLine.parser().parse(args));
+ *     PageViewCounterExample userApp = new PageViewCounterExample();
+ *     ApplicationRunner localRunner = ApplicationRunner.getLocalRunner(config);
+ *     localRunner.run(userApp);
+ *   }
+ *
+ * }
+ * }
+ * */ @InterfaceStability.Unstable public interface StreamApplication { - static final String APP_CLASS_CONFIG = "app.class"; /** * Users are required to implement this abstract method to initialize the processing logic of the application, in terms @@ -38,4 +70,5 @@ public interface StreamApplication { * @param config the {@link Config} of the application */ void init(StreamGraph graph, Config config); + } diff --git a/samza-api/src/main/java/org/apache/samza/operators/MessageStream.java b/samza-api/src/main/java/org/apache/samza/operators/MessageStream.java index 345bff02a205a..c406a933e354c 100644 --- a/samza-api/src/main/java/org/apache/samza/operators/MessageStream.java +++ b/samza-api/src/main/java/org/apache/samza/operators/MessageStream.java @@ -50,7 +50,7 @@ public interface MessageStream { * @param the type of messages in the transformed {@link MessageStream} * @return the transformed {@link MessageStream} */ - MessageStream map(MapFunction mapFn); + MessageStream map(MapFunction mapFn); /** * Applies the provided 1:n function to transform a message in this {@link MessageStream} @@ -60,7 +60,7 @@ public interface MessageStream { * @param the type of messages in the transformed {@link MessageStream} * @return the transformed {@link MessageStream} */ - MessageStream flatMap(FlatMapFunction flatMapFn); + MessageStream flatMap(FlatMapFunction flatMapFn); /** * Applies the provided function to messages in this {@link MessageStream} and returns the @@ -72,7 +72,7 @@ public interface MessageStream { * @param filterFn the predicate to filter messages from this {@link MessageStream} * @return the transformed {@link MessageStream} */ - MessageStream filter(FilterFunction filterFn); + MessageStream filter(FilterFunction filterFn); /** * Allows sending messages in this {@link MessageStream} to an output system using the provided {@link SinkFunction}. @@ -83,7 +83,7 @@ public interface MessageStream { * * @param sinkFn the function to send messages in this stream to an external system */ - void sink(SinkFunction sinkFn); + void sink(SinkFunction sinkFn); /** * Allows sending messages in this {@link MessageStream} to an output {@link MessageStream}. @@ -120,10 +120,10 @@ public interface MessageStream { * @param ttl the ttl for messages in each stream * @param the type of join key * @param the type of messages in the other stream - * @param the type of messages resulting from the {@code joinFn} + * @param the type of messages resulting from the {@code joinFn} * @return the joined {@link MessageStream} */ - MessageStream join(MessageStream otherStream, JoinFunction joinFn, Duration ttl); + MessageStream join(MessageStream otherStream, JoinFunction joinFn, Duration ttl); /** * Merge all {@code otherStreams} with this {@link MessageStream}. @@ -133,7 +133,7 @@ public interface MessageStream { * @param otherStreams other {@link MessageStream}s to be merged with this {@link MessageStream} * @return the merged {@link MessageStream} */ - MessageStream merge(Collection> otherStreams); + MessageStream merge(Collection> otherStreams); /** * Sends the messages of type {@code M}in this {@link MessageStream} to a repartitioned output stream and consumes @@ -144,6 +144,6 @@ public interface MessageStream { * @param the type of output message key and partition key * @return the repartitioned {@link MessageStream} */ - MessageStream partitionBy(Function keyExtractor); + MessageStream partitionBy(Function keyExtractor); } diff --git a/samza-api/src/main/java/org/apache/samza/operators/StreamGraph.java b/samza-api/src/main/java/org/apache/samza/operators/StreamGraph.java index ff1c58060a163..a03f7c32079c0 100644 --- a/samza-api/src/main/java/org/apache/samza/operators/StreamGraph.java +++ b/samza-api/src/main/java/org/apache/samza/operators/StreamGraph.java @@ -40,7 +40,7 @@ public interface StreamGraph { * @param the type of message in the input {@link MessageStream} * @return the input {@link MessageStream} */ - MessageStream getInputStream(String streamId, BiFunction msgBuilder); + MessageStream getInputStream(String streamId, BiFunction msgBuilder); /** * Gets the {@link OutputStream} corresponding to the logical {@code streamId}. @@ -54,7 +54,7 @@ public interface StreamGraph { * @return the output {@link MessageStream} */ OutputStream getOutputStream(String streamId, - Function keyExtractor, Function msgExtractor); + Function keyExtractor, Function msgExtractor); /** * Sets the {@link ContextManager} for this {@link StreamGraph}. diff --git a/samza-api/src/main/java/org/apache/samza/operators/windows/Windows.java b/samza-api/src/main/java/org/apache/samza/operators/windows/Windows.java index 9192fc1890b18..721b4c07389f5 100644 --- a/samza-api/src/main/java/org/apache/samza/operators/windows/Windows.java +++ b/samza-api/src/main/java/org/apache/samza/operators/windows/Windows.java @@ -119,11 +119,12 @@ private Windows() { } * @param the type of the key in the {@link Window} * @return the created {@link Window} function. */ - public static Window keyedTumblingWindow(Function keyFn, Duration interval, - Supplier initialValue, FoldLeftFunction foldFn) { + public static Window keyedTumblingWindow(Function keyFn, Duration interval, + Supplier initialValue, FoldLeftFunction foldFn) { Trigger defaultTrigger = new TimeTrigger<>(interval); - return new WindowInternal(defaultTrigger, initialValue, foldFn, keyFn, null, WindowType.TUMBLING); + return new WindowInternal<>(defaultTrigger, (Supplier) initialValue, (FoldLeftFunction) foldFn, + (Function) keyFn, null, WindowType.TUMBLING); } @@ -147,10 +148,10 @@ public static Window keyedTumblingWindow(Function key * @param the type of the key in the {@link Window} * @return the created {@link Window} function */ - public static Window> keyedTumblingWindow(Function keyFn, Duration interval) { + public static Window> keyedTumblingWindow(Function keyFn, Duration interval) { FoldLeftFunction> aggregator = createAggregator(); - Supplier> initialValue = () -> new ArrayList<>(); + Supplier> initialValue = ArrayList::new; return keyedTumblingWindow(keyFn, interval, initialValue, aggregator); } @@ -175,10 +176,11 @@ public static Window> keyedTumblingWindow(Function the type of the {@link WindowPane} output value * @return the created {@link Window} function */ - public static Window tumblingWindow(Duration duration, Supplier initialValue, - FoldLeftFunction foldFn) { + public static Window tumblingWindow(Duration duration, Supplier initialValue, + FoldLeftFunction foldFn) { Trigger defaultTrigger = Triggers.repeat(new TimeTrigger<>(duration)); - return new WindowInternal<>(defaultTrigger, initialValue, foldFn, null, null, WindowType.TUMBLING); + return new WindowInternal<>(defaultTrigger, (Supplier) initialValue, (FoldLeftFunction) foldFn, + null, null, WindowType.TUMBLING); } /** @@ -203,7 +205,7 @@ public static Window tumblingWindow(Duration duration, Supp public static Window> tumblingWindow(Duration duration) { FoldLeftFunction> aggregator = createAggregator(); - Supplier> initialValue = () -> new ArrayList<>(); + Supplier> initialValue = ArrayList::new; return tumblingWindow(duration, initialValue, aggregator); } @@ -235,10 +237,11 @@ public static Window> tumblingWindow(Duration duratio * @param the type of the output value in the {@link WindowPane} * @return the created {@link Window} function */ - public static Window keyedSessionWindow(Function keyFn, Duration sessionGap, - Supplier initialValue, FoldLeftFunction foldFn) { + public static Window keyedSessionWindow(Function keyFn, Duration sessionGap, + Supplier initialValue, FoldLeftFunction foldFn) { Trigger defaultTrigger = Triggers.timeSinceLastMessage(sessionGap); - return new WindowInternal<>(defaultTrigger, initialValue, foldFn, keyFn, null, WindowType.SESSION); + return new WindowInternal<>(defaultTrigger, (Supplier) initialValue, (FoldLeftFunction) foldFn, (Function) keyFn, + null, WindowType.SESSION); } /** @@ -265,11 +268,11 @@ public static Window keyedSessionWindow(Function keyF * @param the type of the key in the {@link Window} * @return the created {@link Window} function */ - public static Window> keyedSessionWindow(Function keyFn, Duration sessionGap) { + public static Window> keyedSessionWindow(Function keyFn, Duration sessionGap) { FoldLeftFunction> aggregator = createAggregator(); - Supplier> initialValue = () -> new ArrayList<>(); + Supplier> initialValue = ArrayList::new; return keyedSessionWindow(keyFn, sessionGap, initialValue, aggregator); } diff --git a/samza-api/src/main/java/org/apache/samza/task/TaskContext.java b/samza-api/src/main/java/org/apache/samza/task/TaskContext.java index 128cff1468e7e..dc5742f641679 100644 --- a/samza-api/src/main/java/org/apache/samza/task/TaskContext.java +++ b/samza-api/src/main/java/org/apache/samza/task/TaskContext.java @@ -58,10 +58,9 @@ public interface TaskContext { /** * Method to allow user to return customized context * - * @param the type of user-defined task context * @return user-defined task context object */ - default T getUserDefinedContext() { + default Object getUserDefinedContext() { return null; }; } diff --git a/samza-core/src/main/java/org/apache/samza/config/ApplicationConfig.java b/samza-core/src/main/java/org/apache/samza/config/ApplicationConfig.java index 1c0073535557e..9eb4161617e70 100644 --- a/samza-core/src/main/java/org/apache/samza/config/ApplicationConfig.java +++ b/samza-core/src/main/java/org/apache/samza/config/ApplicationConfig.java @@ -46,6 +46,7 @@ public class ApplicationConfig extends MapConfig { public static final String APP_COORDINATION_SERVICE_FACTORY_CLASS = "app.coordination.service.factory.class"; public static final String APP_NAME = "app.name"; public static final String APP_ID = "app.id"; + public static final String APP_CLASS = "app.class"; public ApplicationConfig(Config config) { super(config); @@ -67,6 +68,10 @@ public String getAppId() { return get(APP_ID, get(JobConfig.JOB_ID(), "1")); } + public String getAppClass() { + return get(APP_CLASS, null); + } + @Deprecated public String getProcessorId() { return get(PROCESSOR_ID, null); diff --git a/samza-core/src/main/java/org/apache/samza/operators/MessageStreamImpl.java b/samza-core/src/main/java/org/apache/samza/operators/MessageStreamImpl.java index dfe231ead6f91..69a41dbd7db6f 100644 --- a/samza-core/src/main/java/org/apache/samza/operators/MessageStreamImpl.java +++ b/samza-core/src/main/java/org/apache/samza/operators/MessageStreamImpl.java @@ -72,7 +72,7 @@ public MessageStreamImpl(StreamGraphImpl graph) { } @Override - public MessageStream map(MapFunction mapFn) { + public MessageStream map(MapFunction mapFn) { OperatorSpec op = OperatorSpecs.createMapOperatorSpec( mapFn, new MessageStreamImpl<>(this.graph), this.graph.getNextOpId()); this.registeredOperatorSpecs.add(op); @@ -80,7 +80,7 @@ public MessageStream map(MapFunction mapFn) { } @Override - public MessageStream filter(FilterFunction filterFn) { + public MessageStream filter(FilterFunction filterFn) { OperatorSpec op = OperatorSpecs.createFilterOperatorSpec( filterFn, new MessageStreamImpl<>(this.graph), this.graph.getNextOpId()); this.registeredOperatorSpecs.add(op); @@ -88,7 +88,7 @@ public MessageStream filter(FilterFunction filterFn) { } @Override - public MessageStream flatMap(FlatMapFunction flatMapFn) { + public MessageStream flatMap(FlatMapFunction flatMapFn) { OperatorSpec op = OperatorSpecs.createStreamOperatorSpec( flatMapFn, new MessageStreamImpl<>(this.graph), this.graph.getNextOpId()); this.registeredOperatorSpecs.add(op); @@ -96,7 +96,7 @@ public MessageStream flatMap(FlatMapFunction flatMapFn) { } @Override - public void sink(SinkFunction sinkFn) { + public void sink(SinkFunction sinkFn) { SinkOperatorSpec op = OperatorSpecs.createSinkOperatorSpec(sinkFn, this.graph.getNextOpId()); this.registeredOperatorSpecs.add(op); } @@ -110,22 +110,22 @@ public void sendTo(OutputStream outputStream) { @Override public MessageStream> window(Window window) { - OperatorSpec> wndOp = OperatorSpecs.createWindowOperatorSpec((WindowInternal) window, - new MessageStreamImpl<>(this.graph), this.graph.getNextOpId()); + OperatorSpec> wndOp = OperatorSpecs.createWindowOperatorSpec( + (WindowInternal) window, new MessageStreamImpl<>(this.graph), this.graph.getNextOpId()); this.registeredOperatorSpecs.add(wndOp); return wndOp.getNextStream(); } @Override - public MessageStream join( - MessageStream otherStream, JoinFunction joinFn, Duration ttl) { - MessageStreamImpl nextStream = new MessageStreamImpl<>(this.graph); + public MessageStream join( + MessageStream otherStream, JoinFunction joinFn, Duration ttl) { + MessageStreamImpl nextStream = new MessageStreamImpl<>(this.graph); - PartialJoinFunction thisPartialJoinFn = new PartialJoinFunction() { - private KeyValueStore> thisStreamState; + PartialJoinFunction thisPartialJoinFn = new PartialJoinFunction() { + private KeyValueStore> thisStreamState; @Override - public RM apply(M m, JM jm) { + public TM apply(M m, OM jm) { return joinFn.apply(m, jm); } @@ -148,21 +148,21 @@ public void init(Config config, TaskContext context) { } }; - PartialJoinFunction otherPartialJoinFn = new PartialJoinFunction() { - private KeyValueStore> otherStreamState; + PartialJoinFunction otherPartialJoinFn = new PartialJoinFunction() { + private KeyValueStore> otherStreamState; @Override - public RM apply(JM om, M m) { + public TM apply(OM om, M m) { return joinFn.apply(m, om); } @Override - public K getKey(JM message) { + public K getKey(OM message) { return joinFn.getSecondKey(message); } @Override - public KeyValueStore> getState() { + public KeyValueStore> getState() { return otherStreamState; } @@ -175,7 +175,7 @@ public void init(Config config, TaskContext taskContext) { this.registeredOperatorSpecs.add(OperatorSpecs.createPartialJoinOperatorSpec( thisPartialJoinFn, otherPartialJoinFn, ttl.toMillis(), nextStream, this.graph.getNextOpId())); - ((MessageStreamImpl) otherStream).registeredOperatorSpecs + ((MessageStreamImpl) otherStream).registeredOperatorSpecs .add(OperatorSpecs.createPartialJoinOperatorSpec( otherPartialJoinFn, thisPartialJoinFn, ttl.toMillis(), nextStream, this.graph.getNextOpId())); @@ -183,7 +183,7 @@ public void init(Config config, TaskContext taskContext) { } @Override - public MessageStream merge(Collection> otherStreams) { + public MessageStream merge(Collection> otherStreams) { MessageStreamImpl nextStream = new MessageStreamImpl<>(this.graph); otherStreams.add(this); @@ -193,7 +193,7 @@ public MessageStream merge(Collection> otherStreams) { } @Override - public MessageStream partitionBy(Function keyExtractor) { + public MessageStream partitionBy(Function keyExtractor) { int opId = this.graph.getNextOpId(); String opName = String.format("%s-%s", OperatorSpec.OpCode.PARTITION_BY.name().toLowerCase(), opId); MessageStreamImpl intermediateStream = diff --git a/samza-core/src/main/java/org/apache/samza/operators/StreamGraphImpl.java b/samza-core/src/main/java/org/apache/samza/operators/StreamGraphImpl.java index a49b68ebd2551..86ce6a4868eba 100644 --- a/samza-core/src/main/java/org/apache/samza/operators/StreamGraphImpl.java +++ b/samza-core/src/main/java/org/apache/samza/operators/StreamGraphImpl.java @@ -61,16 +61,25 @@ public StreamGraphImpl(ApplicationRunner runner, Config config) { } @Override - public MessageStream getInputStream(String streamId, BiFunction msgBuilder) { + public MessageStream getInputStream(String streamId, BiFunction msgBuilder) { + if (msgBuilder == null) { + throw new IllegalArgumentException("msgBuilder can't be null for an input stream"); + } return inStreams.computeIfAbsent(runner.getStreamSpec(streamId), - streamSpec -> new InputStreamInternalImpl<>(this, streamSpec, msgBuilder)); + streamSpec -> new InputStreamInternalImpl<>(this, streamSpec, (BiFunction) msgBuilder)); } @Override public OutputStream getOutputStream(String streamId, - Function keyExtractor, Function msgExtractor) { + Function keyExtractor, Function msgExtractor) { + if (keyExtractor == null) { + throw new IllegalArgumentException("keyExtractor can't be null for an output stream."); + } + if (msgExtractor == null) { + throw new IllegalArgumentException("msgExtractor can't be null for an output stream."); + } return outStreams.computeIfAbsent(runner.getStreamSpec(streamId), - streamSpec -> new OutputStreamInternalImpl<>(this, streamSpec, keyExtractor, msgExtractor)); + streamSpec -> new OutputStreamInternalImpl<>(this, streamSpec, (Function) keyExtractor, (Function) msgExtractor)); } @Override @@ -95,16 +104,28 @@ public StreamGraph withContextManager(ContextManager contextManager) { * @return the intermediate {@link MessageStreamImpl} */ MessageStreamImpl getIntermediateStream(String streamName, - Function keyExtractor, Function msgExtractor, BiFunction msgBuilder) { + Function keyExtractor, Function msgExtractor, BiFunction msgBuilder) { String streamId = String.format("%s-%s-%s", config.get(JobConfig.JOB_NAME()), config.get(JobConfig.JOB_ID(), "1"), streamName); + if (msgBuilder == null) { + throw new IllegalArgumentException("msgBuilder cannot be null for an intermediate stream"); + } + + if (keyExtractor == null) { + throw new IllegalArgumentException("keyExtractor can't be null for an output stream."); + } + if (msgExtractor == null) { + throw new IllegalArgumentException("msgExtractor can't be null for an output stream."); + } + StreamSpec streamSpec = runner.getStreamSpec(streamId); IntermediateStreamInternalImpl intStream = (IntermediateStreamInternalImpl) inStreams .computeIfAbsent(streamSpec, - k -> new IntermediateStreamInternalImpl<>(this, streamSpec, keyExtractor, msgExtractor, msgBuilder)); + k -> new IntermediateStreamInternalImpl<>(this, streamSpec, (Function) keyExtractor, + (Function) msgExtractor, (BiFunction) msgBuilder)); outStreams.putIfAbsent(streamSpec, intStream); return intStream; } diff --git a/samza-core/src/main/java/org/apache/samza/operators/spec/OperatorSpecs.java b/samza-core/src/main/java/org/apache/samza/operators/spec/OperatorSpecs.java index e2c4b9aeb94d3..0b93bbe1c1ad0 100644 --- a/samza-core/src/main/java/org/apache/samza/operators/spec/OperatorSpecs.java +++ b/samza-core/src/main/java/org/apache/samza/operators/spec/OperatorSpecs.java @@ -53,7 +53,7 @@ private OperatorSpecs() {} * @return the {@link StreamOperatorSpec} */ public static StreamOperatorSpec createMapOperatorSpec( - MapFunction mapFn, MessageStreamImpl nextStream, int opId) { + MapFunction mapFn, MessageStreamImpl nextStream, int opId) { return new StreamOperatorSpec<>(new FlatMapFunction() { @Override public Collection apply(M message) { @@ -84,7 +84,7 @@ public void init(Config config, TaskContext context) { * @return the {@link StreamOperatorSpec} */ public static StreamOperatorSpec createFilterOperatorSpec( - FilterFunction filterFn, MessageStreamImpl nextStream, int opId) { + FilterFunction filterFn, MessageStreamImpl nextStream, int opId) { return new StreamOperatorSpec<>(new FlatMapFunction() { @Override public Collection apply(M message) { @@ -115,8 +115,8 @@ public void init(Config config, TaskContext context) { * @return the {@link StreamOperatorSpec} */ public static StreamOperatorSpec createStreamOperatorSpec( - FlatMapFunction transformFn, MessageStreamImpl nextStream, int opId) { - return new StreamOperatorSpec<>(transformFn, nextStream, OperatorSpec.OpCode.FLAT_MAP, opId); + FlatMapFunction transformFn, MessageStreamImpl nextStream, int opId) { + return new StreamOperatorSpec<>((FlatMapFunction) transformFn, nextStream, OperatorSpec.OpCode.FLAT_MAP, opId); } /** @@ -127,8 +127,8 @@ public static StreamOperatorSpec createStreamOperatorSpec( * @param type of input message * @return the {@link SinkOperatorSpec} for the sink operator */ - public static SinkOperatorSpec createSinkOperatorSpec(SinkFunction sinkFn, int opId) { - return new SinkOperatorSpec<>(sinkFn, OperatorSpec.OpCode.SINK, opId); + public static SinkOperatorSpec createSinkOperatorSpec(SinkFunction sinkFn, int opId) { + return new SinkOperatorSpec<>((SinkFunction) sinkFn, OperatorSpec.OpCode.SINK, opId); } /** @@ -195,7 +195,7 @@ public static WindowOperatorSpec createWindowOperatorSpec public static PartialJoinOperatorSpec createPartialJoinOperatorSpec( PartialJoinFunction thisPartialJoinFn, PartialJoinFunction otherPartialJoinFn, long ttlMs, MessageStreamImpl nextStream, int opId) { - return new PartialJoinOperatorSpec(thisPartialJoinFn, otherPartialJoinFn, ttlMs, nextStream, opId); + return new PartialJoinOperatorSpec<>(thisPartialJoinFn, otherPartialJoinFn, ttlMs, nextStream, opId); } /** @@ -207,7 +207,7 @@ public static PartialJoinOperatorSpec createPartial * @return the {@link StreamOperatorSpec} for the merge */ public static StreamOperatorSpec createMergeOperatorSpec(MessageStreamImpl nextStream, int opId) { - return new StreamOperatorSpec(message -> + return new StreamOperatorSpec<>(message -> new ArrayList() { { this.add(message); diff --git a/samza-core/src/main/java/org/apache/samza/operators/spec/StreamOperatorSpec.java b/samza-core/src/main/java/org/apache/samza/operators/spec/StreamOperatorSpec.java index 3c427c7a15571..f9bbe2d3b1227 100644 --- a/samza-core/src/main/java/org/apache/samza/operators/spec/StreamOperatorSpec.java +++ b/samza-core/src/main/java/org/apache/samza/operators/spec/StreamOperatorSpec.java @@ -45,7 +45,7 @@ public class StreamOperatorSpec implements OperatorSpec { * @param opCode the {@link OpCode} for this {@link StreamOperatorSpec} * @param opId the unique id for this {@link StreamOperatorSpec} in a {@link org.apache.samza.operators.StreamGraph} */ - StreamOperatorSpec(FlatMapFunction transformFn, MessageStreamImpl nextStream, + StreamOperatorSpec(FlatMapFunction transformFn, MessageStreamImpl nextStream, OperatorSpec.OpCode opCode, int opId) { this.transformFn = transformFn; this.nextStream = nextStream; diff --git a/samza-core/src/main/java/org/apache/samza/operators/stream/IntermediateStreamInternalImpl.java b/samza-core/src/main/java/org/apache/samza/operators/stream/IntermediateStreamInternalImpl.java index a1bee6ab3167d..8f45f7a031648 100644 --- a/samza-core/src/main/java/org/apache/samza/operators/stream/IntermediateStreamInternalImpl.java +++ b/samza-core/src/main/java/org/apache/samza/operators/stream/IntermediateStreamInternalImpl.java @@ -33,8 +33,8 @@ public class IntermediateStreamInternalImpl extends MessageStreamImpl msgExtractor; private final BiFunction msgBuilder; - public IntermediateStreamInternalImpl(StreamGraphImpl graph, StreamSpec streamSpec, - Function keyExtractor, Function msgExtractor, BiFunction msgBuilder) { + public IntermediateStreamInternalImpl(StreamGraphImpl graph, StreamSpec streamSpec, Function keyExtractor, + Function msgExtractor, BiFunction msgBuilder) { super(graph); this.streamSpec = streamSpec; this.keyExtractor = keyExtractor; diff --git a/samza-core/src/main/java/org/apache/samza/task/TaskFactoryUtil.java b/samza-core/src/main/java/org/apache/samza/task/TaskFactoryUtil.java index 445d13e8480f0..6408e6f0418c7 100644 --- a/samza-core/src/main/java/org/apache/samza/task/TaskFactoryUtil.java +++ b/samza-core/src/main/java/org/apache/samza/task/TaskFactoryUtil.java @@ -19,6 +19,7 @@ package org.apache.samza.task; import org.apache.samza.SamzaException; +import org.apache.samza.config.ApplicationConfig; import org.apache.samza.config.Config; import org.apache.samza.config.ConfigException; import org.apache.samza.application.StreamApplication; @@ -158,19 +159,20 @@ private static void validateFactory(Object factory) { * @return {@link StreamApplication} instance */ public static StreamApplication createStreamApplication(Config config) { - if (config.get(StreamApplication.APP_CLASS_CONFIG) != null && !config.get(StreamApplication.APP_CLASS_CONFIG).isEmpty()) { + ApplicationConfig appConfig = new ApplicationConfig(config); + if (appConfig.getAppClass() != null && !appConfig.getAppClass().isEmpty()) { TaskConfig taskConfig = new TaskConfig(config); if (taskConfig.getTaskClass() != null && !taskConfig.getTaskClass().isEmpty()) { throw new ConfigException("High level StreamApplication API cannot be used together with low-level API using task.class."); } - String appClassName = config.get(StreamApplication.APP_CLASS_CONFIG); + String appClassName = appConfig.getAppClass(); try { Class builderClass = Class.forName(appClassName); return (StreamApplication) builderClass.newInstance(); } catch (Throwable t) { String errorMsg = String.format("Failed to create StreamApplication class from the config. %s = %s", - StreamApplication.APP_CLASS_CONFIG, config.get(StreamApplication.APP_CLASS_CONFIG)); + ApplicationConfig.APP_CLASS, appConfig.getAppClass()); log.error(errorMsg, t); throw new ConfigException(errorMsg, t); } diff --git a/samza-core/src/test/java/org/apache/samza/execution/TestExecutionPlanner.java b/samza-core/src/test/java/org/apache/samza/execution/TestExecutionPlanner.java index c55fcd02b6770..b7f952ac02d2a 100644 --- a/samza-core/src/test/java/org/apache/samza/execution/TestExecutionPlanner.java +++ b/samza-core/src/test/java/org/apache/samza/execution/TestExecutionPlanner.java @@ -42,6 +42,8 @@ import java.util.HashMap; import java.util.Map; import java.util.Set; +import java.util.function.BiFunction; +import java.util.function.Function; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertTrue; @@ -116,8 +118,10 @@ private StreamGraphImpl createSimpleGraph() { * */ StreamGraphImpl streamGraph = new StreamGraphImpl(runner, config); - OutputStream output1 = streamGraph.getOutputStream("output1", null, null); - streamGraph.getInputStream("input1", null) + Function mockFn = mock(Function.class); + OutputStream output1 = streamGraph.getOutputStream("output1", mockFn, mockFn); + BiFunction mockBuilder = mock(BiFunction.class); + streamGraph.getInputStream("input1", mockBuilder) .partitionBy(m -> "yes!!!").map(m -> m) .sendTo(output1); return streamGraph; @@ -137,11 +141,13 @@ private StreamGraphImpl createStreamGraphWithJoin() { */ StreamGraphImpl streamGraph = new StreamGraphImpl(runner, config); - MessageStream m1 = streamGraph.getInputStream("input1", null).map(m -> m); - MessageStream m2 = streamGraph.getInputStream("input2", null).partitionBy(m -> "haha").filter(m -> true); - MessageStream m3 = streamGraph.getInputStream("input3", null).filter(m -> true).partitionBy(m -> "hehe").map(m -> m); - OutputStream output1 = streamGraph.getOutputStream("output1", null, null); - OutputStream output2 = streamGraph.getOutputStream("output2", null, null); + BiFunction msgBuilder = mock(BiFunction.class); + MessageStream m1 = streamGraph.getInputStream("input1", msgBuilder).map(m -> m); + MessageStream m2 = streamGraph.getInputStream("input2", msgBuilder).partitionBy(m -> "haha").filter(m -> true); + MessageStream m3 = streamGraph.getInputStream("input3", msgBuilder).filter(m -> true).partitionBy(m -> "hehe").map(m -> m); + Function mockFn = mock(Function.class); + OutputStream output1 = streamGraph.getOutputStream("output1", mockFn, mockFn); + OutputStream output2 = streamGraph.getOutputStream("output2", mockFn, mockFn); m1.join(m2, mock(JoinFunction.class), Duration.ofHours(2)).sendTo(output1); m3.join(m2, mock(JoinFunction.class), Duration.ofHours(1)).sendTo(output2); diff --git a/samza-core/src/test/java/org/apache/samza/execution/TestJobGraphJsonGenerator.java b/samza-core/src/test/java/org/apache/samza/execution/TestJobGraphJsonGenerator.java index 9f9945b91c611..c4ab9224ff9fb 100644 --- a/samza-core/src/test/java/org/apache/samza/execution/TestJobGraphJsonGenerator.java +++ b/samza-core/src/test/java/org/apache/samza/execution/TestJobGraphJsonGenerator.java @@ -22,6 +22,9 @@ import java.time.Duration; import java.util.HashMap; import java.util.Map; +import java.util.function.BiFunction; +import java.util.function.Function; + import org.apache.samza.config.Config; import org.apache.samza.config.JobConfig; import org.apache.samza.config.MapConfig; @@ -101,11 +104,13 @@ public void test() throws Exception { StreamManager streamManager = new StreamManager(systemAdmins); StreamGraphImpl streamGraph = new StreamGraphImpl(runner, config); - MessageStream m1 = streamGraph.getInputStream("input1", null).map(m -> m); - MessageStream m2 = streamGraph.getInputStream("input2", null).partitionBy(m -> "haha").filter(m -> true); - MessageStream m3 = streamGraph.getInputStream("input3", null).filter(m -> true).partitionBy(m -> "hehe").map(m -> m); - OutputStream outputStream1 = streamGraph.getOutputStream("output1", null, null); - OutputStream outputStream2 = streamGraph.getOutputStream("output2", null, null); + BiFunction mockBuilder = mock(BiFunction.class); + MessageStream m1 = streamGraph.getInputStream("input1", mockBuilder).map(m -> m); + MessageStream m2 = streamGraph.getInputStream("input2", mockBuilder).partitionBy(m -> "haha").filter(m -> true); + MessageStream m3 = streamGraph.getInputStream("input3", mockBuilder).filter(m -> true).partitionBy(m -> "hehe").map(m -> m); + Function mockFn = mock(Function.class); + OutputStream outputStream1 = streamGraph.getOutputStream("output1", mockFn, mockFn); + OutputStream outputStream2 = streamGraph.getOutputStream("output2", mockFn, mockFn); m1.join(m2, mock(JoinFunction.class), Duration.ofHours(2)).sendTo(outputStream1); m3.join(m2, mock(JoinFunction.class), Duration.ofHours(1)).sendTo(outputStream2); diff --git a/samza-core/src/test/java/org/apache/samza/operators/TestMessageStreamImpl.java b/samza-core/src/test/java/org/apache/samza/operators/TestMessageStreamImpl.java index e815b81260ebb..44870fd8ad632 100644 --- a/samza-core/src/test/java/org/apache/samza/operators/TestMessageStreamImpl.java +++ b/samza-core/src/test/java/org/apache/samza/operators/TestMessageStreamImpl.java @@ -21,8 +21,7 @@ import org.apache.samza.config.Config; import org.apache.samza.config.JobConfig; import org.apache.samza.config.MapConfig; -import org.apache.samza.operators.data.TestMessageEnvelope; -import org.apache.samza.operators.data.TestOutputMessageEnvelope; +import org.apache.samza.operators.data.*; import org.apache.samza.operators.functions.FilterFunction; import org.apache.samza.operators.functions.FlatMapFunction; import org.apache.samza.operators.functions.JoinFunction; @@ -36,6 +35,7 @@ import org.apache.samza.system.OutgoingMessageEnvelope; import org.apache.samza.system.SystemStream; import org.apache.samza.task.MessageCollector; +import org.apache.samza.task.TaskContext; import org.apache.samza.task.TaskCoordinator; import org.junit.Test; @@ -43,14 +43,15 @@ import java.util.ArrayList; import java.util.Collection; import java.util.HashMap; -import java.util.HashSet; +import java.util.List; import java.util.Map; -import java.util.Set; import java.util.function.Function; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertNull; import static org.junit.Assert.assertTrue; +import static org.mockito.Matchers.any; +import static org.mockito.Mockito.doAnswer; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.when; @@ -72,7 +73,7 @@ public void testMap() { assertEquals(mapOp.getNextStream(), outputStream); // assert that the transformation function is what we defined above TestMessageEnvelope xTestMsg = mock(TestMessageEnvelope.class); - TestMessageEnvelope.MessageType mockInnerTestMessage = mock(TestMessageEnvelope.MessageType.class); + MessageType mockInnerTestMessage = mock(MessageType.class); when(xTestMsg.getKey()).thenReturn("test-msg-key"); when(xTestMsg.getMessage()).thenReturn(mockInnerTestMessage); when(mockInnerTestMessage.getValue()).thenReturn("123456789"); @@ -87,20 +88,74 @@ public void testMap() { @Test public void testFlatMap() { MessageStreamImpl inputStream = new MessageStreamImpl<>(mockGraph); - Set flatOuts = new HashSet() { { + List flatOuts = new ArrayList() { { this.add(mock(TestOutputMessageEnvelope.class)); this.add(mock(TestOutputMessageEnvelope.class)); this.add(mock(TestOutputMessageEnvelope.class)); } }; - FlatMapFunction xFlatMap = (TestMessageEnvelope message) -> flatOuts; + final List inputMsgs = new ArrayList<>(); + FlatMapFunction xFlatMap = (TestMessageEnvelope message) -> { + inputMsgs.add(message); + return flatOuts; + }; + MessageStream outputStream = inputStream.flatMap(xFlatMap); + Collection subs = inputStream.getRegisteredOperatorSpecs(); + assertEquals(subs.size(), 1); + OperatorSpec flatMapOp = subs.iterator().next(); + assertTrue(flatMapOp instanceof StreamOperatorSpec); + assertEquals(flatMapOp.getNextStream(), outputStream); + assertEquals(((StreamOperatorSpec) flatMapOp).getTransformFn(), xFlatMap); + + TestMessageEnvelope mockInput = mock(TestMessageEnvelope.class); + // assert that the transformation function is what we defined above + List result = (List) + ((StreamOperatorSpec) flatMapOp).getTransformFn().apply(mockInput); + assertEquals(flatOuts, result); + assertEquals(inputMsgs.size(), 1); + assertEquals(inputMsgs.get(0), mockInput); + } + + @Test + public void testFlatMapWithRelaxedTypes() { + MessageStreamImpl inputStream = new MessageStreamImpl<>(mockGraph); + List flatOuts = new ArrayList() { { + this.add(new TestExtOutputMessageEnvelope("output-key-1", 1, "output-id-001")); + this.add(new TestExtOutputMessageEnvelope("output-key-2", 2, "output-id-002")); + this.add(new TestExtOutputMessageEnvelope("output-key-3", 3, "output-id-003")); + } }; + + class MyFlatMapFunction implements FlatMapFunction { + public final List inputMsgs = new ArrayList<>(); + + @Override + public Collection apply(TestMessageEnvelope message) { + inputMsgs.add(message); + return flatOuts; + } + + @Override + public void init(Config config, TaskContext context) { + inputMsgs.clear(); + } + } + + MyFlatMapFunction xFlatMap = new MyFlatMapFunction(); + MessageStream outputStream = inputStream.flatMap(xFlatMap); Collection subs = inputStream.getRegisteredOperatorSpecs(); assertEquals(subs.size(), 1); OperatorSpec flatMapOp = subs.iterator().next(); assertTrue(flatMapOp instanceof StreamOperatorSpec); assertEquals(flatMapOp.getNextStream(), outputStream); + assertEquals(((StreamOperatorSpec) flatMapOp).getTransformFn(), xFlatMap); + + TestMessageEnvelope mockInput = mock(TestMessageEnvelope.class); // assert that the transformation function is what we defined above - assertEquals(((StreamOperatorSpec) flatMapOp).getTransformFn(), xFlatMap); + List result = (List) + ((StreamOperatorSpec) flatMapOp).getTransformFn().apply(mockInput); + assertEquals(flatOuts, result); + assertEquals(xFlatMap.inputMsgs.size(), 1); + assertEquals(xFlatMap.inputMsgs.get(0), mockInput); } @Test @@ -116,7 +171,7 @@ public void testFilter() { // assert that the transformation function is what we defined above FlatMapFunction txfmFn = ((StreamOperatorSpec) filterOp).getTransformFn(); TestMessageEnvelope mockMsg = mock(TestMessageEnvelope.class); - TestMessageEnvelope.MessageType mockInnerTestMessage = mock(TestMessageEnvelope.MessageType.class); + MessageType mockInnerTestMessage = mock(MessageType.class); when(mockMsg.getMessage()).thenReturn(mockInnerTestMessage); when(mockInnerTestMessage.getEventTime()).thenReturn(11111L); Collection output = txfmFn.apply(mockMsg); @@ -131,8 +186,9 @@ public void testFilter() { @Test public void testSink() { MessageStreamImpl inputStream = new MessageStreamImpl<>(mockGraph); + SystemStream testStream = new SystemStream("test-sys", "test-stream"); SinkFunction xSink = (TestMessageEnvelope m, MessageCollector mc, TaskCoordinator tc) -> { - mc.send(new OutgoingMessageEnvelope(new SystemStream("test-sys", "test-stream"), m.getMessage())); + mc.send(new OutgoingMessageEnvelope(testStream, m.getMessage())); tc.commit(TaskCoordinator.RequestScope.CURRENT_TASK); }; inputStream.sink(xSink); @@ -141,6 +197,21 @@ public void testSink() { OperatorSpec sinkOp = subs.iterator().next(); assertTrue(sinkOp instanceof SinkOperatorSpec); assertEquals(((SinkOperatorSpec) sinkOp).getSinkFn(), xSink); + + TestMessageEnvelope mockTest1 = mock(TestMessageEnvelope.class); + MessageType mockMsgBody = mock(MessageType.class); + when(mockTest1.getMessage()).thenReturn(mockMsgBody); + final List outMsgs = new ArrayList<>(); + MessageCollector mockCollector = mock(MessageCollector.class); + doAnswer(invocation -> { + outMsgs.add((OutgoingMessageEnvelope) invocation.getArguments()[0]); + return null; + }).when(mockCollector).send(any()); + TaskCoordinator mockCoordinator = mock(TaskCoordinator.class); + ((SinkOperatorSpec) sinkOp).getSinkFn().apply(mockTest1, mockCollector, mockCoordinator); + assertEquals(1, outMsgs.size()); + assertEquals(testStream, outMsgs.get(0).getSystemStream()); + assertEquals(mockMsgBody, outMsgs.get(0).getMessage()); } @Test @@ -189,14 +260,14 @@ public String getSecondKey(TestMessageEnvelope message) { @Test public void testMerge() { MessageStream merge1 = new MessageStreamImpl<>(mockGraph); - Collection> others = new ArrayList>() { { + Collection> others = new ArrayList>() { { this.add(new MessageStreamImpl<>(mockGraph)); this.add(new MessageStreamImpl<>(mockGraph)); } }; MessageStream mergeOutput = merge1.merge(others); validateMergeOperator(merge1, mergeOutput); - others.forEach(merge -> validateMergeOperator(merge, mergeOutput)); + others.forEach(merge -> validateMergeOperator((MessageStream) merge, mergeOutput)); } private void validateMergeOperator(MessageStream mergeSource, MessageStream mergeOutput) { diff --git a/samza-core/src/test/java/org/apache/samza/operators/TestStreamGraphImpl.java b/samza-core/src/test/java/org/apache/samza/operators/TestStreamGraphImpl.java new file mode 100644 index 0000000000000..3ab1a3c442b58 --- /dev/null +++ b/samza-core/src/test/java/org/apache/samza/operators/TestStreamGraphImpl.java @@ -0,0 +1,193 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.samza.operators; + +import org.apache.samza.config.Config; +import org.apache.samza.config.JobConfig; +import org.apache.samza.operators.data.MessageType; +import org.apache.samza.operators.data.TestInputMessageEnvelope; +import org.apache.samza.operators.data.TestMessageEnvelope; +import org.apache.samza.operators.stream.InputStreamInternalImpl; +import org.apache.samza.operators.stream.IntermediateStreamInternalImpl; +import org.apache.samza.operators.stream.OutputStreamInternalImpl; +import org.apache.samza.runtime.ApplicationRunner; +import org.apache.samza.system.StreamSpec; +import org.apache.samza.task.TaskContext; +import org.junit.Test; + +import java.util.function.BiFunction; +import java.util.function.Function; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertTrue; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +public class TestStreamGraphImpl { + + @Test + public void testGetInputStream() { + ApplicationRunner mockRunner = mock(ApplicationRunner.class); + Config mockConfig = mock(Config.class); + StreamSpec testStreamSpec = new StreamSpec("test-stream-1", "physical-stream-1", "test-system"); + when(mockRunner.getStreamSpec("test-stream-1")).thenReturn(testStreamSpec); + + StreamGraphImpl graph = new StreamGraphImpl(mockRunner, mockConfig); + BiFunction xMsgBuilder = + (k, v) -> new TestInputMessageEnvelope(k, v.getValue(), v.getEventTime(), "input-id-1"); + MessageStream mInputStream = graph.getInputStream("test-stream-1", xMsgBuilder); + assertEquals(graph.getInputStreams().get(testStreamSpec), mInputStream); + assertTrue(mInputStream instanceof InputStreamInternalImpl); + assertEquals(((InputStreamInternalImpl) mInputStream).getMsgBuilder(), xMsgBuilder); + + String key = "test-input-key"; + MessageType msgBody = new MessageType("test-msg-value", 333333L); + TestMessageEnvelope xInputMsg = ((InputStreamInternalImpl) mInputStream). + getMsgBuilder().apply(key, msgBody); + assertEquals(xInputMsg.getKey(), key); + assertEquals(xInputMsg.getMessage().getValue(), msgBody.getValue()); + assertEquals(xInputMsg.getMessage().getEventTime(), msgBody.getEventTime()); + assertEquals(((TestInputMessageEnvelope) xInputMsg).getInputId(), "input-id-1"); + } + + @Test + public void testGetOutputStream() { + ApplicationRunner mockRunner = mock(ApplicationRunner.class); + Config mockConfig = mock(Config.class); + StreamSpec testStreamSpec = new StreamSpec("test-stream-1", "physical-stream-1", "test-system"); + when(mockRunner.getStreamSpec("test-stream-1")).thenReturn(testStreamSpec); + + class MyMessageType extends MessageType { + public final String outputId; + + public MyMessageType(String value, long eventTime, String outputId) { + super(value, eventTime); + this.outputId = outputId; + } + } + + StreamGraphImpl graph = new StreamGraphImpl(mockRunner, mockConfig); + Function xKeyExtractor = x -> x.getKey(); + Function xMsgExtractor = + x -> new MyMessageType(x.getMessage().getValue(), x.getMessage().getEventTime(), "test-output-id-1"); + + OutputStream mOutputStream = + graph.getOutputStream("test-stream-1", xKeyExtractor, xMsgExtractor); + assertEquals(graph.getOutputStreams().get(testStreamSpec), mOutputStream); + assertTrue(mOutputStream instanceof OutputStreamInternalImpl); + assertEquals(((OutputStreamInternalImpl) mOutputStream).getKeyExtractor(), xKeyExtractor); + assertEquals(((OutputStreamInternalImpl) mOutputStream).getMsgExtractor(), xMsgExtractor); + + TestInputMessageEnvelope xInputMsg = new TestInputMessageEnvelope("test-key-1", "test-msg-1", 33333L, "input-id-1"); + assertEquals(((OutputStreamInternalImpl) mOutputStream). + getKeyExtractor().apply(xInputMsg), "test-key-1"); + assertEquals(((OutputStreamInternalImpl) mOutputStream). + getMsgExtractor().apply(xInputMsg).getValue(), "test-msg-1"); + assertEquals(((OutputStreamInternalImpl) mOutputStream). + getMsgExtractor().apply(xInputMsg).getEventTime(), 33333L); + assertEquals(((OutputStreamInternalImpl) mOutputStream). + getMsgExtractor().apply(xInputMsg).outputId, "test-output-id-1"); + } + + @Test + public void testWithContextManager() { + ApplicationRunner mockRunner = mock(ApplicationRunner.class); + Config mockConfig = mock(Config.class); + + StreamGraphImpl graph = new StreamGraphImpl(mockRunner, mockConfig); + + // ensure that default is noop + TaskContext mockContext = mock(TaskContext.class); + assertEquals(graph.getContextManager().initTaskContext(mockConfig, mockContext), mockContext); + + ContextManager testContextManager = new ContextManager() { + @Override + public TaskContext initTaskContext(Config config, TaskContext context) { + return null; + } + + @Override + public void finalizeTaskContext() { + + } + }; + + graph.withContextManager(testContextManager); + assertEquals(graph.getContextManager(), testContextManager); + } + + @Test + public void testGetIntermediateStream() { + ApplicationRunner mockRunner = mock(ApplicationRunner.class); + Config mockConfig = mock(Config.class); + StreamSpec testStreamSpec = new StreamSpec("myJob-i001-test-stream-1", "physical-stream-1", "test-system"); + when(mockRunner.getStreamSpec("myJob-i001-test-stream-1")).thenReturn(testStreamSpec); + when(mockConfig.get(JobConfig.JOB_NAME())).thenReturn("myJob"); + when(mockConfig.get(JobConfig.JOB_ID(), "1")).thenReturn("i001"); + + class MyMessageType extends MessageType { + public final String outputId; + + public MyMessageType(String value, long eventTime, String outputId) { + super(value, eventTime); + this.outputId = outputId; + } + } + + StreamGraphImpl graph = new StreamGraphImpl(mockRunner, mockConfig); + Function xKeyExtractor = x -> x.getKey(); + Function xMsgExtractor = + x -> new MyMessageType(x.getMessage().getValue(), x.getMessage().getEventTime(), "test-output-id-1"); + BiFunction xMsgBuilder = + (k, v) -> new TestInputMessageEnvelope(k, v.getValue(), v.getEventTime(), "input-id-1"); + + MessageStream mIntermediateStream = + graph.getIntermediateStream("test-stream-1", xKeyExtractor, xMsgExtractor, xMsgBuilder); + assertEquals(graph.getOutputStreams().get(testStreamSpec), mIntermediateStream); + assertTrue(mIntermediateStream instanceof IntermediateStreamInternalImpl); + assertEquals(((IntermediateStreamInternalImpl) mIntermediateStream).getKeyExtractor(), xKeyExtractor); + assertEquals(((IntermediateStreamInternalImpl) mIntermediateStream).getMsgExtractor(), xMsgExtractor); + assertEquals(((IntermediateStreamInternalImpl) mIntermediateStream).getMsgBuilder(), xMsgBuilder); + + TestMessageEnvelope xInputMsg = new TestMessageEnvelope("test-key-1", "test-msg-1", 33333L); + assertEquals(((IntermediateStreamInternalImpl) mIntermediateStream). + getKeyExtractor().apply(xInputMsg), "test-key-1"); + assertEquals(((IntermediateStreamInternalImpl) mIntermediateStream). + getMsgExtractor().apply(xInputMsg).getValue(), "test-msg-1"); + assertEquals(((IntermediateStreamInternalImpl) mIntermediateStream). + getMsgExtractor().apply(xInputMsg).getEventTime(), 33333L); + assertEquals(((IntermediateStreamInternalImpl) mIntermediateStream). + getMsgBuilder().apply("test-key-1", new MyMessageType("test-msg-1", 33333L, "test-output-id-1")).getKey(), "test-key-1"); + assertEquals(((IntermediateStreamInternalImpl) mIntermediateStream). + getMsgBuilder().apply("test-key-1", new MyMessageType("test-msg-1", 33333L, "test-output-id-1")).getMessage().getValue(), "test-msg-1"); + assertEquals(((IntermediateStreamInternalImpl) mIntermediateStream). + getMsgBuilder().apply("test-key-1", new MyMessageType("test-msg-1", 33333L, "test-output-id-1")).getMessage().getEventTime(), 33333L); + } + + @Test + public void testGetNextOpId() { + ApplicationRunner mockRunner = mock(ApplicationRunner.class); + Config mockConfig = mock(Config.class); + + StreamGraphImpl graph = new StreamGraphImpl(mockRunner, mockConfig); + assertEquals(graph.getNextOpId(), 0); + assertEquals(graph.getNextOpId(), 1); + } + +} diff --git a/samza-core/src/test/java/org/apache/samza/operators/data/MessageType.java b/samza-core/src/test/java/org/apache/samza/operators/data/MessageType.java new file mode 100644 index 0000000000000..3fd015b9d41ef --- /dev/null +++ b/samza-core/src/test/java/org/apache/samza/operators/data/MessageType.java @@ -0,0 +1,37 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.samza.operators.data; + +public class MessageType { + private final String value; + private final long eventTime; + + public MessageType(String value, long eventTime) { + this.value = value; + this.eventTime = eventTime; + } + + public long getEventTime() { + return eventTime; + } + + public String getValue() { + return value; + } +} diff --git a/samza-core/src/test/java/org/apache/samza/operators/data/TestExtOutputMessageEnvelope.java b/samza-core/src/test/java/org/apache/samza/operators/data/TestExtOutputMessageEnvelope.java new file mode 100644 index 0000000000000..22222ed0bba68 --- /dev/null +++ b/samza-core/src/test/java/org/apache/samza/operators/data/TestExtOutputMessageEnvelope.java @@ -0,0 +1,29 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.samza.operators.data; + +public class TestExtOutputMessageEnvelope extends TestOutputMessageEnvelope { + private final String outputId; + + public TestExtOutputMessageEnvelope(String key, Integer value, String outputId) { + super(key, value); + this.outputId = outputId; + } + +} diff --git a/samza-core/src/test/java/org/apache/samza/operators/data/TestInputMessageEnvelope.java b/samza-core/src/test/java/org/apache/samza/operators/data/TestInputMessageEnvelope.java new file mode 100644 index 0000000000000..089f5349738b5 --- /dev/null +++ b/samza-core/src/test/java/org/apache/samza/operators/data/TestInputMessageEnvelope.java @@ -0,0 +1,32 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.samza.operators.data; + +public class TestInputMessageEnvelope extends TestMessageEnvelope { + private final String inputId; + + public TestInputMessageEnvelope(String key, String value, long eventTime, String inputId) { + super(key, value, eventTime); + this.inputId = inputId; + } + + public String getInputId() { + return this.inputId; + } +} diff --git a/samza-core/src/test/java/org/apache/samza/operators/data/TestMessageEnvelope.java b/samza-core/src/test/java/org/apache/samza/operators/data/TestMessageEnvelope.java index 2524c28e623bb..05a63cd2a0faf 100644 --- a/samza-core/src/test/java/org/apache/samza/operators/data/TestMessageEnvelope.java +++ b/samza-core/src/test/java/org/apache/samza/operators/data/TestMessageEnvelope.java @@ -37,21 +37,4 @@ public String getKey() { return this.key; } - public class MessageType { - private final String value; - private final long eventTime; - - public MessageType(String value, long eventTime) { - this.value = value; - this.eventTime = eventTime; - } - - public long getEventTime() { - return eventTime; - } - - public String getValue() { - return value; - } - } } diff --git a/samza-core/src/test/java/org/apache/samza/operators/spec/TestOperatorSpecs.java b/samza-core/src/test/java/org/apache/samza/operators/spec/TestOperatorSpecs.java index 37e3d1a70dd2c..d227206cb7842 100644 --- a/samza-core/src/test/java/org/apache/samza/operators/spec/TestOperatorSpecs.java +++ b/samza-core/src/test/java/org/apache/samza/operators/spec/TestOperatorSpecs.java @@ -21,6 +21,8 @@ import org.apache.samza.operators.MessageStreamImpl; import org.apache.samza.operators.StreamGraphImpl; import org.apache.samza.operators.TestMessageStreamImplUtil; +import org.apache.samza.operators.data.MessageType; +import org.apache.samza.operators.data.TestInputMessageEnvelope; import org.apache.samza.operators.data.TestMessageEnvelope; import org.apache.samza.operators.data.TestOutputMessageEnvelope; import org.apache.samza.operators.functions.FlatMapFunction; @@ -31,39 +33,71 @@ import org.apache.samza.operators.windows.WindowPane; import org.apache.samza.operators.windows.internal.WindowInternal; import org.apache.samza.operators.windows.internal.WindowType; +import org.apache.samza.system.OutgoingMessageEnvelope; +import org.apache.samza.system.SystemStream; import org.apache.samza.task.MessageCollector; import org.apache.samza.task.TaskCoordinator; import org.junit.Test; import java.util.ArrayList; import java.util.Collection; +import java.util.List; import java.util.function.Function; import java.util.function.Supplier; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertNotNull; +import static org.mockito.Matchers.any; +import static org.mockito.Mockito.doAnswer; import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; public class TestOperatorSpecs { @Test public void testCreateStreamOperator() { - FlatMapFunction transformFn = m -> new ArrayList() { { + FlatMapFunction transformFn = m -> new ArrayList() { { this.add(new TestMessageEnvelope(m.toString(), m.toString(), 12345L)); } }; MessageStreamImpl mockOutput = mock(MessageStreamImpl.class); - StreamOperatorSpec streamOp = + StreamOperatorSpec streamOp = OperatorSpecs.createStreamOperatorSpec(transformFn, mockOutput, 1); assertEquals(streamOp.getTransformFn(), transformFn); + + Object mockInput = mock(Object.class); + when(mockInput.toString()).thenReturn("test-string-1"); + List outputs = (List) streamOp.getTransformFn().apply(mockInput); + assertEquals(outputs.size(), 1); + assertEquals(outputs.get(0).getKey(), "test-string-1"); + assertEquals(outputs.get(0).getMessage().getValue(), "test-string-1"); + assertEquals(outputs.get(0).getMessage().getEventTime(), 12345L); assertEquals(streamOp.getNextStream(), mockOutput); } @Test public void testCreateSinkOperator() { + SystemStream testStream = new SystemStream("test-sys", "test-stream"); SinkFunction sinkFn = (TestMessageEnvelope message, MessageCollector messageCollector, - TaskCoordinator taskCoordinator) -> { }; + TaskCoordinator taskCoordinator) -> { + messageCollector.send(new OutgoingMessageEnvelope(testStream, message.getKey(), message.getMessage())); + }; SinkOperatorSpec sinkOp = OperatorSpecs.createSinkOperatorSpec(sinkFn, 1); assertEquals(sinkOp.getSinkFn(), sinkFn); + + TestMessageEnvelope mockInput = mock(TestMessageEnvelope.class); + when(mockInput.getKey()).thenReturn("my-test-msg-key"); + MessageType mockMsgBody = mock(MessageType.class); + when(mockInput.getMessage()).thenReturn(mockMsgBody); + final List outputMsgs = new ArrayList<>(); + MessageCollector mockCollector = mock(MessageCollector.class); + doAnswer(invocation -> { + outputMsgs.add((OutgoingMessageEnvelope) invocation.getArguments()[0]); + return null; + }).when(mockCollector).send(any()); + sinkOp.getSinkFn().apply(mockInput, mockCollector, null); + assertEquals(1, outputMsgs.size()); + assertEquals(outputMsgs.get(0).getKey(), "my-test-msg-key"); + assertEquals(outputMsgs.get(0).getMessage(), mockMsgBody); assertEquals(sinkOp.getOpCode(), OperatorSpec.OpCode.SINK); assertEquals(sinkOp.getNextStream(), null); } @@ -103,6 +137,27 @@ public void testCreateWindowOperator() throws Exception { assertEquals(spec.getWindow().getFoldLeftFunction(), aggregator); } + @Test + public void testCreateWindowOperatorWithRelaxedTypes() throws Exception { + Function keyExtractor = m -> m.getKey(); + FoldLeftFunction aggregator = (m, c) -> c + 1; + Supplier initialValue = () -> 0; + //instantiate a window using reflection + WindowInternal window = new WindowInternal(null, initialValue, aggregator, keyExtractor, null, WindowType.TUMBLING); + + MessageStreamImpl> mockWndOut = mock(MessageStreamImpl.class); + WindowOperatorSpec spec = + OperatorSpecs.createWindowOperatorSpec(window, mockWndOut, 1); + assertEquals(spec.getWindow(), window); + assertEquals(spec.getWindow().getKeyExtractor(), keyExtractor); + assertEquals(spec.getWindow().getFoldLeftFunction(), aggregator); + + // make sure that the functions with relaxed types work as expected + TestInputMessageEnvelope inputMsg = new TestInputMessageEnvelope("test-input-key1", "test-value-1", 23456L, "input-id-1"); + assertEquals("test-input-key1", spec.getWindow().getKeyExtractor().apply(inputMsg)); + assertEquals(1, spec.getWindow().getFoldLeftFunction().apply(inputMsg, 0)); + } + @Test public void testCreatePartialJoinOperator() { PartialJoinFunction thisPartialJoinFn diff --git a/samza-core/src/test/java/org/apache/samza/task/TestTaskFactoryUtil.java b/samza-core/src/test/java/org/apache/samza/task/TestTaskFactoryUtil.java index 0b051e8c3a8e3..e3009968276e5 100644 --- a/samza-core/src/test/java/org/apache/samza/task/TestTaskFactoryUtil.java +++ b/samza-core/src/test/java/org/apache/samza/task/TestTaskFactoryUtil.java @@ -19,6 +19,7 @@ package org.apache.samza.task; import org.apache.samza.SamzaException; +import org.apache.samza.config.ApplicationConfig; import org.apache.samza.config.Config; import org.apache.samza.config.ConfigException; import org.apache.samza.config.MapConfig; @@ -74,7 +75,7 @@ public void testStreamTaskClass() { public void testCreateStreamApplication() throws Exception { Config config = new MapConfig(new HashMap() { { - this.put(StreamApplication.APP_CLASS_CONFIG, "org.apache.samza.testUtils.TestStreamApplication"); + this.put(ApplicationConfig.APP_CLASS, "org.apache.samza.testUtils.TestStreamApplication"); } }); StreamApplication streamApp = TaskFactoryUtil.createStreamApplication(config); @@ -85,7 +86,7 @@ public void testCreateStreamApplication() throws Exception { config = new MapConfig(new HashMap() { { - this.put(StreamApplication.APP_CLASS_CONFIG, "org.apache.samza.testUtils.InvalidStreamApplication"); + this.put(ApplicationConfig.APP_CLASS, "org.apache.samza.testUtils.InvalidStreamApplication"); } }); try { @@ -97,7 +98,7 @@ public void testCreateStreamApplication() throws Exception { config = new MapConfig(new HashMap() { { - this.put(StreamApplication.APP_CLASS_CONFIG, "no.such.class"); + this.put(ApplicationConfig.APP_CLASS, "no.such.class"); } }); try { @@ -109,7 +110,7 @@ public void testCreateStreamApplication() throws Exception { config = new MapConfig(new HashMap() { { - this.put(StreamApplication.APP_CLASS_CONFIG, ""); + this.put(ApplicationConfig.APP_CLASS, ""); } }); streamApp = TaskFactoryUtil.createStreamApplication(config); @@ -124,7 +125,7 @@ public void testCreateStreamApplication() throws Exception { public void testCreateStreamApplicationWithTaskClass() throws Exception { Config config = new MapConfig(new HashMap() { { - this.put(StreamApplication.APP_CLASS_CONFIG, "org.apache.samza.testUtils.TestStreamApplication"); + this.put(ApplicationConfig.APP_CLASS, "org.apache.samza.testUtils.TestStreamApplication"); } }); StreamApplication streamApp = TaskFactoryUtil.createStreamApplication(config); @@ -133,7 +134,7 @@ public void testCreateStreamApplicationWithTaskClass() throws Exception { config = new MapConfig(new HashMap() { { this.put("task.class", "org.apache.samza.testUtils.TestAsyncStreamTask"); - this.put(StreamApplication.APP_CLASS_CONFIG, "org.apache.samza.testUtils.TestStreamApplication"); + this.put(ApplicationConfig.APP_CLASS, "org.apache.samza.testUtils.TestStreamApplication"); } }); try { @@ -146,7 +147,7 @@ public void testCreateStreamApplicationWithTaskClass() throws Exception { config = new MapConfig(new HashMap() { { this.put("task.class", "no.such.class"); - this.put(StreamApplication.APP_CLASS_CONFIG, "org.apache.samza.testUtils.TestStreamApplication"); + this.put(ApplicationConfig.APP_CLASS, "org.apache.samza.testUtils.TestStreamApplication"); } }); try { @@ -162,7 +163,7 @@ public void testStreamTaskClassWithInvalidStreamApplication() throws Exception { Config config = new MapConfig(new HashMap() { { - this.put(StreamApplication.APP_CLASS_CONFIG, "org.apache.samza.testUtils.InvalidStreamApplication"); + this.put(ApplicationConfig.APP_CLASS, "org.apache.samza.testUtils.InvalidStreamApplication"); } }); try { @@ -175,7 +176,7 @@ public void testStreamTaskClassWithInvalidStreamApplication() throws Exception { config = new MapConfig(new HashMap() { { this.put("task.class", "org.apache.samza.testUtils.TestStreamTask"); - this.put(StreamApplication.APP_CLASS_CONFIG, ""); + this.put(ApplicationConfig.APP_CLASS, ""); } }); StreamApplication streamApp = TaskFactoryUtil.createStreamApplication(config); @@ -186,7 +187,7 @@ public void testStreamTaskClassWithInvalidStreamApplication() throws Exception { config = new MapConfig(new HashMap() { { this.put("task.class", ""); - this.put(StreamApplication.APP_CLASS_CONFIG, "org.apache.samza.testUtils.InvalidStreamApplication"); + this.put(ApplicationConfig.APP_CLASS, "org.apache.samza.testUtils.InvalidStreamApplication"); } }); try { @@ -226,7 +227,7 @@ public void testAsyncStreamTaskWithInvalidStreamGraphBuilder() throws Exception Config config = new MapConfig(new HashMap() { { - this.put(StreamApplication.APP_CLASS_CONFIG, "org.apache.samza.testUtils.InvalidStreamApplication"); + this.put(ApplicationConfig.APP_CLASS, "org.apache.samza.testUtils.InvalidStreamApplication"); } }); try { @@ -239,7 +240,7 @@ public void testAsyncStreamTaskWithInvalidStreamGraphBuilder() throws Exception config = new MapConfig(new HashMap() { { this.put("task.class", "org.apache.samza.testUtils.TestAsyncStreamTask"); - this.put(StreamApplication.APP_CLASS_CONFIG, ""); + this.put(ApplicationConfig.APP_CLASS, ""); } }); StreamApplication streamApp = TaskFactoryUtil.createStreamApplication(config); @@ -250,7 +251,7 @@ public void testAsyncStreamTaskWithInvalidStreamGraphBuilder() throws Exception config = new MapConfig(new HashMap() { { this.put("task.class", "org.apache.samza.testUtils.TestAsyncStreamTask"); - this.put(StreamApplication.APP_CLASS_CONFIG, null); + this.put(ApplicationConfig.APP_CLASS, null); } }); streamApp = TaskFactoryUtil.createStreamApplication(config); diff --git a/samza-test/src/test/scala/org/apache/samza/test/integration/StreamTaskTestUtil.scala b/samza-test/src/test/scala/org/apache/samza/test/integration/StreamTaskTestUtil.scala index cda2690785fca..29fb6d3f6e07f 100644 --- a/samza-test/src/test/scala/org/apache/samza/test/integration/StreamTaskTestUtil.scala +++ b/samza-test/src/test/scala/org/apache/samza/test/integration/StreamTaskTestUtil.scala @@ -119,7 +119,7 @@ object StreamTaskTestUtil { servers = configs.map(TestUtils.createServer(_)).toBuffer val brokerList = TestUtils.getBrokerListStrFromServers(servers, SecurityProtocol.PLAINTEXT) - brokers = brokerList.split(",").map(p => "localhost" + p).mkString(",") + brokers = brokerList.split(",").map(p => "127.0.0.1" + p).mkString(",") // setup the zookeeper and bootstrap servers for local kafka cluster jobConfig ++= Map("systems.kafka.consumer.zookeeper.connect" -> zkConnect,