Skip to content

Commit

Permalink
Properly reconnect when switching chains after socials login
Browse files Browse the repository at this point in the history
  • Loading branch information
lefarchi committed Jul 25, 2024
1 parent 03ac87a commit 518a67a
Show file tree
Hide file tree
Showing 3 changed files with 48 additions and 33 deletions.
40 changes: 20 additions & 20 deletions Assets/TestsBasic/IntegrationTests.cs
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
using System.Collections;
using System.Collections.Generic;
using System.Linq;
using NUnit.Framework;
using Treasure;
using UnityEngine;
Expand Down Expand Up @@ -48,14 +47,16 @@ public IEnumerator Integration1()
yield return new WaitForSeconds(1);

if (connected) {
yield return SetChain(); // TODO this breaks the CreateSession call, is it intended?
yield return SetChain();
}

yield return SendAnalyticsEvents();

yield return new WaitForSeconds(3);
yield return new WaitForSeconds(1);

yield return CreateSession();

yield return new WaitForSeconds(1);

yield return ForceFlushCache();
}
Expand Down Expand Up @@ -107,22 +108,24 @@ private IEnumerator SetChain()

tdkLogs.Clear();

_ = TDK.Connect.SetChainId(ChainId.ArbitrumSepolia);
yield return new WaitForSeconds(2);
_ = TDK.Connect.SetChainId(ChainId.Arbitrum);
yield return new WaitForSeconds(2);
_ = TDK.Connect.SetChainId(ChainId.Arbitrum);
yield return new WaitForSeconds(2);
_ = TDK.Connect.SetChainId(ChainId.ArbitrumSepolia);
yield return new WaitForSeconds(2);
yield return TestHelpers.WaitForTask(TDK.Connect.SetChainId(ChainId.ArbitrumSepolia));
yield return TestHelpers.WaitForTask(TDK.Connect.SetChainId(ChainId.Arbitrum));
yield return TestHelpers.WaitForTask(TDK.Connect.SetChainId(ChainId.Arbitrum));
yield return TestHelpers.WaitForTask(TDK.Connect.SetChainId(ChainId.ArbitrumSepolia));

Assert.That(tdkLogs.Count, Is.EqualTo(10));

