Skip to content

Commit

Permalink
Generate spotify dataset with embeddings. (#453)
Browse files Browse the repository at this point in the history
  • Loading branch information
eddyxu authored Jan 26, 2023
1 parent 99a8e5e commit b043fd0
Show file tree
Hide file tree
Showing 2 changed files with 77 additions and 0 deletions.
76 changes: 76 additions & 0 deletions python/benchmarks/spotify/datagen.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
#!/usr/bin/env python3
#

import os
from random import choice

import lance
import pyarrow as pa
import pandas as pd
from gensim.models.word2vec import Word2Vec


def generate_embeddings(col: pd.Series, **hyper_params) -> pd.Series:
# unique_tracks
sentences = []
for s in col:
if not isinstance(s, str):
continue
# We could prob do more cleaning here
sentences.append(s)
model = Word2Vec(sentences, **hyper_params)
print(f"Vector space size: {len(model.wv.index_to_key)}")
embeddings = []
for query_item in col:
if query_item not in model.wv:
query_item = choice(list(model.wv.index_to_key))
embeding = model.wv[query_item]
embeddings.append(embeding)

print(embeddings)
return embeddings


def main():
# Get dataset from:
# https://www.kaggle.com/datasets/andrewmvd/spotify-playlists
#
df = pd.read_csv("spotify_dataset.csv", on_bad_lines="skip")
# clean up the col names
df.columns = df.columns.str.replace('"', "")
df.columns = df.columns.str.replace("name", "")
df.columns = df.columns.str.replace(" ", "")
print(df)
print(f"Value counts: {df.nunique(axis=0)}")
for col in df.columns:
df[col] = df[col].astype("category")

# Hyper parameters from https://outerbounds.com/docs/recsys-tutorial-L4/
embeddings = generate_embeddings(
df["track"],
min_count=3,
epochs=30,
vector_size=48,
window=10,
ns_exponent=0.75,
workers=os.cpu_count(),
)
df["embeddings"] = embeddings
print(embeddings)

schema = pa.schema(
[
pa.field("user_id", pa.dictionary(pa.uint16(), pa.utf8())),
pa.field("artist", pa.dictionary(pa.uint32(), pa.utf8())),
pa.field("track", pa.dictionary(pa.uint32(), pa.utf8())),
pa.field("playlist", pa.dictionary(pa.uint32(), pa.utf8())),
pa.field("embeddings", pa.list_(pa.float32(), 48))
]
)
table = pa.Table.from_pandas(df, schema=schema)

lance.write_dataset(table, "spotify.lance", max_rows_per_group=8192)


if __name__ == "__main__":
main()
1 change: 1 addition & 0 deletions python/benchmarks/spotify/requirements.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
gensim

0 comments on commit b043fd0

Please sign in to comment.