Skip to content

Commit

Permalink
[Improve][Connector-V2] Flink support embedding transform (#7592)
Browse files Browse the repository at this point in the history
* [Improve][Connector-V2] Flink support embedding transform

* [Improve][Connector-V2] Optimized code
  • Loading branch information
corgy-w authored Sep 6, 2024
1 parent b4dbccf commit f7286b7
Show file tree
Hide file tree
Showing 5 changed files with 13 additions and 6 deletions.
2 changes: 1 addition & 1 deletion docs/en/transform-v2/embedding.md
Original file line number Diff line number Diff line change
Expand Up @@ -320,7 +320,7 @@ transform {
inputx = ["${input}"]
}
}
result_table_name = "embedding_output_3"
result_table_name = "embedding_output_1"
}
}
Expand Down
2 changes: 1 addition & 1 deletion docs/zh/transform-v2/embedding.md
Original file line number Diff line number Diff line change
Expand Up @@ -310,7 +310,7 @@ transform {
inputx = ["${input}"]
}
}
result_table_name = "embedding_output_3"
result_table_name = "embedding_output_1"
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,8 +43,8 @@

@DisabledOnContainer(
value = {},
type = {EngineType.SPARK, EngineType.FLINK},
disabledReason = "Currently SPARK and FLINK not support adapt")
type = {EngineType.SPARK},
disabledReason = "Currently SPARK not support adapt")
public class TestEmbeddingIT extends TestSuiteBase implements TestResource {
private static final String TMP_DIR = "/tmp";
private GenericContainer<?> mockserverContainer;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -79,10 +79,10 @@ transform {
}
custom_request_body ={
modelx = "${model}"
inputx = ["${input}","${input}"]
inputx = ["${input}"]
}
}
result_table_name = "embedding_output_3"
result_table_name = "embedding_output_1"
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,12 @@ public EmbeddingTransform(
config.get(EmbeddingTransformConfig.VECTORIZATION_FIELDS));
}

private void tryOpen() {
if (model == null) {
open();
}
}

@Override
public void open() {
// Initialize model
Expand Down Expand Up @@ -161,6 +167,7 @@ private void initOutputFields(SeaTunnelRowType inputRowType, Map<String, String>

@Override
protected Object[] getOutputFieldValues(SeaTunnelRowAccessor inputRow) {
tryOpen();
try {
Object[] fieldArray = new Object[fieldOriginalIndexes.size()];
for (int i = 0; i < fieldOriginalIndexes.size(); i++) {
Expand Down

0 comments on commit f7286b7

Please sign in to comment.