Refactor partition (#118)
* #112 refactor partition

* refactor partition

* lint partition test

* use Guo's solution for size parsing

* update

* change products path

* rename partition to feature partition & fix bugs & test with old implementation of partition algorithm

* save
eedalong authored Feb 22, 2022
1 parent fd115ef commit 119583d
Showing 8 changed files with 998 additions and 379 deletions.
2 changes: 2 additions & 0 deletions srcs/python/quiver/__init__.py
@@ -6,10 +6,12 @@
from .utils import Topo as p2pCliqueTopo
from .utils import init_p2p
from .comm import NcclComm, getNcclId
from .partition import quiver_partition_feature, load_quiver_feature_partition

__all__ = [
"Feature", "DistFeature", "GraphSageSampler", "PartitionInfo", "CSRTopo",
"MixedGraphSageSampler",
"SampleJob",
"quiver_partition_feature", "load_quiver_feature_partition"
"p2pCliqueTopo", "init_p2p", "getNcclId", "NcclComm"
]
29 changes: 4 additions & 25 deletions srcs/python/quiver/feature.py
@@ -1,6 +1,6 @@
import torch
from quiver.shard_tensor import ShardTensor, ShardTensorConfig, Topo
from quiver.utils import reindex_feature, CSRTopo
from quiver.utils import reindex_feature, CSRTopo, parse_size
from typing import List
import numpy as np
from torch._C import device
@@ -70,26 +70,7 @@ def clique_device_symmetry_check(self):
self.topo.p2pClique2Device[0]):
return True
return False

def cal_memory_budget_bytes(self, memory_budget):
if isinstance(memory_budget, int):
return memory_budget
elif isinstance(memory_budget, float):
memory_budget = int(memory_budget)
elif isinstance(memory_budget, str):
if memory_budget.upper().endswith(
"M") or memory_budget.upper().endswith("MB"):
end = -1 if memory_budget.upper().endswith("M") else -2
memory_budget = int(float(memory_budget[:end]) * 1024 * 1024)
elif memory_budget.upper().endswith(
"G") or memory_budget.upper().endswith("GB"):
end = -1 if memory_budget.upper().endswith("G") else -2
memory_budget = int(
float(memory_budget[:end]) * 1024 * 1024 * 1024)
else:
raise Exception("memory budget input is not valid")
return memory_budget


def cal_size(self, cpu_tensor: torch.Tensor, cache_memory_budget: int):
element_size = cpu_tensor.shape[1] * 4
cache_size = cache_memory_budget // element_size
@@ -217,12 +198,10 @@ def from_cpu_tensor(self, cpu_tensor: torch.Tensor):
cpu_tensor (torch.FloatTensor): input cpu tensor
"""
if self.cache_policy == "device_replicate":
cache_memory_budget = self.cal_memory_budget_bytes(
self.device_cache_size)
cache_memory_budget = parse_size(self.device_cache_size)
shuffle_ratio = 0.0
else:
cache_memory_budget = self.cal_memory_budget_bytes(
self.device_cache_size) * len(self.topo.p2pClique2Device[0])
cache_memory_budget = parse_size(self.device_cache_size) * len(self.topo.p2pClique2Device[0])
shuffle_ratio = self.cal_size(
cpu_tensor, cache_memory_budget) / cpu_tensor.size(0)

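For context: the refactor swaps Feature.cal_memory_budget_bytes for a shared quiver.utils.parse_size helper, whose body is not shown in this diff. Based on the removed method above, a minimal sketch of such a helper might look like this (the real parse_size in quiver.utils may differ):

def parse_size(size) -> int:
    # Hypothetical sketch mirroring the removed cal_memory_budget_bytes:
    # accept raw byte counts (int/float) or strings like "200M" / "1.5GB".
    if isinstance(size, int):
        return size
    if isinstance(size, float):
        return int(size)
    if isinstance(size, str):
        upper = size.upper()
        # check two-letter suffixes before one-letter ones
        for suffix, factor in (("GB", 1024 ** 3), ("MB", 1024 ** 2),
                               ("G", 1024 ** 3), ("M", 1024 ** 2)):
            if upper.endswith(suffix):
                return int(float(size[:-len(suffix)]) * factor)
    raise Exception("memory budget input is not valid")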
256 changes: 151 additions & 105 deletions srcs/python/quiver/partition.py
@@ -1,127 +1,173 @@
import torch
import shutil
import os
from typing import List
import quiver.utils as quiver_util

CHUNK_NUM = 32


def partition_without_replication(device, probs, ids):
"""Partition node with given node IDs and node access distribution.
__all__ = ["quiver_partition_feature", "load_quiver_feature_partition"]


QUIVER_MAGIC_NUMBER = 256

def partition_feature_without_replication(probs: List[torch.Tensor], chunk_size: int):
"""Partition node with node access distribution.
The result will cause no replication between each partition.
We assume node IDs can be placed in the given device.
Args:
device (int): device which computes the partitioning strategy
probs (torch.Tensor): node access distribution
ids (Optional[torch.Tensor]): specified node IDs
chunk_size (int): number of nodes each partition picks per scheduling round
Returns:
[torch.Tensor]: list of IDs for each partition
"""
ranks = len(probs)
if ids is not None:
ids = ids.to(device)
probs = [
prob[ids].to(device) if ids is not None else prob.to(device)
for prob in probs
]
total_size = ids.size(0) if ids is not None else probs[0].size(0)
res = [None] * ranks
for rank in range(ranks):
res[rank] = []
CHUNK_SIZE = (total_size + CHUNK_NUM - 1) // CHUNK_NUM
chunk_beg = 0
beg_rank = 0
for i in range(CHUNK_NUM):
chunk_end = min(total_size, chunk_beg + CHUNK_SIZE)
chunk_size = chunk_end - chunk_beg
chunk = torch.arange(chunk_beg,
chunk_end,
dtype=torch.int64,
device=device)

device = torch.cuda.current_device()
partitioned_num = len(probs)

probs = [prob.to(device) for prob in probs]
total_node_num = probs[0].size(0)

res = [[] for _ in range(partitioned_num)]

blob_size = chunk_size * partitioned_num
chunk_num = (total_node_num + chunk_size - 1) // chunk_size

current_chunk_start_pos = 0
current_partition_idx = 0
for _ in range(chunk_num):
current_chunk_end_pos = min(total_node_num, current_chunk_start_pos + blob_size)
current_chunk_size = current_chunk_end_pos - current_chunk_start_pos
chunk = torch.arange(current_chunk_start_pos, current_chunk_end_pos, device=device)
probs_sum_chunk = [
torch.zeros(chunk_size, device=device) + 1e-6 for i in range(ranks)
torch.zeros(current_chunk_size, device=device) + 1e-6 for _ in range(partitioned_num)
]
for rank in range(ranks):
for dst_rank in range(ranks):
if dst_rank == rank:
probs_sum_chunk[rank] += probs[dst_rank][chunk] * ranks
for src_rank in range(partitioned_num):
for dst_rank in range(partitioned_num):
if dst_rank == src_rank:
probs_sum_chunk[src_rank] += probs[dst_rank][chunk] * partitioned_num
else:
probs_sum_chunk[rank] -= probs[dst_rank][chunk]
acc_size = 0
rank_size = (chunk_size + ranks - 1) // ranks
picked_chunk_parts = torch.LongTensor([]).to(device)
for rank_ in range(beg_rank, beg_rank + ranks):
rank = rank_ % ranks
probs_sum_chunk[rank][picked_chunk_parts] -= 1e6
rank_size = min(rank_size, chunk_size - acc_size)
_, rank_order = torch.sort(probs_sum_chunk[rank], descending=True)
pick_chunk_part = rank_order[:rank_size]
probs_sum_chunk[src_rank] -= probs[dst_rank][chunk]
assigned_node_size = 0
per_partition_size = chunk_size
for partition_idx in range(current_partition_idx, current_partition_idx + partitioned_num):
partition_idx = partition_idx % partitioned_num
actual_per_partition_size = min(per_partition_size, current_chunk_size - assigned_node_size)
_, sorted_res_order = torch.sort(probs_sum_chunk[partition_idx], descending=True)
pick_chunk_part = sorted_res_order[:actual_per_partition_size]
pick_ids = chunk[pick_chunk_part]
picked_chunk_parts = torch.cat(
(picked_chunk_parts, pick_chunk_part))
res[rank].append(pick_ids)
acc_size += rank_size
beg_rank += 1
chunk_beg += chunk_size
for rank in range(ranks):
res[rank] = torch.cat(res[rank])
if ids is not None:
res[rank] = ids[res[rank]]
return res


def partition_with_replication(device, probs, ids, per_rank_size):
"""Partition node with given node IDs and node access distribution.
The result will cause replication between each partition,
but the size of each partition will not exceed per_rank_size.
res[partition_idx].append(pick_ids)
for idx in range(partitioned_num):
probs_sum_chunk[idx][pick_chunk_part] = -1
assigned_node_size += actual_per_partition_size
current_partition_idx += 1
current_chunk_start_pos += current_chunk_size

for partition_idx in range(partitioned_num):
res[partition_idx] = torch.cat(res[partition_idx])
return res, probs
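For context: the rewritten algorithm scans the node ID range in blobs of chunk_size * partitioned_num nodes, scores every node in the blob for each partition, then lets the partitions take turns (the starting partition rotates per blob) picking their chunk_size highest-scoring nodes; picked nodes are set to -1 in every score tensor so nothing is assigned twice. A small CPU-only illustration of the scoring rule with made-up numbers (this mimics the probs_sum_chunk computation above, minus the 1e-6 epsilon; it is not the library call):

import torch

# probs[p][v]: how often partition p's trainers access node v (toy values)
probs = [torch.tensor([0.9, 0.1, 0.5, 0.0]),
         torch.tensor([0.1, 0.8, 0.5, 0.7])]
num_parts = len(probs)
# score for partition p: own access prob scaled by num_parts, minus every
# other partition's prob, so "hot for me, cold for others" nodes rank first
scores = [num_parts * probs[p]
          - sum(probs[q] for q in range(num_parts) if q != p)
          for p in range(num_parts)]
print(scores[0])  # tensor([ 1.7000, -0.6000,  0.5000, -0.7000]): node 0 -> partition 0
print(scores[1])  # tensor([-0.7000,  1.5000,  0.5000,  1.4000]): nodes 1, 3 -> partition 1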


def quiver_partition_feature(probs: List[torch.Tensor], result_path: str, cache_memory_budget=0, per_feature_size=0, chunk_size=QUIVER_MAGIC_NUMBER):
"""
partition_res = partition_without_replication(device, probs, ids)
if ids is not None:
ids = ids.to(device)
ranks = len(probs)
total_res = [
torch.empty(per_rank_size, device=device) for i in range(ranks)
]
probs = [prob.clone().to(device) for prob in probs]
for rank in range(ranks):
partition_ids = partition_res[rank]
probs[rank][partition_ids] = -1e6
replication_size = per_rank_size - partition_ids.size(0)
_, prev_order = torch.sort(probs[rank], descending=True)
replication_ids = ids[
prev_order[:
replication_size]] if ids is not None else prev_order[:
replication_size]
total_res[rank] = torch.cat((partition_ids, replication_ids))
return total_res


def select_nodes(device, probs, ids):
nodes = probs[0].size(0)
prob_sum = torch.zeros(nodes, device=device)
for prob in probs:
if ids is None:
prob_sum += prob
Partition graph features based on access probability and generate a result folder. The final result folder will look like:
-result_path
    -feature_partition_0
        -partition_res.pth
        -cache_res.pth
    -feature_partition_1
        -partition_res.pth
        -cache_res.pth
    -feature_partition_2
        -partition_res.pth
        -cache_res.pth
    ...
Args:
probs (List[torch.Tensor]): one node access distribution tensor per target partition
result_path (str): path for partition result
cache_memory_budget (Union[str, int, float]): user-specified memory budget for caching hot feature
per_feature_size (Union[str, int, float]): per-feature size for user's feature
Returns:
partition_book (torch.Tensor): Indicates which partition_idx a node belongs to
feature_partition_res (torch.Tensor): partitioned feature result
feature_cache_res (torch.Tensor): cached feature result
"""

if os.path.exists(result_path):
res = input(f"{result_path} already exists, enter Y/N to continue. If you continue, {result_path} will be deleted: ")
res = res.upper()
if res == "Y":
shutil.rmtree(result_path)
else:
prob_sum[ids] += prob[ids]
node_ids = torch.nonzero(prob_sum)
return prob_sum, node_ids
print("exiting ...")
exit()

partition_num = len(probs)


# create result folder
for partition_idx in range(partition_num):
os.makedirs(os.path.join(result_path, f"feature_partition_{partition_idx}"))

# calculate cached feature count
cache_memory_budget_bytes = quiver_util.parse_size(cache_memory_budget)
per_feature_size_bytes = quiver_util.parse_size(per_feature_size)
cache_count = int(cache_memory_budget_bytes / (per_feature_size_bytes + 1e-6))
per_partition_cache_count = cache_count // partition_num

partition_book = torch.zeros(probs[0].shape, dtype=torch.int64, device=torch.cuda.current_device())
partition_res, changed_probs = partition_feature_without_replication(probs, chunk_size)

cache_res = [None] * partition_num

if cache_count > 0:
for partition_idx in range(partition_num):
_, prev_order = torch.sort(changed_probs[partition_idx], descending=True)
cache_res[partition_idx] = prev_order[: per_partition_cache_count]

for partition_idx in range(partition_num):
partition_result_path = os.path.join(result_path, f"feature_partition_{partition_idx}", "partition_res.pth")
cache_result_path = os.path.join(result_path, f"feature_partition_{partition_idx}", "cache_res.pth")
partition_book[partition_res[partition_idx]] = partition_idx
torch.save(partition_res[partition_idx], partition_result_path)
torch.save(cache_res[partition_idx], cache_result_path)

partition_book_path = os.path.join(result_path, f"feature_partition_book.pth")
torch.save(partition_book, partition_book_path)

return partition_book, partition_res, cache_res
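For context, a hedged usage sketch of the new partitioning API. It assumes a CUDA-capable machine (the partitioner runs on torch.cuda.current_device()); the probability tensors, budget, and path below are illustrative only:

import torch
from quiver.partition import quiver_partition_feature

# one access-frequency tensor over all nodes per target partition
probs = [torch.rand(100_000) for _ in range(2)]
# prompts before deleting an existing result_path, then writes each
# feature_partition_{i}/partition_res.pth and cache_res.pth plus the partition book
partition_book, partition_res, cache_res = quiver_partition_feature(
    probs,
    result_path="feature_partition_result",
    cache_memory_budget="100M",  # total hot-feature cache budget, split across partitions
    per_feature_size=256 * 4)    # bytes per feature row, e.g. 256 float32 values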

def partition_free(device, probs, ids, per_rank_size):
"""Partition node with given node IDs and node access distribution.
The result will cause either replication or missing nodes across partitions.
The size of each partition is limited by per_rank_size.

def load_quiver_feature_partition(partition_idx: int, result_path: str):
"""
prob_sum, node_ids = select_nodes(device, probs, ids)
nodes = node_ids.size(0)
ranks = len(probs)
limit = ranks * per_rank_size
if nodes <= limit:
return partition_with_replication(device, probs, node_ids,
per_rank_size), None
else:
_, prev_order = torch.sort(prob_sum, descending=True)
limit_ids = prev_order[:limit]
return partition_without_replication(device, probs,
node_ids), limit_ids
Load partition result for partition ${partition_idx}
Args:
partition_idx (int): Partition idx
result_path (str): partition result root path
Returns:
partition_book (torch.Tensor): partition_book indicates which partition_idx a node belongs to
partition_res (torch.Tensor): node indexes belong to this partition
cache_res (torch.Tensor): cached node indexes belong to this partition
"""

if not os.path.exists(result_path):
raise Exception("Result path not exists")

partition_result_path = os.path.join(result_path, f"feature_partition_{partition_idx}", "partition_res.pth")
cache_result_path = os.path.join(result_path, f"feature_partition_{partition_idx}", "cache_res.pth")
partition_book_path = os.path.join(result_path, f"feature_partition_book.pth")


partition_book = torch.load(partition_book_path)
partition_res = torch.load(partition_result_path)
cache_res = torch.load(cache_result_path)

return partition_book, partition_res, cache_res
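For context, the loading side as a minimal sketch (the rank value and path are illustrative; typically each training process loads only its own partition):

from quiver.partition import load_quiver_feature_partition

rank = 0  # one partition per training process
partition_book, partition_ids, cache_ids = load_quiver_feature_partition(
    rank, "feature_partition_result")
# partition_book[v] gives the partition that owns node v; partition_ids and
# cache_ids index the feature rows this rank stores and caches locally.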
6 changes: 3 additions & 3 deletions srcs/python/quiver/pyg/sage_sampler.py
@@ -146,12 +146,12 @@ def sample(self, input_nodes):

return nodes, batch_size, adjs[::-1]

def sample_prob(self, train_idx, nodes):
def sample_prob(self, train_idx, total_node_count):
self.lazy_init_quiver()
last_prob = torch.zeros(nodes, device=self.device)
last_prob = torch.zeros(total_node_count, device=self.device)
last_prob[train_idx] = 1
for size in self.sizes:
cur_prob = torch.zeros(nodes, device=self.device)
cur_prob = torch.zeros(total_node_count, device=self.device)
self.quiver.cal_neighbor_prob(0, last_prob, cur_prob, size)
last_prob = cur_prob
return last_prob
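For context: the rename from nodes to total_node_count makes clear the second argument is the graph's node count, not a set of node IDs. sample_prob returns per-node access probabilities simulated from the training set, which is exactly the input quiver_partition_feature consumes. A hedged sketch of wiring the two together; sampler (a GraphSageSampler), train_idx_parts, world_size, and total_node_count are assumptions from a typical setup, not shown in this diff:

from quiver.partition import quiver_partition_feature

# each entry simulates how often one worker's training nodes touch each node;
# sampler, train_idx_parts, world_size, total_node_count are assumed setup
probs = [
    sampler.sample_prob(train_idx_parts[rank], total_node_count)
    for rank in range(world_size)
]
partition_book, partition_res, cache_res = quiver_partition_feature(
    probs, result_path="feature_partition_result")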
