Skip to content

Commit

Permalink
run isort
Browse files Browse the repository at this point in the history
  • Loading branch information
rishabh-ranjan committed Nov 27, 2023
1 parent 4c3e71c commit c9ac328
Show file tree
Hide file tree
Showing 5 changed files with 11 additions and 15 deletions.
5 changes: 2 additions & 3 deletions examples/baseline.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,11 @@
from typing import Dict

import torch
from torch import Tensor
from torchmetrics import AUROC, AveragePrecision, MeanAbsoluteError

from rtb.data import Table
from rtb.data.task import TaskType
from rtb.datasets import get_dataset
from torch import Tensor
from torchmetrics import AUROC, AveragePrecision, MeanAbsoluteError

parser = argparse.ArgumentParser()
parser.add_argument("--dataset", type=str, default="relbench-forum")
Expand Down
7 changes: 3 additions & 4 deletions examples/fake/new_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,12 @@

import torch
import torch.nn.functional as F
from torch_geometric.loader import NodeLoader
from torch_geometric.nn import MLP
from torch_geometric.sampler import NeighborSampler

from rtb.datasets import FakeProductDataset
from rtb.external.graph import make_pkey_fkey_graph
from rtb.external.nn import HeteroEncoder, HeteroGraphSAGE
from torch_geometric.loader import NodeLoader
from torch_geometric.nn import MLP
from torch_geometric.sampler import NeighborSampler

parser = argparse.ArgumentParser()
parser.add_argument("--task", type=str, default="churn", choices=["churn", "ltv"])
Expand Down
7 changes: 3 additions & 4 deletions examples/fake/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,12 @@

import torch
import torch.nn.functional as F
from torch_geometric.loader import NodeLoader
from torch_geometric.nn import MLP
from torch_geometric.sampler import NeighborSampler

from rtb.datasets import FakeProductDataset
from rtb.external.graph import get_train_table_input, make_pkey_fkey_graph
from rtb.external.nn import HeteroEncoder, HeteroGraphSAGE
from torch_geometric.loader import NodeLoader
from torch_geometric.nn import MLP
from torch_geometric.sampler import NeighborSampler

parser = argparse.ArgumentParser()
parser.add_argument("--task", type=str, default="churn", choices=["churn", "ltv"])
Expand Down
5 changes: 2 additions & 3 deletions examples/xgboost_baseline.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,14 @@
import pandas as pd
import torch
import torch_frame
from rtb.data.task import TaskType
from rtb.datasets import get_dataset
from text_embedder import GloveTextEmbedding
from torch_frame.config.text_embedder import TextEmbedderConfig
from torch_frame.data import Dataset
from torch_frame.gbdt import XGBoost
from torch_frame.typing import Metric

from rtb.data.task import TaskType
from rtb.datasets import get_dataset

parser = argparse.ArgumentParser()
parser.add_argument("--dataset", type=str, default="rtb-forum")
parser.add_argument("--task", type=str, default="UserSumCommentScoresTask")
Expand Down
2 changes: 1 addition & 1 deletion relbench/tasks/amazon.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import pandas as pd

from relbench.data import Database, RelBenchTask, Table
from relbench.metrics import accuracy, f1, mae, rmse, roc_auc, average_precision
from relbench.metrics import accuracy, average_precision, f1, mae, rmse, roc_auc


class ChurnTask(RelBenchTask):
Expand Down

0 comments on commit c9ac328

Please sign in to comment.