Skip to content

Commit

Permalink
New default dictionary (more compact). Also improve unittest coverage…
Browse files Browse the repository at this point in the history
… for the default dictionary
  • Loading branch information
wanasit committed Aug 30, 2020
1 parent 55882ba commit 6b75ad8
Show file tree
Hide file tree
Showing 15 changed files with 266 additions and 82 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,8 @@ const val TARGET_DICTIONARY_FILENAME = "../kotori/src/main/resources/default_dic

fun main() {

val sourceDictionary = runAndPrintTimeMillis("Loading source dictionary") {
val sourceDictionary = runAndPrintTimeMillis(
"Loading source dictionary (MeCab IPADict)") {
Dictionaries.Mecab.loadIpadic()
}

Expand All @@ -36,6 +37,12 @@ fun main() {
.map { adjustTermCosts(it) }
println("Filtered term entries from ${deduplicatedTermEntries.size} to ${filteredTermEntries.size}")

filteredTermEntries.forEach {
if (it.leftId != it.rightId) {
throw AssertionError("Unexpected case where leftId != rightId in the source: $it")
}
}

val targetDictionary = runAndPrintTimeMillis("Building target dictionary") {

val terms = PlainTermDictionary.copyOf(filteredTermEntries) {
Expand Down Expand Up @@ -65,9 +72,13 @@ fun main() {

println("Dictionary file size: ${ (File(TARGET_DICTIONARY_FILENAME).length().toDouble() / 1024).format() } KB")

val writtenDictionary = File(TARGET_DICTIONARY_FILENAME).inputStream().use {
GZIPInputStream(it).use {
DefaultDictionary.readFromInputStream(it)
val writtenDictionary = runAndPrintTimeMillis(
"Reading back the written dictionary") {

File(TARGET_DICTIONARY_FILENAME).inputStream().use {
GZIPInputStream(it).use {
DefaultDictionary.readFromInputStream(it)
}
}
}

Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
package com.github.wanasit.kotori.dictionaries

import com.github.wanasit.kotori.utils.termEntries
import org.junit.Test
import kotlin.test.assertTrue

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,6 @@ import com.github.wanasit.kotori.utils.termEntries
import java.io.InputStream
import java.io.OutputStream



class DefaultDictionary(
override val terms: TermDictionary<DefaultTermFeatures>,
override val connection: PlainConnectionCostTable,
Expand All @@ -32,7 +30,7 @@ class DefaultDictionary(
fun readFromInputStream(inputStream: InputStream) : DefaultDictionary {
val unknownExtraction = DefaultUnknownTermExtraction.readFromInputStream(inputStream)
val connection = PlainConnectionCostTable.readFromInputStream(inputStream)
val termEntries = DefaultTermFeatures.readTermEntriesFromInputStream(inputStream)
val termEntries = DefaultTermEntry.readFromInputStream(inputStream)

return DefaultDictionary(PlainTermDictionary(termEntries), connection, unknownExtraction)
}
Expand All @@ -41,7 +39,7 @@ class DefaultDictionary(
val termEntries = value.termEntries.toTypedArray()
DefaultUnknownTermExtraction.writeToOutputStream(outputStream, value.unknownExtraction)
PlainConnectionCostTable.writeToOutputStream(outputStream, value.connection)
DefaultTermFeatures.writeTermEntriesToOutput(outputStream, termEntries)
DefaultTermEntry.writeToOutputAsDefaultTermEntries(outputStream, termEntries)
}
}
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
package com.github.wanasit.kotori.optimized

import com.github.wanasit.kotori.TermEntry
import com.github.wanasit.kotori.mecab.MeCabLikeTermFeatures
import com.github.wanasit.kotori.mecab.MeCabTermFeatures
import com.github.wanasit.kotori.utils.IOUtils
import java.io.InputStream
import java.io.OutputStream
import java.lang.IllegalArgumentException

/**
* A term entry for default dictionary
* To make default dictionary compact, we make assumption that
* the left and right context id of each term entry are the same
*/
data class DefaultTermEntry(
override val surfaceForm: String,
val contextId: Int,
override val cost: Int,
override val features: DefaultTermFeatures,
override val leftId: Int = contextId,
override val rightId: Int = contextId
) : TermEntry<DefaultTermFeatures> {

companion object {
fun copy(other: TermEntry<DefaultTermFeatures>) : DefaultTermEntry {
if (other.leftId != other.rightId) {
throw IllegalArgumentException(
"A default term entry must have the same left and right context ID")
}

return DefaultTermEntry(
surfaceForm = other.surfaceForm,
contextId = other.leftId,
cost = other.cost,
features = other.features)
}

fun readFromInputStream(inputStream: InputStream) : Array<DefaultTermEntry> {
val size = IOUtils.readInt(inputStream)
val sizePerEntry = 3
val flattenTermEntry = IOUtils.readIntArray(inputStream, size * sizePerEntry)
val surfaceForms = IOUtils.readStringArray(inputStream, size)
return Array(size) {
DefaultTermEntry(
surfaceForm = surfaceForms[it],
contextId = flattenTermEntry[it*sizePerEntry],
cost = flattenTermEntry[it*sizePerEntry + 1],
features = DefaultTermFeatures(
partOfSpeech = DefaultTermFeatures.PartOfSpeech.values()[flattenTermEntry[it*sizePerEntry + 2]]
))
}
}

fun writeToOutputAsDefaultTermEntries(outputStream: OutputStream, termEntries: Array<TermEntry<DefaultTermFeatures>>) {
writeToOutput(outputStream, termEntries.map { copy(it) }.toTypedArray() )
}

fun writeToOutput(outputStream: OutputStream, termEntries: Array<DefaultTermEntry>) {
val size = termEntries.size
val surfaceForms = termEntries.map { it.surfaceForm }.toTypedArray()
val flattenTermEntry = termEntries.flatMap { listOf(
it.contextId,
it.cost,
it.features.partOfSpeech.ordinal
)}.toIntArray()

IOUtils.writeInt(outputStream, size)
IOUtils.writeIntArray(outputStream, flattenTermEntry, includeSize = false)
IOUtils.writeStringArray(outputStream, surfaceForms, includeSize = false)
}
}
}
Original file line number Diff line number Diff line change
@@ -1,12 +1,5 @@
package com.github.wanasit.kotori.optimized

import com.github.wanasit.kotori.TermEntry
import com.github.wanasit.kotori.mecab.MeCabLikeTermFeatures
import com.github.wanasit.kotori.mecab.MeCabTermFeatures
import com.github.wanasit.kotori.utils.IOUtils
import java.io.InputStream
import java.io.OutputStream

data class DefaultTermFeatures(
val partOfSpeech: PartOfSpeech = PartOfSpeech.UNKNOWN
) {
Expand All @@ -28,36 +21,4 @@ data class DefaultTermFeatures(
OTHER(),
UNKNOWN()
}

companion object {
fun readTermEntriesFromInputStream(inputStream: InputStream) : Array<TermEntry<DefaultTermFeatures>> {
val size = IOUtils.readInt(inputStream)
val sizePerEntry = 4
val flattenTermEntry = IOUtils.readIntArray(inputStream, size * sizePerEntry)
val surfaceForms = IOUtils.readStringArray(inputStream, size)
return Array(size) {
val leftId = flattenTermEntry[it*sizePerEntry]
val rightId = flattenTermEntry[it*sizePerEntry + 1]
val cost = flattenTermEntry[it*sizePerEntry + 2]

PlainTermEntry(surfaceForms[it],
leftId, rightId, cost,
DefaultTermFeatures(
partOfSpeech = PartOfSpeech.values()[flattenTermEntry[it*sizePerEntry + 3]]
))
}
}

fun writeTermEntriesToOutput(outputStream: OutputStream, termEntries: Array<TermEntry<DefaultTermFeatures>>) {
val size = termEntries.size
val surfaceForms = termEntries.map { it.surfaceForm }.toTypedArray()
val flattenTermEntry = termEntries.flatMap { listOf(it.leftId, it.rightId, it.cost,
it.features.partOfSpeech.ordinal
)}.toIntArray()

IOUtils.writeInt(outputStream, size)
IOUtils.writeIntArray(outputStream, flattenTermEntry, includeSize = false)
IOUtils.writeStringArray(outputStream, surfaceForms, includeSize = false)
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,12 @@ data class PlainToken<TermFeatures>(
override val index: Int,
override val features: TermFeatures) : Token<TermFeatures> {

class EmptyFeatures

companion object {

val EMPTY_FEATURES = EmptyFeatures();

fun createWithEmptyFeatures(text: String, index: Int) : PlainToken<EmptyFeatures> {
return PlainToken(text, index, EmptyFeatures())
return PlainToken(text, index, EMPTY_FEATURES)
}
}

Expand All @@ -21,4 +22,15 @@ data class PlainToken<TermFeatures>(
fun List<Token<*>>.withoutFeatures() : List<Token<EmptyFeatures>> {
return this.map { createWithEmptyFeatures(it.text, it.index) }
}

class EmptyFeatures {

override fun equals(other: Any?): Boolean {
return other is EmptyFeatures
}

override fun hashCode(): Int {
return javaClass.hashCode()
}
}
}
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
package com.github.wanasit.kotori.optimized.unknown

import com.github.wanasit.kotori.TermEntry
import com.github.wanasit.kotori.UnknownTermExtractionStrategy
import com.github.wanasit.kotori.optimized.DefaultTermEntry
import com.github.wanasit.kotori.optimized.DefaultTermFeatures
import com.github.wanasit.kotori.utils.IOUtils
import java.io.InputStream
Expand All @@ -21,7 +21,7 @@ object DefaultUnknownTermExtraction {
val categoryDefinitionFlattenArray = IOUtils.readShortArray(inputStream)
val arraySizes = IOUtils.readIntArray(inputStream)
val flattenCharToCategories = IOUtils.readIntArray(inputStream)
val flattenCategoryToTermEntries = DefaultTermFeatures.readTermEntriesFromInputStream(inputStream)
val flattenCategoryToTermEntries = DefaultTermEntry.readFromInputStream(inputStream)

var index = 0
val charToCategories: Array<IntArray> = Array(charcodeSize) {
Expand All @@ -35,7 +35,7 @@ object DefaultUnknownTermExtraction {
val entries = flattenCategoryToTermEntries.copyOfRange(
index, index + arraySizes[charcodeSize + it]).toList()
index += arraySizes[charcodeSize + it]
entries
entries as List<TermEntry<DefaultTermFeatures>>
}

val categoryToDefinition = Array(charCategorySize) {
Expand Down Expand Up @@ -78,7 +78,7 @@ object DefaultUnknownTermExtraction {
IOUtils.writeShortArray(outputStream, categoryDefinitionFlattenArray)
IOUtils.writeIntArray(outputStream, arraySizes)
IOUtils.writeIntArray(outputStream, flattenCharToCategories)
DefaultTermFeatures.writeTermEntriesToOutput(outputStream, flattenCategoryToTermEntries.toTypedArray())
DefaultTermEntry.writeToOutputAsDefaultTermEntries(outputStream, flattenCategoryToTermEntries.toTypedArray())
}
}

Expand Down
Original file line number Diff line number Diff line change
@@ -1,10 +1,20 @@
package com.github.wanasit.kotori.utils

import com.github.wanasit.kotori.Dictionary
import com.github.wanasit.kotori.TermDictionary
import com.github.wanasit.kotori.TermEntry
import com.github.wanasit.kotori.optimized.PlainTermEntry
import com.github.wanasit.kotori.optimized.PlainToken

val <F> Dictionary<F>.termEntries: List<TermEntry<F>>
get() = this.terms.map { it.second }

val <F> Dictionary<F>.size: Int
get() = this.terms.size()

val <F> TermDictionary<F>.asEntries : List<TermEntry<F>>
get() = this.map { it.second }

fun TermEntry<*>.withoutFeatures(): PlainTermEntry<PlainToken.EmptyFeatures> {
return PlainTermEntry(this, PlainToken.EMPTY_FEATURES)
}
Binary file modified kotori/src/main/resources/default_dictionary.bin.gz
Binary file not shown.
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ fun connectionTable(init: FakeConnectionTable.() -> Unit) : ConnectionCost {

open class FakingTermDictionaryWithEmptyFeatures : FakingTermDictionary<PlainToken.EmptyFeatures>() {
fun term(surfaceForm: String, wordType: WordType, cost: Int) {
term(surfaceForm, wordType, cost, PlainToken.EmptyFeatures())
term(surfaceForm, wordType, cost, PlainToken.EMPTY_FEATURES)
}
}

Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
package com.github.wanasit.kotori.mecab

import com.github.wanasit.kotori.Dictionary
import com.github.wanasit.kotori.TermDictionary
import com.github.wanasit.kotori.TermEntry
import com.github.wanasit.kotori.utils.ResourceUtil
import java.nio.charset.Charset

Expand All @@ -9,18 +11,18 @@ const val FILE_NAME_TERM_DICTIONARY = "Adverb.csv"

fun MeCabDictionary.readFromResource(
namespace: String = DEFAULT_RESOURCE_NAMESPACE,
charset: Charset = MeCabDictionary.DEFAULT_CHARSET
charset: Charset = DEFAULT_CHARSET
) : Dictionary<MeCabTermFeatures> {

val termDictionary = MeCabTermDictionary.readFromInputStream(
ResourceUtil.readResourceAsStream(namespace, FILE_NAME_TERM_DICTIONARY), charset)

val termConnection = MeCabConnectionCost.readFromInputStream(
ResourceUtil.readResourceAsStream(namespace, MeCabDictionary.FILE_NAME_CONNECTION_COST), charset)
ResourceUtil.readResourceAsStream(namespace, FILE_NAME_CONNECTION_COST), charset)

val unknownTermStrategy = MeCabUnknownTermExtractionStrategy.readFromFileInputStreams(
ResourceUtil.readResourceAsStream(namespace, MeCabDictionary.FILE_NAME_UNKNOWN_ENTRIES),
ResourceUtil.readResourceAsStream(namespace, MeCabDictionary.FILE_NAME_CHARACTER_DEFINITION),
ResourceUtil.readResourceAsStream(namespace, FILE_NAME_UNKNOWN_ENTRIES),
ResourceUtil.readResourceAsStream(namespace, FILE_NAME_CHARACTER_DEFINITION),
charset)

return Dictionary(
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
package com.github.wanasit.kotori.optimized

import com.github.wanasit.kotori.connectionTable
import com.github.wanasit.kotori.fakeTermDictionaryWithoutFeature
import com.github.wanasit.kotori.optimized.unknown.UnknownTermExtractionByCharacterCategory
import com.github.wanasit.kotori.utils.asEntries
import com.github.wanasit.kotori.utils.termEntries
import com.github.wanasit.kotori.utils.withoutFeatures
import org.junit.Assert
import org.junit.Assert.*
import org.junit.Test

class TestDefaultDictionary {

@Test
fun testBasicCreationAndSerialization() {
val terms = fakeTermDictionaryWithoutFeature {
term("そこで", CONJ, 10)
term("そこ", NOUN, 40)
term("", VERB, 40)
term("", ADJ, 10)
term("はなし", NOUN, 40)
term("", VERB, 10)
term("なし", NOUN, 40)
term("終わり", NOUN, 40)
term("になった", VERB, 40)
term("", ADJ, 10)
term("なった", VERB, 40)
}.asEntries

val connectionCost = connectionTable {
header( END, NOUN, VERB, ADJ, CONJ)
row(BEGIN, 0, 10, 10, 0, 10)
row(NOUN, 10, 10, 40, 10, 0)
row(VERB, 10, 10, 10, 0, 10)
row(ADJ, 10, 10, 10, 10, 10)
row(CONJ, 0, 10, 10, 0, 10)
}

val unknownExtraction: UnknownTermExtractionByCharacterCategory<DefaultTermFeatures> =
UnknownTermExtractionByCharacterCategory.fromUnoptimizedMapping(emptyMap(), emptyMap(), emptyMap())

val dictionary = DefaultDictionary(
terms = PlainTermDictionary.copyOf(terms) { PlainTermEntry(it, DefaultTermFeatures()) },
unknownExtraction = unknownExtraction,
connection = PlainConnectionCostTable.copyOf(terms, connectionCost)
)

val file = createTempFile()
file.deleteOnExit()
file.outputStream().use {
DefaultDictionary.writeToOutputStream(it, dictionary);
}

val readDictionary = file.inputStream().use {
DefaultDictionary.readFromInputStream(it)
}

assertEquals(
terms.map { it.withoutFeatures() },
readDictionary.termEntries.map { it.withoutFeatures() })
assertEquals(connectionCost.lookup(1, 1), readDictionary.connection.lookup(1 , 1))
assertEquals(connectionCost.lookup(3, 1), readDictionary.connection.lookup(3 , 1))
}
}
Loading

0 comments on commit 6b75ad8

Please sign in to comment.