diff --git a/pkg/tasks/languagemodeling/bert/languagemodel.go b/pkg/tasks/languagemodeling/bert/languagemodel.go index 809067e..6b323d3 100644 --- a/pkg/tasks/languagemodeling/bert/languagemodel.go +++ b/pkg/tasks/languagemodeling/bert/languagemodel.go @@ -77,6 +77,12 @@ func LoadMaskedLanguageModel(modelPath string) (*LanguageModel, error) { }, nil } +// Close finalizes the LanguageModel resources. +// It satisfies the interface io.Closer. +func (m *LanguageModel) Close() error { + return m.embeddingsRepo.Close() +} + // Predict returns the predicted tokens func (m *LanguageModel) Predict(_ context.Context, text string, parameters languagemodeling.Parameters) (languagemodeling.Response, error) { if parameters.K == 0 { diff --git a/pkg/tasks/questionanswering/bert/questionanswering.go b/pkg/tasks/questionanswering/bert/questionanswering.go index 6ba52c2..b36ab9f 100644 --- a/pkg/tasks/questionanswering/bert/questionanswering.go +++ b/pkg/tasks/questionanswering/bert/questionanswering.go @@ -72,6 +72,12 @@ func LoadQuestionAnswering(modelPath string) (*QuestionAnswering, error) { }, nil } +// Close finalizes the QuestionAnswering resources. +// It satisfies the interface io.Closer. +func (qa *QuestionAnswering) Close() error { + return qa.embeddingsRepo.Close() +} + // Answer returns the answers for the given question and passage. // The options may assume default values if those are not set. func (qa *QuestionAnswering) Answer(_ context.Context, question string, passage string, opts *questionanswering.Options) (questionanswering.Response, error) { diff --git a/pkg/tasks/textclassification/bert/textclassification.go b/pkg/tasks/textclassification/bert/textclassification.go index e6277c2..59e38b7 100644 --- a/pkg/tasks/textclassification/bert/textclassification.go +++ b/pkg/tasks/textclassification/bert/textclassification.go @@ -97,6 +97,12 @@ func ID2Label(value map[string]string) []string { return y } +// Close finalizes the TextClassification resources. +// It satisfies the interface io.Closer. +func (m *TextClassification) Close() error { + return m.embeddingsRepo.Close() +} + // Classify returns the classification of the given text. func (m *TextClassification) Classify(_ context.Context, text string) (textclassification.Response, error) { tokenized := m.tokenize(text) diff --git a/pkg/tasks/textencoding/bert/textencoding.go b/pkg/tasks/textencoding/bert/textencoding.go index fd04ca6..31e0d15 100644 --- a/pkg/tasks/textencoding/bert/textencoding.go +++ b/pkg/tasks/textencoding/bert/textencoding.go @@ -73,6 +73,12 @@ func LoadTextEncoding(modelPath string) (*TextEncoding, error) { }, nil } +// Close finalizes the TextEncoding resources. +// It satisfies the interface io.Closer. +func (m *TextEncoding) Close() error { + return m.embeddingsRepo.Close() +} + // Encode returns the dense encoded representation of the given text. func (m *TextEncoding) Encode(_ context.Context, text string, poolingStrategy int) (textencoding.Response, error) { tokenized := m.tokenize(text) diff --git a/pkg/tasks/tokenclassification/bert/tokenclassification.go b/pkg/tasks/tokenclassification/bert/tokenclassification.go index 9fb80fb..5e23b59 100644 --- a/pkg/tasks/tokenclassification/bert/tokenclassification.go +++ b/pkg/tasks/tokenclassification/bert/tokenclassification.go @@ -96,6 +96,12 @@ func ID2Label(value map[string]string) []string { return y } +// Close finalizes the TokenClassification resources. +// It satisfies the interface io.Closer. +func (m *TokenClassification) Close() error { + return m.embeddingsRepo.Close() +} + // Classify returns the classification of the given text. func (m *TokenClassification) Classify(_ context.Context, text string, parameters tokenclassification.Parameters) (tokenclassification.Response, error) { tokenized := m.tokenize(text) diff --git a/pkg/tasks/tokenclassification/flair/tokenclassification.go b/pkg/tasks/tokenclassification/flair/tokenclassification.go index 9b7cb13..cd34fe1 100644 --- a/pkg/tasks/tokenclassification/flair/tokenclassification.go +++ b/pkg/tasks/tokenclassification/flair/tokenclassification.go @@ -79,6 +79,12 @@ func ID2Label(value map[string]string) []string { return y } +// Close finalizes the TokenClassification resources. +// It satisfies the interface io.Closer. +func (m *TokenClassification) Close() error { + return m.embeddingsRepo.Close() +} + // Classify returns the classification of the given text. func (m *TokenClassification) Classify(_ context.Context, text string, parameters tokenclassification.Parameters) (tokenclassification.Response, error) { tokenized := m.tokenize(text)