diff --git a/examples/baseline.py b/examples/baseline.py index 212acecb..79205400 100644 --- a/examples/baseline.py +++ b/examples/baseline.py @@ -2,12 +2,11 @@ from typing import Dict import torch -from torch import Tensor -from torchmetrics import AUROC, AveragePrecision, MeanAbsoluteError - from rtb.data import Table from rtb.data.task import TaskType from rtb.datasets import get_dataset +from torch import Tensor +from torchmetrics import AUROC, AveragePrecision, MeanAbsoluteError parser = argparse.ArgumentParser() parser.add_argument("--dataset", type=str, default="relbench-forum") diff --git a/examples/fake/new_train.py b/examples/fake/new_train.py index d0d80c3e..b74ae77a 100644 --- a/examples/fake/new_train.py +++ b/examples/fake/new_train.py @@ -2,13 +2,12 @@ import torch import torch.nn.functional as F -from torch_geometric.loader import NodeLoader -from torch_geometric.nn import MLP -from torch_geometric.sampler import NeighborSampler - from rtb.datasets import FakeProductDataset from rtb.external.graph import make_pkey_fkey_graph from rtb.external.nn import HeteroEncoder, HeteroGraphSAGE +from torch_geometric.loader import NodeLoader +from torch_geometric.nn import MLP +from torch_geometric.sampler import NeighborSampler parser = argparse.ArgumentParser() parser.add_argument("--task", type=str, default="churn", choices=["churn", "ltv"]) diff --git a/examples/fake/train.py b/examples/fake/train.py index 727a03a7..40f0664b 100644 --- a/examples/fake/train.py +++ b/examples/fake/train.py @@ -2,13 +2,12 @@ import torch import torch.nn.functional as F -from torch_geometric.loader import NodeLoader -from torch_geometric.nn import MLP -from torch_geometric.sampler import NeighborSampler - from rtb.datasets import FakeProductDataset from rtb.external.graph import get_train_table_input, make_pkey_fkey_graph from rtb.external.nn import HeteroEncoder, HeteroGraphSAGE +from torch_geometric.loader import NodeLoader +from torch_geometric.nn import MLP +from torch_geometric.sampler import NeighborSampler parser = argparse.ArgumentParser() parser.add_argument("--task", type=str, default="churn", choices=["churn", "ltv"]) diff --git a/examples/xgboost_baseline.py b/examples/xgboost_baseline.py index 6727b4c0..888b6bfc 100644 --- a/examples/xgboost_baseline.py +++ b/examples/xgboost_baseline.py @@ -4,15 +4,14 @@ import pandas as pd import torch import torch_frame +from rtb.data.task import TaskType +from rtb.datasets import get_dataset from text_embedder import GloveTextEmbedding from torch_frame.config.text_embedder import TextEmbedderConfig from torch_frame.data import Dataset from torch_frame.gbdt import XGBoost from torch_frame.typing import Metric -from rtb.data.task import TaskType -from rtb.datasets import get_dataset - parser = argparse.ArgumentParser() parser.add_argument("--dataset", type=str, default="rtb-forum") parser.add_argument("--task", type=str, default="UserSumCommentScoresTask") diff --git a/relbench/tasks/amazon.py b/relbench/tasks/amazon.py index 25c7a9cb..927647f9 100644 --- a/relbench/tasks/amazon.py +++ b/relbench/tasks/amazon.py @@ -4,7 +4,7 @@ import pandas as pd from relbench.data import Database, RelBenchTask, Table -from relbench.metrics import accuracy, f1, mae, rmse, roc_auc, average_precision +from relbench.metrics import accuracy, average_precision, f1, mae, rmse, roc_auc class ChurnTask(RelBenchTask):