Skip to content

Commit

Permalink
Use ExtractMap instead of HashMap when possible (serenity-rs#2797)
Browse files Browse the repository at this point in the history
  • Loading branch information
GnomedDev committed Mar 19, 2024
1 parent 0abf8f4 commit 2cb925d
Show file tree
Hide file tree
Showing 32 changed files with 239 additions and 346 deletions.
4 changes: 3 additions & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ bool_to_bitflags = { version = "0.1.0" }
nonmax = { version = "0.5.5", features = ["serde"] }
strum = { version = "0.26", features = ["derive"] }
to-arraystring = "0.1.0"
extract_map = { version = "0.1.0", features = ["serde", "iter_mut"] }
# Optional dependencies
fxhash = { version = "0.2.1", optional = true }
chrono = { version = "0.4.31", default-features = false, features = ["clock", "serde"], optional = true }
Expand All @@ -49,7 +50,7 @@ mime_guess = { version = "2.0.4", optional = true }
dashmap = { version = "5.5.3", features = ["serde"], optional = true }
parking_lot = { version = "0.12.1", optional = true }
ed25519-dalek = { version = "2.0.0", optional = true }
typesize = { version = "0.1.5", optional = true, features = ["url", "time", "serde_json", "secrecy", "dashmap", "parking_lot", "nonmax", "details"] }
typesize = { version = "0.1.6", optional = true, features = ["url", "time", "serde_json", "secrecy", "dashmap", "parking_lot", "nonmax", "extract_map_01", "details"] }
# serde feature only allows for serialisation,
# Serenity workspace crates
serenity-voice-model = { version = "0.2.0", path = "./voice-model", optional = true }
Expand Down Expand Up @@ -145,3 +146,4 @@ native_tls_backend = [
[package.metadata.docs.rs]
features = ["full"]
rustdoc-args = ["--cfg", "docsrs"]

77 changes: 41 additions & 36 deletions src/cache/event.rs
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ impl CacheUpdate for ChannelCreateEvent {
let old_channel = cache
.guilds
.get_mut(&self.channel.guild_id)
.and_then(|mut g| g.channels.insert(self.channel.id, self.channel.clone()));
.and_then(|mut g| g.channels.insert(self.channel.clone()));

old_channel
}
Expand All @@ -72,7 +72,7 @@ impl CacheUpdate for ChannelUpdateEvent {
cache
.guilds
.get_mut(&self.channel.guild_id)
.and_then(|mut g| g.channels.insert(self.channel.id, self.channel.clone()))
.and_then(|mut g| g.channels.insert(self.channel.clone()))
}
}

Expand All @@ -82,7 +82,7 @@ impl CacheUpdate for ChannelPinsUpdateEvent {
fn update(&mut self, cache: &Cache) -> Option<()> {
if let Some(guild_id) = self.guild_id {
if let Some(mut guild) = cache.guilds.get_mut(&guild_id) {
if let Some(channel) = guild.channels.get_mut(&self.channel_id) {
if let Some(mut channel) = guild.channels.get_mut(&self.channel_id) {
channel.last_pin_timestamp = self.last_pin_timestamp;
}
}
Expand Down Expand Up @@ -118,9 +118,9 @@ impl CacheUpdate for GuildDeleteEvent {

match cache.guilds.remove(&self.guild.id) {
Some(guild) => {
for channel_id in guild.1.channels.keys() {
for channel in &guild.1.channels {
// Remove the channel's cached messages.
cache.messages.remove(channel_id);
cache.messages.remove(&channel.id);
}

Some(guild.1)
Expand Down Expand Up @@ -148,7 +148,7 @@ impl CacheUpdate for GuildMemberAddEvent {
fn update(&mut self, cache: &Cache) -> Option<()> {
if let Some(mut guild) = cache.guilds.get_mut(&self.member.guild_id) {
guild.member_count += 1;
guild.members.insert(self.member.user.id, self.member.clone());
guild.members.insert(self.member.clone());
}

None
Expand All @@ -173,7 +173,7 @@ impl CacheUpdate for GuildMemberUpdateEvent {

fn update(&mut self, cache: &Cache) -> Option<Self::Output> {
if let Some(mut guild) = cache.guilds.get_mut(&self.guild_id) {
let item = if let Some(member) = guild.members.get_mut(&self.user.id) {
let item = if let Some(mut member) = guild.members.get_mut(&self.user.id) {
let item = Some(member.clone());

member.joined_at.clone_from(&Some(self.joined_at));
Expand Down Expand Up @@ -213,7 +213,7 @@ impl CacheUpdate for GuildMemberUpdateEvent {
new_member.set_deaf(self.deaf());
new_member.set_mute(self.mute());

guild.members.insert(self.user.id, new_member);
guild.members.insert(new_member);
}

item
Expand All @@ -239,11 +239,7 @@ impl CacheUpdate for GuildRoleCreateEvent {
type Output = ();

fn update(&mut self, cache: &Cache) -> Option<()> {
cache
.guilds
.get_mut(&self.role.guild_id)
.map(|mut g| g.roles.insert(self.role.id, self.role.clone()));

cache.guilds.get_mut(&self.role.guild_id).map(|mut g| g.roles.insert(self.role.clone()));
None
}
}
Expand All @@ -261,8 +257,8 @@ impl CacheUpdate for GuildRoleUpdateEvent {

fn update(&mut self, cache: &Cache) -> Option<Self::Output> {
if let Some(mut guild) = cache.guilds.get_mut(&self.role.guild_id) {
if let Some(role) = guild.roles.get_mut(&self.role.id) {
return Some(std::mem::replace(role, self.role.clone()));
if let Some(mut role) = guild.roles.get_mut(&self.role.id) {
return Some(std::mem::replace(&mut *role, self.role.clone()));
}
}

Expand Down Expand Up @@ -328,9 +324,14 @@ impl CacheUpdate for MessageCreateEvent {
let guild = self.message.guild_id.and_then(|g_id| cache.guilds.get_mut(&g_id));

if let Some(mut guild) = guild {
if let Some(channel) = guild.channels.get_mut(&self.message.channel_id) {
update_channel_last_message_id(&self.message, channel, cache);
} else {
let mut found_channel = false;
if let Some(mut channel) = guild.channels.get_mut(&self.message.channel_id) {
update_channel_last_message_id(&self.message, &mut channel, cache);
found_channel = true;
}

// found_channel is to avoid limitations of the NLL borrow checker.
if !found_channel {
// This may be a thread.
let thread =
guild.threads.iter_mut().find(|thread| thread.id == self.message.channel_id);
Expand Down Expand Up @@ -403,25 +404,27 @@ impl CacheUpdate for PresenceUpdateEvent {
if self.presence.status == OnlineStatus::Offline {
guild.presences.remove(&self.presence.user.id);
} else {
guild.presences.insert(self.presence.user.id, self.presence.clone());
guild.presences.insert(self.presence.clone());
}

// Create a partial member instance out of the presence update data.
if let Some(user) = self.presence.user.to_user() {
guild.members.entry(self.presence.user.id).or_insert_with(|| Member {
guild_id,
joined_at: None,
nick: None,
user,
roles: FixedArray::default(),
premium_since: None,
permissions: None,
avatar: None,
communication_disabled_until: None,
flags: GuildMemberFlags::default(),
unusual_dm_activity_until: None,
__generated_flags: MemberGeneratedFlags::empty(),
});
if !guild.members.contains_key(&self.presence.user.id) {
guild.members.insert(Member {
guild_id,
joined_at: None,
nick: None,
user,
roles: FixedArray::default(),
premium_since: None,
permissions: None,
avatar: None,
communication_disabled_until: None,
flags: GuildMemberFlags::default(),
unusual_dm_activity_until: None,
__generated_flags: MemberGeneratedFlags::empty(),
});
}
}
}
}
Expand Down Expand Up @@ -566,12 +569,14 @@ impl CacheUpdate for VoiceStateUpdateEvent {
if let Some(guild_id) = self.voice_state.guild_id {
if let Some(mut guild) = cache.guilds.get_mut(&guild_id) {
if let Some(member) = &self.voice_state.member {
guild.members.insert(member.user.id, member.clone());
guild.members.insert(member.clone());
}

if self.voice_state.channel_id.is_some() {
// Update or add to the voice state list
guild.voice_states.insert(self.voice_state.user_id, self.voice_state.clone())
let old_state = guild.voice_states.remove(&self.voice_state.user_id);
guild.voice_states.insert(self.voice_state.clone());
old_state
} else {
// Remove the user from the voice state list
guild.voice_states.remove(&self.voice_state.user_id)
Expand All @@ -590,7 +595,7 @@ impl CacheUpdate for VoiceChannelStatusUpdateEvent {

fn update(&mut self, cache: &Cache) -> Option<Self::Output> {
let mut guild = cache.guilds.get_mut(&self.guild_id)?;
let channel = guild.channels.get_mut(&self.id)?;
let mut channel = guild.channels.get_mut(&self.id)?;

let old = channel.status.clone();
channel.status.clone_from(&self.status);
Expand Down
18 changes: 8 additions & 10 deletions src/cache/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -445,16 +445,14 @@ impl Cache {
}

/// Clones all channel categories in the given guild and returns them.
pub fn guild_categories(&self, guild_id: GuildId) -> Option<HashMap<ChannelId, GuildChannel>> {
pub fn guild_categories(
&self,
guild_id: GuildId,
) -> Option<ExtractMap<ChannelId, GuildChannel>> {
let guild = self.guilds.get(&guild_id)?;
Some(
guild
.channels
.iter()
.filter(|(_id, channel)| channel.kind == ChannelType::Category)
.map(|(id, channel)| (*id, channel.clone()))
.collect(),
)

let filter = |channel: &&GuildChannel| channel.kind == ChannelType::Category;
Some(guild.channels.iter().filter(filter).cloned().collect())
}

/// Inserts new messages into the message cache for a channel manually.
Expand Down Expand Up @@ -571,7 +569,7 @@ mod test {
let mut guild_create = GuildCreateEvent {
guild: Guild {
id: GuildId::new(1),
channels: HashMap::from([(ChannelId::new(2), channel)]),
channels: ExtractMap::from_iter([channel]),
..Default::default()
},
};
Expand Down
4 changes: 2 additions & 2 deletions src/client/event_handler.rs
Original file line number Diff line number Diff line change
Expand Up @@ -187,7 +187,7 @@ event_handler! {
/// Dispatched when the emojis are updated.
///
/// Provides the guild's id and the new state of the emojis in the guild.
GuildEmojisUpdate { guild_id: GuildId, current_state: HashMap<EmojiId, Emoji> } => async fn guild_emojis_update(&self, ctx: Context);
GuildEmojisUpdate { guild_id: GuildId, current_state: ExtractMap<EmojiId, Emoji> } => async fn guild_emojis_update(&self, ctx: Context);

/// Dispatched when a guild's integration is added, updated or removed.
///
Expand Down Expand Up @@ -245,7 +245,7 @@ event_handler! {
/// Dispatched when the stickers are updated.
///
/// Provides the guild's id and the new state of the stickers in the guild.
GuildStickersUpdate { guild_id: GuildId, current_state: HashMap<StickerId, Sticker> } => async fn guild_stickers_update(&self, ctx: Context);
GuildStickersUpdate { guild_id: GuildId, current_state: ExtractMap<StickerId, Sticker> } => async fn guild_stickers_update(&self, ctx: Context);

/// Dispatched when the guild is updated.
///
Expand Down
7 changes: 5 additions & 2 deletions src/http/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3134,7 +3134,10 @@ impl Http {
}

/// Gets all channels in a guild.
pub async fn get_channels(&self, guild_id: GuildId) -> Result<Vec<GuildChannel>> {
pub async fn get_channels(
&self,
guild_id: GuildId,
) -> Result<ExtractMap<ChannelId, GuildChannel>> {
self.fire(Request {
body: None,
multipart: None,
Expand Down Expand Up @@ -3639,7 +3642,7 @@ impl Http {
}

/// Retrieves a list of roles in a [`Guild`].
pub async fn get_guild_roles(&self, guild_id: GuildId) -> Result<Vec<Role>> {
pub async fn get_guild_roles(&self, guild_id: GuildId) -> Result<ExtractMap<RoleId, Role>> {
let mut value: Value = self
.fire(Request {
body: None,
Expand Down
1 change: 1 addition & 0 deletions src/internal/prelude.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
pub use std::result::Result as StdResult;

pub use extract_map::{ExtractKey, ExtractMap, LendingIterator};
pub use serde_json::Value;
pub use small_fixed_array::{FixedArray, FixedString, TruncatingInto};

Expand Down
11 changes: 11 additions & 0 deletions src/internal/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,3 +12,14 @@ pub(crate) fn join_to_string(
buf.truncate(buf.len() - 1);
buf
}

// Required because of https://github.com/Crazytieguy/gat-lending-iterator/issues/31
macro_rules! lending_for_each {
($iter:expr, |$item:ident| $body:expr ) => {
while let Some(mut $item) = $iter.next() {
$body
}
};
}

pub(crate) use lending_for_each;
46 changes: 35 additions & 11 deletions src/model/application/command_interaction.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ use crate::client::Context;
#[cfg(feature = "model")]
use crate::http::Http;
use crate::internal::prelude::*;
use crate::internal::utils::lending_for_each;
use crate::model::application::{CommandOptionType, CommandType};
use crate::model::channel::{Attachment, Message, PartialChannel};
use crate::model::guild::{Member, PartialMember, Role};
Expand Down Expand Up @@ -253,7 +254,9 @@ impl<'de> Deserialize<'de> for CommandInteraction {
// If `member` is present, `user` wasn't sent and is still filled with default data
interaction.user = member.user.clone();
}
interaction.data.resolved.roles.values_mut().for_each(|r| r.guild_id = guild_id);

let mut role_iter = interaction.data.resolved.roles.iter_mut();
lending_for_each!(role_iter, |r| r.guild_id = guild_id);
}
Ok(interaction)
}
Expand Down Expand Up @@ -486,23 +489,44 @@ pub enum ResolvedTarget<'a> {
#[non_exhaustive]
pub struct CommandDataResolved {
/// The resolved users.
#[serde(default, skip_serializing_if = "HashMap::is_empty")]
pub users: HashMap<UserId, User>,
#[serde(
default,
skip_serializing_if = "ExtractMap::is_empty",
serialize_with = "extract_map::serialize_as_map"
)]
pub users: ExtractMap<UserId, User>,
/// The resolved partial members.
// Cannot use ExtractMap, as PartialMember does not always store an ID.
#[serde(default, skip_serializing_if = "HashMap::is_empty")]
pub members: HashMap<UserId, PartialMember>,
/// The resolved roles.
#[serde(default, skip_serializing_if = "HashMap::is_empty")]
pub roles: HashMap<RoleId, Role>,
#[serde(
default,
skip_serializing_if = "ExtractMap::is_empty",
serialize_with = "extract_map::serialize_as_map"
)]
pub roles: ExtractMap<RoleId, Role>,
/// The resolved partial channels.
#[serde(default, skip_serializing_if = "HashMap::is_empty")]
pub channels: HashMap<ChannelId, PartialChannel>,
#[serde(
default,
skip_serializing_if = "ExtractMap::is_empty",
serialize_with = "extract_map::serialize_as_map"
)]
pub channels: ExtractMap<ChannelId, PartialChannel>,
/// The resolved messages.
#[serde(default, skip_serializing_if = "HashMap::is_empty")]
pub messages: HashMap<MessageId, Message>,
#[serde(
default,
skip_serializing_if = "ExtractMap::is_empty",
serialize_with = "extract_map::serialize_as_map"
)]
pub messages: ExtractMap<MessageId, Message>,
/// The resolved attachments.
#[serde(default, skip_serializing_if = "HashMap::is_empty")]
pub attachments: HashMap<AttachmentId, Attachment>,
#[serde(
default,
skip_serializing_if = "ExtractMap::is_empty",
serialize_with = "extract_map::serialize_as_map"
)]
pub attachments: ExtractMap<AttachmentId, Attachment>,
}

/// A set of a parameter and a value from the user.
Expand Down
6 changes: 6 additions & 0 deletions src/model/channel/attachment.rs
Original file line number Diff line number Diff line change
Expand Up @@ -155,3 +155,9 @@ impl Attachment {
Ok(bytes.to_vec())
}
}

impl ExtractKey<AttachmentId> for Attachment {
fn extract_key(&self) -> &AttachmentId {
&self.id
}
}
Loading

0 comments on commit 2cb925d

Please sign in to comment.