Skip to content

Commit

Permalink
Allow nsfw force (#115)
Browse files Browse the repository at this point in the history
* force overrides nsfw protection

* warn when forcing nsfw to non-nsfw
  • Loading branch information
circuitsacul authored Jul 3, 2022
1 parent 3e028fd commit 3464cbc
Show file tree
Hide file tree
Showing 2 changed files with 72 additions and 14 deletions.
84 changes: 71 additions & 13 deletions starboard/commands/utility.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@
from starboard.database import Message, SBMessage, Starboard
from starboard.exceptions import StarboardError
from starboard.utils import jump
from starboard.views import Paginator
from starboard.views import Confirm, Paginator

from ._autocomplete import starboard_autocomplete
from ._checks import has_guild_perms
Expand Down Expand Up @@ -228,6 +228,39 @@ async def toggle_trashed(


# FORCING
async def validate_nsfw_force_if_needed(
bot: Bot, ctx: crescent.Context, sbids: list[int]
) -> bool:
for sbid in sbids:
channel = await bot.cache.gof_guild_channel_wnsfw(sbid)
if channel is None:
continue

assert channel.is_nsfw is not None
if not channel.is_nsfw:
return await validate_nsfw_force(ctx)

return True


async def validate_nsfw_force(ctx: crescent.Context) -> bool:
confirm = Confirm(ctx.user.id, True)

msg = await ctx.respond(
"This message is from an NSFW channel, and forcing it will cause it "
"to appear on non-NSFW starboards. Are you sure you want to do this?",
ephemeral=True,
components=confirm.build(),
ensure_message=True,
)
confirm.start(msg)
await confirm.wait()

if not confirm.result:
return False
return True


@plugin.include
@utils.child
@crescent.command(
Expand Down Expand Up @@ -271,25 +304,39 @@ async def callback(self, ctx: crescent.Context) -> None:
if self.starboard:
sb = await Starboard.from_name(ctx.guild_id, self.starboard)
sbids = [sb.id]
sbchids = [sb.channel_id]
else:
assert ctx.guild_id
sbids = [
sb.id
for sb in await Starboard.fetch_query()
sbids = []
sbchids = []
for sb in (
await Starboard.fetch_query()
.where(guild_id=ctx.guild_id)
.fetchmany()
]
):
sbids.append(sb.id)
sbchids.append(sb.channel_id)
if not sbids:
raise StarboardError(
"This server has no starboards, so you can't force this "
"message."
)

replied = False
if msg.is_nsfw:
replied = True
if not await validate_nsfw_force_if_needed(bot, ctx, sbchids):
await ctx.edit("Cancelled.", components=[])
return

orig_force = set(msg.forced_to)
orig_force.update(sbids)
msg.forced_to = list(orig_force)
await msg.save()
await ctx.respond("Message forced.", ephemeral=True)
if replied:
await ctx.edit("Message forced.", components=[])
else:
await ctx.respond("Message forced.", ephemeral=True)
await refresh_message(bot, msg, sbids, force=True)


Expand Down Expand Up @@ -348,12 +395,13 @@ async def force_message(
) -> None:
bot = cast("Bot", ctx.app)

sbids = [
sb.id
for sb in await Starboard.fetch_query()
.where(guild_id=ctx.guild_id)
.fetchmany()
]
sbids: list[int] = []
sbchids: list[int] = []
for sb in (
await Starboard.fetch_query().where(guild_id=ctx.guild_id).fetchmany()
):
sbids.append(sb.id)
sbchids.append(sb.channel_id)
if not sbids:
raise StarboardError(
"There are no starboards in this server, so you can't force this "
Expand All @@ -378,9 +426,19 @@ async def force_message(
obj.author.is_bot,
)

replied = False
if msg.is_nsfw:
replied = True
if not await validate_nsfw_force_if_needed(bot, ctx, sbchids):
await ctx.edit("Cancelled.", components=[])
return

msg.forced_to = sbids
await msg.save()
await ctx.respond("Message forced to all starboards.", ephemeral=True)
if replied:
await ctx.edit("Message forced to all starboards.", components=[])
else:
await ctx.respond("Message forced to all starboards.", ephemeral=True)
await refresh_message(bot, msg, sbids, force=True)


Expand Down
2 changes: 1 addition & 1 deletion starboard/core/starboards.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,7 @@ async def _refresh_message_for_starboard(
force: bool,
premium: bool,
) -> None:
if orig_msg.is_nsfw:
if orig_msg.is_nsfw and config.starboard.id not in orig_msg.forced_to:
sbchannel = await bot.cache.gof_guild_channel_wnsfw(
config.starboard.channel_id
)
Expand Down

0 comments on commit 3464cbc

Please sign in to comment.