
Preserve dataset structure #175

Merged (10 commits) on Aug 7, 2021
8 changes: 7 additions & 1 deletion docs/storages.md
@@ -3,6 +3,11 @@
The storage defines how the ASTs should be saved on disk.
For now, `astminer` supports tree-based and path-based storage formats.

`astminer` also detects the structure of the dataset and can save trees or path contexts
into the corresponding holdout folders (`train`, `val`, and `test`). If the data is not structured,
all trees are saved in the `data` folder. Description files for trees or paths are
saved along with the holdouts in the same `outputPath` directory.

Storage config classes are defined in [StorageConfigs.kt](../src/main/kotlin/astminer/config/StorageConfigs.kt).

## Tree formats
@@ -47,6 +52,7 @@ Extract paths from each AST. Output is 4 files:
2. `tokens.csv` contains numeric ids and corresponding tokens;
3. `paths.csv` contains numeric ids and AST paths in form of space-separated sequences of node type ids;
4. `path_contexts.c2s` contains the labels and sequences of path-contexts (each representing two tokens and a path between them).
This file will be generated for every holdout.

Each line in `path_contexts.c2s` starts with a label, followed by a sequence of space-separated triples. Each triple contains start token id, path id, end token id, separated with commas.
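The line format described above can be illustrated with a small parser sketch (a hypothetical helper for reading `path_contexts.csv`-style output, not part of astminer):

```kotlin
// Sketch: parse one line of path_contexts.c2s into a label and id triples.
// Assumes the format described above: "label start,path,end start,path,end ...".
data class PathContext(val startTokenId: Int, val pathId: Int, val endTokenId: Int)

fun parsePathContextLine(line: String): Pair<String, List<PathContext>> {
    val parts = line.split(" ")
    val label = parts.first()
    val contexts = parts.drop(1).map { triple ->
        val (start, path, end) = triple.split(",").map(String::toInt)
        PathContext(start, path, end)
    }
    return label to contexts
}

fun main() {
    val (label, contexts) = parsePathContextLine("myMethod 1,5,2 3,7,4")
    println(label)          // myMethod
    println(contexts.size)  // 2
}
```

Note that labels are normalized before writing, so a label never contains the space that separates it from the triples.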

@@ -63,7 +69,7 @@ Each line in `path_contexts.c2s` starts with a label, followed by a sequence of
### Code2seq

Extract paths from each AST and save in the code2seq format.
The output is `path_context.c2s` file.
The output is a `path_context.c2s` file, which will be generated for every holdout.
Each line starts with a label, followed by a sequence of space-separated triples.
Each triple contains the start token, path node types, and end token, separated with commas.

2 changes: 2 additions & 0 deletions src/main/kotlin/astminer/common/model/ParsingResultModel.kt
@@ -35,6 +35,8 @@ interface ParsingResultFactory {
val results = mutableListOf<T?>()
val threads = mutableListOf<Thread>()

if (files.isEmpty()) { return emptyList() }

synchronized(results) {
files.chunked(ceil(files.size.toDouble() / numOfThreads).toInt()).filter { it.isNotEmpty() }
.map { chunk ->
35 changes: 32 additions & 3 deletions src/main/kotlin/astminer/common/model/PipelineModel.kt
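The chunking arithmetic in `parseFilesInThreads` above, together with the empty-list early return added in this PR, can be checked with a tiny standalone sketch (names are illustrative):

```kotlin
import kotlin.math.ceil

// Sketch: how parseFilesInThreads splits files across worker threads.
// ceil(size / numOfThreads) files go into each chunk; empty chunks are dropped.
fun <T> splitForThreads(files: List<T>, numOfThreads: Int): List<List<T>> {
    if (files.isEmpty()) return emptyList()  // mirrors the early return added in this PR
    return files.chunked(ceil(files.size.toDouble() / numOfThreads).toInt())
        .filter { it.isNotEmpty() }
}

fun main() {
    println(splitForThreads(listOf(1, 2, 3, 4, 5), 2).map { it.size })  // [3, 2]
    println(splitForThreads(emptyList<Int>(), 4))                       // []
}
```

Without the early return, `files.size.toDouble() / numOfThreads` would be zero for an empty input and `chunked(0)` would throw, which is exactly what the added guard prevents.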
@@ -1,6 +1,7 @@
package astminer.common.model

import java.io.Closeable
import java.io.File

interface Filter

@@ -40,11 +41,39 @@ fun <T : Node> ParsingResult<T>.labeledWith(label: String): LabeledResult<T> = L
interface Storage : Closeable {
val outputDirectoryPath: String

fun store(labeledResult: LabeledResult<out Node>)
fun store(labeledResult: LabeledResult<out Node>, holdout: DatasetHoldout = DatasetHoldout.None)

fun store(labeledResults: Iterable<LabeledResult<out Node>>) {
fun store(labeledResults: Iterable<LabeledResult<out Node>>, holdout: DatasetHoldout = DatasetHoldout.None) {
for (labeledResult in labeledResults) {
store(labeledResult)
store(labeledResult, holdout)
}
}
}

enum class DatasetHoldout(val dirName: String) {
Train("train"),
Validation("val"),
Test("test"),
None("data");
}

/**
 * Returns a map from dataset holdouts (train, validation, test) to their
 * directories if the dataset structure is present; otherwise a single entry
 * mapping [DatasetHoldout.None] to [inputDir].
 */
fun findDatasetHoldouts(inputDir: File): Map<DatasetHoldout, File> {
val trainDir = inputDir.resolve(DatasetHoldout.Train.dirName)
val valDir = inputDir.resolve(DatasetHoldout.Validation.dirName)
val testDir = inputDir.resolve(DatasetHoldout.Test.dirName)

return if (trainDir.exists() && valDir.exists() && testDir.exists()) {
mapOf(
DatasetHoldout.Train to trainDir,
DatasetHoldout.Validation to valDir,
DatasetHoldout.Test to testDir
)
} else {
mapOf(
DatasetHoldout.None to inputDir
)
}
}
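The holdout-detection logic above can be exercised with a standalone sketch that mirrors the PR's `findDatasetHoldouts` (renamed here to avoid clashing with the real function):

```kotlin
import java.io.File
import kotlin.io.path.createTempDirectory

// Standalone sketch of the holdout detection from this PR: an input directory
// containing train/, val/ and test/ subfolders is treated as a structured
// dataset; anything else falls back to a single "data" pool.
enum class Holdout(val dirName: String) { Train("train"), Validation("val"), Test("test"), None("data") }

fun findHoldouts(inputDir: File): Map<Holdout, File> {
    val dirs = listOf(Holdout.Train, Holdout.Validation, Holdout.Test)
        .associateWith { inputDir.resolve(it.dirName) }
    return if (dirs.values.all { it.exists() }) dirs else mapOf(Holdout.None to inputDir)
}

fun main() {
    val root = createTempDirectory("dataset").toFile()
    println(findHoldouts(root).keys)  // [None] — no structure yet
    listOf("train", "val", "test").forEach { root.resolve(it).mkdirs() }
    println(findHoldouts(root).keys)  // [Train, Validation, Test]
}
```

Note that all three directories must exist for the structure to be recognized; a dataset with only `train` and `test` is treated as unstructured.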
59 changes: 37 additions & 22 deletions src/main/kotlin/astminer/pipeline/Pipeline.kt
@@ -1,16 +1,15 @@
package astminer.pipeline

import astminer.common.getProjectFilesWithExtension
import astminer.common.model.FileLabelExtractor
import astminer.common.model.FunctionLabelExtractor
import astminer.common.model.Storage
import astminer.common.model.*
import astminer.config.FileExtension
import astminer.config.PipelineConfig
import astminer.parse.getParsingResultFactory
import astminer.pipeline.branch.FilePipelineBranch
import astminer.pipeline.branch.FunctionPipelineBranch
import astminer.pipeline.branch.IllegalLabelExtractorException
import me.tongfei.progressbar.ProgressBar
import java.io.Closeable
import java.io.File

/**
@@ -24,6 +23,9 @@ class Pipeline(private val config: PipelineConfig) {
private val filters = config.filters.map { it.filterImpl }
private val labelExtractor = config.labelExtractor.labelExtractorImpl

private val holdoutMap = findDatasetHoldouts(inputDirectory)
private val isDataset = holdoutMap.size > 1

private val branch = when (labelExtractor) {
is FileLabelExtractor -> FilePipelineBranch(filters, labelExtractor)
is FunctionLabelExtractor -> FunctionPipelineBranch(filters, labelExtractor)
@@ -41,32 +43,45 @@ class Pipeline(private val config: PipelineConfig) {
return config.storage.createStorage(storagePath)
}

private fun <T : Closeable, R> T.useSynchronously(callback: (T) -> R) =
this.use {
synchronized(this) {
callback(this)
}
}
Contributor:
That looks complicated ;)

Contributor:
Maybe we can add a `synchronizedStore` to the storage interface:

fun synchronizedStore(labeledResult: LabeledResult<out Node>) = synchronized(this) {
    store(labeledResult)
}
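The reviewer's suggestion, spelled out as compilable Kotlin (a sketch of the idea with illustrative names, not code from the PR):

```kotlin
// Sketch: a default method on a storage interface that serializes concurrent
// store() calls, as the reviewer proposes, instead of wrapping the whole
// storage in an external synchronized block.
interface SyncStorage {
    fun store(item: String)

    fun synchronizedStore(item: String) = synchronized(this) {
        store(item)
    }
}

fun main() {
    val sink = mutableListOf<String>()
    val storage = object : SyncStorage {
        override fun store(item: String) { sink += item }
    }
    // Many threads can now call synchronizedStore without corrupting the sink.
    (1..4).map { i -> Thread { storage.synchronizedStore("result-$i") } }
        .onEach { it.start() }
        .forEach { it.join() }
    println(sink.size)  // 4
}
```

The appeal of this design is that the locking lives next to the state it protects, so callers such as `parseLanguage` no longer need to know that the storage is not thread-safe.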


private fun parseLanguage(language: FileExtension) {
val parsingResultFactory = getParsingResultFactory(language, config.parser.name)
createStorage(language).useSynchronously { storage ->
for ((holdoutType, holdoutDir) in holdoutMap) {
val holdoutFiles = getProjectFilesWithExtension(holdoutDir, language.fileExtension)
printHoldoutStat(holdoutFiles, holdoutType)
val progressBar = ProgressBar("", holdoutFiles.size.toLong())
parsingResultFactory.parseFilesInThreads(holdoutFiles, config.numOfThreads) { parseResult ->
val labeledResults = branch.process(parseResult)
storage.store(labeledResults, holdoutType)
progressBar.step()
}
progressBar.close()
}
}
}

private fun printHoldoutStat(files: List<File>, holdoutType: DatasetHoldout) {
var output = "${files.size} file(s) found"
Contributor:
You can use StringBuilder for string manipulation to avoid repeatedly concatenating immutable strings

if (isDataset) { output += " in ${holdoutType.name}" }
println(output)
}
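The StringBuilder variant the reviewer hints at might look like this (an illustrative sketch, extracted into a pure function so it can be tested):

```kotlin
// Sketch: building the per-holdout statistics line with StringBuilder
// instead of reassigning an immutable String.
fun holdoutStat(fileCount: Int, holdoutName: String?, isDataset: Boolean): String {
    val output = StringBuilder("$fileCount file(s) found")
    if (isDataset && holdoutName != null) output.append(" in ").append(holdoutName)
    return output.toString()
}

fun main() {
    println(holdoutStat(12, "Train", isDataset = true))  // 12 file(s) found in Train
    println(holdoutStat(3, null, isDataset = false))     // 3 file(s) found
}
```

For two fragments the difference is cosmetic; the pattern pays off when a message is assembled from many optional parts.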

/**
* Runs the pipeline that is defined in the [config].
*/
fun run() {
println("Working in ${config.numOfThreads} thread(s)")
if (isDataset) { println("Dataset structure found") }
for (language in config.parser.languages) {
println("Parsing $language")
val parsingResultFactory = getParsingResultFactory(language, config.parser.name)

println("Collecting files...")
val files = getProjectFilesWithExtension(inputDirectory, language.fileExtension)
println("${files.size} files retrieved")

val progressBar = ProgressBar("", files.size.toLong())

createStorage(language).use { storage ->
synchronized(storage) {
parsingResultFactory.parseFilesInThreads(files, config.numOfThreads) { parseResult ->
for (labeledResult in branch.process(parseResult)) {
storage.store(labeledResult)
}
progressBar.step()
}
}
}
progressBar.close()
parseLanguage(language)
}
println("Done!")
}
28 changes: 18 additions & 10 deletions src/main/kotlin/astminer/storage/ast/CsvAstStorage.kt
@@ -1,5 +1,6 @@
package astminer.storage.ast

import astminer.common.model.DatasetHoldout
import astminer.common.model.LabeledResult
import astminer.common.model.Node
import astminer.common.model.Storage
@@ -19,29 +20,26 @@ class CsvAstStorage(override val outputDirectoryPath: String) : Storage {
private val tokensMap: RankedIncrementalIdStorage<String> = RankedIncrementalIdStorage()
private val nodeTypesMap: RankedIncrementalIdStorage<String> = RankedIncrementalIdStorage()

private val astsOutputStream: PrintWriter
private val astsPrintWriters = mutableMapOf<DatasetHoldout, PrintWriter>()

init {
File(outputDirectoryPath).mkdirs()
val astsFile = File("$outputDirectoryPath/asts.csv")
astsFile.createNewFile()
astsOutputStream = PrintWriter(astsFile)
astsOutputStream.write("id,ast\n")
}

override fun store(labeledResult: LabeledResult<out Node>) {
override fun store(labeledResult: LabeledResult<out Node>, holdout: DatasetHoldout) {
for (node in labeledResult.root.preOrder()) {
tokensMap.record(node.token)
nodeTypesMap.record(node.typeLabel)
}
dumpAst(labeledResult.root, labeledResult.label)
val writer = astsPrintWriters.getOrPut(holdout) { holdout.resolveHoldout() }
dumpAst(labeledResult.root, labeledResult.label, writer)
}

override fun close() {
dumpTokenStorage(File("$outputDirectoryPath/tokens.csv"))
dumpNodeTypesStorage(File("$outputDirectoryPath/node_types.csv"))

astsOutputStream.close()
astsPrintWriters.values.map { it.close() }
}

private fun dumpTokenStorage(file: File) {
@@ -52,13 +50,23 @@ class CsvAstStorage(override val outputDirectoryPath: String) : Storage {
dumpIdStorageToCsv(nodeTypesMap, "node_type", nodeTypeToCsvString, file)
}

private fun dumpAst(root: Node, id: String) {
astsOutputStream.write("$id,${astString(root)}\n")
private fun dumpAst(root: Node, id: String, writer: PrintWriter) {
writer.println("$id,${astString(root)}")
}

internal fun astString(node: Node): String {
return "${tokensMap.getId(node.token)} ${nodeTypesMap.getId(node.typeLabel)}{${
node.children.joinToString(separator = "", transform = ::astString)
}}"
}

private fun DatasetHoldout.resolveHoldout(): PrintWriter {
val holdoutDir = File(outputDirectoryPath).resolve(this.dirName)
holdoutDir.mkdirs()
val astFile = holdoutDir.resolve("asts.csv")
astFile.createNewFile()
val newWriter = PrintWriter(astFile)
newWriter.println("id,ast")
return newWriter
}
}
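The per-holdout writer pattern that `CsvAstStorage` introduces here (and that the other storages in this PR repeat) reduces to a small sketch, using string builders in place of `PrintWriter`s so it needs no filesystem:

```kotlin
// Sketch of the getOrPut pattern from CsvAstStorage: one output sink per
// holdout, created lazily on the first store() call for that holdout,
// with the CSV header written exactly once per sink.
class LazyWriters {
    private val writers = mutableMapOf<String, StringBuilder>()

    fun store(holdout: String, row: String) {
        val writer = writers.getOrPut(holdout) { StringBuilder("id,ast\n") }
        writer.append(row).append('\n')
    }

    fun contents(): Map<String, String> = writers.mapValues { it.value.toString() }
}

fun main() {
    val storage = LazyWriters()
    storage.store("train", "1,root{}")
    storage.store("train", "2,root{}")
    storage.store("test", "3,root{}")
    println(storage.contents().keys)  // only holdouts that received data
}
```

Lazy creation matters here: unstructured runs only ever see the `None` holdout, so no empty `train`/`val`/`test` files are created.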
15 changes: 11 additions & 4 deletions src/main/kotlin/astminer/storage/ast/DotAstStorage.kt
@@ -1,5 +1,6 @@
package astminer.storage.ast

import astminer.common.model.DatasetHoldout
import astminer.common.model.LabeledResult
import astminer.common.model.Node
import astminer.common.model.Storage
@@ -15,26 +16,25 @@ class DotAstStorage(override val outputDirectoryPath: String) : Storage {

internal data class FilePath(val parentPath: String, val fileName: String)

private val astDirectoryPath: File
private val astDirectoryPaths = mutableMapOf<DatasetHoldout, File>()
private val astFilenameFormat = "ast_%d.dot"
private val descriptionFileStream: PrintWriter
private var index: Long = 0

init {
File(outputDirectoryPath).mkdirs()
astDirectoryPath = File(outputDirectoryPath, "asts")
astDirectoryPath.mkdirs()
val descriptionFile = File(outputDirectoryPath, "description.csv")
descriptionFile.createNewFile()
descriptionFileStream = PrintWriter(descriptionFile)
descriptionFileStream.write("dot_file,source_file,label,node_id,token,type\n")
}

override fun store(labeledResult: LabeledResult<out Node>) {
override fun store(labeledResult: LabeledResult<out Node>, holdout: DatasetHoldout) {
// Use filename as a label for ast
// TODO: save full signature for method
val normalizedLabel = normalizeAstLabel(labeledResult.label)
val normalizedFilepath = normalizeFilepath(labeledResult.filePath)
val astDirectoryPath = astDirectoryPaths.getOrPut(holdout) { holdout.resolveHoldout() }
val nodesMap =
dumpAst(labeledResult.root, File(astDirectoryPath, astFilenameFormat.format(index)), normalizedLabel)
val nodeDescriptionFormat = "${astFilenameFormat.format(index)},$normalizedFilepath,$normalizedLabel,%d,%s,%s"
@@ -68,6 +68,13 @@ class DotAstStorage(override val outputDirectoryPath: String) : Storage {
return nodesMap
}

private fun DatasetHoldout.resolveHoldout(): File {
val outputDir = File(outputDirectoryPath)
val asts = outputDir.resolve(this.dirName).resolve("asts")
asts.mkdirs()
return asts
}

// Label should contain only latin letters, numbers and underscores, other symbols replace with an underscore
internal fun normalizeAstLabel(label: String): String =
label.replace("[^A-z0-9_]".toRegex(), "_")
19 changes: 13 additions & 6 deletions src/main/kotlin/astminer/storage/ast/JsonAstStorage.kt
@@ -1,5 +1,6 @@
package astminer.storage.ast

import astminer.common.model.DatasetHoldout
import astminer.common.model.LabeledResult
import astminer.common.model.Node
import astminer.common.model.Storage
@@ -19,14 +20,11 @@ private typealias Id = Int
class JsonAstStorage(override val outputDirectoryPath: String, private val withPaths: Boolean) : Storage {
private val treeFlattener = TreeFlattener()

private val writer: PrintWriter
private val datasetWriters = mutableMapOf<DatasetHoldout, PrintWriter>()

init {
val outputDirectory = File(outputDirectoryPath)
outputDirectory.mkdirs()
val file = outputDirectory.resolve("asts.jsonl")
file.createNewFile()
writer = file.printWriter()
}

@Serializable
@@ -38,15 +36,24 @@ class JsonAstStorage(override val outputDirectoryPath: String, private val withP
private fun TreeFlattener.EnumeratedNode.toOutputNode() =
OutputNode(node.token, node.typeLabel, children.map { it.id })

override fun store(labeledResult: LabeledResult<out Node>) {
override fun store(labeledResult: LabeledResult<out Node>, holdout: DatasetHoldout) {
val outputNodes = treeFlattener.flatten(labeledResult.root).map { it.toOutputNode() }
val path = if (withPaths) labeledResult.filePath else null
val labeledAst = LabeledAst(labeledResult.label, path, outputNodes)
val writer = datasetWriters.getOrPut(holdout) { holdout.resolveHoldout() }
writer.println(Json.encodeToString(labeledAst))
}

override fun close() {
writer.close()
datasetWriters.values.map { it.close() }
}

private fun DatasetHoldout.resolveHoldout(): PrintWriter {
val holdoutDir = File(outputDirectoryPath).resolve(this.dirName)
holdoutDir.mkdirs()
val astFile = holdoutDir.resolve("asts.jsonl")
astFile.createNewFile()
return PrintWriter(astFile)
}
}

22 changes: 13 additions & 9 deletions src/main/kotlin/astminer/storage/path/PathBasedStorage.kt
@@ -37,15 +37,10 @@ abstract class PathBasedStorage(
) : Storage {

private val pathMiner = PathMiner(PathRetrievalSettings(config.maxPathLength, config.maxPathWidth))

private val pathsFile: File
private val pathContextPrintWriter: PrintWriter
private val datasetFileWriters = mutableMapOf<DatasetHoldout, PrintWriter>()

init {
File(outputDirectoryPath).mkdirs()
pathsFile = File(outputDirectoryPath).resolve("path_contexts.c2s")
pathsFile.createNewFile()
pathContextPrintWriter = PrintWriter(pathsFile)
}

private fun retrievePaths(node: Node) = if (config.maxPathContextsPerEntity != null) {
@@ -69,13 +64,22 @@ abstract class PathBasedStorage(
/**
* Extract paths from [labeledResult] and store them in the specified format.
*/
override fun store(labeledResult: LabeledResult<out Node>) {
override fun store(labeledResult: LabeledResult<out Node>, holdout: DatasetHoldout) {
val labeledPathContexts = retrieveLabeledPathContexts(labeledResult)
val output = labeledPathContextsToString(labeledPathContexts)
pathContextPrintWriter.println(output)
val writer = datasetFileWriters.getOrPut(holdout) { holdout.resolveWriter() }
writer.println(output)
}

override fun close() {
pathContextPrintWriter.close()
datasetFileWriters.values.map { it.close() }
}

private fun DatasetHoldout.resolveWriter(): PrintWriter {
val holdoutDir = File(outputDirectoryPath).resolve(this.dirName)
holdoutDir.mkdirs()
val pathContextFile = holdoutDir.resolve("path_contexts.c2s")
pathContextFile.createNewFile()
return PrintWriter(pathContextFile)
}
}