Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feature: (minor) support various boolean expressions for outputting flags #5489

Merged
merged 6 commits into from
Nov 15, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions source/module_base/formatter.h
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,20 @@ class FmtCore
[&delim](const std::string& acc, const std::string& s) { return acc + delim + s; });
}

static std::string upper(const std::string& in)
{
std::string dst = in;
std::transform(dst.begin(), dst.end(), dst.begin(), ::toupper);
return dst;
}

static std::string lower(const std::string& in)
{
std::string dst = in;
std::transform(dst.begin(), dst.end(), dst.begin(), ::tolower);
return dst;
}

private:
std::string fmt_;
template<typename T>
Expand Down
96 changes: 30 additions & 66 deletions source/module_io/read_input.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,79 +5,52 @@
#include <fstream>
#include <iostream>
#include <sstream>
#include <array>
#include <vector>
#include <cassert>
#include "module_base/formatter.h"
#include "module_base/global_file.h"
#include "module_base/global_function.h"
#include "module_base/tool_quit.h"
#include "module_base/tool_title.h"
namespace ModuleIO
{

void strtolower(char* sa, char* sb)
std::string longstring(const std::vector<std::string>& words)
{
char c;
int len = strlen(sa);
for (int i = 0; i < len; i++)
{
c = sa[i];
sb[i] = tolower(c);
}
sb[len] = '\0';
return FmtCore::join(" ", words);
}

std::string longstring(const std::vector<std::string>& str_values)
bool assume_as_boolean(const std::string& val)
{
std::string output;
output = "";
const size_t length = str_values.size();
for (int i = 0; i < length; ++i)
{
output += str_values[i];
if (i != length - 1)
{
output += " ";
}
}
return output;
}
const std::string val_ = FmtCore::lower(val);

bool convert_bool(std::string str)
{
for (auto& i: str)
{
i = tolower(i);
}
if (str == "true")
{
return true;
}
else if (str == "false")
{
return false;
}
else if (str == "1")
{
return true;
}
else if (str == "0")
{
return false;
}
else if (str == "t")
const std::array<std::string, 7> t_ = {"true", "1", "t", "yes", "y", "on", ".true."};
const std::array<std::string, 7> f_ = {"false", "0", "f", "no", "n", "off", ".false."};
// This will work because std::array<T, N>::size() is a constexpr function
// Ouch it is of C++17 standard...
// static_assert(t_.size() == f_.size(), "t_ and f_ must have the same lengths");
#ifdef __DEBUG // C++11 can do this
assert(t_.size() == f_.size());
#endif

if (std::find(t_.begin(), t_.end(), val_) != t_.end())
{
return true;
}
else if (str == "f")
else if (std::find(f_.begin(), f_.end(), val_) != f_.end())
{
return false;
}
else
{
std::string warningstr = "Bad boolean parameter ";
warningstr.append(str);
warningstr.append(", please check the input parameters in file INPUT");
ModuleBase::WARNING_QUIT("Input", warningstr);
std::string warnmsg = "Bad boolean parameter ";
warnmsg.append(val);
warnmsg.append(", please check the input parameters in file INPUT");
ModuleBase::WARNING_QUIT("Input", warnmsg);
}
}

std::string to_dir(const std::string& str)
{
std::string str_dir = str;
Expand Down Expand Up @@ -216,8 +189,7 @@ void ReadInput::read_txt_input(Parameter& param, const std::string& filename)
ifs.clear();
ifs.seekg(0);

char word[80];
char word1[80];
std::string word, word1;
int ierr = 0;

// ifs >> std::setiosflags(ios::uppercase);
Expand All @@ -226,7 +198,7 @@ void ReadInput::read_txt_input(Parameter& param, const std::string& filename)
{
ifs >> word;
ifs.ignore(150, '\n');
if (strcmp(word, "INPUT_PARAMETERS") == 0)
if (word == "INPUT_PARAMETERS")
{
ierr = 1;
break;
Expand All @@ -247,10 +219,8 @@ void ReadInput::read_txt_input(Parameter& param, const std::string& filename)
while (ifs.good())
{
ifs >> word1;
if (ifs.eof()) {
break;
}
strtolower(word1, word);
if (ifs.eof()) { break; }
word = FmtCore::lower(word1);
auto it = std::find_if(input_lists.begin(),
input_lists.end(),
[&word](const std::pair<std::string, Input_Item>& item) { return item.first == word; });
Expand Down Expand Up @@ -311,7 +281,7 @@ void ReadInput::read_txt_input(Parameter& param, const std::string& filename)
Input_Item* resetvalue_item = &(input_item.second);
if (resetvalue_item->reset_value != nullptr) {
resetvalue_item->reset_value(*resetvalue_item, param);
}
}
}
this->set_globalv(param);

Expand All @@ -327,7 +297,7 @@ void ReadInput::read_txt_input(Parameter& param, const std::string& filename)
Input_Item* checkvalue_item = &(input_item.second);
if (checkvalue_item->check_value != nullptr) {
checkvalue_item->check_value(*checkvalue_item, param);
}
}
}
}

