Skip to content

Commit

Permalink
Fixes brave/brave-browser#11043 - Uses PKCE flow for Gemini auth
Browse files Browse the repository at this point in the history
Using new shared utility
  • Loading branch information
ryanml committed Aug 3, 2020
1 parent e007a15 commit 38db332
Show file tree
Hide file tree
Showing 4 changed files with 15 additions and 49 deletions.
39 changes: 3 additions & 36 deletions components/binance/browser/binance_service.cc
Original file line number Diff line number Diff line change
Expand Up @@ -14,22 +14,20 @@
#include "base/containers/flat_set.h"
#include "base/files/file_enumerator.h"
#include "base/files/file_util.h"
#include "base/strings/string_number_conversions.h"
#include "base/strings/utf_string_conversions.h"
#include "base/task/post_task.h"
#include "base/task_runner_util.h"
#include "base/time/time.h"
#include "base/token.h"
#include "brave/common/pref_names.h"
#include "brave/components/binance/browser/binance_json_parser.h"
#include "brave/components/crypto_exchange/browser/crypto_exchange_oauth_util.h"
#include "components/country_codes/country_codes.h"
#include "components/os_crypt/os_crypt.h"
#include "components/prefs/pref_service.h"
#include "components/user_prefs/user_prefs.h"
#include "content/public/browser/browser_context.h"
#include "content/public/browser/storage_partition.h"
#include "crypto/random.h"
#include "crypto/sha2.h"
#include "net/base/load_flags.h"
#include "net/base/url_util.h"
#include "services/network/public/cpp/shared_url_loader_factory.h"
Expand Down Expand Up @@ -72,15 +70,6 @@ GURL GetURLWithPath(const std::string& host, const std::string& path) {
return GURL(std::string(url::kHttpsScheme) + "://" + host).Resolve(path);
}

std::string GetHexEncodedCryptoRandomSeed() {
const size_t kSeedByteLength = 32;
// crypto::RandBytes is fail safe.
uint8_t random_seed_bytes[kSeedByteLength];
crypto::RandBytes(random_seed_bytes, kSeedByteLength);
return base::HexEncode(
reinterpret_cast<char*>(random_seed_bytes), kSeedByteLength);
}

} // namespace

BinanceService::BinanceService(content::BrowserContext* context)
Expand All @@ -98,36 +87,14 @@ BinanceService::BinanceService(content::BrowserContext* context)
BinanceService::~BinanceService() {
}

// static
std::string BinanceService::GetCodeChallenge(const std::string& code_verifier) {
std::string code_challenge;
char raw[crypto::kSHA256Length] = {0};
crypto::SHA256HashString(code_verifier,
raw,
crypto::kSHA256Length);
base::Base64Encode(base::StringPiece(raw,
crypto::kSHA256Length),
&code_challenge);

// Binance expects the following conversions for the base64 encoded value:
std::replace(code_challenge.begin(), code_challenge.end(), '+', '-');
std::replace(code_challenge.begin(), code_challenge.end(), '/', '_');
code_challenge.erase(std::find_if(code_challenge.rbegin(),
code_challenge.rend(), [](int ch) {
return ch != '=';
}).base(), code_challenge.end());

return code_challenge;
}

