Skip to content

Commit

Permalink
run black
Browse files Browse the repository at this point in the history
  • Loading branch information
rishabh-ranjan committed Nov 27, 2023
1 parent 4380ff4 commit 4c3e71c
Show file tree
Hide file tree
Showing 5 changed files with 25 additions and 23 deletions.
6 changes: 3 additions & 3 deletions examples/text_embedder.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,17 @@
from typing import List, Optional

import torch

# Please run `pip install -U sentence-transformers`
from sentence_transformers import SentenceTransformer
from torch import Tensor


class GloveTextEmbedding:

def __init__(self, device: Optional[torch.device] = None):
self.model = SentenceTransformer(
"sentence-transformers/average_word_embeddings_glove.6B.300d",
device=device)
"sentence-transformers/average_word_embeddings_glove.6B.300d", device=device
)

def __call__(self, sentences: List[str]) -> Tensor:
return torch.from_numpy(self.model.encode(sentences))
15 changes: 8 additions & 7 deletions examples/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,14 @@

import torch
import torch_frame
from rtb.data.task import TaskType
from rtb.datasets import get_dataset
from rtb.external.graph import (
get_stype_proposal,
get_train_table_input,
make_pkey_fkey_graph,
)
from rtb.external.nn import HeteroEncoder, HeteroGraphSAGE, HeteroTemporalEncoder
from text_embedder import GloveTextEmbedding
from torch import Tensor
from torch.nn import BCEWithLogitsLoss, L1Loss
Expand All @@ -18,13 +26,6 @@
from torchmetrics import AUROC, AveragePrecision, MeanAbsoluteError
from tqdm import tqdm

from rtb.data.task import TaskType
from rtb.datasets import get_dataset
from rtb.external.graph import (get_stype_proposal, get_train_table_input,
make_pkey_fkey_graph)
from rtb.external.nn import (HeteroEncoder, HeteroGraphSAGE,
HeteroTemporalEncoder)

# Stores the informative text columns to retain for each table:
dataset_to_informative_text_cols = {}
dataset_to_informative_text_cols["rtb-forum"] = {
Expand Down
5 changes: 4 additions & 1 deletion relbench/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,15 +36,18 @@ def roc_auc(true: NDArray[np.float64], pred: NDArray[np.float64]) -> float:
assert pred.ndim == 1 or pred.shape[1] == 1
return skm.roc_auc_score(true, pred)


def average_precision(true: NDArray[np.float64], pred: NDArray[np.float64]) -> float:
assert pred.ndim == 1 or pred.shape[1] == 1
return skm.average_precision_score(true, pred)


def auprc(true: NDArray[np.float64], pred: NDArray[np.float64]) -> float:
assert pred.ndim == 1 or pred.shape[1] == 1
precision, recall, _ = skm.precision_recall_curve(true, pred)
return skm.auc(recall, precision)



### applicable to multiclass classification only


Expand Down
19 changes: 9 additions & 10 deletions relbench/tasks/stackex.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
import os
from typing import Dict, Tuple

import pandas as pd
import numpy as np
import pandas as pd
from tqdm import tqdm

from relbench.data import Database, RelBenchTask, Table
from relbench.metrics import accuracy, f1, mae, rmse, roc_auc, average_precision
from relbench.metrics import accuracy, average_precision, f1, mae, rmse, roc_auc
from relbench.utils import get_df_in_window
from tqdm import tqdm


class EngageTask(RelBenchTask):
Expand All @@ -21,7 +22,6 @@ class EngageTask(RelBenchTask):
timedelta = pd.Timedelta(days=365 * 2)
metrics = [average_precision, accuracy, f1, roc_auc]


def make_table(self, db: Database, timestamps: "pd.Series[pd.Timestamp]") -> Table:
r"""Create Task object for UserContributionTask."""
timestamp_df = pd.DataFrame({"timestamp": timestamps})
Expand All @@ -44,7 +44,9 @@ def make_table(self, db: Database, timestamps: "pd.Series[pd.Timestamp]") -> Tab

def get_values_in_window(row, posts, users):
posts_window = get_df_in_window(posts, "CreationDate", row, self.timedelta)
comments_window = get_df_in_window(comments, "CreationDate", row, self.timedelta)
comments_window = get_df_in_window(
comments, "CreationDate", row, self.timedelta
)
votes_window = get_df_in_window(votes, "CreationDate", row, self.timedelta)

user_made_posts_in_this_period = posts_window.OwnerUserId.unique()
Expand Down Expand Up @@ -136,10 +138,7 @@ def get_values_in_window(row, votes, posts):
votes_window = get_df_in_window(votes, "CreationDate", row, self.timedelta)
posts_exist = posts[
(posts.CreationDate <= row["timestamp"])
& (
posts.CreationDate
> (row["timestamp"] - pd.Timedelta(days=365 * 2))
)
& (posts.CreationDate > (row["timestamp"] - pd.Timedelta(days=365 * 2)))
] ## posts exist and active defined by created in the last 2 years
posts_exist_ids = posts_exist.Id.values
train_table = pd.DataFrame()
Expand All @@ -164,4 +163,4 @@ def get_values_in_window(row, votes, posts):
fkey_col_to_pkey_table={self.entity_col: self.entity_table},
pkey_col=None,
time_col=self.time_col,
)
)
3 changes: 1 addition & 2 deletions relbench/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,5 @@ def unzip_processor(fname: Union[str, Path], action: str, pooch: pooch.Pooch) ->

def get_df_in_window(df, time_col, row, delta):
return df[
(df[time_col] > row["timestamp"])
& (df[time_col] <= (row["timestamp"] + delta))
(df[time_col] > row["timestamp"]) & (df[time_col] <= (row["timestamp"] + delta))
]

0 comments on commit 4c3e71c

Please sign in to comment.