Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add environment variable interface #3568

Merged
merged 5 commits into from
Aug 18, 2017
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions paddle/memory/memory.cc
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,13 @@ limitations under the License. */
#include <memory> // for unique_ptr
#include <mutex> // for call_once

#include "glog/logging.h"

#include "paddle/memory/detail/buddy_allocator.h"
#include "paddle/memory/detail/system_allocator.h"
#include "paddle/platform/gpu_info.h"

DECLARE_double(fraction_of_gpu_memory_to_use);

namespace paddle {
namespace memory {
Expand Down Expand Up @@ -80,6 +85,11 @@ BuddyAllocator* GetGPUBuddyAllocator(int gpu_id) {
platform::GpuMinChunkSize(),
platform::GpuMaxChunkSize()));
}
VLOG(3) << "\n\nNOTE: each GPU device use "
<< FLAGS_fraction_of_gpu_memory_to_use * 100 << "% of GPU memory.\n"
<< "You can set environment variable '"
<< platform::kEnvFractionGpuMemoryToUse
<< "' to change the fraction of GPU usage.\n\n";
});

platform::SetDeviceId(gpu_id);
Expand Down
1 change: 0 additions & 1 deletion paddle/memory/memory.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@ limitations under the License. */

#pragma once

#include "paddle/platform/gpu_info.h"
#include "paddle/platform/place.h"

namespace paddle {
Expand Down
3 changes: 2 additions & 1 deletion paddle/platform/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,14 +1,15 @@
cc_library(cpu_info SRCS cpu_info.cc DEPS gflags glog)
cc_test(cpu_info_test SRCS cpu_info_test.cc DEPS cpu_info)

nv_library(gpu_info SRCS gpu_info.cc DEPS gflags)
nv_library(gpu_info SRCS gpu_info.cc DEPS gflags glog)

cc_library(place SRCS place.cc)
cc_test(place_test SRCS place_test.cc DEPS place glog gflags)

add_subdirectory(dynload)

cc_test(enforce_test SRCS enforce_test.cc DEPS stringpiece)
cc_test(environment_test SRCS environment_test.cc DEPS stringpiece)

IF(WITH_GPU)
set(GPU_CTX_DEPS dynload_cuda dynamic_loader)
Expand Down
60 changes: 60 additions & 0 deletions paddle/platform/environment.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include <stdlib.h>
#include <unistd.h>
#include <vector>

#include "paddle/platform/enforce.h"
#include "paddle/string/piece.h"

extern char** environ; // for environment variables

namespace paddle {
namespace platform {

inline void SetEnvVariable(const std::string& name, const std::string& value) {
PADDLE_ENFORCE_NE(setenv(name.c_str(), value.c_str(), 1), -1,
"Failed to set environment variable %s=%s", name, value);
}

inline void UnsetEnvVariable(const std::string& name) {
PADDLE_ENFORCE_NE(unsetenv(name.c_str()), -1,
"Failed to unset environment variable %s", name);
}

inline bool IsEnvVarDefined(const std::string& name) {
return std::getenv(name.c_str()) != nullptr;
}

inline std::string GetEnvValue(const std::string& name) {
PADDLE_ENFORCE(IsEnvVarDefined(name),
"Tried to access undefined environment variable %s", name);
return std::getenv(name.c_str());
}

inline std::vector<std::string> GetAllEnvVariables() {
std::vector<std::string> vars;
for (auto var = environ; *var != nullptr; ++var) {
auto tail = string::Index(*var, "=");
auto name = string::SubStr(*var, 0, tail).ToString();
vars.push_back(name);
}
return vars;
}

} // namespace platform
} // namespace paddle
54 changes: 54 additions & 0 deletions paddle/platform/environment_test.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/platform/environment.h"

#include "glog/logging.h"
#include "gtest/gtest.h"

TEST(ENVIRONMENT, ACCESS) {
namespace platform = paddle::platform;
namespace string = paddle::string;

platform::SetEnvVariable("PADDLE_USE_ENV", "TRUE");

EXPECT_TRUE(platform::IsEnvVarDefined("PADDLE_USE_ENV"));
EXPECT_EQ(platform::GetEnvValue("PADDLE_USE_ENV"), "TRUE");

platform::UnsetEnvVariable("PADDLE_USE_ENV");
EXPECT_FALSE(platform::IsEnvVarDefined("PADDLE_USE_ENV"));

platform::SetEnvVariable("PADDLE_USE_ENV1", "Hello ");
platform::SetEnvVariable("PADDLE_USE_ENV2", "World, ");
platform::SetEnvVariable("PADDLE_USE_ENV3", "PaddlePaddle!");

std::string env_info;
auto vars = platform::GetAllEnvVariables();
for_each(vars.begin(), vars.end(), [&](const std::string& var) {
env_info += platform::GetEnvValue(var);
});

EXPECT_TRUE(string::Contains(env_info, "Hello World, PaddlePaddle!"));
platform::UnsetEnvVariable("PADDLE_USE_ENV1");
platform::UnsetEnvVariable("PADDLE_USE_ENV2");
platform::UnsetEnvVariable("PADDLE_USE_ENV3");

env_info.clear();
vars = platform::GetAllEnvVariables();
for_each(vars.begin(), vars.end(), [&](const std::string& var) {
env_info += platform::GetEnvValue(var);
});

EXPECT_FALSE(string::Contains(env_info, "Hello World, PaddlePaddle!"));
EXPECT_FALSE(platform::IsEnvVarDefined("PADDLE_USE_ENV1"));
EXPECT_FALSE(platform::IsEnvVarDefined("PADDLE_USE_ENV2"));
EXPECT_FALSE(platform::IsEnvVarDefined("PADDLE_USE_ENV3"));
}
10 changes: 10 additions & 0 deletions paddle/platform/gpu_info.cc
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,11 @@ See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/platform/gpu_info.h"

#include "gflags/gflags.h"

#include "paddle/platform/enforce.h"
#include "paddle/platform/environment.h"

DEFINE_double(fraction_of_gpu_memory_to_use, 0.95,
"Default use 95% of GPU memory for PaddlePaddle,"
Expand Down Expand Up @@ -70,6 +73,13 @@ size_t GpuMaxChunkSize() {

GpuMemoryUsage(available, total);

if (IsEnvVarDefined(kEnvFractionGpuMemoryToUse)) {
auto val = std::stod(GetEnvValue(kEnvFractionGpuMemoryToUse));
PADDLE_ENFORCE_GT(val, 0.0);
PADDLE_ENFORCE_LE(val, 1.0);
FLAGS_fraction_of_gpu_memory_to_use = val;
}

// Reserving the rest memory for page tables, etc.
size_t reserving = (1 - FLAGS_fraction_of_gpu_memory_to_use) * total;

Expand Down
5 changes: 5 additions & 0 deletions paddle/platform/gpu_info.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,15 @@ limitations under the License. */

#include <cuda_runtime.h>
#include <stddef.h>
#include <string>

namespace paddle {
namespace platform {

//! Environment variable: fraction of GPU memory to use on each device.
const std::string kEnvFractionGpuMemoryToUse =
"PADDLE_FRACTION_GPU_MEMORY_TO_USE";

//! Get the total number of GPU devices in system.
int GetDeviceCount();

Expand Down