std::string BinanceService::GetOAuthClientUrl() {
// The code_challenge_ value is derived from the code_verifier value.
// Step 1 of the oauth process uses the code_challenge_ value.
// Step 4 of the oauth process uess the code_verifer_.
// We never need to persist these values, they are just used to get an
// access token.
code_verifier_ = GetHexEncodedCryptoRandomSeed();
code_challenge_ = GetCodeChallenge(code_verifier_);
code_verifier_ = ::crypto_exchange::GetCryptoRandomString(true);
code_challenge_ = crypto_exchange::GetCodeChallenge(code_verifier_, true);

GURL url(oauth_url);
url = net::AppendQueryParameter(url, "response_type", "code");
Expand Down
1 change: 0 additions & 1 deletion components/binance/browser/binance_service.h
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,6 @@ class BinanceService : public KeyedService {

std::string GetBinanceTLD();
std::string GetOAuthClientUrl();
static std::string GetCodeChallenge(const std::string& code_verifier);
void SetAuthToken(const std::string& auth_token);

private:
Expand Down
22 changes: 10 additions & 12 deletions components/gemini/browser/gemini_service.cc
Original file line number Diff line number Diff line change
Expand Up @@ -14,20 +14,19 @@
#include "base/files/file_enumerator.h"
#include "base/files/file_util.h"
#include "base/json/json_writer.h"
#include "base/strings/string_number_conversions.h"
#include "base/strings/utf_string_conversions.h"
#include "base/task/post_task.h"
#include "base/task_runner_util.h"
#include "base/time/time.h"
#include "base/token.h"
#include "brave/components/crypto_exchange/browser/crypto_exchange_oauth_util.h"
#include "brave/components/gemini/browser/gemini_json_parser.h"
#include "brave/components/gemini/browser/pref_names.h"
#include "components/os_crypt/os_crypt.h"
#include "components/prefs/pref_service.h"
#include "components/user_prefs/user_prefs.h"
#include "content/public/browser/browser_context.h"
#include "content/public/browser/storage_partition.h"
#include "crypto/random.h"
#include "net/base/load_flags.h"
#include "net/base/url_util.h"
#include "services/network/public/cpp/shared_url_loader_factory.h"
Expand Down Expand Up @@ -106,15 +105,6 @@ namespace {
return encoded_payload;
}

std::string GetEncodedCryptoRandomCSRF() {
uint8_t random_seed_bytes[24];
crypto::RandBytes(random_seed_bytes, 24);
std::string encoded_csrf;
base::Base64Encode(
reinterpret_cast<char*>(random_seed_bytes), &encoded_csrf);
return encoded_csrf;
}

} // namespace

GeminiService::GeminiService(content::BrowserContext* context)
Expand All @@ -135,11 +125,16 @@ GeminiService::~GeminiService() {

std::string GeminiService::GetOAuthClientUrl() {
GURL url(oauth_url);
code_verifier_ = crypto_exchange::GetCryptoRandomString(false);
code_challenge_ = crypto_exchange::GetCodeChallenge(code_verifier_, false);
url = net::AppendQueryParameter(url, "response_type", "code");
url = net::AppendQueryParameter(url, "client_id", client_id_);
url = net::AppendQueryParameter(url, "redirect_uri", oauth_callback);
url = net::AppendQueryParameter(url, "scope", oauth_scope);
url = net::AppendQueryParameter(url, "state", GetEncodedCryptoRandomCSRF());
url = net::AppendQueryParameter(url, "code_challenge", code_challenge_);
url = net::AppendQueryParameter(url, "code_challenge_method", "S256");
url = net::AppendQueryParameter(
url, "state", ::crypto_exchange::GetCryptoRandomString(false));
return url.spec();
}

Expand All @@ -157,6 +152,7 @@ bool GeminiService::GetAccessToken(AccessTokenCallback callback) {
dict.SetStringKey("client_secret", client_secret_);
dict.SetStringKey("code", auth_token_);
dict.SetStringKey("redirect_uri", oauth_callback);
dict.SetStringKey("code_verifier", code_verifier_);
dict.SetStringKey("grant_type", "authorization_code");
std::string request_body = CreateJSONRequestBody(dict);

Expand Down Expand Up @@ -278,6 +274,8 @@ void GeminiService::OnRevokeAccessToken(
const std::map<std::string, std::string>& headers) {
bool success = status >= 200 && status <= 299;
if (success) {
code_challenge_ = "";
code_verifier_ = "";
ResetAccessTokens();
}
std::move(callback).Run(success);
Expand Down
2 changes: 2 additions & 0 deletions components/gemini/browser/gemini_service.h
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,8 @@ class GeminiService : public KeyedService {
std::string auth_token_;
std::string access_token_;
std::string refresh_token_;
std::string code_challenge_;
std::string code_verifier_;
std::string client_id_;
std::string client_secret_;
std::string oauth_host_;
Expand Down

0 comments on commit 38db332

Please sign in to comment.