Skip to content

Commit

Permalink
Use name column matching in CassandraRecordCursor
Browse files Browse the repository at this point in the history
This is needed for upcoming query pass-through function.
  • Loading branch information
ebyhr committed Feb 15, 2023
1 parent 2f13230 commit 6c10805
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 17 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

import com.datastax.oss.driver.api.core.cql.ResultSet;
import com.datastax.oss.driver.api.core.cql.Row;
import com.google.common.collect.ImmutableList;
import io.airlift.slice.Slice;
import io.trino.plugin.cassandra.CassandraType.Kind;
import io.trino.spi.connector.RecordCursor;
Expand All @@ -24,21 +25,27 @@

import java.util.List;

import static com.google.common.base.Preconditions.checkArgument;
import static io.airlift.slice.Slices.utf8Slice;
import static io.trino.plugin.cassandra.util.CassandraCqlUtils.validColumnName;
import static io.trino.spi.type.DateTimeEncoding.packDateTimeWithZone;
import static java.lang.Float.floatToRawIntBits;
import static java.util.Objects.requireNonNull;

public class CassandraRecordCursor
implements RecordCursor
{
private final List<String> columnNames;
private final List<CassandraType> cassandraTypes;
private final CassandraTypeManager cassandraTypeManager;
private final ResultSet rs;
private Row currentRow;

public CassandraRecordCursor(CassandraSession cassandraSession, CassandraTypeManager cassandraTypeManager, List<CassandraType> cassandraTypes, String cql)
public CassandraRecordCursor(CassandraSession cassandraSession, CassandraTypeManager cassandraTypeManager, List<String> columnNames, List<CassandraType> cassandraTypes, String cql)
{
this.columnNames = ImmutableList.copyOf(requireNonNull(columnNames, "columnNames is null"));
this.cassandraTypes = cassandraTypes;
checkArgument(columnNames.size() == cassandraTypes.size(), "columnNames and cassandraTypes sizes don't match");
this.cassandraTypeManager = cassandraTypeManager;
rs = cassandraSession.execute(cql);
currentRow = null;
Expand All @@ -63,7 +70,7 @@ public void close()
@Override
public boolean getBoolean(int i)
{
return currentRow.getBool(i);
return currentRow.getBool(validColumnName(columnNames.get(i)));
}

@Override
Expand All @@ -81,13 +88,14 @@ public long getReadTimeNanos()
@Override
public double getDouble(int i)
{
String columnName = validColumnName(columnNames.get(i));
switch (getCassandraType(i).getKind()) {
case DOUBLE:
return currentRow.getDouble(i);
return currentRow.getDouble(columnName);
case FLOAT:
return currentRow.getFloat(i);
return currentRow.getFloat(columnName);
case DECIMAL:
return currentRow.getBigDecimal(i).doubleValue();
return currentRow.getBigDecimal(columnName).doubleValue();
default:
throw new IllegalStateException("Cannot retrieve double for " + getCassandraType(i));
}
Expand All @@ -96,22 +104,23 @@ public double getDouble(int i)
@Override
public long getLong(int i)
{
String columnName = validColumnName(columnNames.get(i));
switch (getCassandraType(i).getKind()) {
case INT:
return currentRow.getInt(i);
return currentRow.getInt(columnName);
case SMALLINT:
return currentRow.getShort(i);
return currentRow.getShort(columnName);
case TINYINT:
return currentRow.getByte(i);
return currentRow.getByte(columnName);
case BIGINT:
case COUNTER:
return currentRow.getLong(i);
return currentRow.getLong(columnName);
case TIMESTAMP:
return packDateTimeWithZone(currentRow.getInstant(i).toEpochMilli(), TimeZoneKey.UTC_KEY);
return packDateTimeWithZone(currentRow.getInstant(columnName).toEpochMilli(), TimeZoneKey.UTC_KEY);
case DATE:
return currentRow.getLocalDate(i).toEpochDay();
return currentRow.getLocalDate(columnName).toEpochDay();
case FLOAT:
return floatToRawIntBits(currentRow.getFloat(i));
return floatToRawIntBits(currentRow.getFloat(columnName));
default:
throw new IllegalStateException("Cannot retrieve long for " + getCassandraType(i));
}
Expand All @@ -128,7 +137,7 @@ public Slice getSlice(int i)
if (getCassandraType(i).getKind() == Kind.TIMESTAMP) {
throw new IllegalArgumentException("Timestamp column can not be accessed with getSlice");
}
NullableValue value = cassandraTypeManager.getColumnValue(cassandraTypes.get(i), currentRow, i);
NullableValue value = cassandraTypeManager.getColumnValue(cassandraTypes.get(i), currentRow, currentRow.firstIndexOf(validColumnName(columnNames.get(i))));
if (value.getValue() instanceof Slice) {
return (Slice) value.getValue();
}
Expand All @@ -142,7 +151,7 @@ public Object getObject(int i)
switch (cassandraType.getKind()) {
case TUPLE:
case UDT:
return cassandraTypeManager.getColumnValue(cassandraType, currentRow, i).getValue();
return cassandraTypeManager.getColumnValue(cassandraType, currentRow, currentRow.firstIndexOf(validColumnName(columnNames.get(i)))).getValue();
default:
throw new IllegalArgumentException("getObject cannot be called for " + cassandraType);
}
Expand All @@ -157,9 +166,10 @@ public Type getType(int i)
@Override
public boolean isNull(int i)
{
String columnName = validColumnName(columnNames.get(i));
if (getCassandraType(i).getKind() == Kind.TIMESTAMP) {
return currentRow.getInstant(i) == null;
return currentRow.getInstant(columnName) == null;
}
return currentRow.isNull(i);
return currentRow.isNull(columnName);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ public class CassandraRecordSet
private final CassandraSession cassandraSession;
private final CassandraTypeManager cassandraTypeManager;
private final String cql;
private final List<String> cassandraNames;
private final List<CassandraType> cassandraTypes;
private final List<Type> columnTypes;

Expand All @@ -39,6 +40,7 @@ public CassandraRecordSet(CassandraSession cassandraSession, CassandraTypeManage
this.cql = requireNonNull(cql, "cql is null");

requireNonNull(cassandraColumns, "cassandraColumns is null");
this.cassandraNames = transformList(cassandraColumns, CassandraColumnHandle::getName);
this.cassandraTypes = transformList(cassandraColumns, CassandraColumnHandle::getCassandraType);
this.columnTypes = transformList(cassandraColumns, CassandraColumnHandle::getType);
}
Expand All @@ -52,7 +54,7 @@ public List<Type> getColumnTypes()
@Override
public RecordCursor cursor()
{
return new CassandraRecordCursor(cassandraSession, cassandraTypeManager, cassandraTypes, cql);
return new CassandraRecordCursor(cassandraSession, cassandraTypeManager, cassandraNames, cassandraTypes, cql);
}

private static <T, R> List<R> transformList(List<T> list, Function<T, R> function)
Expand Down

0 comments on commit 6c10805

Please sign in to comment.