Skip to content

Commit

Permalink
All setting index-base for input/output tensor in convert_tensor
Browse files Browse the repository at this point in the history
  • Loading branch information
etphipp committed Oct 22, 2024
1 parent 0535c77 commit c290e2a
Show file tree
Hide file tree
Showing 3 changed files with 38 additions and 20 deletions.
8 changes: 6 additions & 2 deletions src/Genten_TensorIO.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1088,7 +1088,9 @@ queryFile()
template <typename ExecSpace>
TensorWriter<ExecSpace>::
TensorWriter(const std::string& fname,
const bool comp) : filename(fname), compressed(comp) {}
const ttb_indx ib,
const bool comp) :
filename(fname), index_base(ib), compressed(comp) {}

template <typename ExecSpace>
void
Expand Down Expand Up @@ -1119,7 +1121,9 @@ writeText(const SptensorT<ExecSpace>& X) const
{
Sptensor X_host = create_mirror_view(X);
deep_copy(X_host, X);
export_sptensor(filename, X_host, true, 15, true, compressed);
if (index_base != 0 && index_base != 1)
Genten::error("Writing a sparse tensor requires index base of 0 or 1");
export_sptensor(filename, X_host, true, 15, index_base==0, compressed);
}

template <typename ExecSpace>
Expand Down
2 changes: 2 additions & 0 deletions src/Genten_TensorIO.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,7 @@ template <typename ExecSpace>
class TensorWriter {
public:
TensorWriter(const std::string& filename,
const ttb_indx index_base = 0,
const bool compressed = false);

void writeBinary(const SptensorT<ExecSpace>& X,
Expand All @@ -199,6 +200,7 @@ class TensorWriter {
void writeText(const TensorT<ExecSpace>& X) const;
private:
std::string filename;
ttb_indx index_base;
bool compressed;
};

Expand Down
48 changes: 30 additions & 18 deletions tools/convert_tensor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@
template <typename TensorType>
void print_tensor_stats(const TensorType& x)
{
std::cout << " Stats: ";
std::cout << " Stats: ";
const ttb_indx nd = x.ndims();
for (ttb_indx i=0; i<nd; ++i) {
std::cout << x.size(i);
Expand All @@ -65,19 +65,22 @@ void print_tensor_stats(const TensorType& x)

template <typename TensorType>
void save_tensor(const TensorType& x_in, const std::string& filename,
const std::string format, const std::string type, bool gz,
const std::string format, const std::string type,
const ttb_indx index_base, bool gz,
bool header)
{
std::cout << "\nOutput:\n"
<< " File: " << filename << std::endl
<< " Format: " << format << std::endl
<< " Type: " << type;
<< " File: " << filename << std::endl
<< " Index Base: " << index_base << std::endl
<< " Format: " << format << std::endl
<< " Type: " << type;
if (type == "text" && gz)
std::cout << " (compressed)";
if (type == "binary" && !header)
std::cout << " (no header)";
std::cout << std::endl;
Genten::TensorWriter<Genten::DefaultHostExecutionSpace> writer(filename,gz);
Genten::TensorWriter<Genten::DefaultHostExecutionSpace> writer(
filename,index_base,gz);
if (format == "sparse") {
Genten::Sptensor x_out(x_in);
print_tensor_stats(x_out);
Expand All @@ -97,10 +100,12 @@ void save_tensor(const TensorType& x_in, const std::string& filename,
}

void read_tensor_file(const std::string& filename,
const ttb_indx index_base,
std::string& format, std::string& type, bool gz,
Genten::Sptensor& x_sparse, Genten::Tensor& x_dense)
{
Genten::TensorReader<Genten::DefaultHostExecutionSpace> reader(filename,0,gz);
Genten::TensorReader<Genten::DefaultHostExecutionSpace> reader(
filename,index_base,gz);
reader.read();

if (reader.isSparse()) {
Expand All @@ -125,14 +130,16 @@ int main(int argc, char* argv[])
auto args = Genten::build_arg_list(argc,argv);
const bool help =
Genten::parse_ttb_bool(args, "--help", "--no-help", false);
if (argc < 9 || argc > 11 || help) {
if (argc < 9 || argc > 16 || help) {
std::cout << "\nconvert-tensor: a helper utility for converting tensor data between\n"
<< "tensor formats (sparse or dense), and file types (text or binary).\n\n"
<< "Usage: " << argv[0] << " --input-file <string> --output-file <string> --output-format <sparse|dense> --output-type <text|binary> [options] \n"
<< "Options:\n"
<< " --input-gz Input tensor is Gzip compressed (text-only, default: off)\n"
<< " --output-gz Output tensor is Gzip compressed (text-only, default: off)\n"
<< " --output-header Write header to output file (binary-only, default: on)\n";
<< " --input-gz Input tensor is Gzip compressed (text-only, default: off)\n"
<< " --output-gz Output tensor is Gzip compressed (text-only, default: off)\n"
<< " --output-header Write header to output file (binary-only, default: on)\n"
<< " --input-index-base Starting index for input tensor (sparse-only, default: 0)\n"
<< " --output-index-base Starting index for output tensor (sparse-only, default: 0)\n";
return 0;
}

Expand All @@ -153,6 +160,10 @@ int main(int argc, char* argv[])
Genten::parse_ttb_bool(args, "--output-gz", "--no-output-gz", false);
const bool output_header =
Genten::parse_ttb_bool(args, "--output-header", "--no-output-header", true);
const ttb_indx input_index_base =
Genten::parse_ttb_indx(args, "--input-index-base", 0, 0, INT_MAX);
const ttb_indx output_index_base =
Genten::parse_ttb_indx(args, "--output-index-base", 0, 0, INT_MAX);

if (input_filename == "")
Genten::error("input filename must be specified");
Expand All @@ -168,29 +179,30 @@ int main(int argc, char* argv[])
Genten::error("No header option only supported for binary output files");

std::cout << "\nInput:\n"
<< " File: " << input_filename << std::endl;
<< " File: " << input_filename << std::endl
<< " Index base: " << input_index_base << std::endl;

std::string input_format = "unknown";
std::string input_type = "unknown";
Genten::Sptensor x_sparse;
Genten::Tensor x_dense;
read_tensor_file(input_filename, input_format, input_type, input_gz,
x_sparse, x_dense);
read_tensor_file(input_filename, input_index_base, input_format, input_type,
input_gz, x_sparse, x_dense);

std::cout << " Format: " << input_format << std::endl
<< " Type: " << input_type;
std::cout << " Format: " << input_format << std::endl
<< " Type: " << input_type;
if (input_type == "text" && input_gz)
std::cout << " (compressed)";
std::cout << std::endl;
if (input_format == "sparse") {
print_tensor_stats(x_sparse);
save_tensor(x_sparse, output_filename, output_format, output_type,
output_gz, output_header);
output_index_base, output_gz, output_header);
}
else if (input_format == "dense") {
print_tensor_stats(x_dense);
save_tensor(x_dense, output_filename, output_format, output_type,
output_gz, output_header);
output_index_base, output_gz, output_header);
}
else
Genten::error("Invalid input tensor format!");
Expand Down

0 comments on commit c290e2a

Please sign in to comment.