Expand Down Expand Up @@ -505,12 +475,6 @@ void ReadInput::add_item(const Input_Item& item)
}
}

bool find_str(const std::vector<std::string>& strings, const std::string& strToFind)
{
auto it = std::find(strings.begin(), strings.end(), strToFind);
return it != strings.end();
}

std::string nofound_str(std::vector<std::string> init_chgs, const std::string& str)
{
std::string warningstr = "The parameter ";
Expand Down
6 changes: 1 addition & 5 deletions source/module_io/read_input.h
Original file line number Diff line number Diff line change
Expand Up @@ -139,14 +139,10 @@ class ReadInput
std::vector<std::function<void(Parameter&)>> bcastfuncs;
};

// convert string to lower case
void strtolower(char* sa, char* sb);
// convert string vector to a long string
std::string longstring(const std::vector<std::string>& str_values);
// convert string to bool
bool convert_bool(std::string str);
// if find a string in a vector of strings
bool find_str(const std::vector<std::string>& strings, const std::string& strToFind);
bool assume_as_boolean(const std::string& val);
// convert to directory format
std::string to_dir(const std::string& str);
// return a warning string if the string is not found in the vector
Expand Down
6 changes: 3 additions & 3 deletions source/module_io/read_input_item_elec_stru.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -76,15 +76,15 @@ void ReadInput::item_elec_stru()

