Skip to content

Commit

Permalink
Merge pull request #201 from JetBrains-Research/code2seq_meta
Browse files Browse the repository at this point in the history
Code2seq metadata
  • Loading branch information
SpirinEgor authored Feb 11, 2022
2 parents 6b733bd + bf83283 commit 7afcc92
Show file tree
Hide file tree
Showing 12 changed files with 252 additions and 49 deletions.
8 changes: 8 additions & 0 deletions src/main/kotlin/astminer/common/model/PipelineModel.kt
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@ package astminer.common.model
import astminer.storage.structurallyNormalized
import java.io.Closeable
import java.io.File
import java.nio.file.Path
import kotlin.io.path.Path

interface Filter

Expand Down Expand Up @@ -73,6 +75,12 @@ enum class DatasetHoldout(val dirName: String) {
Validation("val"),
Test("test"),
None("data");

/**
 * Resolves the directory named after this holdout's [dirName] inside [parent] and
 * asks the file system to create it (together with any missing parents).
 *
 * The result of the creation attempt is not checked; the resolved directory is
 * returned either way.
 *
 * @param parent directory under which the holdout directory is resolved
 * @return the resolved holdout directory
 */
fun createDir(parent: Path = Path("")): File =
    parent.toFile().resolve(this.dirName).also { it.mkdirs() }
}

