Skip to content

Commit

Permalink
Update of fastutil to 8.3.0; fixes to #840
Browse files Browse the repository at this point in the history
  • Loading branch information
lintool committed Oct 25, 2019
1 parent f2ef38d commit 50d4314
Show file tree
Hide file tree
Showing 4 changed files with 138 additions and 77 deletions.
2 changes: 1 addition & 1 deletion pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -315,7 +315,7 @@
<dependency>
<groupId>it.unimi.dsi</groupId>
<artifactId>fastutil</artifactId>
<version>6.5.6</version>
<version>8.3.0</version>
</dependency>
<dependency>
<groupId>org.wikiclean</groupId>
Expand Down
13 changes: 5 additions & 8 deletions src/main/java/io/anserini/rerank/lib/Rm3Reranker.java
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ public ScoredDocuments rerank(ScoredDocuments docs, RerankerContext context) {
Iterator<String> terms = rm.iterator();
while (terms.hasNext()) {
String term = terms.next();
float prob = rm.getFeatureWeight(term);
float prob = rm.getValue(term);
feedbackQueryBuilder.add(new BoostQuery(new TermQuery(new Term(this.field, term)), prob), BooleanClause.Occur.SHOULD);
}

Expand Down Expand Up @@ -133,10 +133,7 @@ private FeatureVector estimateRelevanceModel(ScoredDocuments docs, IndexReader r

for (int i = 0; i < numdocs; i++) {
try {
FeatureVector docVector = createdFeatureVector(
reader.getTermVector(docs.ids[i], field), reader, tweetsearch);
docVector.pruneToSize(fbTerms);

FeatureVector docVector = createdFeatureVector(reader.getTermVector(docs.ids[i], field), reader, tweetsearch);
vocab.addAll(docVector.getFeatures());
docvectors[i] = docVector;
} catch (IOException e) {
Expand All @@ -159,10 +156,10 @@ private FeatureVector estimateRelevanceModel(ScoredDocuments docs, IndexReader r
// Zero-length feedback documents occur (e.g., with CAR17) when a document has only terms
// that accents (which are indexed, but not selected for feedback).
if (norms[i] > 0.001f) {
fbWeight += (docvectors[i].getFeatureWeight(term) / norms[i]) * docs.scores[i];
fbWeight += (docvectors[i].getValue(term) / norms[i]) * docs.scores[i];
}
}
f.addFeatureWeight(term, fbWeight);
f.addFeatureValue(term, fbWeight);
}

f.pruneToSize(fbTerms);
Expand Down Expand Up @@ -230,7 +227,7 @@ private FeatureVector createdFeatureVector(Terms terms, IndexReader reader, bool
} else if (ratio > 0.1f) continue;

int freq = (int) termsEnum.totalTermFreq();
f.addFeatureWeight(term, (float) freq);
f.addFeatureValue(term, (float) freq);
}
} catch (Exception e) {
e.printStackTrace();
Expand Down
134 changes: 77 additions & 57 deletions src/main/java/io/anserini/util/FeatureVector.java
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@

import it.unimi.dsi.fastutil.objects.Object2FloatOpenHashMap;

import java.text.DecimalFormat;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
Expand All @@ -30,22 +29,26 @@
public class FeatureVector {
private Object2FloatOpenHashMap<String> features = new Object2FloatOpenHashMap<String>();

public enum Order {
FEATURE_DESCENDING, FEATURE_ASCENDING, VALUE_DESCENDING, VALUE_ASCENDING
}

public FeatureVector() {}

public void addFeatureWeight(String term, float weight) {
if (!features.containsKey(term)) {
features.put(term, weight);
public void addFeatureValue(String feature, float value) {
if (!features.containsKey(feature)) {
features.put(feature, value);
} else {
features.put(term, features.get(term) + weight);
features.put(feature, features.getFloat(feature) + value);
}
}

public FeatureVector pruneToSize(int k) {
List<KeyValuePair> pairs = getOrderedFeatures();
Object2FloatOpenHashMap<String> pruned = new Object2FloatOpenHashMap<String>();
List<FeatureValuePair> pairs = getOrderedFeatures();
Object2FloatOpenHashMap<String> pruned = new Object2FloatOpenHashMap<>();

for (KeyValuePair pair : pairs) {
pruned.put((String) pair.getKey(), pair.getValue());
for (FeatureValuePair pair : pairs) {
pruned.put(pair.getFeature(), pair.getValue());
if (pruned.size() >= k) {
break;
}
Expand All @@ -58,7 +61,7 @@ public FeatureVector pruneToSize(int k) {
public FeatureVector scaleToUnitL2Norm() {
double norm = computeL2Norm();
for (String f : features.keySet()) {
features.put(f, (float) (features.get(f) / norm));
features.put(f, (float) (features.getFloat(f) / norm));
}

return this;
Expand All @@ -67,7 +70,7 @@ public FeatureVector scaleToUnitL2Norm() {
public FeatureVector scaleToUnitL1Norm() {
double norm = computeL1Norm();
for (String f : features.keySet()) {
features.put(f, (float) (features.get(f) / norm));
features.put(f, (float) (features.getFloat(f) / norm));
}

return this;
Expand All @@ -77,8 +80,8 @@ public Set<String> getFeatures() {
return features.keySet();
}

public float getFeatureWeight(String feature) {
return features.containsKey(feature) ? features.get(feature) : 0.0f;
public float getValue(String feature) {
return features.containsKey(feature) ? features.getFloat(feature) : 0.0f;
}

public Iterator<String> iterator() {
Expand All @@ -105,55 +108,74 @@ public double computeL1Norm() {
return norm;
}

public static FeatureVector fromTerms(List<String> terms) {
public static FeatureVector fromTerms(List<String> features) {
FeatureVector f = new FeatureVector();
for (String t : terms) {
f.addFeatureWeight(t, 1.0f);
for (String t : features) {
f.addFeatureValue(t, 1.0f);
}
return f;
}

// VIEWING

@Override
public String toString() {
return this.toString(features.size());
private List<FeatureValuePair> getOrderedFeatures() {
return getOrderedFeatures(Order.VALUE_DESCENDING);
}

private List<KeyValuePair> getOrderedFeatures() {
List<KeyValuePair> kvpList = new ArrayList<KeyValuePair>(features.size());
private List<FeatureValuePair> getOrderedFeatures(Order order) {
List<FeatureValuePair> pairs = new ArrayList<>(features.size());
Iterator<String> featureIterator = features.keySet().iterator();
while (featureIterator.hasNext()) {
String feature = featureIterator.next();
float value = features.get(feature);
KeyValuePair keyValuePair = new KeyValuePair(feature, value);
kvpList.add(keyValuePair);
float value = features.getFloat(feature);
FeatureValuePair featureValuePair = new FeatureValuePair(feature, value);
pairs.add(featureValuePair);
}

Collections.sort(kvpList, new Comparator<KeyValuePair>() {
public int compare(KeyValuePair x, KeyValuePair y) {
double xVal = x.getValue();
double yVal = y.getValue();
if (order.equals(Order.VALUE_DESCENDING)) {
Collections.sort(pairs, (FeatureValuePair x, FeatureValuePair y) -> {
if (x.getValue() == y.getValue()) return x.getFeature().compareTo(y.getFeature());
return x.getValue() > y.getValue() ? -1 : 1;
});
} else if (order.equals(Order.VALUE_ASCENDING)) {
Collections.sort(pairs, (FeatureValuePair x, FeatureValuePair y) -> {
if (x.getValue() == y.getValue()) return x.getFeature().compareTo(y.getFeature());
return x.getValue() > y.getValue() ? 1 : -1;
});
} else if (order.equals(Order.FEATURE_ASCENDING)) {
Collections.sort(pairs, Comparator.comparing(FeatureValuePair::getFeature));
} else if (order.equals(Order.FEATURE_DESCENDING)) {
Collections.sort(pairs, Comparator.comparing(FeatureValuePair::getFeature).reversed());
}

return (xVal > yVal ? -1 : (xVal == yVal ? 0 : 1));
}
});
return pairs;
}

return kvpList;
@Override
public String toString() {
return this.toString(Order.VALUE_DESCENDING, features.size());
}

public String toString(int k) {
DecimalFormat format = new DecimalFormat("#.#########");
StringBuilder b = new StringBuilder();
List<KeyValuePair> kvpList = getOrderedFeatures();
Iterator<KeyValuePair> it = kvpList.iterator();
return this.toString(Order.VALUE_DESCENDING, k);
}

public String toString(Order order) {
return this.toString(order, features.size());
}

public String toString(Order order, int k) {
StringBuilder builder = new StringBuilder();
List<FeatureValuePair> features = getOrderedFeatures(order);
Iterator<FeatureValuePair> it = features.iterator();
builder.append("[");
int i = 0;
while (it.hasNext() && i++ < k) {
KeyValuePair pair = it.next();
b.append(format.format(pair.getValue()) + " " + pair.getKey() + "\n");
FeatureValuePair pair = it.next();
if (i!= 1)
builder.append(", ");
builder.append(pair.getFeature() + "=" + pair.getValue());
}
return b.toString();

builder.append("]");
return builder.toString();
}

public static FeatureVector interpolate(FeatureVector x, FeatureVector y, float xWeight) {
Expand All @@ -164,34 +186,32 @@ public static FeatureVector interpolate(FeatureVector x, FeatureVector y, float
Iterator<String> features = vocab.iterator();
while (features.hasNext()) {
String feature = features.next();
float weight = (float) (xWeight * x.getFeatureWeight(feature) + (1.0 - xWeight)
* y.getFeatureWeight(feature));
z.addFeatureWeight(feature, weight);
float weight = (float) (xWeight * x.getValue(feature) + (1.0 - xWeight) * y.getValue(feature));
z.addFeatureValue(feature, weight);
}
return z;
}

private class KeyValuePair {
private String key;
public class FeatureValuePair {
private String feature;
private float value;

public KeyValuePair(String key, float value) {
this.key = key;
public FeatureValuePair(String key, float value) {
this.feature = key;
this.value = value;
}

public String getKey() {
return key;
}

@Override
public String toString() {
StringBuilder b = new StringBuilder(value + "\t" + key);
return b.toString();
public String getFeature() {
return feature;
}

public float getValue() {
return value;
}

@Override
public String toString() {
return new StringBuilder(feature + "=" + value).toString();
}
}
}
66 changes: 55 additions & 11 deletions src/test/java/io/anserini/util/FeatureVectorTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -23,24 +23,68 @@
import java.util.HashSet;

public class FeatureVectorTest extends LuceneTestCase {
private final FeatureVector createAndAddFeatureWeights() {
private final FeatureVector createAndAddFeatureWeights1() {
FeatureVector fv = new FeatureVector();
fv.addFeatureWeight("a", 0.2f);
fv.addFeatureWeight("a", 0.3f);
fv.addFeatureWeight("b", 0.3f);
fv.addFeatureWeight("b", 0.5f);
fv.addFeatureWeight("c", 0.4f);
fv.addFeatureWeight("d", 0.1f);
fv.addFeatureValue("a", 0.2f);
fv.addFeatureValue("a", 0.3f);
fv.addFeatureValue("b", 0.3f);
fv.addFeatureValue("b", 0.5f);
fv.addFeatureValue("c", 0.4f);
fv.addFeatureValue("d", 0.1f);
return fv;
}


// To test tie-breaking
private final FeatureVector createAndAddFeatureWeights2() {
FeatureVector fv = new FeatureVector();
fv.addFeatureValue("ds", 0.2f);
fv.addFeatureValue("z", 0.5f);
fv.addFeatureValue("zz", 0.01f);
fv.addFeatureValue("c", 0.4f);
fv.addFeatureValue("a", 0.2f);
fv.addFeatureValue("x", 0.2f);
fv.addFeatureValue("1a", 0.2f);
fv.addFeatureValue("d", 0.6f);
return fv;
}

@Test
public void pruneToSizeTest() {
FeatureVector fv1 = createAndAddFeatureWeights();
FeatureVector fv1 = createAndAddFeatureWeights1();
assertEquals(fv1.pruneToSize(2).getFeatures().size(), 2);
FeatureVector fv2 = createAndAddFeatureWeights();
FeatureVector fv2 = createAndAddFeatureWeights1();
assertEquals(fv2.pruneToSize(1).getFeatures().size(), 1);
FeatureVector fv3 = createAndAddFeatureWeights();
FeatureVector fv3 = createAndAddFeatureWeights1();
assertEquals(fv3.pruneToSize(2).getFeatures(), new HashSet<>(Arrays.asList(new String[]{"a", "b"})));
}

@Test
public void toStringTest1() {
FeatureVector fv = createAndAddFeatureWeights2();
assertEquals("[d=0.6, z=0.5, c=0.4, 1a=0.2, a=0.2, ds=0.2, x=0.2, zz=0.01]", fv.toString());
// Make sure that feature value ties are broken lexicographically

assertEquals("[d=0.6, z=0.5, c=0.4]", fv.toString(3));
assertEquals("[d=0.6, z=0.5, c=0.4, 1a=0.2, a=0.2]", fv.toString(5));
}

@Test
public void toStringTest2() {
FeatureVector fv = createAndAddFeatureWeights2();
assertEquals("[d=0.6, z=0.5, c=0.4, 1a=0.2, a=0.2, ds=0.2, x=0.2, zz=0.01]",
fv.toString(FeatureVector.Order.VALUE_DESCENDING));
assertEquals("[d=0.6, z=0.5, c=0.4]", fv.toString(FeatureVector.Order.VALUE_DESCENDING, 3));

assertEquals("[zz=0.01, 1a=0.2, a=0.2, ds=0.2, x=0.2, c=0.4, z=0.5, d=0.6]",
fv.toString(FeatureVector.Order.VALUE_ASCENDING));
assertEquals("[zz=0.01, 1a=0.2, a=0.2]", fv.toString(FeatureVector.Order.VALUE_ASCENDING, 3));

assertEquals("[1a=0.2, a=0.2, c=0.4, d=0.6, ds=0.2, x=0.2, z=0.5, zz=0.01]",
fv.toString(FeatureVector.Order.FEATURE_ASCENDING));
assertEquals("[1a=0.2, a=0.2, c=0.4]", fv.toString(FeatureVector.Order.FEATURE_ASCENDING, 3));

assertEquals("[zz=0.01, z=0.5, x=0.2, ds=0.2, d=0.6, c=0.4, a=0.2, 1a=0.2]",
fv.toString(FeatureVector.Order.FEATURE_DESCENDING));
assertEquals("[zz=0.01, z=0.5, x=0.2]", fv.toString(FeatureVector.Order.FEATURE_DESCENDING, 3));
}
}

0 comments on commit 50d4314

Please sign in to comment.