From efc8ae1e69ed517be411b4cc3978e277aab36675 Mon Sep 17 00:00:00 2001 From: Gnome! Date: Sun, 17 Mar 2024 23:37:27 +0000 Subject: [PATCH] Use ExtractMap instead of HashMap when possible (#2797) --- Cargo.toml | 4 +- src/cache/event.rs | 77 ++++----- src/cache/mod.rs | 18 +-- src/client/event_handler.rs | 4 +- src/http/client.rs | 7 +- src/internal/prelude.rs | 1 + src/internal/utils.rs | 11 ++ src/model/application/command_interaction.rs | 46 ++++-- src/model/channel/attachment.rs | 6 + src/model/channel/guild_channel.rs | 14 +- src/model/channel/message.rs | 6 + src/model/channel/partial_channel.rs | 6 + src/model/event.rs | 27 ++-- src/model/gateway.rs | 6 + src/model/guild/audit_log/mod.rs | 8 +- src/model/guild/audit_log/utils.rs | 38 ----- src/model/guild/emoji.rs | 6 + src/model/guild/guild_id.rs | 12 +- src/model/guild/member.rs | 12 +- src/model/guild/mod.rs | 55 +++---- src/model/guild/partial_guild.rs | 17 +- src/model/guild/role.rs | 6 + src/model/mod.rs | 1 - src/model/sticker.rs | 6 + src/model/user.rs | 6 + src/model/utils.rs | 160 +------------------ src/model/voice.rs | 6 + src/model/webhook.rs | 6 + src/utils/argument_convert/channel.rs | 2 +- src/utils/argument_convert/emoji.rs | 2 +- src/utils/argument_convert/role.rs | 2 +- src/utils/content_safe.rs | 7 +- 32 files changed, 239 insertions(+), 346 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index ab9aeb277e6..72d19407759 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -37,6 +37,7 @@ bool_to_bitflags = { version = "0.1.0" } nonmax = { version = "0.5.5", features = ["serde"] } strum = { version = "0.26", features = ["derive"] } to-arraystring = "0.1.0" +extract_map = { version = "0.1.0", features = ["serde", "iter_mut"] } # Optional dependencies fxhash = { version = "0.2.1", optional = true } chrono = { version = "0.4.31", default-features = false, features = ["clock", "serde"], optional = true } @@ -50,7 +51,7 @@ mime_guess = { version = "2.0.4", optional = true } dashmap = { version = "5.5.3", features = ["serde"], optional = true } parking_lot = { version = "0.12.1", optional = true } ed25519-dalek = { version = "2.0.0", optional = true } -typesize = { version = "0.1.5", optional = true, features = ["url", "time", "serde_json", "secrecy", "dashmap", "parking_lot", "nonmax", "details"] } +typesize = { version = "0.1.6", optional = true, features = ["url", "time", "serde_json", "secrecy", "dashmap", "parking_lot", "nonmax", "extract_map_01", "details"] } # serde feature only allows for serialisation, # Serenity workspace crates serenity-voice-model = { version = "0.2.0", path = "./voice-model", optional = true } @@ -146,3 +147,4 @@ native_tls_backend = [ [package.metadata.docs.rs] features = ["full"] rustdoc-args = ["--cfg", "docsrs"] + diff --git a/src/cache/event.rs b/src/cache/event.rs index 01994319334..d01057c13b4 100644 --- a/src/cache/event.rs +++ b/src/cache/event.rs @@ -45,7 +45,7 @@ impl CacheUpdate for ChannelCreateEvent { let old_channel = cache .guilds .get_mut(&self.channel.guild_id) - .and_then(|mut g| g.channels.insert(self.channel.id, self.channel.clone())); + .and_then(|mut g| g.channels.insert(self.channel.clone())); old_channel } @@ -71,7 +71,7 @@ impl CacheUpdate for ChannelUpdateEvent { cache .guilds .get_mut(&self.channel.guild_id) - .and_then(|mut g| g.channels.insert(self.channel.id, self.channel.clone())) + .and_then(|mut g| g.channels.insert(self.channel.clone())) } } @@ -81,7 +81,7 @@ impl CacheUpdate for ChannelPinsUpdateEvent { fn update(&mut self, cache: &Cache) -> Option<()> { if let Some(guild_id) = self.guild_id { if let Some(mut guild) = cache.guilds.get_mut(&guild_id) { - if let Some(channel) = guild.channels.get_mut(&self.channel_id) { + if let Some(mut channel) = guild.channels.get_mut(&self.channel_id) { channel.last_pin_timestamp = self.last_pin_timestamp; } } @@ -117,9 +117,9 @@ impl CacheUpdate for GuildDeleteEvent { match cache.guilds.remove(&self.guild.id) { Some(guild) => { - for channel_id in guild.1.channels.keys() { + for channel in &guild.1.channels { // Remove the channel's cached messages. - cache.messages.remove(channel_id); + cache.messages.remove(&channel.id); } Some(guild.1) @@ -147,7 +147,7 @@ impl CacheUpdate for GuildMemberAddEvent { fn update(&mut self, cache: &Cache) -> Option<()> { if let Some(mut guild) = cache.guilds.get_mut(&self.member.guild_id) { guild.member_count += 1; - guild.members.insert(self.member.user.id, self.member.clone()); + guild.members.insert(self.member.clone()); } None @@ -172,7 +172,7 @@ impl CacheUpdate for GuildMemberUpdateEvent { fn update(&mut self, cache: &Cache) -> Option { if let Some(mut guild) = cache.guilds.get_mut(&self.guild_id) { - let item = if let Some(member) = guild.members.get_mut(&self.user.id) { + let item = if let Some(mut member) = guild.members.get_mut(&self.user.id) { let item = Some(member.clone()); member.joined_at.clone_from(&Some(self.joined_at)); @@ -212,7 +212,7 @@ impl CacheUpdate for GuildMemberUpdateEvent { new_member.set_deaf(self.deaf()); new_member.set_mute(self.mute()); - guild.members.insert(self.user.id, new_member); + guild.members.insert(new_member); } item @@ -238,11 +238,7 @@ impl CacheUpdate for GuildRoleCreateEvent { type Output = (); fn update(&mut self, cache: &Cache) -> Option<()> { - cache - .guilds - .get_mut(&self.role.guild_id) - .map(|mut g| g.roles.insert(self.role.id, self.role.clone())); - + cache.guilds.get_mut(&self.role.guild_id).map(|mut g| g.roles.insert(self.role.clone())); None } } @@ -260,8 +256,8 @@ impl CacheUpdate for GuildRoleUpdateEvent { fn update(&mut self, cache: &Cache) -> Option { if let Some(mut guild) = cache.guilds.get_mut(&self.role.guild_id) { - if let Some(role) = guild.roles.get_mut(&self.role.id) { - return Some(std::mem::replace(role, self.role.clone())); + if let Some(mut role) = guild.roles.get_mut(&self.role.id) { + return Some(std::mem::replace(&mut *role, self.role.clone())); } } @@ -327,9 +323,14 @@ impl CacheUpdate for MessageCreateEvent { let guild = self.message.guild_id.and_then(|g_id| cache.guilds.get_mut(&g_id)); if let Some(mut guild) = guild { - if let Some(channel) = guild.channels.get_mut(&self.message.channel_id) { - update_channel_last_message_id(&self.message, channel, cache); - } else { + let mut found_channel = false; + if let Some(mut channel) = guild.channels.get_mut(&self.message.channel_id) { + update_channel_last_message_id(&self.message, &mut channel, cache); + found_channel = true; + } + + // found_channel is to avoid limitations of the NLL borrow checker. + if !found_channel { // This may be a thread. let thread = guild.threads.iter_mut().find(|thread| thread.id == self.message.channel_id); @@ -402,25 +403,27 @@ impl CacheUpdate for PresenceUpdateEvent { if self.presence.status == OnlineStatus::Offline { guild.presences.remove(&self.presence.user.id); } else { - guild.presences.insert(self.presence.user.id, self.presence.clone()); + guild.presences.insert(self.presence.clone()); } // Create a partial member instance out of the presence update data. if let Some(user) = self.presence.user.to_user() { - guild.members.entry(self.presence.user.id).or_insert_with(|| Member { - guild_id, - joined_at: None, - nick: None, - user, - roles: FixedArray::default(), - premium_since: None, - permissions: None, - avatar: None, - communication_disabled_until: None, - flags: GuildMemberFlags::default(), - unusual_dm_activity_until: None, - __generated_flags: MemberGeneratedFlags::empty(), - }); + if !guild.members.contains_key(&self.presence.user.id) { + guild.members.insert(Member { + guild_id, + joined_at: None, + nick: None, + user, + roles: FixedArray::default(), + premium_since: None, + permissions: None, + avatar: None, + communication_disabled_until: None, + flags: GuildMemberFlags::default(), + unusual_dm_activity_until: None, + __generated_flags: MemberGeneratedFlags::empty(), + }); + } } } } @@ -553,12 +556,14 @@ impl CacheUpdate for VoiceStateUpdateEvent { if let Some(guild_id) = self.voice_state.guild_id { if let Some(mut guild) = cache.guilds.get_mut(&guild_id) { if let Some(member) = &self.voice_state.member { - guild.members.insert(member.user.id, member.clone()); + guild.members.insert(member.clone()); } if self.voice_state.channel_id.is_some() { // Update or add to the voice state list - guild.voice_states.insert(self.voice_state.user_id, self.voice_state.clone()) + let old_state = guild.voice_states.remove(&self.voice_state.user_id); + guild.voice_states.insert(self.voice_state.clone()); + old_state } else { // Remove the user from the voice state list guild.voice_states.remove(&self.voice_state.user_id) @@ -577,7 +582,7 @@ impl CacheUpdate for VoiceChannelStatusUpdateEvent { fn update(&mut self, cache: &Cache) -> Option { let mut guild = cache.guilds.get_mut(&self.guild_id)?; - let channel = guild.channels.get_mut(&self.id)?; + let mut channel = guild.channels.get_mut(&self.id)?; let old = channel.status.clone(); channel.status.clone_from(&self.status); diff --git a/src/cache/mod.rs b/src/cache/mod.rs index d0f463d9918..2beadf54fc5 100644 --- a/src/cache/mod.rs +++ b/src/cache/mod.rs @@ -446,16 +446,14 @@ impl Cache { } /// Clones all channel categories in the given guild and returns them. - pub fn guild_categories(&self, guild_id: GuildId) -> Option> { + pub fn guild_categories( + &self, + guild_id: GuildId, + ) -> Option> { let guild = self.guilds.get(&guild_id)?; - Some( - guild - .channels - .iter() - .filter(|(_id, channel)| channel.kind == ChannelType::Category) - .map(|(id, channel)| (*id, channel.clone())) - .collect(), - ) + + let filter = |channel: &&GuildChannel| channel.kind == ChannelType::Category; + Some(guild.channels.iter().filter(filter).cloned().collect()) } /// Inserts new messages into the message cache for a channel manually. @@ -572,7 +570,7 @@ mod test { let mut guild_create = GuildCreateEvent { guild: Guild { id: GuildId::new(1), - channels: HashMap::from([(ChannelId::new(2), channel)]), + channels: ExtractMap::from_iter([channel]), ..Default::default() }, }; diff --git a/src/client/event_handler.rs b/src/client/event_handler.rs index 678f18d9d6f..8f4c3157d1a 100644 --- a/src/client/event_handler.rs +++ b/src/client/event_handler.rs @@ -192,7 +192,7 @@ event_handler! { /// Dispatched when the emojis are updated. /// /// Provides the guild's id and the new state of the emojis in the guild. - GuildEmojisUpdate { guild_id: GuildId, current_state: HashMap } => async fn guild_emojis_update(&self, ctx: Context); + GuildEmojisUpdate { guild_id: GuildId, current_state: ExtractMap } => async fn guild_emojis_update(&self, ctx: Context); /// Dispatched when a guild's integration is added, updated or removed. /// @@ -250,7 +250,7 @@ event_handler! { /// Dispatched when the stickers are updated. /// /// Provides the guild's id and the new state of the stickers in the guild. - GuildStickersUpdate { guild_id: GuildId, current_state: HashMap } => async fn guild_stickers_update(&self, ctx: Context); + GuildStickersUpdate { guild_id: GuildId, current_state: ExtractMap } => async fn guild_stickers_update(&self, ctx: Context); /// Dispatched when the guild is updated. /// diff --git a/src/http/client.rs b/src/http/client.rs index 222d5bc0787..98dba51c766 100644 --- a/src/http/client.rs +++ b/src/http/client.rs @@ -3194,7 +3194,10 @@ impl Http { } /// Gets all channels in a guild. - pub async fn get_channels(&self, guild_id: GuildId) -> Result> { + pub async fn get_channels( + &self, + guild_id: GuildId, + ) -> Result> { self.fire(Request { body: None, multipart: None, @@ -3762,7 +3765,7 @@ impl Http { } /// Retrieves a list of roles in a [`Guild`]. - pub async fn get_guild_roles(&self, guild_id: GuildId) -> Result> { + pub async fn get_guild_roles(&self, guild_id: GuildId) -> Result> { let mut value: Value = self .fire(Request { body: None, diff --git a/src/internal/prelude.rs b/src/internal/prelude.rs index f806e249f26..53593b74d36 100644 --- a/src/internal/prelude.rs +++ b/src/internal/prelude.rs @@ -4,6 +4,7 @@ pub use std::result::Result as StdResult; +pub use extract_map::{ExtractKey, ExtractMap, LendingIterator}; pub use serde_json::Value; pub use small_fixed_array::{FixedArray, FixedString, TruncatingInto}; diff --git a/src/internal/utils.rs b/src/internal/utils.rs index d6b16888501..df3e4b32cd1 100644 --- a/src/internal/utils.rs +++ b/src/internal/utils.rs @@ -12,3 +12,14 @@ pub(crate) fn join_to_string( buf.truncate(buf.len() - 1); buf } + +// Required because of https://github.com/Crazytieguy/gat-lending-iterator/issues/31 +macro_rules! lending_for_each { + ($iter:expr, |$item:ident| $body:expr ) => { + while let Some(mut $item) = $iter.next() { + $body + } + }; +} + +pub(crate) use lending_for_each; diff --git a/src/model/application/command_interaction.rs b/src/model/application/command_interaction.rs index 56ed4ce3cd0..e5a8095085c 100644 --- a/src/model/application/command_interaction.rs +++ b/src/model/application/command_interaction.rs @@ -18,6 +18,7 @@ use crate::client::Context; #[cfg(feature = "model")] use crate::http::Http; use crate::internal::prelude::*; +use crate::internal::utils::lending_for_each; use crate::model::application::{CommandOptionType, CommandType}; use crate::model::channel::{Attachment, Message, PartialChannel}; use crate::model::guild::{Member, PartialMember, Role}; @@ -253,7 +254,9 @@ impl<'de> Deserialize<'de> for CommandInteraction { // If `member` is present, `user` wasn't sent and is still filled with default data interaction.user = member.user.clone(); } - interaction.data.resolved.roles.values_mut().for_each(|r| r.guild_id = guild_id); + + let mut role_iter = interaction.data.resolved.roles.iter_mut(); + lending_for_each!(role_iter, |r| r.guild_id = guild_id); } Ok(interaction) } @@ -486,23 +489,44 @@ pub enum ResolvedTarget<'a> { #[non_exhaustive] pub struct CommandDataResolved { /// The resolved users. - #[serde(default, skip_serializing_if = "HashMap::is_empty")] - pub users: HashMap, + #[serde( + default, + skip_serializing_if = "ExtractMap::is_empty", + serialize_with = "extract_map::serialize_as_map" + )] + pub users: ExtractMap, /// The resolved partial members. + // Cannot use ExtractMap, as PartialMember does not always store an ID. #[serde(default, skip_serializing_if = "HashMap::is_empty")] pub members: HashMap, /// The resolved roles. - #[serde(default, skip_serializing_if = "HashMap::is_empty")] - pub roles: HashMap, + #[serde( + default, + skip_serializing_if = "ExtractMap::is_empty", + serialize_with = "extract_map::serialize_as_map" + )] + pub roles: ExtractMap, /// The resolved partial channels. - #[serde(default, skip_serializing_if = "HashMap::is_empty")] - pub channels: HashMap, + #[serde( + default, + skip_serializing_if = "ExtractMap::is_empty", + serialize_with = "extract_map::serialize_as_map" + )] + pub channels: ExtractMap, /// The resolved messages. - #[serde(default, skip_serializing_if = "HashMap::is_empty")] - pub messages: HashMap, + #[serde( + default, + skip_serializing_if = "ExtractMap::is_empty", + serialize_with = "extract_map::serialize_as_map" + )] + pub messages: ExtractMap, /// The resolved attachments. - #[serde(default, skip_serializing_if = "HashMap::is_empty")] - pub attachments: HashMap, + #[serde( + default, + skip_serializing_if = "ExtractMap::is_empty", + serialize_with = "extract_map::serialize_as_map" + )] + pub attachments: ExtractMap, } /// A set of a parameter and a value from the user. diff --git a/src/model/channel/attachment.rs b/src/model/channel/attachment.rs index 54a43a6c607..f1d7b5ce398 100644 --- a/src/model/channel/attachment.rs +++ b/src/model/channel/attachment.rs @@ -156,3 +156,9 @@ impl Attachment { Ok(bytes.to_vec()) } } + +impl ExtractKey for Attachment { + fn extract_key(&self) -> &AttachmentId { + &self.id + } +} diff --git a/src/model/channel/guild_channel.rs b/src/model/channel/guild_channel.rs index 3a17473dfd7..7e20d77e2eb 100644 --- a/src/model/channel/guild_channel.rs +++ b/src/model/channel/guild_channel.rs @@ -842,7 +842,7 @@ impl GuildChannel { match self.kind { ChannelType::Voice | ChannelType::Stage => Ok(guild .voice_states - .values() + .iter() .filter_map(|v| { v.channel_id.and_then(|c| { if c == self.id { @@ -856,12 +856,12 @@ impl GuildChannel { ChannelType::News | ChannelType::Text => Ok(guild .members .iter() - .filter(|(id, _)| { - self.permissions_for_user(cache, **id) + .filter(|m| { + self.permissions_for_user(cache, m.user.id) .map(|p| p.contains(Permissions::VIEW_CHANNEL)) .unwrap_or(false) }) - .map(|e| e.1.clone()) + .cloned() .collect::>()), _ => Err(Error::from(ModelError::InvalidChannelType)), } @@ -1033,6 +1033,12 @@ impl fmt::Display for GuildChannel { } } +impl ExtractKey for GuildChannel { + fn extract_key(&self) -> &ChannelId { + &self.id + } +} + /// A partial guild channel. /// /// [Discord docs](https://discord.com/developers/docs/resources/channel#channel-object), diff --git a/src/model/channel/message.rs b/src/model/channel/message.rs index 228579e3ab2..f2ef648273e 100644 --- a/src/model/channel/message.rs +++ b/src/model/channel/message.rs @@ -828,6 +828,12 @@ impl Message { } } +impl ExtractKey for Message { + fn extract_key(&self) -> &MessageId { + &self.id + } +} + impl From for MessageId { /// Gets the Id of a [`Message`]. fn from(message: Message) -> MessageId { diff --git a/src/model/channel/partial_channel.rs b/src/model/channel/partial_channel.rs index 1e544c9d115..4317986b384 100644 --- a/src/model/channel/partial_channel.rs +++ b/src/model/channel/partial_channel.rs @@ -30,6 +30,12 @@ pub struct PartialChannel { pub parent_id: Option, } +impl ExtractKey for PartialChannel { + fn extract_key(&self) -> &ChannelId { + &self.id + } +} + /// A container for the IDs returned by following a news channel. /// /// [Discord docs](https://discord.com/developers/docs/resources/channel#followed-channel-object). diff --git a/src/model/event.rs b/src/model/event.rs index 314efb7d5c5..e397c0666aa 100644 --- a/src/model/event.rs +++ b/src/model/event.rs @@ -14,15 +14,9 @@ use tracing::{debug, warn}; use crate::constants::Opcode; use crate::internal::prelude::*; +use crate::internal::utils::lending_for_each; use crate::model::prelude::*; -use crate::model::utils::{ - deserialize_val, - emojis, - members, - remove_from_map, - remove_from_map_opt, - stickers, -}; +use crate::model::utils::{deserialize_val, remove_from_map, remove_from_map_opt}; /// Requires no gateway intents. /// @@ -179,9 +173,9 @@ pub struct GuildCreateEvent { impl<'de> Deserialize<'de> for GuildCreateEvent { fn deserialize>(deserializer: D) -> StdResult { let mut guild: Guild = Guild::deserialize(deserializer)?; - guild.channels.values_mut().for_each(|x| x.guild_id = guild.id); - guild.members.values_mut().for_each(|x| x.guild_id = guild.id); - guild.roles.values_mut().for_each(|x| x.guild_id = guild.id); + lending_for_each!(guild.channels.iter_mut(), |x| x.guild_id = guild.id); + lending_for_each!(guild.members.iter_mut(), |x| x.guild_id = guild.id); + lending_for_each!(guild.roles.iter_mut(), |x| x.guild_id = guild.id); Ok(Self { guild, }) @@ -206,8 +200,7 @@ pub struct GuildDeleteEvent { #[derive(Clone, Debug, Deserialize, Serialize)] #[non_exhaustive] pub struct GuildEmojisUpdateEvent { - #[serde(with = "emojis")] - pub emojis: HashMap, + pub emojis: ExtractMap, pub guild_id: GuildId, } @@ -279,8 +272,7 @@ pub struct GuildMembersChunkEvent { /// ID of the guild. pub guild_id: GuildId, /// Set of guild members. - #[serde(with = "members")] - pub members: HashMap, + pub members: ExtractMap, /// Chunk index in the expected chunks for this response (0 <= chunk_index < chunk_count). pub chunk_index: u32, /// Total number of expected chunks for this response. @@ -300,7 +292,7 @@ pub struct GuildMembersChunkEvent { impl<'de> Deserialize<'de> for GuildMembersChunkEvent { fn deserialize>(deserializer: D) -> StdResult { let mut event = Self::deserialize(deserializer)?; // calls #[serde(remote)]-generated inherent method - event.members.values_mut().for_each(|m| m.guild_id = event.guild_id); + lending_for_each!(event.members.iter_mut(), |m| m.guild_id = event.guild_id); Ok(event) } } @@ -379,8 +371,7 @@ impl<'de> Deserialize<'de> for GuildRoleUpdateEvent { #[derive(Clone, Debug, Deserialize, Serialize)] #[non_exhaustive] pub struct GuildStickersUpdateEvent { - #[serde(with = "stickers")] - pub stickers: HashMap, + pub stickers: ExtractMap, pub guild_id: GuildId, } diff --git a/src/model/gateway.rs b/src/model/gateway.rs index 030b413fa80..ebdd4451eee 100644 --- a/src/model/gateway.rs +++ b/src/model/gateway.rs @@ -316,6 +316,12 @@ pub struct Presence { pub client_status: Option, } +impl ExtractKey for Presence { + fn extract_key(&self) -> &UserId { + &self.user.id + } +} + /// An initial set of information given after IDENTIFYing to the gateway. /// /// [Discord docs](https://discord.com/developers/docs/topics/gateway#ready-ready-event-fields). diff --git a/src/model/guild/audit_log/mod.rs b/src/model/guild/audit_log/mod.rs index b2c658c76c5..37d1263cfce 100644 --- a/src/model/guild/audit_log/mod.rs +++ b/src/model/guild/audit_log/mod.rs @@ -7,7 +7,7 @@ mod change; mod utils; pub use change::{AffectedRole, Change, EntityType}; -use utils::{optional_string, users, webhooks}; +use utils::optional_string; use crate::internal::prelude::*; use crate::model::prelude::*; @@ -350,11 +350,9 @@ pub struct AuditLogs { /// map since archived threads might not be kept in memory by clients. pub threads: FixedArray, /// List of users referenced in the audit log. - #[serde(with = "users")] - pub users: HashMap, + pub users: ExtractMap, /// List of webhooks referenced in the audit log. - #[serde(with = "webhooks")] - pub webhooks: HashMap, + pub webhooks: ExtractMap, } /// Partial version of [`Integration`], used in [`AuditLogs::integrations`]. diff --git a/src/model/guild/audit_log/utils.rs b/src/model/guild/audit_log/utils.rs index 90f15f28f4e..f91cab1d5ca 100644 --- a/src/model/guild/audit_log/utils.rs +++ b/src/model/guild/audit_log/utils.rs @@ -1,41 +1,3 @@ -/// Used with `#[serde(with = "users")]` -pub mod users { - use std::collections::HashMap; - - use serde::Deserializer; - - use crate::model::id::UserId; - use crate::model::user::User; - use crate::model::utils::SequenceToMapVisitor; - - pub fn deserialize<'de, D: Deserializer<'de>>( - deserializer: D, - ) -> Result, D::Error> { - deserializer.deserialize_seq(SequenceToMapVisitor::new(|u: &User| u.id)) - } - - pub use crate::model::utils::serialize_map_values as serialize; -} - -/// Used with `#[serde(with = "webhooks")]` -pub mod webhooks { - use std::collections::HashMap; - - use serde::Deserializer; - - use crate::model::id::WebhookId; - use crate::model::utils::SequenceToMapVisitor; - use crate::model::webhook::Webhook; - - pub fn deserialize<'de, D: Deserializer<'de>>( - deserializer: D, - ) -> Result, D::Error> { - deserializer.deserialize_seq(SequenceToMapVisitor::new(|h: &Webhook| h.id)) - } - - pub use crate::model::utils::serialize_map_values as serialize; -} - /// Deserializes an optional string containing a valid integer as `Option`. /// /// Used with `#[serde(with = "optional_string")]`. diff --git a/src/model/guild/emoji.rs b/src/model/guild/emoji.rs index 0a669ada567..5f261888aca 100644 --- a/src/model/guild/emoji.rs +++ b/src/model/guild/emoji.rs @@ -85,6 +85,12 @@ impl fmt::Display for Emoji { } } +impl ExtractKey for Emoji { + fn extract_key(&self) -> &EmojiId { + &self.id + } +} + impl From for EmojiId { /// Gets the Id of an [`Emoji`]. fn from(emoji: Emoji) -> EmojiId { diff --git a/src/model/guild/guild_id.rs b/src/model/guild/guild_id.rs index da34463006f..41efbb94ccf 100644 --- a/src/model/guild/guild_id.rs +++ b/src/model/guild/guild_id.rs @@ -300,10 +300,8 @@ impl GuildId { /// # Errors /// /// Returns [`Error::Http`] if the current user is not in the guild. - pub async fn channels(self, http: &Http) -> Result> { - let channels = http.get_channels(self).await?; - - Ok(channels.into_iter().map(|c| (c.id, c)).collect()) + pub async fn channels(self, http: &Http) -> Result> { + http.get_channels(self).await } /// Creates a [`GuildChannel`] in the the guild. @@ -823,10 +821,8 @@ impl GuildId { /// /// Returns [`Error::Http`] if the current user is not in /// the guild. - pub async fn roles(self, http: &Http) -> Result> { - let roles = http.get_guild_roles(self).await?; - - Ok(roles.into_iter().map(|r| (r.id, r)).collect()) + pub async fn roles(self, http: &Http) -> Result> { + http.get_guild_roles(self).await } /// Gets the default permission role (@everyone) from the guild. diff --git a/src/model/guild/member.rs b/src/model/guild/member.rs index 3fdf50c65d2..8a1c1b1483e 100644 --- a/src/model/guild/member.rs +++ b/src/model/guild/member.rs @@ -170,7 +170,7 @@ impl Member { let member = guild.members.get(&self.user.id)?; - for channel in guild.channels.values() { + for channel in &guild.channels { if channel.kind != ChannelType::Category && guild.user_permissions_in(channel, member).view_channel() { @@ -453,8 +453,8 @@ impl Member { .guild(self.guild_id)? .roles .iter() - .filter(|(id, _)| self.roles.contains(id)) - .map(|(_, role)| role.clone()) + .filter(|r| self.roles.contains(&r.id)) + .cloned() .collect(), ) } @@ -508,6 +508,12 @@ impl fmt::Display for Member { } } +impl ExtractKey for Member { + fn extract_key(&self) -> &UserId { + &self.user.id + } +} + /// A partial amount of data for a member. /// /// This is used in [`Message`]s from [`Guild`]s. diff --git a/src/model/guild/mod.rs b/src/model/guild/mod.rs index e2f01d78c38..4e9140b14e8 100644 --- a/src/model/guild/mod.rs +++ b/src/model/guild/mod.rs @@ -147,11 +147,9 @@ pub struct Guild { /// Default explicit content filter level. pub explicit_content_filter: ExplicitContentFilter, /// A mapping of the guild's roles. - #[serde(with = "roles")] - pub roles: HashMap, + pub roles: ExtractMap, /// All of the guild's custom emojis. - #[serde(with = "emojis")] - pub emojis: HashMap, + pub emojis: ExtractMap, /// The guild features. More information available at [`discord documentation`]. /// /// The following is a list of known features: @@ -238,8 +236,7 @@ pub struct Guild { /// [`discord support article`]: https://support.discord.com/hc/en-us/articles/1500005389362-NSFW-Server-Designation pub nsfw_level: NsfwLevel, /// All of the guild's custom stickers. - #[serde(with = "stickers")] - pub stickers: HashMap, + pub stickers: ExtractMap, /// Whether the guild has the boost progress bar enabled pub premium_progress_bar_enabled: bool, @@ -256,29 +253,23 @@ pub struct Guild { /// The number of members in the guild. pub member_count: u64, /// A mapping of [`User`]s to their current voice state. - #[serde(serialize_with = "serialize_map_values")] - #[serde(deserialize_with = "deserialize_voice_states")] - pub voice_states: HashMap, + pub voice_states: ExtractMap, /// Users who are members of the guild. /// /// Members might not all be available when the [`ReadyEvent`] is received if the /// [`Self::member_count`] is greater than the [`LARGE_THRESHOLD`] set by the library. - #[serde(with = "members")] - pub members: HashMap, + pub members: ExtractMap, /// All voice and text channels contained within a guild. /// /// This contains all channels regardless of permissions (i.e. the ability of the bot to read /// from or connect to them). - #[serde(serialize_with = "serialize_map_values")] - #[serde(deserialize_with = "deserialize_guild_channels")] - pub channels: HashMap, + pub channels: ExtractMap, /// All active threads in this guild that current user has permission to view. pub threads: FixedArray, /// A mapping of [`User`]s' Ids to their current presences. /// /// **Note**: This will be empty unless the "guild presences" privileged intent is enabled. - #[serde(with = "presences")] - pub presences: HashMap, + pub presences: ExtractMap, /// The stage instances in this guild. pub stage_instances: FixedArray, /// The stage instances in this guild. @@ -385,7 +376,7 @@ impl Guild { #[must_use] pub fn default_channel(&self, uid: UserId) -> Option<&GuildChannel> { let member = self.members.get(&uid)?; - self.channels.values().find(|&channel| { + self.channels.iter().find(|&channel| { channel.kind != ChannelType::Category && self.user_permissions_in(channel, member).view_channel() }) @@ -397,11 +388,11 @@ impl Guild { /// **Note**: This is very costly if used in a server with lots of channels, members, or both. #[must_use] pub fn default_channel_guaranteed(&self) -> Option<&GuildChannel> { - self.channels.values().find(|&channel| { + self.channels.iter().find(|&channel| { channel.kind != ChannelType::Category && self .members - .values() + .iter() .map(|member| self.user_permissions_in(channel, member)) .all(Permissions::view_channel) }) @@ -582,7 +573,7 @@ impl Guild { /// # Errors /// /// Returns [`Error::Http`] if the guild is currently unavailable. - pub async fn channels(&self, http: &Http) -> Result> { + pub async fn channels(&self, http: &Http) -> Result> { self.id.channels(http).await } @@ -1529,9 +1520,8 @@ impl Guild { /// Gets a list of all the members (satisfying the status provided to the function) in this /// guild. pub fn members_with_status(&self, status: OnlineStatus) -> impl Iterator { - self.members.iter().filter_map(move |(id, member)| match self.presences.get(id) { - Some(presence) if presence.status == status => Some(member), - _ => None, + self.members.iter().filter(move |member| { + self.presences.get(&member.user.id).is_some_and(|p| p.status == status) }) } @@ -1558,7 +1548,7 @@ impl Guild { None => (name, None), }; - for member in self.members.values() { + for member in &self.members { if &*member.user.name == username && discrim.map_or(true, |d| member.user.discriminator == d) { @@ -1566,7 +1556,7 @@ impl Guild { } } - self.members.values().find(|member| member.nick.as_deref().is_some_and(|nick| nick == name)) + self.members.iter().find(|member| member.nick.as_deref().is_some_and(|nick| nick == name)) } /// Retrieves all [`Member`] that start with a given [`String`]. @@ -1598,7 +1588,7 @@ impl Guild { let mut members = self .members - .values() + .iter() .filter_map(|member| { let username = &member.user.name; @@ -1653,7 +1643,7 @@ impl Guild { ) -> Vec<(&Member, String)> { let mut members = self .members - .values() + .iter() .filter_map(|member| { let username = &member.user.name; @@ -1703,7 +1693,7 @@ impl Guild { ) -> Vec<(&Member, String)> { let mut members = self .members - .values() + .iter() .filter_map(|member| { let name = &member.user.name; contains(name, substring, case_sensitive).then(|| (member, name.clone().into())) @@ -1746,7 +1736,7 @@ impl Guild { ) -> Vec<(&Member, String)> { let mut members = self .members - .values() + .iter() .filter_map(|member| { let nick = member.nick.as_ref().unwrap_or(&member.user.name); contains(nick, substring, case_sensitive).then(|| (member, nick.clone().into())) @@ -1838,7 +1828,7 @@ impl Guild { member_user_id: UserId, member_roles: &[RoleId], guild_id: GuildId, - guild_roles: &HashMap, + guild_roles: &ExtractMap, guild_owner_id: UserId, ) -> Permissions { let mut everyone_allow_overwrites = Permissions::empty(); @@ -2223,7 +2213,7 @@ impl Guild { /// ``` #[must_use] pub fn role_by_name(&self, role_name: &str) -> Option<&Role> { - self.roles.values().find(|role| role_name == &*role.name) + self.roles.iter().find(|role| role_name == &*role.name) } /// Returns a builder which can be awaited to obtain a message or stream of messages in this @@ -2575,7 +2565,6 @@ enum_number! { mod test { #[cfg(feature = "model")] mod model { - use std::collections::*; use std::num::NonZeroU16; use crate::model::prelude::*; @@ -2596,7 +2585,7 @@ mod test { let m = gen_member(); Guild { - members: HashMap::from([(m.user.id, m)]), + members: ExtractMap::from_iter([m]), ..Default::default() } } diff --git a/src/model/guild/partial_guild.rs b/src/model/guild/partial_guild.rs index 15d21faf443..0f5ae9f6eb4 100644 --- a/src/model/guild/partial_guild.rs +++ b/src/model/guild/partial_guild.rs @@ -24,10 +24,10 @@ use crate::gateway::ShardMessenger; #[cfg(feature = "model")] use crate::http::{CacheHttp, Http, UserPagination}; use crate::internal::prelude::*; +use crate::internal::utils::lending_for_each; use crate::model::prelude::*; #[cfg(feature = "model")] use crate::model::utils::icon_url; -use crate::model::utils::{emojis, roles, stickers}; /// Partial information about a [`Guild`]. This does not include information like member data. /// @@ -82,11 +82,9 @@ pub struct PartialGuild { /// Default explicit content filter level. pub explicit_content_filter: ExplicitContentFilter, /// A mapping of the guild's roles. - #[serde(with = "roles")] - pub roles: HashMap, + pub roles: ExtractMap, /// All of the guild's custom emojis. - #[serde(with = "emojis")] - pub emojis: HashMap, + pub emojis: ExtractMap, /// The guild features. More information available at [`discord documentation`]. /// /// The following is a list of known features: @@ -173,8 +171,7 @@ pub struct PartialGuild { /// [`discord support article`]: https://support.discord.com/hc/en-us/articles/1500005389362-NSFW-Server-Designation pub nsfw_level: NsfwLevel, /// All of the guild's custom stickers. - #[serde(with = "stickers")] - pub stickers: HashMap, + pub stickers: ExtractMap, /// Whether the guild has the boost progress bar enabled pub premium_progress_bar_enabled: bool, } @@ -349,7 +346,7 @@ impl PartialGuild { /// /// Returns [`Error::Http`] if the current user is not in the guild or if the guild is /// otherwise unavailable. - pub async fn channels(&self, http: &Http) -> Result> { + pub async fn channels(&self, http: &Http) -> Result> { self.id.channels(http).await } @@ -1339,7 +1336,7 @@ impl PartialGuild { /// ``` #[must_use] pub fn role_by_name(&self, role_name: &str) -> Option<&Role> { - self.roles.values().find(|role| role_name == &*role.name) + self.roles.iter().find(|role| role_name == &*role.name) } /// Returns a builder which can be awaited to obtain a message or stream of messages in this @@ -1383,7 +1380,7 @@ impl PartialGuild { impl<'de> Deserialize<'de> for PartialGuildGeneratedOriginal { fn deserialize>(deserializer: D) -> StdResult { let mut guild = Self::deserialize(deserializer)?; // calls #[serde(remote)]-generated inherent method - guild.roles.values_mut().for_each(|r| r.guild_id = guild.id); + lending_for_each!(guild.roles.iter_mut(), |r| r.guild_id = guild.id); Ok(guild) } } diff --git a/src/model/guild/role.rs b/src/model/guild/role.rs index c4773f10c09..d4a7c51e61d 100644 --- a/src/model/guild/role.rs +++ b/src/model/guild/role.rs @@ -130,6 +130,12 @@ impl fmt::Display for Role { } } +impl ExtractKey for Role { + fn extract_key(&self) -> &RoleId { + &self.id + } +} + impl PartialOrd for Role { fn partial_cmp(&self, other: &Role) -> Option { Some(self.cmp(other)) diff --git a/src/model/mod.rs b/src/model/mod.rs index 878ba7a5536..14c4e1dc2c2 100644 --- a/src/model/mod.rs +++ b/src/model/mod.rs @@ -64,7 +64,6 @@ pub use self::timestamp::Timestamp; pub mod prelude { pub(crate) use std::collections::HashMap; - pub(crate) use serde::de::Visitor; pub(crate) use serde::{Deserialize, Deserializer}; pub use super::guild::automod::EventType as AutomodEventType; diff --git a/src/model/sticker.rs b/src/model/sticker.rs index b58b9224bdf..40dadfc188a 100644 --- a/src/model/sticker.rs +++ b/src/model/sticker.rs @@ -208,6 +208,12 @@ impl Sticker { } } +impl ExtractKey for Sticker { + fn extract_key(&self) -> &StickerId { + &self.id + } +} + enum_number! { /// Differentiates between sticker types. /// diff --git a/src/model/user.rs b/src/model/user.rs index 7ba3ce1f079..71ccbedd1f7 100644 --- a/src/model/user.rs +++ b/src/model/user.rs @@ -289,6 +289,12 @@ pub struct User { pub member: Option>, } +impl ExtractKey for User { + fn extract_key(&self) -> &UserId { + &self.id + } +} + enum_number! { /// Premium types denote the level of premium a user has. Visit the [Nitro](https://discord.com/nitro) /// page to learn more about the premium plans Discord currently offers. diff --git a/src/model/utils.rs b/src/model/utils.rs index 5d559254b26..d7f7c79a379 100644 --- a/src/model/utils.rs +++ b/src/model/utils.rs @@ -1,10 +1,8 @@ use std::cell::Cell; use std::fmt; -use std::hash::Hash; -use std::marker::PhantomData; use arrayvec::ArrayVec; -use serde::ser::{Serialize, SerializeSeq, Serializer}; +use serde::ser::SerializeSeq; use super::prelude::*; use crate::internal::prelude::*; @@ -157,7 +155,7 @@ impl<'de> serde::Deserialize<'de> for StrOrInt<'de> { #[track_caller] pub(crate) fn assert_json(data: &T, json: Value) where - T: Serialize + for<'de> Deserialize<'de> + PartialEq + std::fmt::Debug, + T: serde::Serialize + for<'de> Deserialize<'de> + PartialEq + std::fmt::Debug, { // test serialization let serialized = serde_json::to_value(data).unwrap(); @@ -174,69 +172,6 @@ where ); } -/// Used with `#[serde(with = "emojis")]` -pub mod emojis { - use std::collections::HashMap; - - use serde::Deserializer; - - use super::SequenceToMapVisitor; - use crate::model::guild::Emoji; - use crate::model::id::EmojiId; - - pub fn deserialize<'de, D: Deserializer<'de>>( - deserializer: D, - ) -> Result, D::Error> { - deserializer.deserialize_seq(SequenceToMapVisitor::new(|emoji: &Emoji| emoji.id)) - } - - pub use super::serialize_map_values as serialize; -} - -pub fn deserialize_guild_channels<'de, D: Deserializer<'de>>( - deserializer: D, -) -> StdResult, D::Error> { - deserializer.deserialize_seq(SequenceToMapVisitor::new(|channel: &GuildChannel| channel.id)) -} - -/// Used with `#[serde(with = "members")] -pub mod members { - use std::collections::HashMap; - - use serde::Deserializer; - - use super::SequenceToMapVisitor; - use crate::model::guild::Member; - use crate::model::id::UserId; - - pub fn deserialize<'de, D: Deserializer<'de>>( - deserializer: D, - ) -> Result, D::Error> { - deserializer.deserialize_seq(SequenceToMapVisitor::new(|member: &Member| member.user.id)) - } - - pub use super::serialize_map_values as serialize; -} - -/// Used with `#[serde(with = "presences")]` -pub mod presences { - use std::collections::HashMap; - - use serde::Deserializer; - - use super::SequenceToMapVisitor; - use crate::model::gateway::Presence; - use crate::model::id::UserId; - - pub fn deserialize<'de, D: Deserializer<'de>>( - deserializer: D, - ) -> Result, D::Error> { - deserializer.deserialize_seq(SequenceToMapVisitor::new(|p: &Presence| p.user.id)) - } - - pub use super::serialize_map_values as serialize; -} - pub fn deserialize_buttons<'de, D: Deserializer<'de>>( deserializer: D, ) -> StdResult, D::Error> { @@ -253,44 +188,6 @@ pub fn deserialize_buttons<'de, D: Deserializer<'de>>( }) } -/// Used with `#[serde(with = "roles")]` -pub mod roles { - use std::collections::HashMap; - - use serde::Deserializer; - - use super::SequenceToMapVisitor; - use crate::model::guild::Role; - use crate::model::id::RoleId; - - pub fn deserialize<'de, D: Deserializer<'de>>( - deserializer: D, - ) -> Result, D::Error> { - deserializer.deserialize_seq(SequenceToMapVisitor::new(|role: &Role| role.id)) - } - - pub use super::serialize_map_values as serialize; -} - -/// Used with `#[serde(with = "stickers")]` -pub mod stickers { - use std::collections::HashMap; - - use serde::Deserializer; - - use super::SequenceToMapVisitor; - use crate::model::id::StickerId; - use crate::model::sticker::Sticker; - - pub fn deserialize<'de, D: Deserializer<'de>>( - deserializer: D, - ) -> Result, D::Error> { - deserializer.deserialize_seq(SequenceToMapVisitor::new(|sticker: &Sticker| sticker.id)) - } - - pub use super::serialize_map_values as serialize; -} - /// Used with `#[serde(with = "comma_separated_string")]` pub mod comma_separated_string { use serde::{Deserialize, Deserializer, Serializer}; @@ -361,56 +258,3 @@ pub mod secret { secret.as_ref().map(ExposeSecret::expose_secret).serialize(serializer) } } - -pub fn deserialize_voice_states<'de, D: Deserializer<'de>>( - deserializer: D, -) -> StdResult, D::Error> { - deserializer.deserialize_seq(SequenceToMapVisitor::new(|state: &VoiceState| state.user_id)) -} - -pub fn serialize_map_values( - map: &HashMap, - serializer: S, -) -> StdResult { - serializer.collect_seq(map.values()) -} - -/// Deserializes a sequence and builds a `HashMap` with the key extraction function. -pub(in crate::model) struct SequenceToMapVisitor { - key: F, - marker: PhantomData, -} - -impl SequenceToMapVisitor { - pub fn new(key: F) -> Self { - Self { - key, - marker: PhantomData, - } - } -} - -impl<'de, F, K, V> Visitor<'de> for SequenceToMapVisitor -where - K: Eq + Hash, - V: Deserialize<'de>, - F: FnMut(&V) -> K, -{ - type Value = HashMap; - - fn expecting(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result { - formatter.write_str("sequence") - } - - fn visit_seq(mut self, mut seq: A) -> StdResult - where - A: serde::de::SeqAccess<'de>, - { - let mut map = seq.size_hint().map_or_else(HashMap::new, HashMap::with_capacity); - while let Some(elem) = seq.next_element()? { - map.insert((self.key)(&elem), elem); - } - - Ok(map) - } -} diff --git a/src/model/voice.rs b/src/model/voice.rs index d258d479346..a25711928b4 100644 --- a/src/model/voice.rs +++ b/src/model/voice.rs @@ -54,6 +54,12 @@ pub struct VoiceState { pub request_to_speak_timestamp: Option, } +impl extract_map::ExtractKey for VoiceState { + fn extract_key(&self) -> &UserId { + &self.user_id + } +} + // Manual impl needed to insert guild_id into Member impl<'de> Deserialize<'de> for VoiceStateGeneratedOriginal { fn deserialize>(deserializer: D) -> Result { diff --git a/src/model/webhook.rs b/src/model/webhook.rs index fbb643d06f3..1e8e115df33 100644 --- a/src/model/webhook.rs +++ b/src/model/webhook.rs @@ -92,6 +92,12 @@ pub struct Webhook { pub url: Option, } +impl ExtractKey for Webhook { + fn extract_key(&self) -> &WebhookId { + &self.id + } +} + /// The guild object returned by a [`Webhook`], of type [`WebhookType::ChannelFollower`]. #[cfg_attr(feature = "typesize", derive(typesize::derive::TypeSize))] #[derive(Debug, Clone, Deserialize, Serialize)] diff --git a/src/utils/argument_convert/channel.rs b/src/utils/argument_convert/channel.rs index f5bf17534f1..5fe9664dbca 100644 --- a/src/utils/argument_convert/channel.rs +++ b/src/utils/argument_convert/channel.rs @@ -54,7 +54,7 @@ async fn lookup_channel_global( #[cfg(feature = "cache")] if let Some(cache) = ctx.cache() { if let Some(guild) = cache.guild(guild_id) { - let channel = guild.channels.values().find(|c| c.name.eq_ignore_ascii_case(s)); + let channel = guild.channels.iter().find(|c| c.name.eq_ignore_ascii_case(s)); if let Some(channel) = channel { return Ok(Channel::Guild(channel.clone())); } diff --git a/src/utils/argument_convert/emoji.rs b/src/utils/argument_convert/emoji.rs index b42a067afc7..9a5d986ddb5 100644 --- a/src/utils/argument_convert/emoji.rs +++ b/src/utils/argument_convert/emoji.rs @@ -64,7 +64,7 @@ impl ArgumentConvert for Emoji { } if let Some(emoji) = - guild.emojis.values().find(|emoji| emoji.name.eq_ignore_ascii_case(s)).cloned() + guild.emojis.iter().find(|emoji| emoji.name.eq_ignore_ascii_case(s)).cloned() { return Ok(emoji); } diff --git a/src/utils/argument_convert/role.rs b/src/utils/argument_convert/role.rs index 6ea6807c2a1..c18319d9a9a 100644 --- a/src/utils/argument_convert/role.rs +++ b/src/utils/argument_convert/role.rs @@ -84,7 +84,7 @@ impl ArgumentConvert for Role { } #[cfg(feature = "cache")] - if let Some(role) = roles.values().find(|role| role.name.eq_ignore_ascii_case(s)) { + if let Some(role) = roles.iter().find(|role| role.name.eq_ignore_ascii_case(s)) { return Ok(role.clone()); } #[cfg(not(feature = "cache"))] diff --git a/src/utils/content_safe.rs b/src/utils/content_safe.rs index e50c581eed6..607af0290ff 100644 --- a/src/utils/content_safe.rs +++ b/src/utils/content_safe.rs @@ -220,6 +220,7 @@ mod tests { let member = Member { nick: Some(FixedString::from_static_trunc("Ferris")), + user: user.clone(), ..Default::default() }; @@ -235,9 +236,9 @@ mod tests { ..Default::default() }; - guild.channels.insert(channel.id, channel.clone()); - guild.members.insert(user.id, member.clone()); - guild.roles.insert(role.id, role); + guild.channels.insert(channel.clone()); + guild.members.insert(member.clone()); + guild.roles.insert(role); let with_user_mentions = "<@!100000000000000000> <@!000000000000000000> <@123> <@!123> \ <@!123123123123123123123> <@123> <@123123123123123123> <@!invalid> \