/** Returns map with three entries (keys: train data pool, validation data pool and test data pool;
Expand Down
1 change: 1 addition & 0 deletions src/main/kotlin/astminer/config/PipelineConfig.kt
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ data class PipelineConfig(
val filters: List<FilterConfig> = emptyList(),
@SerialName("label") val labelExtractor: LabelExtractorConfig,
val storage: StorageConfig,
val collectMetadata: Boolean = false,
val numOfThreads: Int = 1,
val compressBeforeSaving: Boolean = false
) {
Expand Down
9 changes: 2 additions & 7 deletions src/main/kotlin/astminer/config/StorageConfigs.kt
Original file line number Diff line number Diff line change
Expand Up @@ -42,15 +42,10 @@ class DotAstStorageConfig : StorageConfig() {
*/
@Serializable
@SerialName("json AST")
class JsonAstStorageConfig(
private val withPaths: Boolean = false,
private val withRanges: Boolean = false
) : StorageConfig() {
class JsonAstStorageConfig : StorageConfig() {
override fun createStorage(outputDirectoryPath: String) =
JsonAstStorage(
outputDirectoryPath,
withPaths,
withRanges
outputDirectoryPath
)
}

Expand Down
25 changes: 22 additions & 3 deletions src/main/kotlin/astminer/pipeline/Pipeline.kt
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import astminer.parse.getParsingResultFactory
import astminer.pipeline.branch.FilePipelineBranch
import astminer.pipeline.branch.FunctionPipelineBranch
import astminer.pipeline.branch.IllegalLabelExtractorException
import astminer.storage.MetaDataStorage
import me.tongfei.progressbar.ProgressBar
import java.io.File

Expand Down Expand Up @@ -46,20 +47,33 @@ class Pipeline(private val config: PipelineConfig) {

/**
 * Parses the files of [language] found in every holdout directory, runs each parsing
 * result through the pipeline branch and stores the labeled results. When
 * `config.collectMetadata` is set, the same results are also written to a metadata storage.
 */
private fun parseLanguage(language: FileExtension) {
val parsingResultFactory = getParsingResultFactory(language, config.parser.name)
createStorage(language).use { storage ->
val storage = createStorage(language)
// The metadata storage is optional: it only exists when enabled in the config.
val metaStorage = if (config.collectMetadata) createMetaStorage(language) else null
try {
for ((holdoutType, holdoutDir) in holdoutMap) {
val holdoutFiles = getProjectFilesWithExtension(holdoutDir, language.fileExtension)
printHoldoutStat(holdoutFiles, holdoutType)
val progressBar = ProgressBar("", holdoutFiles.size.toLong())
parsingResultFactory.parseFilesInThreads(holdoutFiles, config.numOfThreads, inputDirectory.path) {

parsingResultFactory.parseFilesInThreads(
files = holdoutFiles,
numOfThreads = config.numOfThreads,
inputDirectoryPath = inputDirectory.path,
) {
// Optionally compress the labeled trees before they are written out.
val labeledResults = branch.process(it).let { results ->
if (config.compressBeforeSaving) { results.toStructurallyNormalized() } else { results }
}
storage.storeSynchronously(labeledResults, holdoutType)
// Data and metadata are written under one lock so both writes for a result stay together.
// NOTE(review): holdoutType is not passed to either store call below, unlike the
// storeSynchronously call above - confirm results still land in the intended holdout.
synchronized(this) {
storage.store(labeledResults)
metaStorage?.store(labeledResults)
}
progressBar.step()
}
progressBar.close()
}
} finally {
// NOTE(review): if storage.close() throws, metaStorage?.close() is never reached.
storage.close()
metaStorage?.close()
}
}

Expand All @@ -68,6 +82,11 @@ class Pipeline(private val config: PipelineConfig) {
return config.storage.createStorage(storagePath)
}

/** Builds a [MetaDataStorage] rooted at the storage directory created for [extension]. */
private fun createMetaStorage(extension: FileExtension): MetaDataStorage =
    MetaDataStorage(createStorageDirectory(extension).path)

private fun createStorageDirectory(extension: FileExtension): File {
val outputDirectoryForExtension = outputDirectory.resolve(extension.fileExtension)
outputDirectoryForExtension.mkdir()
Expand Down
44 changes: 44 additions & 0 deletions src/main/kotlin/astminer/storage/MetaDataStorage.kt
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
package astminer.storage

import astminer.common.model.*
import kotlinx.serialization.Serializable
import kotlinx.serialization.encodeToString
import kotlinx.serialization.json.Json
import java.io.PrintWriter
import kotlin.io.path.Path

/**
 * Serializable metadata record for one labeled tree: its label, the path of the file
 * it was taken from and the range of its root node (may be absent).
 */
@Serializable
data class TreeMetaData(val label: String, val path: String, val range: NodeRange?) {
    /** Fills the record from [labeledResult]'s label, file path and root node range. */
    constructor(labeledResult: LabeledResult<out Node>) : this(
        label = labeledResult.label,
        path = labeledResult.filePath,
        range = labeledResult.root.range,
    )
}

/**
 * [Storage] that writes one JSON line of [TreeMetaData] per stored result into
 * [METADATA_FILENAME], placed in the directory of the result's holdout under
 * [outputDirectoryPath].
 */
class MetaDataStorage(override val outputDirectoryPath: String) : Storage {
    // Writers are opened on first use, one per holdout.
    private val metadataWriters = mutableMapOf<DatasetHoldout, PrintWriter>()

    /** Creates the metadata file for [holdout] and opens an auto-flushing writer to it. */
    private fun openMetadataWriter(holdout: DatasetHoldout): PrintWriter {
        val holdoutDir = holdout.createDir(Path(outputDirectoryPath))
        val metadataFile = holdoutDir.resolve(METADATA_FILENAME)
        metadataFile.createNewFile()
        return PrintWriter(metadataFile.outputStream(), true)
    }

    /** Appends the metadata of [labeledResult] as a single JSON line to the file of [holdout]. */
    override fun store(labeledResult: LabeledResult<out Node>, holdout: DatasetHoldout) {
        val writer = metadataWriters.getOrPut(holdout) { openMetadataWriter(holdout) }
        val jsonLine = Json.encodeToString(TreeMetaData(labeledResult))
        writer.println(jsonLine)
    }

    /** Closes every writer opened so far. */
    override fun close() {
        for (writer in metadataWriters.values) {
            writer.close()
        }
    }

    companion object {
        const val METADATA_FILENAME = "metadata.jsonl"
    }
}
2 changes: 1 addition & 1 deletion src/main/kotlin/astminer/storage/ast/CsvAstStorage.kt
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ class CsvAstStorage(override val outputDirectoryPath: String) : Storage {
holdoutDir.mkdirs()
val astFile = holdoutDir.resolve("asts.csv")
astFile.createNewFile()
val newWriter = PrintWriter(astFile)
val newWriter = PrintWriter(astFile.outputStream(), true)
newWriter.println("id,ast")
return newWriter
}
Expand Down
23 changes: 10 additions & 13 deletions src/main/kotlin/astminer/storage/ast/JsonAstStorage.kt
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import kotlinx.serialization.encodeToString
import kotlinx.serialization.json.Json
import java.io.File
import java.io.PrintWriter
import kotlin.io.path.Path

private typealias Id = Int

Expand All @@ -15,9 +16,7 @@ private typealias Id = Int
* Each tree is flattened and represented as a list of nodes.
*/
class JsonAstStorage(
override val outputDirectoryPath: String,
private val withPaths: Boolean,
private val withRanges: Boolean
override val outputDirectoryPath: String
) : Storage {
private val treeFlattener = TreeFlattener()

Expand All @@ -31,30 +30,26 @@ class JsonAstStorage(
@Serializable
private data class LabeledAst(
val label: String,
val path: String? = null,
val ast: List<OutputNode>
)

@Serializable
private data class OutputNode(
val token: String,
val typeLabel: String,
val range: NodeRange? = null,
val children: List<Id>
)

private fun TreeFlattener.EnumeratedNode.toOutputNode() =
OutputNode(
node.token.final(),
node.typeLabel,
if (withRanges) node.range else null,
children.map { it.id }
)

override fun store(labeledResult: LabeledResult<out Node>, holdout: DatasetHoldout) {
val outputNodes = treeFlattener.flatten(labeledResult.root).map { it.toOutputNode() }
val path = if (withPaths) labeledResult.filePath else null
val labeledAst = LabeledAst(labeledResult.label, path, outputNodes)
val labeledAst = LabeledAst(labeledResult.label, outputNodes)
val writer = datasetWriters.getOrPut(holdout) { holdout.resolveHoldout() }
writer.println(Json.encodeToString(labeledAst))
}
Expand All @@ -64,11 +59,13 @@ class JsonAstStorage(
}

private fun DatasetHoldout.resolveHoldout(): PrintWriter {
val holdoutDir = File(outputDirectoryPath).resolve(this.dirName)
holdoutDir.mkdirs()
val astFile = holdoutDir.resolve("asts.jsonl")
astFile.createNewFile()
return PrintWriter(astFile)
val newOutputFile = this.createDir(Path(outputDirectoryPath)).resolve(AST_FILENAME)
newOutputFile.createNewFile()
return PrintWriter(newOutputFile.outputStream(), true)
}

companion object {
const val AST_FILENAME = "asts.jsonl"
}
}

Expand Down
5 changes: 4 additions & 1 deletion src/main/kotlin/astminer/storage/path/Code2VecPathStorage.kt
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,10 @@ import astminer.common.model.*
import astminer.common.storage.*
import java.io.File

class Code2VecPathStorage(outputDirectoryPath: String, private val config: PathBasedStorageConfig) :
class Code2VecPathStorage(
outputDirectoryPath: String,
private val config: PathBasedStorageConfig,
) :
PathBasedStorage(outputDirectoryPath, config) {

private val tokensMap: RankedIncrementalIdStorage<String> = RankedIncrementalIdStorage()
Expand Down
39 changes: 24 additions & 15 deletions src/main/kotlin/astminer/storage/path/PathBasedStorage.kt
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import astminer.paths.PathRetrievalSettings
import astminer.paths.toPathContext
import java.io.File
import java.io.PrintWriter
import kotlin.io.path.Path

/**
* Config for CountingPathStorage which contains all hyperparameters for path extraction.
Expand All @@ -29,22 +30,26 @@ data class PathBasedStorageConfig(
/**
* Base class for all path storages. Extracts paths from given LabellingResult and stores it in a specified format.
* @property outputDirectoryPath The path to the output directory.
* @property config The config that contains hyperparameters for path extraction.
* @property pathExtractionConfig The config that contains hyperparameters for path extraction.
* (for example enabling filepath storage)
*/
abstract class PathBasedStorage(
final override val outputDirectoryPath: String,
private val config: PathBasedStorageConfig,
private val pathExtractionConfig: PathBasedStorageConfig,
) : Storage {

private val pathMiner = PathMiner(PathRetrievalSettings(config.maxPathLength, config.maxPathWidth))
private val pathMiner = PathMiner(
PathRetrievalSettings(
pathExtractionConfig.maxPathLength,
pathExtractionConfig.maxPathWidth
)
)
private val datasetFileWriters = mutableMapOf<DatasetHoldout, PrintWriter>()

init {
File(outputDirectoryPath).mkdirs()
}
init { File(outputDirectoryPath).mkdirs() }

private fun retrievePaths(node: Node) = if (config.maxPathContextsPerEntity != null) {
pathMiner.retrievePaths(node).shuffled().take(config.maxPathContextsPerEntity)
private fun retrievePaths(node: Node) = if (pathExtractionConfig.maxPathContextsPerEntity != null) {
pathMiner.retrievePaths(node).shuffled().take(pathExtractionConfig.maxPathContextsPerEntity)
} else {
pathMiner.retrievePaths(node)
}
Expand All @@ -67,19 +72,23 @@ abstract class PathBasedStorage(
override fun store(labeledResult: LabeledResult<out Node>, holdout: DatasetHoldout) {
val labeledPathContexts = retrieveLabeledPathContexts(labeledResult)
val output = labeledPathContextsToString(labeledPathContexts)
val writer = datasetFileWriters.getOrPut(holdout) { holdout.resolveWriter() }
val writer = datasetFileWriters.getOrPut(holdout) { holdout.resolveDataWriter() }
writer.println(output)
}

/** Closes every dataset writer opened so far. */
override fun close() {
    // forEach instead of map: closing is a pure side effect, so no result list should be built.
    datasetFileWriters.values.forEach { it.close() }
}

private fun DatasetHoldout.resolveWriter(): PrintWriter {
val holdoutDir = File(outputDirectoryPath).resolve(this.dirName)
holdoutDir.mkdirs()
val pathContextFile = holdoutDir.resolve("path_contexts.c2s")
pathContextFile.createNewFile()
return PrintWriter(pathContextFile)
private fun DatasetHoldout.resolveWriter(outputFile: String): PrintWriter {
val newOutputFile = this.createDir(Path(outputDirectoryPath)).resolve(outputFile)
newOutputFile.createNewFile()
return PrintWriter(newOutputFile.outputStream(), true)
}

/** Opens the writer for this holdout's path-context file, named by [PATH_CONTEXT_FILENAME]. */
private fun DatasetHoldout.resolveDataWriter() = resolveWriter(PATH_CONTEXT_FILENAME)

companion object {
const val PATH_CONTEXT_FILENAME = "path_contexts.c2s"
}
}
Loading

0 comments on commit 7afcc92

Please sign in to comment.