if (para.input.basis_type == "pw")
{
if (!find_str(pw_solvers, ks_solver))
if (std::find(pw_solvers.begin(), pw_solvers.end(), ks_solver) == pw_solvers.end())
{
const std::string warningstr = "For PW basis: " + nofound_str(pw_solvers, "ks_solver");
ModuleBase::WARNING_QUIT("ReadInput", warningstr);
}
}
else if (para.input.basis_type == "lcao")
{
if (!find_str(lcao_solvers, ks_solver))
if (std::find(lcao_solvers.begin(), lcao_solvers.end(), ks_solver) == lcao_solvers.end())
{
const std::string warningstr = "For LCAO basis: " + nofound_str(lcao_solvers, "ks_solver");
ModuleBase::WARNING_QUIT("ReadInput", warningstr);
Expand Down Expand Up @@ -163,7 +163,7 @@ void ReadInput::item_elec_stru()
};
item.check_value = [](const Input_Item& item, const Parameter& para) {
const std::vector<std::string> basis_types = {"pw", "lcao_in_pw", "lcao"};
if (!find_str(basis_types, para.input.basis_type))
if (std::find(basis_types.begin(), basis_types.end(), para.input.basis_type) == basis_types.end())
{
const std::string warningstr = nofound_str(basis_types, "basis_type");
ModuleBase::WARNING_QUIT("ReadInput", warningstr);
Expand Down
50 changes: 13 additions & 37 deletions source/module_io/read_input_item_output.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ void ReadInput::item_output()
item.annotation = "output the structure files after each ion step";
item.reset_value = [](const Input_Item& item, Parameter& para) {
const std::vector<std::string> offlist = {"nscf", "get_S", "get_pchg", "get_wf"};
if (find_str(offlist, para.input.calculation))
if (std::find(offlist.begin(), offlist.end(), para.input.calculation) != offlist.end())
{
para.input.out_stru = false;
}
Expand Down Expand Up @@ -96,21 +96,13 @@ void ReadInput::item_output()
Input_Item item("out_band");
item.annotation = "output energy and band structure (with precision 8)";
item.read_value = [](const Input_Item& item, Parameter& para) {
size_t count = item.get_size();
if (count == 1)
{
para.input.out_band[0] = std::stoi(item.str_values[0]);
para.input.out_band[1] = 8;
}
else if (count == 2)
{
para.input.out_band[0] = std::stoi(item.str_values[0]);
para.input.out_band[1] = std::stoi(item.str_values[1]);
}
else
const size_t count = item.get_size();
if (count != 1 && count != 2)
{
ModuleBase::WARNING_QUIT("ReadInput", "out_band should have 1 or 2 values");
}
para.input.out_band[0] = assume_as_boolean(item.str_values[0]);
para.input.out_band[1] = (count == 2) ? std::stoi(item.str_values[1]) : 8;
};
item.reset_value = [](const Input_Item& item, Parameter& para) {
if (para.input.calculation == "get_wf" || para.input.calculation == "get_pchg")
Expand Down Expand Up @@ -239,21 +231,13 @@ void ReadInput::item_output()
Input_Item item("out_mat_hs");
item.annotation = "output H and S matrix (with precision 8)";
item.read_value = [](const Input_Item& item, Parameter& para) {
size_t count = item.get_size();
if (count == 1)
{
para.input.out_mat_hs[0] = std::stoi(item.str_values[0]);
para.input.out_mat_hs[1] = 8;
}
else if (count == 2)
{
para.input.out_mat_hs[0] = std::stoi(item.str_values[0]);
para.input.out_mat_hs[1] = std::stoi(item.str_values[1]);
}
else
const size_t count = item.get_size();
if (count != 1 && count != 2)
{
ModuleBase::WARNING_QUIT("ReadInput", "out_mat_hs should have 1 or 2 values");
}
para.input.out_mat_hs[0] = assume_as_boolean(item.str_values[0]);
para.input.out_mat_hs[1] = (count == 2) ? std::stoi(item.str_values[1]) : 8;
};
item.reset_value = [](const Input_Item& item, Parameter& para) {
if (para.input.qo_switch)
Expand All @@ -268,21 +252,13 @@ void ReadInput::item_output()
Input_Item item("out_mat_tk");
item.annotation = "output T(k)";
item.read_value = [](const Input_Item& item, Parameter& para) {
size_t count = item.get_size();
if (count == 1)
{
para.input.out_mat_tk[0] = std::stoi(item.str_values[0]);
para.input.out_mat_tk[1] = 8;
}
else if (count == 2)
{
para.input.out_mat_tk[0] = std::stoi(item.str_values[0]);
para.input.out_mat_tk[1] = std::stoi(item.str_values[1]);
}
else
const size_t count = item.get_size();
if (count != 1 && count != 2)
{
ModuleBase::WARNING_QUIT("ReadInput", "out_mat_tk should have 1 or 2 values");
}
para.input.out_mat_tk[0] = assume_as_boolean(item.str_values[0]);
para.input.out_mat_tk[1] = (count == 2) ? std::stoi(item.str_values[1]) : 8;
};
sync_intvec(input.out_mat_tk, 2, 0);
this->add_item(item);
Expand Down
4 changes: 2 additions & 2 deletions source/module_io/read_input_item_relax.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ void ReadInput::item_relax()
read_sync_string(input.relax_method);
item.check_value = [](const Input_Item& item, const Parameter& para) {
const std::vector<std::string> relax_methods = {"cg", "bfgs", "sd", "cg_bfgs"};
if (!find_str(relax_methods, para.input.relax_method))
if (std::find(relax_methods.begin(), relax_methods.end(), para.input.relax_method) == relax_methods.end())
{
const std::string warningstr = nofound_str(relax_methods, "relax_method");
ModuleBase::WARNING_QUIT("ReadInput", warningstr);
Expand Down Expand Up @@ -52,7 +52,7 @@ void ReadInput::item_relax()
const std::string& calculation = para.input.calculation;
const std::vector<std::string> singlelist
= {"scf", "nscf", "get_S", "get_pchg", "get_wf", "test_memory", "test_neighbour", "gen_bessel"};
if (find_str(singlelist, calculation))
if (std::find(singlelist.begin(), singlelist.end(), calculation) != singlelist.end())
{
para.input.relax_nmax = 1;
}
Expand Down
10 changes: 5 additions & 5 deletions source/module_io/read_input_item_system.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ void ReadInput::item_system()
"get_wf",
"get_pchg",
"gen_bessel"};
if (!find_str(callist, calculation))
if (std::find(callist.begin(), callist.end(), calculation) == callist.end())
{
const std::string warningstr = nofound_str(callist, "calculation");
ModuleBase::WARNING_QUIT("ReadInput", warningstr);
Expand Down Expand Up @@ -111,7 +111,7 @@ void ReadInput::item_system()
read_sync_string(input.esolver_type);
item.check_value = [](const Input_Item& item, const Parameter& para) {
const std::vector<std::string> esolver_types = { "ksdft", "sdft", "ofdft", "tddft", "lj", "dp", "lr", "ks-lr" };
if (!find_str(esolver_types, para.input.esolver_type))
if (std::find(esolver_types.begin(), esolver_types.end(), para.input.esolver_type) == esolver_types.end())
{
const std::string warningstr = nofound_str(esolver_types, "esolver_type");
ModuleBase::WARNING_QUIT("ReadInput", warningstr);
Expand Down Expand Up @@ -208,15 +208,15 @@ void ReadInput::item_system()
item.reset_value = [](const Input_Item& item, Parameter& para) {
std::vector<std::string> use_force = {"cell-relax", "relax", "md"};
std::vector<std::string> not_use_force = {"get_wf", "get_pchg", "nscf", "get_S"};
if (find_str(use_force, para.input.calculation))
if (std::find(use_force.begin(), use_force.end(), para.input.calculation) != use_force.end())
{
if (!para.input.cal_force)
{
ModuleBase::GlobalFunc::AUTO_SET("cal_force", "true");
}
para.input.cal_force = true;
}
else if (find_str(not_use_force, para.input.calculation))
else if (std::find(not_use_force.begin(), not_use_force.end(), para.input.calculation) != not_use_force.end())
{
if (para.input.cal_force)
{
Expand Down Expand Up @@ -538,7 +538,7 @@ void ReadInput::item_system()
};
item.check_value = [](const Input_Item& item, const Parameter& para) {
const std::vector<std::string> init_chgs = {"atomic", "file", "wfc", "auto"};
if (!find_str(init_chgs, para.input.init_chg))
if (std::find(init_chgs.begin(), init_chgs.end(), para.input.init_chg) == init_chgs.end())
{
const std::string warningstr = nofound_str(init_chgs, "init_chg");
ModuleBase::WARNING_QUIT("ReadInput", warningstr);
Expand Down
2 changes: 1 addition & 1 deletion source/module_io/read_input_tool.h
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
#define strvalue item.str_values[0]
#define intvalue std::stoi(item.str_values[0])
#define doublevalue std::stod(item.str_values[0])
#define boolvalue convert_bool(item.str_values[0])
#define boolvalue assume_as_boolean(item.str_values[0])

#ifdef __MPI
#define add_double_bcast(PARAMETER) \
Expand Down
Loading