From 71ac147f7f6b27ce9e4709782746ff8a9c5d4f23 Mon Sep 17 00:00:00 2001 From: Raymond Lam Date: Tue, 22 Sep 2020 11:10:48 -0700 Subject: [PATCH] Create index-based iterator for non-mutable map keySet and values access Create this change to fix UnsafeArrayData bug --- .../transport/spark/data/SparkMap.scala | 64 +++++++++++++++---- 1 file changed, 52 insertions(+), 12 deletions(-) diff --git a/transportable-udfs-spark/src/main/scala/com/linkedin/transport/spark/data/SparkMap.scala b/transportable-udfs-spark/src/main/scala/com/linkedin/transport/spark/data/SparkMap.scala index 4859d92e..da16aa04 100644 --- a/transportable-udfs-spark/src/main/scala/com/linkedin/transport/spark/data/SparkMap.scala +++ b/transportable-udfs-spark/src/main/scala/com/linkedin/transport/spark/data/SparkMap.scala @@ -31,17 +31,37 @@ case class SparkMap(private var _mapData: MapData, } override def keySet(): util.Set[StdData] = { - new util.AbstractSet[StdData] { + if (_mutableMap == null) { + new util.AbstractSet[StdData] { + + override def iterator(): util.Iterator[StdData] = new util.Iterator[StdData] { + var offset : Int = 0 - override def iterator(): util.Iterator[StdData] = new util.Iterator[StdData] { - private val keysIterator = if (_mutableMap == null) _mapData.keyArray().array.iterator else _mutableMap.keysIterator + override def next(): StdData = { + offset += 1 + SparkWrapper.createStdData(_mapData.keyArray().get(offset - 1, _keyType), _keyType) + } - override def next(): StdData = SparkWrapper.createStdData(keysIterator.next(), _keyType) + override def hasNext: Boolean = { + offset < SparkMap.this.size() + } + } - override def hasNext: Boolean = keysIterator.hasNext + override def size(): Int = SparkMap.this.size() } + } else { + new util.AbstractSet[StdData] { + + override def iterator(): util.Iterator[StdData] = new util.Iterator[StdData] { + private val keysIterator = _mutableMap.keysIterator - override def size(): Int = SparkMap.this.size() + override def next(): StdData = SparkWrapper.createStdData(keysIterator.next(), _keyType) + + override def hasNext: Boolean = keysIterator.hasNext + } + + override def size(): Int = SparkMap.this.size() + } } } @@ -54,17 +74,37 @@ case class SparkMap(private var _mapData: MapData, } override def values(): util.Collection[StdData] = { - new util.AbstractCollection[StdData] { + if (_mutableMap == null) { + new util.AbstractCollection[StdData] { + + override def iterator(): util.Iterator[StdData] = new util.Iterator[StdData] { + var offset : Int = 0 - override def iterator(): util.Iterator[StdData] = new util.Iterator[StdData] { - private val valueIterator = if (_mutableMap == null) _mapData.valueArray().array.iterator else _mutableMap.valuesIterator + override def next(): StdData = { + offset += 1 + SparkWrapper.createStdData(_mapData.valueArray().get(offset - 1, _valueType), _valueType) + } - override def next(): StdData = SparkWrapper.createStdData(valueIterator.next(), _valueType) + override def hasNext: Boolean = { + offset < SparkMap.this.size() + } + } - override def hasNext: Boolean = valueIterator.hasNext + override def size(): Int = SparkMap.this.size() } + } else { + new util.AbstractCollection[StdData] { + + override def iterator(): util.Iterator[StdData] = new util.Iterator[StdData] { + private val valueIterator = _mutableMap.valuesIterator - override def size(): Int = SparkMap.this.size() + override def next(): StdData = SparkWrapper.createStdData(valueIterator.next(), _valueType) + + override def hasNext: Boolean = valueIterator.hasNext + } + + override def size(): Int = SparkMap.this.size() + } } }