Skip to content

Commit

Permalink
Merge pull request #464 from ClearTK/feature/463-Ability-to-work-in-multi-classloader-environments
Browse files Browse the repository at this point in the history

Issue #463: Ability to work in multi-classloader environments
  • Loading branch information
reckart authored Nov 7, 2022
2 parents 334c84e + 79a775e commit 421ad73
Show file tree
Hide file tree
Showing 8 changed files with 304 additions and 272 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
import org.cleartk.ml.jar.ClassifierBuilder_ImplBase;
import org.cleartk.ml.jar.JarStreams;
import org.cleartk.ml.mallet.factory.ClassifierTrainerFactory;
import org.cleartk.util.ClassLookup;
import org.cleartk.util.ReflectionUtil;

import cc.mallet.classify.Classifier;
Expand All @@ -53,7 +54,8 @@
*/

public abstract class MalletClassifierBuilder_ImplBase<CLASSIFIER_TYPE extends MalletClassifier_ImplBase<OUTCOME_TYPE>, OUTCOME_TYPE>
extends ClassifierBuilder_ImplBase<CLASSIFIER_TYPE, List<NameNumber>, OUTCOME_TYPE, String> {
extends
ClassifierBuilder_ImplBase<CLASSIFIER_TYPE, List<NameNumber>, OUTCOME_TYPE, String> {

private static final String MODEL_NAME = "model.mallet";

Expand All @@ -62,6 +64,7 @@ public File getTrainingDataFile(File dir) {
return new File(dir, "training-data.mallet");
}

@Override
public void trainClassifier(File dir, String... args) throws Exception {

InstanceListCreator instanceListCreator = new InstanceListCreator();
Expand All @@ -71,20 +74,14 @@ public void trainClassifier(File dir, String... args) throws Exception {
String factoryName = args[0];
Class<ClassifierTrainerFactory<?>> factoryClass = createTrainerFactory(factoryName);
if (factoryClass == null) {
String factoryName2 = "org.cleartk.ml.mallet.factory." + factoryName
+ "TrainerFactory";
String factoryName2 = "org.cleartk.ml.mallet.factory." + factoryName + "TrainerFactory";
factoryClass = createTrainerFactory(factoryName2);
}
if (factoryClass == null) {
throw new IllegalArgumentException(
String
.format(
"name for classifier trainer factory is not valid: name given ='%s'. Valid classifier names include: %s, %s, %s, and %s",
factoryName,
ClassifierTrainerFactory.NAMES[0],
ClassifierTrainerFactory.NAMES[1],
ClassifierTrainerFactory.NAMES[2],
ClassifierTrainerFactory.NAMES[3]));
throw new IllegalArgumentException(String.format(
"name for classifier trainer factory is not valid: name given ='%s'. Valid classifier names include: %s, %s, %s, and %s",
factoryName, ClassifierTrainerFactory.NAMES[0], ClassifierTrainerFactory.NAMES[1],
ClassifierTrainerFactory.NAMES[2], ClassifierTrainerFactory.NAMES[3]));
}

String[] factoryArgs = new String[args.length - 1];
Expand All @@ -96,20 +93,21 @@ public void trainClassifier(File dir, String... args) throws Exception {
trainer = factory.createTrainer(factoryArgs);
} catch (Throwable t) {
throw new IllegalArgumentException("Unable to create trainer. Usage for "
+ factoryClass.getCanonicalName() + ": " + factory.getUsageMessage(), t);
+ factoryClass.getCanonicalName() + ": " + factory.getUsageMessage(), t);
}

this.classifier = trainer.train(instanceList);

ObjectOutputStream oos = new ObjectOutputStream(new FileOutputStream(new File(dir, MODEL_NAME)));
ObjectOutputStream oos = new ObjectOutputStream(
new FileOutputStream(new File(dir, MODEL_NAME)));
oos.writeObject(classifier);
oos.close();

}

private Class<ClassifierTrainerFactory<?>> createTrainerFactory(String className) {
try {
return ReflectionUtil.uncheckedCast(Class.forName(className));
return ReflectionUtil.uncheckedCast(ClassLookup.lookupClass(className));
} catch (ClassNotFoundException cnfe) {
return null;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
import org.cleartk.ml.encoder.features.FeaturesEncoder;
import org.cleartk.ml.encoder.features.FeaturesEncoder_ImplBase;
import org.cleartk.ml.encoder.outcome.OutcomeEncoder;
import org.cleartk.util.ClassLookup;
import org.cleartk.util.ReflectionUtil;

/**
Expand All @@ -57,7 +58,7 @@
* @author Steven Bethard
*/
public abstract class EncodingJarClassifierBuilder<CLASSIFIER_TYPE, ENCODED_FEATURES_TYPE, OUTCOME_TYPE, ENCODED_OUTCOME_TYPE>
extends JarClassifierBuilder<CLASSIFIER_TYPE> {
extends JarClassifierBuilder<CLASSIFIER_TYPE> {

private static final String ENCODERS_FILE_NAME = FeaturesEncoder_ImplBase.ENCODERS_FILE_NAME;

Expand Down Expand Up @@ -115,7 +116,7 @@ protected void packageClassifier(File dir, JarOutputStream modelStream) throws I
protected void unpackageClassifier(JarInputStream modelStream) throws IOException {
super.unpackageClassifier(modelStream);
JarStreams.getNextJarEntry(modelStream, ENCODERS_FILE_NAME);
ObjectInputStream is = new ObjectInputStream(modelStream);
ObjectInputStream is = ClassLookup.streamObjects(modelStream);
try {
this.featuresEncoder = this.featuresEncoderCast(is.readObject());
this.outcomeEncoder = this.outcomeEncoderCast(is.readObject());
Expand All @@ -128,14 +129,9 @@ protected void unpackageClassifier(JarInputStream modelStream) throws IOExceptio
private FeaturesEncoder<ENCODED_FEATURES_TYPE> featuresEncoderCast(Object object) {
  // Purpose: turn an untyped object (the caller passes the result of readObject()) into a
  // FeaturesEncoder typed for this builder.
  //
  // The cast itself is unchecked with respect to ENCODED_FEATURES_TYPE because of erasure, so it
  // only verifies that the object is some FeaturesEncoder.
  FeaturesEncoder<ENCODED_FEATURES_TYPE> encoder = (FeaturesEncoder<ENCODED_FEATURES_TYPE>) object;

  // Compare the ENCODED_FEATURES_TYPE binding of this builder with that of the encoder.
  // NOTE(review): presumably ReflectionUtil raises the supplied ClassCastException type when the
  // two bindings differ - confirm against ReflectionUtil. The check is performed exactly once;
  // repeating the identical call adds nothing.
  ReflectionUtil.checkTypeParametersAreEqual(EncodingJarClassifierBuilder.class,
          "ENCODED_FEATURES_TYPE", this, FeaturesEncoder.class, "ENCODED_FEATURES_TYPE", encoder,
          ClassCastException.class);

  return encoder;
}
Expand All @@ -145,23 +141,12 @@ private OutcomeEncoder<OUTCOME_TYPE, ENCODED_OUTCOME_TYPE> outcomeEncoderCast(Ob
OutcomeEncoder<OUTCOME_TYPE, ENCODED_OUTCOME_TYPE> encoder;
encoder = (OutcomeEncoder<OUTCOME_TYPE, ENCODED_OUTCOME_TYPE>) object;

ReflectionUtil.checkTypeParametersAreEqual(
EncodingJarClassifierBuilder.class,
"OUTCOME_TYPE",
this,
OutcomeEncoder.class,
"OUTCOME_TYPE",
encoder,
ClassCastException.class);

ReflectionUtil.checkTypeParametersAreEqual(
EncodingJarClassifierBuilder.class,
"ENCODED_OUTCOME_TYPE",
this,
OutcomeEncoder.class,
"ENCODED_OUTCOME_TYPE",
encoder,
ClassCastException.class);
ReflectionUtil.checkTypeParametersAreEqual(EncodingJarClassifierBuilder.class, "OUTCOME_TYPE",
this, OutcomeEncoder.class, "OUTCOME_TYPE", encoder, ClassCastException.class);

ReflectionUtil.checkTypeParametersAreEqual(EncodingJarClassifierBuilder.class,
"ENCODED_OUTCOME_TYPE", this, OutcomeEncoder.class, "ENCODED_OUTCOME_TYPE", encoder,
ClassCastException.class);

return encoder;
}
Expand Down
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
/*
* Copyright (c) 2012, Regents of the University of Colorado
/*
* Copyright (c) 2012, Regents of the University of Colorado
* All rights reserved.
*
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer.
* Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution.
* Neither the name of the University of Colorado at Boulder nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission.
*
*
* Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer.
* Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution.
* Neither the name of the University of Colorado at Boulder nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
Expand All @@ -19,43 +19,37 @@
* INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
* CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
* ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
* POSSIBILITY OF SUCH DAMAGE.
* POSSIBILITY OF SUCH DAMAGE.
*/
package org.cleartk.ml.jar;

import java.io.File;
import java.io.IOException;
import java.lang.reflect.InvocationTargetException;

import org.cleartk.util.ClassLookup;

/**
* Serves as a common base class for data writer factories such as {@link DefaultDataWriterFactory}
* and {@link DefaultSequenceDataWriterFactory}.
*
*
* <br>
* Copyright (c) 2012, Regents of the University of Colorado <br>
* All rights reserved.
*
*
* @author Steven Bethard
*/
public class GenericDataWriterFactory<OUTCOME_TYPE> extends DirectoryDataWriterFactory {

protected <T> T createDataWriter(String dataWriterClassName, Class<T> superClass)
throws IOException {
throws IOException {
try {
Class<?> untypedCls = Class.forName(dataWriterClassName);
Class<?> untypedCls = ClassLookup.lookupClass(dataWriterClassName);
Class<? extends T> cls = untypedCls.asSubclass(superClass);
return cls.getConstructor(File.class).newInstance(this.outputDirectory);
} catch (ClassNotFoundException e) {
throw new IOException(e);
} catch (InstantiationException e) {
throw new IOException(e);
} catch (IllegalAccessException e) {
throw new IOException(e);
} catch (InvocationTargetException e) {
throw new IOException(e);
} catch (NoSuchMethodException e) {
return cls.getConstructor(File.class).newInstance(outputDirectory);
} catch (ClassNotFoundException | InstantiationException | IllegalAccessException
| InvocationTargetException | NoSuchMethodException e) {
throw new IOException(e);
}
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@

import org.cleartk.ml.DataWriter;
import org.cleartk.ml.SequenceDataWriter;
import org.cleartk.util.ClassLookup;

/**
* Superclass for builders which package classifiers as jar files. Saves a manifest from which new
Expand Down Expand Up @@ -65,7 +66,7 @@ public abstract class JarClassifierBuilder<CLASSIFIER_TYPE> {
* The name of the attribute where the classifier builder class is stored.
*/
private static final Attributes.Name CLASSIFIER_BUILDER_ATTRIBUTE_NAME = new Attributes.Name(
"classifierBuilderClass");
"classifierBuilderClass");

/**
* The manifest associated with this classifier builder. The manifest will be saved to directories
Expand Down Expand Up @@ -102,7 +103,8 @@ public static JarClassifierBuilder<?> fromManifest(Manifest manifest) {
String className = manifest.getMainAttributes().getValue(CLASSIFIER_BUILDER_ATTRIBUTE_NAME);
JarClassifierBuilder<?> builder;
try {
builder = Class.forName(className).asSubclass(JarClassifierBuilder.class).newInstance();
builder = ClassLookup.lookupClass(className).asSubclass(JarClassifierBuilder.class)
.newInstance();
} catch (Exception e) {
throw new RuntimeException("ClassifierBuilder class read from manifest does not exist", e);
}
Expand All @@ -122,7 +124,7 @@ public static JarClassifierBuilder<?> fromManifest(Manifest manifest) {
* {@link #trainClassifier(File, String...)}.
*/
public static void trainAndPackage(File trainingDirectory, String... trainingArguments)
throws Exception {
throws Exception {
JarClassifierBuilder<?> classifierBuilder = fromTrainingDirectory(trainingDirectory);
classifierBuilder.trainClassifier(trainingDirectory, trainingArguments);
classifierBuilder.packageClassifier(trainingDirectory);
Expand All @@ -132,7 +134,6 @@ public static void trainAndPackage(File trainingDirectory, String... trainingArg
* Creates a new classifier builder with a default manifest.
*/
public JarClassifierBuilder() {
super();
this.manifest = new Manifest();
Attributes attributes = this.manifest.getMainAttributes();
attributes.put(Attributes.Name.MANIFEST_VERSION, "1.0");
Expand All @@ -149,7 +150,7 @@ public JarClassifierBuilder() {
public void saveToTrainingDirectory(File dir) throws IOException {
  // Save the manifest to the file that getManifestFile(dir) designates. The stream is managed by
  // try-with-resources so it is closed even when the write fails.
  try (FileOutputStream manifestStream = new FileOutputStream(getManifestFile(dir))) {
    // Written exactly once: a second write to the same stream would append a duplicate copy of
    // the manifest to the file.
    this.manifest.write(manifestStream);
  }
}

Expand Down Expand Up @@ -177,9 +178,9 @@ public void saveToTrainingDirectory(File dir) throws IOException {
* The directory where the classifier model was trained.
*/
public void packageClassifier(File dir) throws IOException {
try (JarOutputStream modelStream = new JarOutputStream(new BufferedOutputStream(
new FileOutputStream(getModelJarFile(dir))), this.manifest)) {
this.packageClassifier(dir, modelStream);
try (JarOutputStream modelStream = new JarOutputStream(
new BufferedOutputStream(new FileOutputStream(getModelJarFile(dir))), this.manifest)) {
this.packageClassifier(dir, modelStream);
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,8 @@
import org.apache.commons.io.filefilter.IOFileFilter;
import org.apache.commons.io.filefilter.SuffixFileFilter;
import org.apache.commons.io.filefilter.TrueFileFilter;
import org.junit.Assert;
import org.apache.uima.fit.factory.ConfigurationParameterFactory;
import org.junit.Assert;

/**
* <br>
Expand All @@ -47,18 +47,17 @@
public class ParametersTestUtil {

public static void testParameterDefinitions(String outputDirectory, String... excludeFiles)
throws ClassNotFoundException {
throws ClassNotFoundException {
IOFileFilter includeFilter = new SuffixFileFilter(".java");

if (excludeFiles != null) {
IOFileFilter excludeFilter = FileFilterUtils.notFileFilter(new SuffixFileFilter(excludeFiles));
IOFileFilter excludeFilter = FileFilterUtils
.notFileFilter(new SuffixFileFilter(excludeFiles));
includeFilter = FileFilterUtils.and(excludeFilter, includeFilter);
}

Iterator<File> files = org.apache.commons.io.FileUtils.iterateFiles(
new File(outputDirectory),
includeFilter,
TrueFileFilter.INSTANCE);
Iterator<File> files = org.apache.commons.io.FileUtils.iterateFiles(new File(outputDirectory),
includeFilter, TrueFileFilter.INSTANCE);
testParameterDefinitions(files);
}

Expand All @@ -76,22 +75,24 @@ public static void testParameterDefinitions(Iterator<File> files) throws ClassNo
Field[] fields = cls.getDeclaredFields();
for (Field field : fields) {
if (ConfigurationParameterFactory.isConfigurationParameterField(field)) {
org.apache.uima.fit.descriptor.ConfigurationParameter annotation = field.getAnnotation(org.apache.uima.fit.descriptor.ConfigurationParameter.class);
org.apache.uima.fit.descriptor.ConfigurationParameter annotation = field
.getAnnotation(org.apache.uima.fit.descriptor.ConfigurationParameter.class);
String parameterName = annotation.name();
String expectedName = field.getName();
if (!expectedName.equals(parameterName)) {
badParameters.add("'" + parameterName + "' should be '" + expectedName + "'");
}

expectedName = className+"."+field.getName();
expectedName = className + "." + field.getName();
String fieldName = getParameterNameField(expectedName);
try {
Field fld = cls.getDeclaredField(fieldName);
if ((fld.getModifiers() & Modifier.PUBLIC) == 0
|| (fld.getModifiers() & Modifier.FINAL) == 0
|| (fld.getModifiers() & Modifier.PUBLIC) == 0) {
|| (fld.getModifiers() & Modifier.FINAL) == 0
|| (fld.getModifiers() & Modifier.PUBLIC) == 0) {
missingParameterNameFields.add(expectedName);
} else if (!fld.get(null).equals(expectedName.substring(expectedName.lastIndexOf(".")+1))) {
} else if (!fld.get(null)
.equals(expectedName.substring(expectedName.lastIndexOf(".") + 1))) {
missingParameterNameFields.add(expectedName);
}
} catch (Exception e) {
Expand All @@ -103,19 +104,19 @@ public static void testParameterDefinitions(Iterator<File> files) throws ClassNo

if (badParameters.size() > 0 || missingParameterNameFields.size() > 0) {
String message = String.format(
"%d descriptor parameters with bad names and %d descriptor parameters with no name field. ",
badParameters.size(),
missingParameterNameFields.size());
"%d descriptor parameters with bad names and %d descriptor parameters with no name field. ",
badParameters.size(), missingParameterNameFields.size());
System.err.println(message);
System.err.println("descriptor parameters with bad names: ");
for (String badParameter : badParameters) {
System.err.println(badParameter);
}
System.err.println("each configuration parameter should have a public static final String that specifies its name. The missing fields are: ");
System.err.println(
"each configuration parameter should have a public static final String that specifies its name. The missing fields are: ");
for (String missingParameterNameField : missingParameterNameFields) {
System.err.println(missingParameterNameField + " should be named by "
+ missingParameterNameField.substring(0, missingParameterNameField.lastIndexOf('.'))
+ "." + getParameterNameField(missingParameterNameField));
+ missingParameterNameField.substring(0, missingParameterNameField.lastIndexOf('.'))
+ "." + getParameterNameField(missingParameterNameField));
}
Assert.fail(message);
}
Expand Down
Loading

0 comments on commit 421ad73

Please sign in to comment.