-
Notifications
You must be signed in to change notification settings - Fork 128
/
KNNMethodContext.java
201 lines (172 loc) · 6.95 KB
/
KNNMethodContext.java
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
/*
* SPDX-License-Identifier: Apache-2.0
*
* The OpenSearch Contributors require contributions made to
* this file be licensed under the Apache-2.0 license or a
* compatible open source license.
*
* Modifications Copyright OpenSearch Contributors. See
* GitHub history for details.
*/
package org.opensearch.knn.index;
import org.opensearch.knn.index.util.KNNEngine;
import org.opensearch.common.xcontent.ToXContentFragment;
import org.opensearch.common.xcontent.XContentBuilder;
import org.opensearch.index.mapper.MapperParsingException;
import java.io.IOException;
import java.util.Collections;
import java.util.Map;
import org.apache.commons.lang.builder.EqualsBuilder;
import org.apache.commons.lang.builder.HashCodeBuilder;
import static org.opensearch.knn.common.KNNConstants.KNN_ENGINE;
import static org.opensearch.knn.common.KNNConstants.METHOD_HNSW;
import static org.opensearch.knn.common.KNNConstants.METHOD_PARAMETER_SPACE_TYPE;
import static org.opensearch.knn.common.KNNConstants.NAME;
import static org.opensearch.knn.common.KNNConstants.PARAMETERS;
/**
 * KNNMethodContext contains the information necessary to produce a library index from an OpenSearch mapping:
 * the engine, the space type and the main method component (a name plus its parameters).
 * It encompasses all parameters necessary to build the index.
 */
public class KNNMethodContext implements ToXContentFragment {

    private static KNNMethodContext defaultInstance = null;

    /**
     * Returns the shared default context, creating it on the first call.
     *
     * <p>The default uses {@link KNNEngine#DEFAULT}, {@link SpaceType#DEFAULT} and a {@code METHOD_HNSW}
     * method component with an empty parameter map. The method is {@code synchronized}, so concurrent first
     * calls observe a single instance.
     *
     * @return the default KNNMethodContext
     */
    public static synchronized KNNMethodContext getDefault() {
        if (defaultInstance == null) {
            defaultInstance = new KNNMethodContext(
                KNNEngine.DEFAULT,
                SpaceType.DEFAULT,
                new MethodComponentContext(METHOD_HNSW, Collections.emptyMap())
            );
        }
        return defaultInstance;
    }

    private final KNNEngine knnEngine;
    private final SpaceType spaceType;
    private final MethodComponentContext methodComponent;

    /**
     * Constructor
     *
     * @param knnEngine engine that this method uses
     * @param spaceType space type that this method uses
     * @param methodComponent MethodComponent describing the main index
     */
    public KNNMethodContext(KNNEngine knnEngine, SpaceType spaceType, MethodComponentContext methodComponent) {
        this.knnEngine = knnEngine;
        this.spaceType = spaceType;
        this.methodComponent = methodComponent;
    }

    /**
     * Gets the main method component
     *
     * @return methodComponent
     */
    public MethodComponentContext getMethodComponent() {
        return methodComponent;
    }

    /**
     * Gets the engine to be used for this context
     *
     * @return knnEngine
     */
    public KNNEngine getEngine() {
        return knnEngine;
    }

    /**
     * Gets the space type for this context
     *
     * @return spaceType
     */
    public SpaceType getSpaceType() {
        return spaceType;
    }

    /**
     * Validates this context by delegating to {@code knnEngine.validateMethod(this)}.
     * NOTE(review): the failure mode (exception type/message) is defined by the engine, not here — confirm
     * against KNNEngine.
     */
    public void validate() {
        knnEngine.validateMethod(this);
    }

