From 0c43e297b8c6930b16f04c87ef6eb75932df1ab5 Mon Sep 17 00:00:00 2001 From: Pierre-Yves Ricau Date: Thu, 25 Mar 2021 11:58:54 -0700 Subject: [PATCH] Fix thread safety crash This allows position to be thread local. Fixes #2084 --- .../java/shark/internal/ClassFieldsReader.kt | 229 ++++++++++-------- 1 file changed, 123 insertions(+), 106 deletions(-) diff --git a/shark-graph/src/main/java/shark/internal/ClassFieldsReader.kt b/shark-graph/src/main/java/shark/internal/ClassFieldsReader.kt index a53b2e6762..b6de72a2a4 100644 --- a/shark-graph/src/main/java/shark/internal/ClassFieldsReader.kt +++ b/shark-graph/src/main/java/shark/internal/ClassFieldsReader.kt @@ -28,143 +28,160 @@ internal class ClassFieldsReader( private val classFieldBytes: ByteArray ) { - private var position = 0 - fun classDumpStaticFields(indexedClass: IndexedClass): List { - position = indexedClass.fieldsIndex - val staticFieldCount = readUnsignedShort() - val staticFields = ArrayList(staticFieldCount) - for (i in 0 until staticFieldCount) { - val nameStringId = readId() - val type = readUnsignedByte() - val value = readValue(type) - staticFields.add( - StaticFieldRecord( - nameStringId = nameStringId, - type = type, - value = value + return read(initialPosition = indexedClass.fieldsIndex) { + val staticFieldCount = readUnsignedShort() + val staticFields = ArrayList(staticFieldCount) + for (i in 0 until staticFieldCount) { + val nameStringId = readId() + val type = readUnsignedByte() + val value = readValue(type) + staticFields.add( + StaticFieldRecord( + nameStringId = nameStringId, + type = type, + value = value + ) ) - ) + } + staticFields } - return staticFields } fun classDumpFields(indexedClass: IndexedClass): List { - position = indexedClass.fieldsIndex - - skipStaticFields() + return read(initialPosition = indexedClass.fieldsIndex) { + skipStaticFields() - val fieldCount = readUnsignedShort() - val fields = ArrayList(fieldCount) - for (i in 0 until fieldCount) { - fields.add(FieldRecord(nameStringId = readId(), type = readUnsignedByte())) + val fieldCount = readUnsignedShort() + val fields = ArrayList(fieldCount) + for (i in 0 until fieldCount) { + fields.add(FieldRecord(nameStringId = readId(), type = readUnsignedByte())) + } + fields } - return fields } fun classDumpHasReferenceFields(indexedClass: IndexedClass): Boolean { - position = indexedClass.fieldsIndex - skipStaticFields() - val fieldCount = readUnsignedShort() - for (i in 0 until fieldCount) { - position += identifierByteSize - val type = readUnsignedByte() - if (type == PrimitiveType.REFERENCE_HPROF_TYPE) { - return true + return read(initialPosition = indexedClass.fieldsIndex) { + skipStaticFields() + val fieldCount = readUnsignedShort() + for (i in 0 until fieldCount) { + position += identifierByteSize + val type = readUnsignedByte() + if (type == PrimitiveType.REFERENCE_HPROF_TYPE) { + return@read true + } } + return@read false } - return false } - private fun skipStaticFields() { - val staticFieldCount = readUnsignedShort() - for (i in 0 until staticFieldCount) { - position += identifierByteSize - val type = readUnsignedByte() - position += if (type == PrimitiveType.REFERENCE_HPROF_TYPE) { - identifierByteSize - } else { - PrimitiveType.byteSizeByHprofType.getValue(type) + private val readInFlightThreadLocal = object : ThreadLocal() { + override fun initialValue() = ReadInFlight() + } + + private fun read( + initialPosition: Int, + block: ReadInFlight.() -> R + ): R { + val readInFlight = readInFlightThreadLocal.get() + readInFlight.position = initialPosition + return readInFlight.run(block) + } + + private inner class ReadInFlight { + var position = 0 + + fun skipStaticFields() { + val staticFieldCount = readUnsignedShort() + for (i in 0 until staticFieldCount) { + position += identifierByteSize + val type = readUnsignedByte() + position += if (type == PrimitiveType.REFERENCE_HPROF_TYPE) { + identifierByteSize + } else { + PrimitiveType.byteSizeByHprofType.getValue(type) + } } } - } - private fun readValue(type: Int): ValueHolder { - return when (type) { - PrimitiveType.REFERENCE_HPROF_TYPE -> ReferenceHolder(readId()) - BOOLEAN_TYPE -> BooleanHolder(readBoolean()) - CHAR_TYPE -> CharHolder(readChar()) - FLOAT_TYPE -> FloatHolder(readFloat()) - DOUBLE_TYPE -> DoubleHolder(readDouble()) - BYTE_TYPE -> ByteHolder(readByte()) - SHORT_TYPE -> ShortHolder(readShort()) - INT_TYPE -> IntHolder(readInt()) - LONG_TYPE -> LongHolder(readLong()) - else -> throw IllegalStateException("Unknown type $type") + fun readValue(type: Int): ValueHolder { + return when (type) { + PrimitiveType.REFERENCE_HPROF_TYPE -> ReferenceHolder(readId()) + BOOLEAN_TYPE -> BooleanHolder(readBoolean()) + CHAR_TYPE -> CharHolder(readChar()) + FLOAT_TYPE -> FloatHolder(readFloat()) + DOUBLE_TYPE -> DoubleHolder(readDouble()) + BYTE_TYPE -> ByteHolder(readByte()) + SHORT_TYPE -> ShortHolder(readShort()) + INT_TYPE -> IntHolder(readInt()) + LONG_TYPE -> LongHolder(readLong()) + else -> throw IllegalStateException("Unknown type $type") + } } - } - private fun readByte(): Byte { - return classFieldBytes[position++] - } + fun readByte(): Byte { + return classFieldBytes[position++] + } - private fun readInt(): Int { - return (classFieldBytes[position++].toInt() and 0xff shl 24) or - (classFieldBytes[position++].toInt() and 0xff shl 16) or - (classFieldBytes[position++].toInt() and 0xff shl 8) or - (classFieldBytes[position++].toInt() and 0xff) - } + fun readInt(): Int { + return (classFieldBytes[position++].toInt() and 0xff shl 24) or + (classFieldBytes[position++].toInt() and 0xff shl 16) or + (classFieldBytes[position++].toInt() and 0xff shl 8) or + (classFieldBytes[position++].toInt() and 0xff) + } - private fun readLong(): Long { - return (classFieldBytes[position++].toLong() and 0xff shl 56) or - (classFieldBytes[position++].toLong() and 0xff shl 48) or - (classFieldBytes[position++].toLong() and 0xff shl 40) or - (classFieldBytes[position++].toLong() and 0xff shl 32) or - (classFieldBytes[position++].toLong() and 0xff shl 24) or - (classFieldBytes[position++].toLong() and 0xff shl 16) or - (classFieldBytes[position++].toLong() and 0xff shl 8) or - (classFieldBytes[position++].toLong() and 0xff) - } + fun readLong(): Long { + return (classFieldBytes[position++].toLong() and 0xff shl 56) or + (classFieldBytes[position++].toLong() and 0xff shl 48) or + (classFieldBytes[position++].toLong() and 0xff shl 40) or + (classFieldBytes[position++].toLong() and 0xff shl 32) or + (classFieldBytes[position++].toLong() and 0xff shl 24) or + (classFieldBytes[position++].toLong() and 0xff shl 16) or + (classFieldBytes[position++].toLong() and 0xff shl 8) or + (classFieldBytes[position++].toLong() and 0xff) + } - private fun readShort(): Short { - return ((classFieldBytes[position++].toInt() and 0xff shl 8) or - (classFieldBytes[position++].toInt() and 0xff)).toShort() - } + fun readShort(): Short { + return ((classFieldBytes[position++].toInt() and 0xff shl 8) or + (classFieldBytes[position++].toInt() and 0xff)).toShort() + } - private fun readUnsignedShort(): Int { - return readShort().toInt() and 0xFFFF - } + fun readUnsignedShort(): Int { + return readShort().toInt() and 0xFFFF + } - private fun readUnsignedByte(): Int { - return readByte().toInt() and 0xFF - } + fun readUnsignedByte(): Int { + return readByte().toInt() and 0xFF + } - private fun readId(): Long { - // As long as we don't interpret IDs, reading signed values here is fine. - return when (identifierByteSize) { - 1 -> readByte().toLong() - 2 -> readShort().toLong() - 4 -> readInt().toLong() - 8 -> readLong() - else -> throw IllegalArgumentException("ID Length must be 1, 2, 4, or 8") + fun readId(): Long { + // As long as we don't interpret IDs, reading signed values here is fine. + return when (identifierByteSize) { + 1 -> readByte().toLong() + 2 -> readShort().toLong() + 4 -> readInt().toLong() + 8 -> readLong() + else -> throw IllegalArgumentException("ID Length must be 1, 2, 4, or 8") + } } - } - private fun readBoolean(): Boolean { - return readByte() - .toInt() != 0 - } + fun readBoolean(): Boolean { + return readByte() + .toInt() != 0 + } - private fun readChar(): Char { - return readShort().toChar() - } + fun readChar(): Char { + return readShort().toChar() + } - private fun readFloat(): Float { - return Float.fromBits(readInt()) - } + fun readFloat(): Float { + return Float.fromBits(readInt()) + } - private fun readDouble(): Double { - return Double.fromBits(readLong()) + fun readDouble(): Double { + return Double.fromBits(readLong()) + } } companion object {