Skip to content

Commit

Permalink
Do not cache messages in GetMessagesAsync.
Browse files Browse the repository at this point in the history
GetTokensAsync will wait for ongoing token fetching tasks before fetching for tokens.
  • Loading branch information
CXuesong committed May 20, 2017
1 parent 780be75 commit cb52c39
Show file tree
Hide file tree
Showing 3 changed files with 118 additions and 231 deletions.
199 changes: 115 additions & 84 deletions WikiClientLibrary/Site.cs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
using System;
using System.Collections.Concurrent;
using System.Collections.Generic;
using System.Collections.ObjectModel;
using System.Diagnostics;
Expand All @@ -24,8 +25,14 @@ public class Site

#region Services

/// <summary>
/// Gets the <see cref="WikiClientBase" /> used to perform the requests.
/// </summary>
public WikiClientBase WikiClient { get; }

/// <summary>
/// Gets or sets the <see cref="ILogger"/> used to log the requests.
/// </summary>
public ILogger Logger { get; set; }

/// <summary>
Expand Down Expand Up @@ -427,6 +434,8 @@ public Task<IDictionary<string, string>> GetTokensAsync(IEnumerable<string> toke
return GetTokensAsync(tokenTypes, false, cancellationToken);
}

private readonly SemaphoreSlim fetchTokensAsyncCoreSemaphore = new SemaphoreSlim(1, 1);

/// <summary>
/// Request tokens for operations.
/// </summary>
Expand All @@ -438,18 +447,78 @@ public Task<IDictionary<string, string>> GetTokensAsync(IEnumerable<string> toke
/// <para>This method is thread-safe.</para>
/// <para>See https://www.mediawiki.org/wiki/API:Tokens .</para>
/// </remarks>
public async Task<IDictionary<string, string>> GetTokensAsync(IEnumerable<string> tokenTypes, bool forceRefetch, CancellationToken cancellationToken)
public async Task<IDictionary<string, string>> GetTokensAsync(IEnumerable<string> tokenTypes, bool forceRefetch,
CancellationToken cancellationToken)
{
// TODO wait for other threads to fetch a token, instead of fetching whenever it's not
// available, even if there's already another thread requesting it.
if (tokenTypes == null) throw new ArgumentNullException(nameof(tokenTypes));
var tokenTypesList = tokenTypes as IReadOnlyList<string> ?? tokenTypes.ToList();
List<string> pendingtokens;
List<string> pendingtokens = null;
var result = new Dictionary<string, string>();
lock (_TokensCache)
{
pendingtokens = tokenTypesList.Where(tt => forceRefetch || !_TokensCache.ContainsKey(tt)).ToList();
foreach (var tt in tokenTypes)
{
if (string.IsNullOrEmpty(tt))
throw new ArgumentException("tokenTypes contains null or empty item.", nameof(tokenTypes));
if (forceRefetch || !_TokensCache.TryGetValue(tt, out var value))
{
if (pendingtokens == null) pendingtokens = new List<string>();
pendingtokens.Add(tt);
}
else
{
result[tt] = value;
}
}
}
if (pendingtokens != null)
{
await fetchTokensAsyncCoreSemaphore.WaitAsync(cancellationToken);
try
{
// In case some tokens have just been fetched…
if (!forceRefetch)
{
lock (_TokensCache)
{
for (int i = 0; i < pendingtokens.Count; i++)
{
if (_TokensCache.TryGetValue(pendingtokens[i], out var value))
{
result[pendingtokens[i]] = value;
pendingtokens.RemoveAt(i);
i--;
}
}
}
}
await FetchTokensAsyncCore(pendingtokens, cancellationToken);
}
finally
{
fetchTokensAsyncCoreSemaphore.Release();
}
}
lock (_TokensCache)
{
foreach (var key in pendingtokens)
{
if (_TokensCache.TryGetValue(key, out var value))
{
result[key] = value;
}
else
{
throw new InvalidOperationException("Unrecognized token: " + key + ".");
}
}
}
return result;
}

public async Task FetchTokensAsyncCore(IList<string> tokenTypes, CancellationToken cancellationToken)
{
JObject fetchedTokens = null;
var localTokenTypes = tokenTypes.ToList();
if (SiteInfo.Version < new Version("1.24"))
{
/*
Expand All @@ -458,17 +527,12 @@ Patrol was added in v1.14.
For v1.17-19, the patrol token must be obtained from the query
list recentchanges.
*/
var needPatrolFromRC = false;
// Check whether we need a patrol token.
if (SiteInfo.Version < new Version("1.20"))
needPatrolFromRC = pendingtokens.Remove("patrol");
if (needPatrolFromRC)
if (SiteInfo.Version < new Version("1.20") && localTokenTypes.Remove("patrol"))
{
string patrolToken;
lock (_TokensCache)
if (!_TokensCache.TryGetValue("patrol", out patrolToken)) patrolToken = null;
if (patrolToken == null)
if (!_TokensCache.ContainsKey("patrol"))
{
string patrolToken;
if (SiteInfo.Version < new Version("1.17"))
{
patrolToken = await GetTokenAsync("edit");
Expand All @@ -482,62 +546,46 @@ list recentchanges.
rctoken = "patrol",
rclimit = 1
}, cancellationToken);
patrolToken = (string)jobj["query"]["recentchanges"]["patroltoken"];
patrolToken = (string) jobj["query"]["recentchanges"]["patroltoken"];
}
lock (_TokensCache)
_TokensCache["patrol"] = patrolToken;
_TokensCache["patrol"] = patrolToken;
}
}
if (pendingtokens.Count > 0)
fetchedTokens = await FetchTokensAsync(string.Join("|", pendingtokens), cancellationToken);
if (localTokenTypes.Count > 0)
fetchedTokens = await FetchTokensAsync(string.Join("|", localTokenTypes), cancellationToken);
}
else
{
// Use csrf token if possible.
if (!pendingtokens.Contains("csrf"))
if (!localTokenTypes.Contains("csrf"))
{
var needCsrf = false;
foreach (var t in CsrfTokens)
{
if (pendingtokens.Remove(t)) needCsrf = true;
if (localTokenTypes.Remove(t)) needCsrf = true;
}
if (needCsrf) pendingtokens.Add("csrf");
if (needCsrf) localTokenTypes.Add("csrf");
}
if (pendingtokens.Count > 0)
if (localTokenTypes.Count > 0)
{
fetchedTokens = await FetchTokensAsync2(string.Join("|", pendingtokens), cancellationToken);
var csrf = (string)fetchedTokens["csrftoken"];
fetchedTokens = await FetchTokensAsync2(string.Join("|", localTokenTypes), cancellationToken);
var csrf = (string) fetchedTokens["csrftoken"];
if (csrf != null)
{
lock (_TokensCache)
foreach (var t in CsrfTokens) _TokensCache[t] = csrf;
foreach (var t in CsrfTokens) _TokensCache[t] = csrf;
}
}
}
// Put tokens into cache first.
if (fetchedTokens != null)
if (fetchedTokens == null) return;
foreach (var p in fetchedTokens.Properties())
{
foreach (var p in fetchedTokens.Properties())
{
// Remove "token" in the result
var tokenName = p.Name.EndsWith("token")
? p.Name.Substring(0, p.Name.Length - 5)
: p.Name;
lock (_TokensCache)
{
_TokensCache[tokenName] = (string)p.Value;
}
pendingtokens.Remove(tokenName);
}
if (pendingtokens.Count > 0)
{
throw new InvalidOperationException(
"Unrecognized token(s): " + string.Join(", ", pendingtokens) + ".");
}
// Remove "token" in the result
var tokenName = p.Name.EndsWith("token")
? p.Name.Substring(0, p.Name.Length - 5)
: p.Name;
_TokensCache[tokenName] = (string) p.Value;
}
// Then return.
lock (_TokensCache)
return tokenTypesList.ToDictionary(t => t, t => _TokensCache[t]);
}