    /**
     * Parses an Object into a KNNMethodContext.
     *
     * <p>The object must be a {@code Map}. Recognized keys are {@code KNN_ENGINE},
     * {@code METHOD_PARAMETER_SPACE_TYPE}, {@code NAME} and {@code PARAMETERS}; any other key is rejected.
     * {@code NAME} is required and must be non-empty. The engine and the space type fall back to
     * {@link KNNEngine#DEFAULT} and {@link SpaceType#DEFAULT} when the key is absent or its value is null.
     *
     * @param in Object containing mapping to be parsed
     * @return KNNMethodContext
     * @throws MapperParsingException if {@code in} is not a map, contains an unknown key, holds a value of the
     *         wrong type, names an unknown engine or space type, or does not set a non-empty {@code NAME}
     */
    public static KNNMethodContext parse(Object in) {
        if (!(in instanceof Map<?, ?>)) {
            throw new MapperParsingException("Unable to parse mapping into KNNMethodContext. Object not of type \"Map\"");
        }

        @SuppressWarnings("unchecked")
        Map<String, Object> methodMap = (Map<String, Object>) in;

        KNNEngine engine = KNNEngine.DEFAULT; // Get or default
        SpaceType spaceType = SpaceType.DEFAULT; // Get or default
        String name = "";
        Map<String, Object> parameters = null;

        for (Map.Entry<String, Object> methodEntry : methodMap.entrySet()) {
            String key = methodEntry.getKey();
            Object value = methodEntry.getValue();

            if (KNN_ENGINE.equals(key)) {
                // An explicit null keeps the default engine
                if (value != null) {
                    engine = parseEngine(value);
                }
            } else if (METHOD_PARAMETER_SPACE_TYPE.equals(key)) {
                // An explicit null keeps the default space type, mirroring the engine handling above
                if (value != null) {
                    spaceType = parseSpaceType(value);
                }
            } else if (NAME.equals(key)) {
                name = parseName(value);
            } else if (PARAMETERS.equals(key)) {
                parameters = parseParameters(value);
            } else {
                throw new MapperParsingException("Invalid parameter: " + key);
            }
        }

        if (name.isEmpty()) {
            throw new MapperParsingException(NAME + " needs to be set");
        }

        MethodComponentContext method = new MethodComponentContext(name, parameters);
        return new KNNMethodContext(engine, spaceType, method);
    }

    /** Converts a non-null engine value into a KNNEngine, rejecting non-strings and unknown engine names. */
    private static KNNEngine parseEngine(Object value) {
        if (!(value instanceof String)) {
            throw new MapperParsingException("\"" + KNN_ENGINE + "\" must be a string");
        }
        try {
            return KNNEngine.getEngine((String) value);
        } catch (IllegalArgumentException iae) {
            // Keep the original lookup failure as the cause for easier debugging
            throw new MapperParsingException("Invalid " + KNN_ENGINE + ": " + value, iae);
        }
    }

    /** Converts a non-null space type value into a SpaceType, rejecting non-strings and unknown space names. */
    private static SpaceType parseSpaceType(Object value) {
        if (!(value instanceof String)) {
            throw new MapperParsingException("\"" + METHOD_PARAMETER_SPACE_TYPE + "\" must be a string");
        }
        try {
            return SpaceType.getSpace((String) value);
        } catch (IllegalArgumentException iae) {
            // Keep the original lookup failure as the cause for easier debugging
            throw new MapperParsingException("Invalid " + METHOD_PARAMETER_SPACE_TYPE + ": " + value, iae);
        }
    }

    /** Returns the method name value as a String; null and non-string values are rejected. */
    private static String parseName(Object value) {
        if (!(value instanceof String)) {
            throw new MapperParsingException(NAME + " has to be a string");
        }
        return (String) value;
    }

    /** Returns the parameters value cast to a map; null is returned as-is, other non-map values are rejected. */
    @SuppressWarnings("unchecked")
    private static Map<String, Object> parseParameters(Object value) {
        if (value != null && !(value instanceof Map)) {
            throw new MapperParsingException("Unable to parse parameters for main method component");
        }
        // Unchecked: entry types of the incoming map are not verified here
        return (Map<String, Object>) value;
    }

    /**
     * Writes the engine name ({@code KNN_ENGINE}) and the space type value ({@code METHOD_PARAMETER_SPACE_TYPE})
     * as fields, then delegates to the main method component's {@code toXContent}. This method does not call
     * {@code startObject}/{@code endObject} itself; that is left to the caller.
     */
    @Override
    public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
        builder.field(KNN_ENGINE, knnEngine.getName());
        builder.field(METHOD_PARAMETER_SPACE_TYPE, spaceType.getValue());
        return methodComponent.toXContent(builder, params);
    }

    /** Two contexts are equal when their engine, space type and main method component are all equal. */
    @Override
    public boolean equals(Object obj) {
        if (this == obj) {
            return true;
        }
        if (obj == null || getClass() != obj.getClass()) {
            return false;
        }
        KNNMethodContext other = (KNNMethodContext) obj;

        EqualsBuilder equalsBuilder = new EqualsBuilder();
        equalsBuilder.append(knnEngine, other.knnEngine);
        equalsBuilder.append(spaceType, other.spaceType);
        equalsBuilder.append(methodComponent, other.methodComponent);
        return equalsBuilder.isEquals();
    }

    /** Hash over the same fields used by {@link #equals(Object)}. */
    @Override
    public int hashCode() {
        return new HashCodeBuilder().append(knnEngine).append(spaceType).append(methodComponent).toHashCode();
    }
}