Add support for dynamic indices in VMEM loads and stores
... at least in all but the last two dimensions, which have more stringent alignment requirements.

PiperOrigin-RevId: 578463563
apaszke authored and jax authors committed Nov 1, 2023
1 parent 32a317f commit f17f549
Showing 3 changed files with 183 additions and 99 deletions.
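
In IR terms, the change allows the batch (leading-dimension) indices of a vector.load or vector.store to be runtime values, while the indices of the last two (tiled) dimensions must remain compile-time constants. A minimal sketch of the newly accepted form — shapes and value names here are illustrative only, not taken from this commit:

%c0 = arith.constant 0 : index
// Dynamic batch index %i: now supported.
%v = vector.load %ref[%i, %c0, %c0] : memref<4x128x512xf32>, vector<1x128x512xf32>
// Still rejected: a dynamic index in either of the last two dimensions fails
// layout inference with "Dynamic indices are not supported in the last two
// dimensions".
%bad = vector.load %ref[%c0, %i, %c0] : memref<4x128x512xf32>, vector<1x128x512xf32>

Internally, each dynamic batch index is expanded by the new getDimIndices helper into per-tile arith.addi ops, while the indices of the last two dimensions still flow through getIntConstsFromOperandRange.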
117 changes: 78 additions & 39 deletions jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc
@@ -8,6 +8,7 @@
#include <optional>
#include <tuple>
#include <utility>
+#include <vector>

#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/STLExtras.h"
@@ -271,12 +272,15 @@ FailureOr<TypedAttr> getZeroIntOrFloatAttr(Type ty) {
return emitError(UnknownLoc::get(ty.getContext()), "Not implemented: ") << ty;
}

-FailureOr<int64_t> getIntConst(Value v) {
+FailureOr<int64_t> getIntConst(Value v, bool silent = false) {
if (auto constant_op = v.getDefiningOp<arith::ConstantOp>()) {
if (auto integer_attr = dyn_cast<IntegerAttr>(constant_op.getValue())) {
return integer_attr.getValue().getSExtValue();
}
}
+if (silent) {
+return failure();
+}
return emitError(v.getLoc(), "Expected an integer constant");
}

@@ -289,6 +293,31 @@ FailureOr<SmallVector<int64_t>> getIntConstsFromOperandRange(
return res;
}

+// For each dimension, expands a (possibly dynamic) base index into the
+// per-offset index values base + 0, ..., base + shape[dim] - 1. Constant
+// bases fold into fresh constants; dynamic bases emit arith::AddIOp.
+SmallVector<std::vector<Value>> getDimIndices(OperandRange indices,
+ArrayRef<int64_t> shape,
+ImplicitLocOpBuilder& builder) {
+CHECK_EQ(indices.size(), shape.size());
+SmallVector<std::vector<Value>> result(indices.size());
+for (int dim = 0; dim < indices.size(); ++dim) {
+auto& dim_idx = result[dim];
+dim_idx.reserve(shape[dim]);
+if (auto idx_const = getIntConst(indices[dim], /*silent=*/true);
+succeeded(idx_const)) {
+int64_t cst = idx_const.value();
+for (int64_t off = 0; off < shape[dim]; ++off) {
+dim_idx.push_back(IdxConst(cst + off, builder, builder.getLoc()));
+}
+} else {
+for (int64_t off = 0; off < shape[dim]; ++off) {
+dim_idx.push_back(builder.create<arith::AddIOp>(
+indices[dim], IdxConst(off, builder, builder.getLoc())));
+}
+}
+}
+return result;
+}


// Returns the first-level tiling of a (packed and tiled) memref value.
FailureOr<std::array<int64_t, 2>> getMemRefTiling(
TypedValue<MemRefType> value, const std::array<int64_t, 2> target_shape) {
@@ -1555,11 +1584,16 @@ LogicalResult vector_load_rule(RewriteContext &ctx, Operation &op,
FAILUREOR_ASSIGN_OR_RETURN(
VectorType target_ty,
getNativeVregType(vty.getElementType(), ctx.target_shape));
-if (layout_out.implicit_dim() == VectorLayout::ImplicitDim::kMinor) {
-return op.emitOpError("Not implemented");
+if (vty.getRank() == 0) {
+return op.emitOpError("Not implemented: scalar loads from vmem");
}
+const bool is_1d = vty.getRank() == 1;
+VectorLayout::ImplicitDim expected_dim =
+is_1d ? VectorLayout::ImplicitDim::kSecondMinor
+: VectorLayout::ImplicitDim::kNone;
+if (layout_out.implicit_dim() != expected_dim) {
+return op.emitOpError("Not implemented: unsupported layout");
+}
-const bool is_1d =
-layout_out.implicit_dim() != VectorLayout::ImplicitDim::kNone;
using Tiling = std::array<int64_t, 2>; // To avoid comma in macro
FAILUREOR_ASSIGN_OR_RETURN(
Tiling memref_tiling,
@@ -1574,16 +1608,8 @@ LogicalResult vector_load_rule(RewriteContext &ctx, Operation &op,
}
// TODO(apaszke): Check that loads are from vmem!
FAILUREOR_ASSIGN_OR_RETURN(
-const SmallVector<int64_t> indices,
-getIntConstsFromOperandRange(load_op.getIndices()));
-if (llvm::any_of(
-llvm::zip_equal(indices, vty.getShape(), memref_ty.getShape()),
-[](auto tup) {
-auto [idx, n, extent] = tup;
-return idx + n > extent;
-})) {
-return op.emitOpError("Reading out of bounds");
-}
+const SmallVector<int64_t> tile_indices,
+getIntConstsFromOperandRange(load_op.getIndices().take_back(2 - is_1d)));
const SmallVector<int64_t> implicit_shape =
layout_out.implicitShape(vty.getShape());
const int64_t ss = implicit_shape[implicit_shape.size() - 2];
@@ -1624,38 +1650,42 @@ LogicalResult vector_load_rule(RewriteContext &ctx, Operation &op,
layout_out.tileArrayShape(vty.getShape(), ctx.target_shape));
const std::array<int64_t, 2> vreg_slice =
layout_out.vregSlice(ctx.target_shape);
-const int64_t num_dims = indices.size();
+const int64_t num_dims = vty.getRank();
const int64_t num_batch_dims = num_dims - (is_1d ? 1 : 2);
+SmallVector<std::vector<Value>> base_batch =
+getDimIndices(load_op.getIndices().take_front(num_batch_dims),
+vty.getShape().take_front(num_batch_dims),
+builder);
const absl::Status status =
tiles.EachStatus([&](absl::Span<const int64_t> tile_idxs, Value * /*v*/) {
CHECK_EQ(num_dims, tile_idxs.size());
-SmallVector<int64_t> idxs(tile_idxs.size());
+SmallVector<Value> idxs(tile_idxs.size());
for (int64_t i = 0; i < num_batch_dims; ++i) {
-idxs[i] = tile_idxs[i] + indices[i];
+idxs[i] = base_batch[i][tile_idxs[i]];
}
-const int64_t base_l = indices[num_dims - 1];
+const int64_t base_l = tile_indices.back();
const int64_t lidx = tile_idxs[num_dims - 1];
-idxs[num_dims - 1] = base_l + lidx * vreg_slice[1] - *offsets[1];
+idxs[num_dims - 1] =
+IdxConst(base_l + lidx * vreg_slice[1] - *offsets[1], builder,
+load_op->getLoc());
if (!is_1d) {
-const int64_t base_s = indices[num_dims - 2];
+CHECK_EQ(tile_indices.size(), 2);
+const int64_t base_s = tile_indices.front();
const int64_t sidx = tile_idxs[num_dims - 2];
idxs[num_dims - 2] =
-base_s + sidx * vreg_slice[0] - offsets[0].value_or(0);
+IdxConst(base_s + sidx * vreg_slice[0] - offsets[0].value_or(0),
+builder, load_op->getLoc());
}
CHECK(tile_idxs[num_dims - 1] + ctx.target_shape[1] <=
memref_ty.getShape()[num_dims - 1]);
std::unique_ptr<VRegDataBounds> bounds = layout_out.tileDataBounds(
mlir_ctx, vty.getShape(), toArrayRef(tile_idxs), ctx.target_shape,
/*allow_replicated =*/{true, false});
-SmallVector<Value> idxs_vs(idxs.size());
-for (int64_t i = 0; i < idxs.size(); ++i) {
-idxs_vs[i] = IdxConst(idxs[i], builder, load_op->getLoc());
-}
Operation *tile;
if (bounds->maskVariesAlong(Direction::kSublanes, ctx.target_shape)) {
CHECK(offsets[0].has_value());
tile = builder.create<tpu::LoadOp>(
-target_ty, load_op.getBase(), idxs_vs,
+target_ty, load_op.getBase(), idxs,
bounds->getSublaneMask(mlir_ctx, ctx.target_shape),
builder.getI32IntegerAttr(sublane_stride));
} else {
@@ -1666,14 +1696,14 @@ LogicalResult vector_load_rule(RewriteContext &ctx, Operation &op,
return absl::UnimplementedError("");
}
tile = builder.create<vector::TransferReadOp>(
-target_ty, load_op.getBase(), idxs_vs, load_map, padding,
+target_ty, load_op.getBase(), idxs, load_map, padding,
nullptr, nullptr);
} else {
const SmallVector<bool> sublane_mask(ctx.target_shape[0], true);
const auto sublane_mask_attr =
DenseBoolArrayAttr::get(mlir_ctx, sublane_mask);
tile = builder.create<tpu::LoadOp>(
-target_ty, load_op.getBase(), idxs_vs, sublane_mask_attr,
+target_ty, load_op.getBase(), idxs, sublane_mask_attr,
builder.getI32IntegerAttr(sublane_stride));
}
}
@@ -2507,11 +2537,16 @@ LogicalResult vector_store_rule(RewriteContext &ctx, Operation &op,
vector::StoreOp store_op = cast<vector::StoreOp>(op);
const VectorType ty = store_op.getValueToStore().getType();
const VectorLayout &to_store_layout = *layouts_in.front();
-if (to_store_layout.implicit_dim() == VectorLayout::ImplicitDim::kMinor) {
-return op.emitOpError("Not implemented");
+if (!ty.getRank()) {
+return op.emitOpError("Not implemented: scalar stores to vmem");
}
+const bool is_1d = ty.getRank() == 1;
+VectorLayout::ImplicitDim expected_dim =
+is_1d ? VectorLayout::ImplicitDim::kSecondMinor
+: VectorLayout::ImplicitDim::kNone;
+if (to_store_layout.implicit_dim() != expected_dim) {
+return op.emitOpError("Not implemented: unsupported layout");
+}
-const bool is_1d =
-to_store_layout.implicit_dim() != VectorLayout::ImplicitDim::kNone;
using Tiling = std::array<int64_t, 2>;
FAILUREOR_ASSIGN_OR_RETURN(
const Tiling memref_tiling,
@@ -2525,19 +2560,23 @@ LogicalResult vector_store_rule(RewriteContext &ctx, Operation &op,
}
}
FAILUREOR_ASSIGN_OR_RETURN(
-const SmallVector<int64_t> base_indices,
-getIntConstsFromOperandRange(store_op.getIndices()));
+const SmallVector<int64_t> tile_indices,
+getIntConstsFromOperandRange(store_op.getIndices().take_back(2 - is_1d)));
FAILUREOR_ASSIGN_OR_RETURN(
xla::Array<Value> tiles,
disassemble(ctx, builder, to_store_layout, store_op.getValueToStore()));
-const int64_t ndims = base_indices.size();
+const int64_t ndims = ty.getRank();
const int64_t nbatchdims = is_1d ? ndims - 1 : ndims - 2;
-const int64_t base_s = is_1d ? 0 : base_indices[ndims - 2];
-const int64_t base_l = base_indices[ndims - 1];
+const int64_t base_s = is_1d ? 0 : tile_indices.front();
+const int64_t base_l = tile_indices.back();
if (is_1d) {
tiles.Reshape(
to_store_layout.implicitShape(toArrayRef(tiles.dimensions())));
}
+SmallVector<std::vector<Value>> base_batch =
+getDimIndices(store_op.getIndices().take_front(nbatchdims),
+ty.getShape().take_front(nbatchdims),
+builder);
const LayoutOffset sublane_offset = to_store_layout.offsets()[0];
const LayoutOffset lane_offset = to_store_layout.offsets()[1];
if (!sublane_offset.has_value() || !lane_offset.has_value()) {
@@ -2567,7 +2606,7 @@ LogicalResult vector_store_rule(RewriteContext &ctx, Operation &op,
auto boundIdxConst = std::bind(IdxConst, std::placeholders::_1, builder,
store_op->getLoc());
for (int64_t i = 0; i < nbatchdims; ++i) {
-indices[i] = boundIdxConst(idx[i] + base_indices[i]);
+indices[i] = base_batch[i][idx[i]];
}
if (!is_1d) {
*(indices.end() - 2) =
64 changes: 40 additions & 24 deletions jaxlib/mosaic/dialect/tpu/transforms/infer_vector_layout.cc
@@ -849,17 +849,24 @@ class VectorLayoutInferer {

SmallVector<Layout, 4> in_layout(op->getNumOperands(), kNoLayout);
CHECK_EQ(op->getNumOperands(), op.getIndices().size() + 1);
-SmallVector<int64_t, 4> indices;
-indices.reserve(rank);
-for (Value v : op.getIndices()) {
-auto cst_op = v.getDefiningOp<arith::ConstantOp>();
-TPU_CHECK_OP(cst_op, "only constant indices are supported");
-indices.push_back(cast<IntegerAttr>(cst_op.getValue()).getInt());
-}
-for (int64_t i = 0; i < rank; ++i) {
-TPU_CHECK_OP(indices[i] + res_ty.getDimSize(i) <= src_ty.getDimSize(i),
-"Loading elements out of bounds");
+SmallVector<int64_t, 2> tile_indices;
+for (int i = rank - 1; i >= 0; --i) {
+auto cst_op = op.getIndices()[i].getDefiningOp<arith::ConstantOp>();
+if (cst_op) {
+int64_t idx = cast<IntegerAttr>(cst_op.getValue()).getInt();
+TPU_CHECK_OP(idx + res_ty.getDimSize(i) <= src_ty.getDimSize(i),
+"Loading elements out of bounds");
+if (tile_indices.size() < 2) {
+tile_indices.push_back(idx);
+}
+} else {
+TPU_CHECK_OP(
+tile_indices.size() == 2,
+"Dynamic indices are not supported in the last two dimensions");
+}
}
+// We pushed the indices in reverse.
+std::reverse(tile_indices.begin(), tile_indices.end());

if (rank == 0) {
op.emitOpError("rank 0 vectors unsupported");
Expand All @@ -870,16 +877,17 @@ class VectorLayoutInferer {
auto tile = tiling.front();
TPU_CHECK_OP(tile % target_shape_[1] == 0,
"Unsupported tiling for 1D load");
-int64_t idx = indices.front();
+CHECK_EQ(tile_indices.size(), 1);
+int64_t idx = tile_indices.front();
int64_t offset = idx % kVmemAlignment32;
// TODO(apaszke): We could generate replicated loads for short values.
setLayout(op, in_layout,
VectorLayout(bitwidth, {0, offset}, {1, tile},
ImplicitDim::kSecondMinor));
} else { // rank >= 2
TPU_CHECK_OP(tiling.size() == 2, "Expected 2D tiling in 2D+ loads");
+CHECK_EQ(tile_indices.size(), 2);
std::array<std::optional<int64_t>, 2> offsets;
-const auto tile_indices = ArrayRef<int64_t>(indices).take_back(2);
const auto tile_src_shape = src_ty.getShape().take_back(2);
const auto tile_res_shape = res_ty.getShape().take_back(2);
const int64_t num_sublanes = tile_res_shape[0];
@@ -1140,17 +1148,24 @@ class VectorLayoutInferer {
}
auto tiling = *maybe_tiling;

-SmallVector<int64_t, 4> indices;
-indices.reserve(rank);
-for (Value v : op.getIndices()) {
-auto cst_op = v.getDefiningOp<arith::ConstantOp>();
-TPU_CHECK_OP(cst_op, "only constant indices are supported");
-indices.push_back(cast<IntegerAttr>(cst_op.getValue()).getInt());
-}
-for (int64_t i = 0; i < rank; ++i) {
-TPU_CHECK_OP(indices[i] + store_ty.getDimSize(i) <= ref_ty.getDimSize(i),
-"storing elements out of bounds");
+SmallVector<int64_t, 2> tile_indices;
+for (int i = rank - 1; i >= 0; --i) {
+auto cst_op = op.getIndices()[i].getDefiningOp<arith::ConstantOp>();
+if (cst_op) {
+int64_t idx = cast<IntegerAttr>(cst_op.getValue()).getInt();
+TPU_CHECK_OP(idx + store_ty.getDimSize(i) <= ref_ty.getDimSize(i),
+"Storing elements out of bounds");
+if (tile_indices.size() < 2) {
+tile_indices.push_back(idx);
+}
+} else {
+TPU_CHECK_OP(
+tile_indices.size() == 2,
+"Dynamic indices are not supported in the last two dimensions");
+}
}
+// We pushed the indices in reverse.
+std::reverse(tile_indices.begin(), tile_indices.end());

Layout store_layout;
if (rank == 0) {
Expand All @@ -1162,14 +1177,15 @@ class VectorLayoutInferer {
auto tile = tiling.front();
TPU_CHECK_OP(tile % target_shape_[1] == 0,
"Unsupported 1D tiling for 1D store");
-int64_t idx = indices.front();
+CHECK_EQ(tile_indices.size(), 1);
+int64_t idx = tile_indices.front();
int64_t offset = idx % kVmemAlignment32;
store_layout = VectorLayout(bitwidth, {0, offset}, {1, tile},
ImplicitDim::kSecondMinor);
} else { // rank >= 2 // NOLINT(readability-else-after-return)
TPU_CHECK_OP(tiling.size() == 2, "Expected 2D tiling in 2D+ store");
+CHECK_EQ(tile_indices.size(), 2);
std::array<std::optional<int64_t>, 2> offsets;
-const auto tile_indices = ArrayRef<int64_t>(indices).take_back(2);
const auto tile_ref_shape = ref_ty.getShape().take_back(2);
const auto tile_store_shape = store_ty.getShape().take_back(2);
const int64_t num_sublanes = tile_store_shape[0];