/// <summary>
Expand Down Expand Up @@ -644,7 +692,7 @@ public async Task LoginAsync(string userName, string password, string domain, Ca
switch (result)
{
case "Success":
_TokensCache.Clear();
lock (_TokensCache) _TokensCache.Clear();
await RefreshAccountInfoAsync();
Debug.Assert(AccountInfo.IsUser);
return;
Expand Down Expand Up @@ -681,7 +729,7 @@ public async Task LogoutAsync()
{
action = "logout",
}, true, CancellationToken.None);
_TokensCache.Clear();
lock (_TokensCache) _TokensCache.Clear();
if (options.ExplicitInfoRefresh)
_AccountInfo = null;
else
Expand All @@ -701,8 +749,6 @@ private async Task<bool> Relogin()
#endregion

#region Query
private readonly AsyncReaderWriterLock cachedMessagesLock = new AsyncReaderWriterLock();
private readonly IDictionary<string, string> _CachedMessages = new Dictionary<string, string>();

private async Task<JArray> FetchMessagesAsync(string messagesExpr, CancellationToken cancellationToken)
{
Expand Down Expand Up @@ -757,41 +803,26 @@ public async Task<IDictionary<string, string>> GetMessagesAsync(IEnumerable<stri
{
if (messages == null) throw new ArgumentNullException(nameof(messages));
cancellationToken.ThrowIfCancellationRequested();
var impending = new List<string>();
var exprBuilder = new StringBuilder();
var result = new Dictionary<string, string>();
using (await cachedMessagesLock.ReaderLockAsync(cancellationToken))
foreach (var m in messages)
{
cancellationToken.ThrowIfCancellationRequested();
foreach (var m in messages)
if (m == null) throw new ArgumentException("The sequence contains null item.", nameof(messages));
if (m.Contains("|"))
throw new ArgumentException($"The message name \"{m}\" contains pipe character.",
nameof(messages));
if (m == "*") throw new InvalidOperationException("Getting all the messages is deprecated.");
if (exprBuilder.Length > 0) exprBuilder.Append('|');
exprBuilder.Append(m);
var jr = await FetchMessagesAsync(exprBuilder.ToString(), cancellationToken);
foreach (var entry in jr)
{
if (m == null) throw new ArgumentException("The sequence contains null item.", nameof(messages));
if (m.Contains("|"))
throw new ArgumentException($"The message name \"{m}\" contains pipe character.",
nameof(messages));
if (m == "*") throw new InvalidOperationException("Getting all the messages is deprecated.");
string content;
if (_CachedMessages.TryGetValue(m.ToLowerInvariant(), out content))
result[m] = content;
else
impending.Add(m);
}
}
if (impending.Count > 0)
{
using (await cachedMessagesLock.WriterLockAsync(cancellationToken))
{
cancellationToken.ThrowIfCancellationRequested();
var jr = await FetchMessagesAsync(string.Join("|", impending), cancellationToken);
foreach (var entry in jr)
{
var name = (string)entry["name"];
//var nname = (string)entry["normalizedname"];
// for Wikia, there's no normalizedname
var message = (string)entry["*"];
//var missing = entry["missing"] != null; message will be null
result[name] = message;
_CachedMessages[name] = message;
}
var name = (string) entry["name"];
//var nname = (string)entry["normalizedname"];
// for Wikia, there's no normalizedname
var message = (string) entry["*"];
//var missing = entry["missing"] != null; message will be null
if (message != null) result[name] = message;
}
}
return result;
Expand Down
Loading

0 comments on commit cb52c39

Please sign in to comment.