Assert.That(tdkLogs, Is.EquivalentTo(new List<string> {
Assert.That(tdkLogs, Is.EqualTo(new List<string> {
"Chain is already set to ArbitrumSepolia",
"Initializing Thirdweb SDK for chain: arbitrum",
"[TDK.Connect:Connect] Connecting to SmartWallet...",
"[TDK.Connect:Connect] Connection success!",
"Switched chain to Arbitrum",
"Chain is already set to Arbitrum",
"Initializing Thirdweb SDK for chain: arbitrum-sepolia",
"Switched chain to ArbitrumSepolia",
"[TDK.Connect:Connect] Connecting to SmartWallet...",
"[TDK.Connect:Connect] Connection success!",
"Switched chain to ArbitrumSepolia"
}));
}

Expand Down Expand Up @@ -157,8 +160,7 @@ private IEnumerator CreateSession()
tdkLogs.Clear();

Assert.That(TDK.Identity.IsAuthenticated, Is.False);
var startTask = TDK.Identity.StartUserSession();
yield return TestHelpers.WaitUntilWithMax(() => startTask.IsCompleted, 10);
yield return TestHelpers.WaitForTask(TDK.Identity.StartUserSession(), 10);
Assert.That(TDK.Identity.IsAuthenticated, Is.True);

Assert.That(tdkLogs.Count, Is.EqualTo(5));
Expand All @@ -170,10 +172,9 @@ private IEnumerator CreateSession()

tdkLogs.Clear();

startTask = TDK.Identity.StartUserSession();
yield return TestHelpers.WaitUntilWithMax(() => startTask.IsCompleted, 10);
yield return TestHelpers.WaitForTask(TDK.Identity.StartUserSession(), 10);

Assert.That(tdkLogs, Is.EquivalentTo(new List<string> {
Assert.That(tdkLogs, Is.EqualTo(new List<string> {
"Validating existing user session",
"Fetching user details",
"Existing user session is valid",
Expand All @@ -182,8 +183,7 @@ private IEnumerator CreateSession()

tdkLogs.Clear();

var endTask = TDK.Identity.EndUserSession();
yield return TestHelpers.WaitUntilWithMax(() => endTask.IsCompleted, 10);
yield return TestHelpers.WaitForTask(TDK.Identity.EndUserSession(), 10);
Assert.That(TDK.Identity.IsAuthenticated, Is.False);

yield return ForceFlushCache();
Expand Down
7 changes: 7 additions & 0 deletions Assets/TestsBasic/TestHelpers.cs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
using System;
using System.Collections;
using System.IO;
using System.Threading.Tasks;
using NUnit.Framework;
using Treasure;
using UnityEngine;

Expand Down Expand Up @@ -65,4 +67,9 @@ public static IEnumerator WaitUntilWithMax(Func<bool> predicate, float maxWait)
return timeout <= 0 || predicate();
});
}

public static IEnumerator WaitForTask(Task task, float maxWait = 5) {
yield return WaitUntilWithMax(() => task.IsCompleted, maxWait);
Assert.That(task.IsCompleted);
}
}
34 changes: 21 additions & 13 deletions Assets/Treasure/TDK/Runtime/Connect/TDK.Connect.cs
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ public struct Options
#region private vars
private Options? _options;
private ChainId _chainId = ChainId.Unknown;
private string _email;
private WalletConnection _lastWalletConnection;
private string _address;
#endregion

Expand Down Expand Up @@ -87,11 +87,6 @@ public async Task<ChainId> GetChainId()
return _chainId;
}

public string Email
{
get { return _email; }
}

public string Address
{
get { return _address; }
Expand All @@ -115,9 +110,10 @@ private async Task ConnectWallet(WalletConnection wc, ChainId chainId)
TDKLogger.Log($"[TDK.Connect:Connect] Connecting to {wc.provider}...");
var result = await TDKServiceLocator.GetService<TDKThirdwebService>().Wallet.Connect(wc);
_address = result;
_email = wc.email;
_lastWalletConnection = wc;
OnConnected?.Invoke(_address);
TDK.Analytics.SetTreasureConnectInfo(_address, (int)chainId);
TDKLogger.LogDebug($"[TDK.Connect:Connect] Connection success!");
#else
TDKLogger.LogError("Unable to connect wallet. TDK Connect wallet service not implemented.");
#endif
Expand Down Expand Up @@ -147,11 +143,11 @@ public async Task SetChainId(ChainId chainId)
#if TDK_THIRDWEB
// Thirdweb SDK currently doesn't allow you to switch networks while connected to a smart wallet
// Reinitialize it and auto-connect instead
var connectedEmail = _email;
var lastWalletConnection = _lastWalletConnection;
TDKServiceLocator.GetService<TDKThirdwebService>().InitializeSDK(Constants.ChainIdToName[chainId]);
if (!string.IsNullOrEmpty(connectedEmail))
if (lastWalletConnection != null)
{
await ConnectEmail(connectedEmail, new Options { isSilent = true });
await Reconnect(lastWalletConnection);
}
#endif

Expand All @@ -175,10 +171,10 @@ public void HideConnectModal()
TDKConnectUIManager.Instance.Hide();
}

public async Task ConnectEmail(string email, Options? options = null)
public async Task ConnectEmail(string email)
{
#if TDK_THIRDWEB
_options = options;
_options = null;
var chainId = await GetChainId();
var wc = new WalletConnection(
provider: WalletProvider.SmartWallet,
Expand All @@ -196,6 +192,7 @@ public async Task ConnectEmail(string email, Options? options = null)
public async Task ConnectSocial(SocialAuthProvider provider)
{
#if TDK_THIRDWEB
_options = null;
var chainId = await GetChainId();
var wc = new WalletConnection(
provider: WalletProvider.SmartWallet,
Expand All @@ -209,6 +206,17 @@ public async Task ConnectSocial(SocialAuthProvider provider)
#endif
}

public async Task Reconnect(WalletConnection lastWalletConnection)
{
#if TDK_THIRDWEB
_options = new Options { isSilent = true };
var chainId = await GetChainId();
await ConnectWallet(lastWalletConnection, chainId);
#else
TDKLogger.LogError("Unable to reconnect. TDK Connect wallet service not implemented.");
#endif
}

public async Task Disconnect(bool endSession = false)
{
#if TDK_THIRDWEB
Expand All @@ -224,7 +232,7 @@ public async Task Disconnect(bool endSession = false)
#endif

_address = null;
_email = null;
_lastWalletConnection = null;
TDK.Analytics.TrackCustomEvent(AnalyticsConstants.EVT_TREASURECONNECT_DISCONNECTED);
}
#endregion
Expand Down

0 comments on commit 518a67a

Please sign in to comment.