Skip to content

Commit

Permalink
Add tokenizer to trainer
Browse files Browse the repository at this point in the history
  • Loading branch information
johnml1135 committed Nov 27, 2024
1 parent 9f6648b commit 1f29ecd
Show file tree
Hide file tree
Showing 10 changed files with 149 additions and 92 deletions.
8 changes: 5 additions & 3 deletions src/SIL.Machine.Tool/AlignCommand.cs
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ public AlignCommand()
);
_symHeuristicOption = Option(
"-sh|--sym-heuristic <SYM_HEURISTIC>",
$"The symmetrization heuristic.\nHeuristics: \"{ToolHelpers.Och}\" (default), \"{ToolHelpers.Union}\", \"{ToolHelpers.Intersection}\", \"{ToolHelpers.Grow}\", \"{ToolHelpers.GrowDiag}\", \"{ToolHelpers.GrowDiagFinal}\", \"{ToolHelpers.GrowDiagFinalAnd}\", \"{ToolHelpers.None}\".",
$"The symmetrization heuristic.\nHeuristics: \"{SymmetrizationHelpers.Och}\" (default), \"{SymmetrizationHelpers.Union}\", \"{SymmetrizationHelpers.Intersection}\", \"{SymmetrizationHelpers.Grow}\", \"{SymmetrizationHelpers.GrowDiag}\", \"{SymmetrizationHelpers.GrowDiagFinal}\", \"{SymmetrizationHelpers.GrowDiagFinalAnd}\", \"{SymmetrizationHelpers.None}\".",
CommandOptionType.SingleValue
);
_scoresOption = Option("-s|--scores", "Include scores in the output.", CommandOptionType.NoValue);
Expand All @@ -53,7 +53,7 @@ protected override async Task<int> ExecuteCommandAsync(CancellationToken cancell
if (code != 0)
return code;

if (!ToolHelpers.ValidateSymmetrizationHeuristicOption(_symHeuristicOption?.Value()))
if (!SymmetrizationHelpers.ValidateSymmetrizationHeuristicOption(_symHeuristicOption?.Value()))
{
Out.WriteLine("The specified symmetrization heuristic is invalid.");
return 1;
Expand All @@ -75,7 +75,9 @@ protected override async Task<int> ExecuteCommandAsync(CancellationToken cancell

int processorCount = Environment.ProcessorCount;

SymmetrizationHeuristic symHeuristic = ToolHelpers.GetSymmetrizationHeuristic(_symHeuristicOption?.Value());
SymmetrizationHeuristic symHeuristic = SymmetrizationHelpers.GetSymmetrizationHeuristic(
_symHeuristicOption?.Value()
);

if (!_quietOption.HasValue())
Out.Write("Loading model... ");
Expand Down
22 changes: 12 additions & 10 deletions src/SIL.Machine.Tool/AlignmentModelCommandSpec.cs
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ public void AddParameters(CommandBase command)
_modelArgument = command.Argument("MODEL_PATH", "The word alignment model.").IsRequired();
_modelTypeOption = command.Option(
"-mt|--model-type <MODEL_TYPE>",
$"The word alignment model type.\nTypes: \"{ToolHelpers.Hmm}\" (default), \"{ToolHelpers.Ibm1}\", \"{ToolHelpers.Ibm2}\", \"{ToolHelpers.Ibm3}\", \"{ToolHelpers.Ibm4}\", \"{ToolHelpers.FastAlign}\".",
$"The word alignment model type.\nTypes: \"{ThotWordAlignmentHelpers.Hmm}\" (default), \"{ThotWordAlignmentHelpers.Ibm1}\", \"{ThotWordAlignmentHelpers.Ibm2}\", \"{ThotWordAlignmentHelpers.Ibm3}\", \"{ThotWordAlignmentHelpers.Ibm4}\", \"{ThotWordAlignmentHelpers.FastAlign}\".",
CommandOptionType.SingleValue
);
_pluginOption = command.Option(
Expand Down Expand Up @@ -90,7 +90,7 @@ public IWordAlignmentModel CreateAlignmentModel(
ThotWordAlignmentModelType modelType = ThotWordAlignmentModelType.Hmm;
if (_modelTypeOption.HasValue())
{
modelType = ToolHelpers.GetThotWordAlignmentModelType(_modelTypeOption.Value());
modelType = ThotWordAlignmentHelpers.GetThotWordAlignmentModelType(_modelTypeOption.Value());
}
else
{
Expand All @@ -103,7 +103,7 @@ public IWordAlignmentModel CreateAlignmentModel(
yaml.Load(reader);
var root = (YamlMappingNode)yaml.Documents.First().RootNode;
var modelTypeStr = (string)root[new YamlScalarNode("model")];
modelType = ToolHelpers.GetThotWordAlignmentModelType(modelTypeStr);
modelType = ThotWordAlignmentHelpers.GetThotWordAlignmentModelType(modelTypeStr);
}
}
}
Expand Down Expand Up @@ -144,7 +144,9 @@ public ITrainer CreateAlignmentModelTrainer(
return _modelFactory.CreateTrainer(_modelArgument.Value, corpus, maxSize, parameters, direct);
}

ThotWordAlignmentModelType modelType = ToolHelpers.GetThotWordAlignmentModelType(_modelTypeOption.Value());
ThotWordAlignmentModelType modelType = ThotWordAlignmentHelpers.GetThotWordAlignmentModelType(
_modelTypeOption.Value()
);

string modelPath = _modelArgument.Value;
if (ToolHelpers.IsDirectoryPath(modelPath))
Expand Down Expand Up @@ -227,12 +229,12 @@ private static bool ValidateAlignmentModelTypeOption(string value, IEnumerable<s
{
var validTypes = new HashSet<string>
{
ToolHelpers.Hmm,
ToolHelpers.Ibm1,
ToolHelpers.Ibm2,
ToolHelpers.FastAlign,
ToolHelpers.Ibm3,
ToolHelpers.Ibm4
ThotWordAlignmentHelpers.Hmm,
ThotWordAlignmentHelpers.Ibm1,
ThotWordAlignmentHelpers.Ibm2,
ThotWordAlignmentHelpers.FastAlign,
ThotWordAlignmentHelpers.Ibm3,
ThotWordAlignmentHelpers.Ibm4
};
validTypes.UnionWith(pluginTypes);
return string.IsNullOrEmpty(value) || validTypes.Contains(value);
Expand Down
13 changes: 10 additions & 3 deletions src/SIL.Machine.Tool/SymmetrizeCommand.cs
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ public SymmetrizeCommand()
);
_symHeuristicOption = Option(
"-sh|--sym-heuristic <SYM_HEURISTIC>",
$"The symmetrization heuristic.\nHeuristics: \"{ToolHelpers.Och}\" (default), \"{ToolHelpers.Union}\", \"{ToolHelpers.Intersection}\", \"{ToolHelpers.Grow}\", \"{ToolHelpers.GrowDiag}\", \"{ToolHelpers.GrowDiagFinal}\", \"{ToolHelpers.GrowDiagFinalAnd}\".",
$"The symmetrization heuristic.\nHeuristics: \"{SymmetrizationHelpers.Och}\" (default), \"{SymmetrizationHelpers.Union}\", \"{SymmetrizationHelpers.Intersection}\", \"{SymmetrizationHelpers.Grow}\", \"{SymmetrizationHelpers.GrowDiag}\", \"{SymmetrizationHelpers.GrowDiagFinal}\", \"{SymmetrizationHelpers.GrowDiagFinalAnd}\".",
CommandOptionType.SingleValue
);
_quietOption = Option("-q|--quiet", "Only display results.", CommandOptionType.NoValue);
Expand Down Expand Up @@ -67,14 +67,21 @@ protected override async Task<int> ExecuteCommandAsync(CancellationToken cancell
return 1;
}

if (!ToolHelpers.ValidateSymmetrizationHeuristicOption(_symHeuristicOption.Value(), noneAllowed: false))
if (
!SymmetrizationHelpers.ValidateSymmetrizationHeuristicOption(
_symHeuristicOption.Value(),
noneAllowed: false
)
)
{
Out.WriteLine("The specified symmetrization heuristic is invalid.");
return 1;
}

string outputFormat = _outputFormatOption.Value() ?? Pharaoh;
SymmetrizationHeuristic heuristic = ToolHelpers.GetSymmetrizationHeuristic(_symHeuristicOption.Value());
SymmetrizationHeuristic heuristic = SymmetrizationHelpers.GetSymmetrizationHeuristic(
_symHeuristicOption.Value()
);

using var directReader = new StreamReader(_directArgument.Value);
using var inverseReader = new StreamReader(_inverseArgument.Value);
Expand Down
81 changes: 10 additions & 71 deletions src/SIL.Machine.Tool/ToolHelpers.cs
Original file line number Diff line number Diff line change
Expand Up @@ -14,22 +14,6 @@ namespace SIL.Machine;

internal static class ToolHelpers
{
public const string FastAlign = "fast_align";
public const string Ibm1 = "ibm1";
public const string Ibm2 = "ibm2";
public const string Hmm = "hmm";
public const string Ibm3 = "ibm3";
public const string Ibm4 = "ibm4";

public const string Och = "och";
public const string Union = "union";
public const string Intersection = "intersection";
public const string Grow = "grow";
public const string GrowDiag = "grow-diag";
public const string GrowDiagFinal = "grow-diag-final";
public const string GrowDiagFinalAnd = "grow-diag-final-and";
public const string None = "none";

public static bool ValidateCorpusFormatOption(string value)
{
return string.IsNullOrEmpty(value) || value.ToLowerInvariant().IsOneOf("dbl", "usx", "text", "pt", "pt_m");
Expand Down Expand Up @@ -154,29 +138,14 @@ public static string GetTranslationModelConfigFileName(string path)

public static bool ValidateTranslationModelTypeOption(string value)
{
var validTypes = new HashSet<string> { Hmm, Ibm1, Ibm2, FastAlign };
return string.IsNullOrEmpty(value) || validTypes.Contains(value);
}

public static ThotWordAlignmentModelType GetThotWordAlignmentModelType(string modelType)
{
switch (modelType)
var validTypes = new HashSet<string>
{
case "fastAlign":
case FastAlign:
return ThotWordAlignmentModelType.FastAlign;
case Ibm1:
return ThotWordAlignmentModelType.Ibm1;
case Ibm2:
return ThotWordAlignmentModelType.Ibm2;
default:
case Hmm:
return ThotWordAlignmentModelType.Hmm;
case Ibm3:
return ThotWordAlignmentModelType.Ibm3;
case Ibm4:
return ThotWordAlignmentModelType.Ibm4;
}
ThotWordAlignmentHelpers.Hmm,
ThotWordAlignmentHelpers.Ibm1,
ThotWordAlignmentHelpers.Ibm2,
ThotWordAlignmentHelpers.FastAlign
};
return string.IsNullOrEmpty(value) || validTypes.Contains(value);
}

public static ITrainer CreateTranslationModelTrainer(
Expand All @@ -186,7 +155,9 @@ public static ITrainer CreateTranslationModelTrainer(
int maxSize
)
{
ThotWordAlignmentModelType wordAlignmentModelType = GetThotWordAlignmentModelType(modelType);
ThotWordAlignmentModelType wordAlignmentModelType = ThotWordAlignmentHelpers.GetThotWordAlignmentModelType(
modelType
);

string modelDir = Path.GetDirectoryName(modelConfigFileName);
if (!Directory.Exists(modelDir))
Expand All @@ -201,38 +172,6 @@ int maxSize
};
}

/// <summary>
/// Checks whether a symmetrization heuristic option value is acceptable.
/// </summary>
/// <param name="value">The option value as supplied on the command line; may be null or empty.</param>
/// <param name="noneAllowed">Whether the "none" heuristic counts as a valid choice.</param>
/// <returns>
/// <c>true</c> when <paramref name="value"/> is null, empty, or (compared in lower case) one of the
/// known heuristic names; otherwise <c>false</c>.
/// </returns>
public static bool ValidateSymmetrizationHeuristicOption(string value, bool noneAllowed = true)
{
    // A missing option is accepted without consulting the list of names.
    if (string.IsNullOrEmpty(value))
        return true;

    var accepted = new HashSet<string>
    {
        Och,
        Union,
        Intersection,
        Grow,
        GrowDiag,
        GrowDiagFinal,
        GrowDiagFinalAnd
    };
    if (noneAllowed)
        accepted.Add(None);

    string normalized = value.ToLowerInvariant();
    return accepted.Contains(normalized);
}

/// <summary>
/// Maps a symmetrization heuristic option value to its <see cref="SymmetrizationHeuristic"/>.
/// </summary>
/// <param name="value">The option value; may be null. Matched case-insensitively.</param>
/// <returns>
/// The heuristic named by <paramref name="value"/>, or <see cref="SymmetrizationHeuristic.Och"/> when
/// the value is null or does not match any known name.
/// </returns>
public static SymmetrizationHeuristic GetSymmetrizationHeuristic(string value)
{
    // ValidateSymmetrizationHeuristicOption lower-cases the value before checking it, so the same
    // normalization is applied here; otherwise a value such as "UNION" would pass validation and
    // then silently fall through to the Och default.
    return value?.ToLowerInvariant() switch
    {
        None => SymmetrizationHeuristic.None,
        Union => SymmetrizationHeuristic.Union,
        Intersection => SymmetrizationHeuristic.Intersection,
        Grow => SymmetrizationHeuristic.Grow,
        GrowDiag => SymmetrizationHeuristic.GrowDiag,
        GrowDiagFinal => SymmetrizationHeuristic.GrowDiagFinal,
        GrowDiagFinalAnd => SymmetrizationHeuristic.GrowDiagFinalAnd,
        _ => SymmetrizationHeuristic.Och,
    };
}

public static StreamWriter CreateStreamWriter(string fileName)
{
var utf8Encoding = new UTF8Encoding(false);
Expand Down
4 changes: 2 additions & 2 deletions src/SIL.Machine.Tool/TranslationModelCommandSpec.cs
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ public void AddParameters(CommandBase command)
_modelArgument = command.Argument("MODEL_PATH", "The translation model.").IsRequired();
_modelTypeOption = command.Option(
"-mt|--model-type <MODEL_TYPE>",
$"The word alignment model type.\nTypes: \"{ToolHelpers.Hmm}\" (default), \"{ToolHelpers.Ibm1}\", \"{ToolHelpers.Ibm2}\", \"{ToolHelpers.FastAlign}\".",
$"The word alignment model type.\nTypes: \"{ThotWordAlignmentHelpers.Hmm}\" (default), \"{ThotWordAlignmentHelpers.Ibm1}\", \"{ThotWordAlignmentHelpers.Ibm2}\", \"{ThotWordAlignmentHelpers.FastAlign}\".",
CommandOptionType.SingleValue
);
_pluginOption = command.Option(
Expand Down Expand Up @@ -67,7 +67,7 @@ public ITranslationModel CreateModel()
if (_modelFactory != null)
return _modelFactory.CreateModel(_modelArgument.Value);

ThotWordAlignmentModelType wordAlignmentModelType = ToolHelpers.GetThotWordAlignmentModelType(
ThotWordAlignmentModelType wordAlignmentModelType = ThotWordAlignmentHelpers.GetThotWordAlignmentModelType(
_modelTypeOption.Value()
);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@

<ItemGroup>
<PackageReference Include="Thot" Version="3.4.4" />
<PackageReference Include="CaseExtensions" Version="1.1.0" />
</ItemGroup>

<ItemGroup>
Expand Down
14 changes: 13 additions & 1 deletion src/SIL.Machine.Translation.Thot/ThotWordAlignmentModel.cs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
using System.Threading.Tasks;
using SIL.Extensions;
using SIL.Machine.Corpora;
using SIL.Machine.Tokenization;
using SIL.ObjectModel;

namespace SIL.Machine.Translation.Thot
Expand Down Expand Up @@ -116,6 +117,11 @@ public void CreateNew(string prefFileName)
}

public ITrainer CreateTrainer(IParallelTextCorpus corpus)
{
return CreateTrainer(corpus, null);
}

public ITrainer CreateTrainer(IParallelTextCorpus corpus, ITokenizer<string, int, string> tokenizer = null)
{
CheckDisposed();

Expand All @@ -126,7 +132,13 @@ public ITrainer CreateTrainer(IParallelTextCorpus corpus)
);
}

return new Trainer(this, corpus);
var trainer = new Trainer(this, corpus);
if (tokenizer != null)
{
trainer.SourceTokenizer = tokenizer;
trainer.TargetTokenizer = tokenizer;
}
return trainer;
}

public Task SaveAsync()
Expand Down
40 changes: 39 additions & 1 deletion src/SIL.Machine.Translation.Thot/ThotWordAlignmentModelType.cs
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
namespace SIL.Machine.Translation.Thot
using System;
using CaseExtensions;

namespace SIL.Machine.Translation.Thot
{
public enum ThotWordAlignmentModelType
{
Expand All @@ -9,4 +12,39 @@ public enum ThotWordAlignmentModelType
Ibm3,
Ibm4
}

/// <summary>
/// String identifiers for the Thot word alignment model types, together with a conversion from those
/// identifiers to <see cref="ThotWordAlignmentModelType"/>.
/// </summary>
public static class ThotWordAlignmentHelpers
{
    public const string FastAlign = "fast_align";
    public const string Ibm1 = "ibm1";
    public const string Ibm2 = "ibm2";
    public const string Hmm = "hmm";
    public const string Ibm3 = "ibm3";
    public const string Ibm4 = "ibm4";

    /// <summary>
    /// Converts a model type identifier to the corresponding <see cref="ThotWordAlignmentModelType"/>.
    /// </summary>
    /// <param name="modelType">
    /// The model type identifier; may be null. It is converted with <c>ToSnakeCase</c> before being
    /// compared against the identifier constants (presumably so that spellings such as "fastAlign"
    /// resolve to "fast_align").
    /// </param>
    /// <param name="defaultType">
    /// The value returned when <paramref name="modelType"/> is null or matches no identifier. When it is
    /// null, an unmatched identifier is treated as an error.
    /// </param>
    /// <returns>The matching model type, or <paramref name="defaultType"/> if there is no match.</returns>
    /// <exception cref="ArgumentException">
    /// <paramref name="modelType"/> matches no identifier and <paramref name="defaultType"/> is null.
    /// </exception>
    public static ThotWordAlignmentModelType GetThotWordAlignmentModelType(
        string modelType,
        ThotWordAlignmentModelType? defaultType = null
    )
    {
        // Null-conditional call: a null modelType skips the conversion and lands in the default
        // branch, so the caller gets defaultType (or the ArgumentException below) instead of a
        // failure raised from inside the case-conversion call.
        switch (modelType?.ToSnakeCase())
        {
            case FastAlign:
                return ThotWordAlignmentModelType.FastAlign;
            case Ibm1:
                return ThotWordAlignmentModelType.Ibm1;
            case Ibm2:
                return ThotWordAlignmentModelType.Ibm2;
            case Hmm:
                return ThotWordAlignmentModelType.Hmm;
            case Ibm3:
                return ThotWordAlignmentModelType.Ibm3;
            case Ibm4:
                return ThotWordAlignmentModelType.Ibm4;
            default:
                return defaultType
                    ?? throw new ArgumentException($"Invalid word alignment model type: {modelType}");
        }
    }
}
}
1 change: 1 addition & 0 deletions src/SIL.Machine/SIL.Machine.csproj
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@
<PackageReference Include="SIL.Scripture" Version="12.0.1" />
<PackageReference Include="System.Text.Encoding.CodePages" Version="6.0.0" />
<PackageReference Include="System.Threading.Tasks.Dataflow" Version="6.0.0" />
<PackageReference Include="CaseExtensions" Version="1.1.0" />
</ItemGroup>

<ItemGroup Condition="'$(TargetFramework)' == 'net461'">
Expand Down
Loading

0 comments on commit 1f29ecd

Please sign in to comment.