feat: add redis cache to store polls state
This commit is contained in:
@@ -3,13 +3,15 @@ import requests
|
|||||||
import json
|
import json
|
||||||
from discord.ext import commands
|
from discord.ext import commands
|
||||||
|
|
||||||
|
ALPHAVANTAGE_API_KEY = os.getenv('ALPHAVANTAGE_API_KEY')
|
||||||
|
LUNAR_CRUSH_API_KEY = os.getenv("LUNAR_CRUSH_API_KEY")
|
||||||
|
|
||||||
class Finance(commands.Cog):
|
class Finance(commands.Cog):
|
||||||
"""Commands to query stock and crypto prices."""
|
"""Commands to query stock and crypto prices."""
|
||||||
|
|
||||||
@commands.command(usage="<tickers>")
|
@commands.command(usage="<tickers>")
|
||||||
async def stock(self, ctx, *tickers: str) -> None:
|
async def stock(self, ctx, *tickers: str) -> None:
|
||||||
"""Gets prices for the given stock tickers. Will default tickers to AAPL, GOOG, MSFT and AMZN if none provided."""
|
"""Gets prices for the given stock tickers. Will default tickers to AAPL, GOOG, MSFT and AMZN if none provided."""
|
||||||
ALPHAVANTAGE_API_KEY = os.getenv('ALPHAVANTAGE_API_KEY')
|
|
||||||
|
|
||||||
# default tickers
|
# default tickers
|
||||||
if len(tickers) == 0:
|
if len(tickers) == 0:
|
||||||
@@ -43,7 +45,6 @@ class Finance(commands.Cog):
|
|||||||
@commands.command(usage="<tickers>")
|
@commands.command(usage="<tickers>")
|
||||||
async def crypto(self, ctx, *tickers: str) -> None:
|
async def crypto(self, ctx, *tickers: str) -> None:
|
||||||
"""Gets prices for the given crypto tickers. Will default tickers to BTC, ETH and LTC if none provided."""
|
"""Gets prices for the given crypto tickers. Will default tickers to BTC, ETH and LTC if none provided."""
|
||||||
LUNAR_CRUSH_API_KEY = os.getenv("LUNAR_CRUSH_API_KEY")
|
|
||||||
|
|
||||||
# default tickers
|
# default tickers
|
||||||
if len(tickers) == 0:
|
if len(tickers) == 0:
|
||||||
|
|||||||
68
cogs/fun.py
68
cogs/fun.py
@@ -1,21 +1,17 @@
|
|||||||
from os import name
|
|
||||||
import random
|
import random
|
||||||
import discord
|
import discord
|
||||||
from discord.ext.commands.core import command
|
|
||||||
import requests
|
import requests
|
||||||
import json
|
import json
|
||||||
import requests
|
import requests
|
||||||
from data import poll_cache
|
from model import Poll
|
||||||
|
from data import polls
|
||||||
from datetime import date
|
from datetime import date
|
||||||
from discord.ext import commands
|
from discord.ext import commands
|
||||||
|
from discord.raw_models import RawReactionActionEvent
|
||||||
|
|
||||||
class Fun(commands.Cog):
|
class Fun(commands.Cog):
|
||||||
"""Commands that are good for the soul!"""
|
"""Commands that are good for the soul!"""
|
||||||
|
|
||||||
OPTION_EMOJIS = ["1️⃣", "2️⃣", "3️⃣", "4️⃣", "5️⃣"]
|
|
||||||
FULL_CHAR = "█"
|
|
||||||
BAR_LENGTH = 20
|
|
||||||
|
|
||||||
@commands.command()
|
@commands.command()
|
||||||
async def flip(self, ctx) -> None:
|
async def flip(self, ctx) -> None:
|
||||||
"""Flips a coin and return either heads or tails."""
|
"""Flips a coin and return either heads or tails."""
|
||||||
@@ -65,27 +61,25 @@ class Fun(commands.Cog):
|
|||||||
await ctx.send("**Please provide the parameters for the poll. ex.** `!poll \"My question?\" option1 option2`")
|
await ctx.send("**Please provide the parameters for the poll. ex.** `!poll \"My question?\" option1 option2`")
|
||||||
return
|
return
|
||||||
|
|
||||||
|
# parse the parameters
|
||||||
question = params[0]
|
question = params[0]
|
||||||
options = params[1:]
|
options = params[1:]
|
||||||
|
|
||||||
# create embed response
|
# create embed response
|
||||||
poll_embed = discord.Embed(title=question, description="\u200b", color=0x3DCCDD)
|
poll_embed = discord.Embed(title=question, description="\u200b", color=0x3DCCDD)
|
||||||
for i in range(len(options)):
|
for i in range(len(options)):
|
||||||
poll_embed.add_field(value="\u200b", name=f"{self.OPTION_EMOJIS[i]} {options[i]}", inline=False)
|
poll_embed.add_field(value="\u200b", name=f"{Poll.OPTION_EMOJIS[i]} {options[i]}", inline=False)
|
||||||
poll_embed.set_footer(text=f"Poll created on • {date.today().strftime('%m/%d/%y')}")
|
poll_embed.set_footer(text=f"Poll created on • {date.today().strftime('%m/%d/%y')}")
|
||||||
|
|
||||||
message = await ctx.send(embed=poll_embed)
|
message = await ctx.send(embed=poll_embed)
|
||||||
|
|
||||||
# add poll to cache
|
# add poll to cache
|
||||||
poll_cache[message.id] = {}
|
poll = Poll(message.id, question, options)
|
||||||
for i in range(len(options)):
|
polls[message.id] = json.dumps(poll.__dict__)
|
||||||
poll_cache[message.id][f"{self.OPTION_EMOJIS[i]}"] = []
|
|
||||||
poll_cache[message.id]["question"] = question
|
|
||||||
poll_cache[message.id]["options"] = ",".join(options)
|
|
||||||
|
|
||||||
# add option emojis to poll
|
# add option emoji reactions to poll
|
||||||
for i in range(len(options)):
|
for i in range(len(options)):
|
||||||
await message.add_reaction(self.OPTION_EMOJIS[i])
|
await message.add_reaction(Poll.OPTION_EMOJIS[i])
|
||||||
|
|
||||||
@poll.command(usage="<message_id>", name="results")
|
@poll.command(usage="<message_id>", name="results")
|
||||||
async def poll_results(self, ctx, message_id: int = None) -> None:
|
async def poll_results(self, ctx, message_id: int = None) -> None:
|
||||||
@@ -95,27 +89,39 @@ class Fun(commands.Cog):
|
|||||||
await ctx.send("**Please provide the message id of a poll. ex.** `!poll results 870733802190807050")
|
await ctx.send("**Please provide the message id of a poll. ex.** `!poll results 870733802190807050")
|
||||||
return
|
return
|
||||||
|
|
||||||
if message_id not in poll_cache:
|
if message_id not in polls:
|
||||||
await ctx.send("**Could not find poll with that message id.**")
|
await ctx.send("**Could not find poll with that message id.**")
|
||||||
return
|
return
|
||||||
|
|
||||||
poll = poll_cache[message_id]
|
poll = Poll.create_from(json.loads(polls[message_id]))
|
||||||
question = poll["question"]
|
|
||||||
options = poll["options"].split(",")
|
|
||||||
|
|
||||||
total_votes = sum([0 if item[0] == "question" or item[0] == "options" else len(item[1]) for item in poll.items()])
|
|
||||||
option_percentages = []
|
|
||||||
for i in range(len(options)):
|
|
||||||
option_votes = len(poll[self.OPTION_EMOJIS[i]])
|
|
||||||
option_percentages.append(0 if total_votes == 0 else option_votes / total_votes)
|
|
||||||
|
|
||||||
# construct embed response
|
# construct embed response
|
||||||
results_embed = discord.Embed(title=question, description="🥁", color=0x3DCCDD)
|
results_embed = discord.Embed(title=poll.question, description="🥁", color=0x3DCCDD)
|
||||||
for i in range(len(options)):
|
for option in poll.options:
|
||||||
str_bar = "".join([self.FULL_CHAR for _ in range(int(self.BAR_LENGTH * option_percentages[i]))])
|
str_bar = "".join([Poll.BAR_CHAR for _ in range(int(Poll.BAR_LENGTH * poll.get_vote_percentage(option)))])
|
||||||
str_percent = f"{option_percentages[i] * 100:.2f}%"
|
str_percent = f"{poll.get_vote_percentage(option) * 100:.2f}%"
|
||||||
str_votes = f"({int(option_percentages[i] * total_votes)} votes)"
|
str_votes = f"({poll.get_vote_count(option)} votes)"
|
||||||
results_embed.add_field(name=f"{self.OPTION_EMOJIS[i]} {options[i]}", value=f"{str_bar} {str_percent} {str_votes}", inline=False)
|
results_embed.add_field(name=f"{poll.get_emoji(option)} {option}", value=f"{str_bar} {str_percent} {str_votes}", inline=False)
|
||||||
results_embed.set_footer(text=f"Poll queried on • {date.today().strftime('%m/%d/%y')}")
|
results_embed.set_footer(text=f"Poll queried on • {date.today().strftime('%m/%d/%y')}")
|
||||||
|
|
||||||
await ctx.send(embed=results_embed)
|
await ctx.send(embed=results_embed)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def poll_vote(event: RawReactionActionEvent) -> None:
|
||||||
|
"""Handles a user voting on a poll."""
|
||||||
|
|
||||||
|
# check if message is a poll and emoji is a valid option
|
||||||
|
if event.message_id in polls and event.emoji.name in Poll.OPTION_EMOJIS:
|
||||||
|
# get poll from cache
|
||||||
|
poll = Poll.create_from(json.loads(polls[event.message_id]))
|
||||||
|
|
||||||
|
print(event)
|
||||||
|
|
||||||
|
# check reaction emoji type
|
||||||
|
if event.event_type == "REACTION_ADD":
|
||||||
|
poll.add_vote(event.user_id, event.emoji.name)
|
||||||
|
elif event.event_type == "REACTION_REMOVE":
|
||||||
|
poll.remove_vote(event.user_id, event.emoji.name)
|
||||||
|
|
||||||
|
# update cache
|
||||||
|
polls[event.message_id] = json.dumps(poll.__dict__)
|
||||||
|
|||||||
@@ -1 +1 @@
|
|||||||
poll_cache = {} # TODO: use a real cache
|
from data.redis import polls
|
||||||
6
data/redis.py
Normal file
6
data/redis.py
Normal file
@@ -0,0 +1,6 @@
|
|||||||
|
import os
|
||||||
|
from walrus import *
|
||||||
|
|
||||||
|
_REDIS_URL = os.getenv('REDIS_URL')
|
||||||
|
_db = Database.from_url(_REDIS_URL)
|
||||||
|
polls = _db.Hash("polls")
|
||||||
1
model/__init__.py
Normal file
1
model/__init__.py
Normal file
@@ -0,0 +1 @@
|
|||||||
|
from model.poll import Poll
|
||||||
65
model/poll.py
Normal file
65
model/poll.py
Normal file
@@ -0,0 +1,65 @@
|
|||||||
|
from typing import Dict, List
|
||||||
|
|
||||||
|
class Poll:
|
||||||
|
|
||||||
|
OPTION_EMOJIS = ["1️⃣", "2️⃣", "3️⃣", "4️⃣", "5️⃣"]
|
||||||
|
BAR_CHAR = "█"
|
||||||
|
BAR_LENGTH = 20
|
||||||
|
|
||||||
|
def __init__(self, message_id: int, question: str, options: List[str], votes: Dict = None):
|
||||||
|
self.message_id = message_id
|
||||||
|
self.question = question
|
||||||
|
self.options = options
|
||||||
|
self.votes = {option: [] for option in options} if not votes else votes
|
||||||
|
|
||||||
|
def get_vote_count(self, option: str) -> int:
|
||||||
|
if option not in self.options:
|
||||||
|
return 0
|
||||||
|
return len(self.votes[option])
|
||||||
|
|
||||||
|
def get_total_vote_count(self) -> int:
|
||||||
|
return sum([self.get_vote_count(option) for option in self.options])
|
||||||
|
|
||||||
|
def get_vote_percentage(self, option: str) -> float:
|
||||||
|
vote_count = self.get_vote_count(option)
|
||||||
|
|
||||||
|
# avoid divide by zero
|
||||||
|
return 0 if self.get_total_vote_count() == 0 else vote_count / self.get_total_vote_count()
|
||||||
|
|
||||||
|
def add_vote(self, user_id: str, emoji: str) -> None:
|
||||||
|
option = self._get_option(emoji)
|
||||||
|
|
||||||
|
# only add vote if user has not already voted
|
||||||
|
if user_id not in self.votes[option]:
|
||||||
|
self.votes[option].append(user_id)
|
||||||
|
|
||||||
|
print(self.votes) # TODO: remove
|
||||||
|
|
||||||
|
def remove_vote(self, user_id: str, emoji: str) -> None:
|
||||||
|
option = self._get_option(emoji)
|
||||||
|
|
||||||
|
# only remove if user has voted
|
||||||
|
if user_id in self.votes[option]:
|
||||||
|
self.votes[option].remove(user_id)
|
||||||
|
|
||||||
|
print(self.votes) # TODO: remove
|
||||||
|
|
||||||
|
def get_emoji(self, option: str) -> str:
|
||||||
|
return Poll.OPTION_EMOJIS[self.options.index(option)]
|
||||||
|
|
||||||
|
def _get_option(self, emoji: str) -> str:
|
||||||
|
# check if valid emoji
|
||||||
|
if emoji not in self.OPTION_EMOJIS:
|
||||||
|
return None
|
||||||
|
|
||||||
|
option_index = Poll.OPTION_EMOJIS.index(emoji)
|
||||||
|
|
||||||
|
# check if valid vote
|
||||||
|
if option_index >= len(self.options):
|
||||||
|
return None
|
||||||
|
|
||||||
|
return self.options[option_index]
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def create_from(poll_dict: Dict) -> "Poll":
|
||||||
|
return Poll(int(poll_dict["message_id"]), poll_dict["question"], poll_dict["options"], poll_dict["votes"])
|
||||||
36
pyvis.py
36
pyvis.py
@@ -1,20 +1,18 @@
|
|||||||
|
# load environment variables from .env file
|
||||||
|
from dotenv import load_dotenv
|
||||||
|
load_dotenv()
|
||||||
|
|
||||||
import discord
|
import discord
|
||||||
import os
|
import os
|
||||||
from discord.flags import MessageFlags
|
|
||||||
import requests
|
import requests
|
||||||
import json
|
import json
|
||||||
import random
|
import random
|
||||||
from cogs import Programming
|
from cogs import Programming
|
||||||
from cogs import Fun
|
from cogs import Fun
|
||||||
from cogs import Finance
|
from cogs import Finance
|
||||||
from data import poll_cache
|
|
||||||
from discord.ext import commands
|
from discord.ext import commands
|
||||||
from dotenv import load_dotenv
|
|
||||||
from pretty_help import DefaultMenu, PrettyHelp
|
from pretty_help import DefaultMenu, PrettyHelp
|
||||||
|
|
||||||
# load environment variables from .env file
|
|
||||||
load_dotenv()
|
|
||||||
|
|
||||||
DISCORD_BOT_TOKEN = os.getenv('DISCORD_BOT_TOKEN')
|
DISCORD_BOT_TOKEN = os.getenv('DISCORD_BOT_TOKEN')
|
||||||
DISCORD_GUILD_NAME = os.getenv('DISCORD_GUILD_NAME')
|
DISCORD_GUILD_NAME = os.getenv('DISCORD_GUILD_NAME')
|
||||||
GIPHY_API_KEY = os.getenv("GIPHY_API_KEY")
|
GIPHY_API_KEY = os.getenv("GIPHY_API_KEY")
|
||||||
@@ -36,7 +34,7 @@ async def on_ready():
|
|||||||
|
|
||||||
@bot.event
|
@bot.event
|
||||||
async def on_member_join(member):
|
async def on_member_join(member):
|
||||||
""" WELCOME START """
|
""" WELCOME START """ # TODO: extract to a separate file
|
||||||
guild = discord.utils.find(lambda g: g.name == DISCORD_GUILD_NAME, bot.guilds) # TODO: check if find will throw an error
|
guild = discord.utils.find(lambda g: g.name == DISCORD_GUILD_NAME, bot.guilds) # TODO: check if find will throw an error
|
||||||
welcome_channel = discord.utils.find(lambda c: c.name == "👋welcome", guild.channels) # TODO: remove welcome channel hardcode
|
welcome_channel = discord.utils.find(lambda c: c.name == "👋welcome", guild.channels) # TODO: remove welcome channel hardcode
|
||||||
|
|
||||||
@@ -57,7 +55,7 @@ async def on_message(message):
|
|||||||
if message.author == bot.user:
|
if message.author == bot.user:
|
||||||
return
|
return
|
||||||
|
|
||||||
# TODO: is channel is a proposal channel check for poll command
|
# TODO: if channel is a proposal channel check for poll command
|
||||||
|
|
||||||
# repsond with command prefix
|
# repsond with command prefix
|
||||||
if bot.user.mentioned_in(message) and not message.mention_everyone:
|
if bot.user.mentioned_in(message) and not message.mention_everyone:
|
||||||
@@ -71,32 +69,14 @@ async def on_raw_reaction_add(event):
|
|||||||
# avoid bot self trigger
|
# avoid bot self trigger
|
||||||
if event.user_id == bot.user.id:
|
if event.user_id == bot.user.id:
|
||||||
return
|
return
|
||||||
|
Fun.poll_vote(event)
|
||||||
""" POLL START """
|
|
||||||
# check if message is in cache and if reaction is valid
|
|
||||||
if event.message_id in poll_cache and event.emoji.name in Fun.OPTION_EMOJIS:
|
|
||||||
# if user has not voted yet, add their vote
|
|
||||||
if event.user_id not in poll_cache[event.message_id][event.emoji.name]:
|
|
||||||
poll_cache[event.message_id][event.emoji.name].append(event.user_id)
|
|
||||||
""" POLL END """
|
|
||||||
|
|
||||||
print(poll_cache)
|
|
||||||
|
|
||||||
@bot.event
|
@bot.event
|
||||||
async def on_raw_reaction_remove(event):
|
async def on_raw_reaction_remove(event):
|
||||||
# avoid bot self trigger
|
# avoid bot self trigger
|
||||||
if event.user_id == bot.user.id:
|
if event.user_id == bot.user.id:
|
||||||
return
|
return
|
||||||
|
Fun.poll_vote(event)
|
||||||
""" POLL START """
|
|
||||||
# check if message is in cache and if reaction is valid
|
|
||||||
if event.message_id in poll_cache and event.emoji.name in Fun.OPTION_EMOJIS:
|
|
||||||
# if user has voted, remove their vote
|
|
||||||
if event.user_id in poll_cache[event.message_id][event.emoji.name]:
|
|
||||||
poll_cache[event.message_id][event.emoji.name].remove(event.user_id)
|
|
||||||
""" POLL END """
|
|
||||||
|
|
||||||
print(poll_cache)
|
|
||||||
|
|
||||||
# override default help command
|
# override default help command
|
||||||
menu = DefaultMenu('◀️', '▶️', '❌')
|
menu = DefaultMenu('◀️', '▶️', '❌')
|
||||||
|
|||||||
@@ -9,7 +9,9 @@ discord.py==1.7.3
|
|||||||
idna==3.2
|
idna==3.2
|
||||||
multidict==5.1.0
|
multidict==5.1.0
|
||||||
python-dotenv==0.19.0
|
python-dotenv==0.19.0
|
||||||
|
redis==3.5.3
|
||||||
requests==2.26.0
|
requests==2.26.0
|
||||||
typing-extensions==3.10.0.0
|
typing-extensions==3.10.0.0
|
||||||
urllib3==1.26.6
|
urllib3==1.26.6
|
||||||
|
walrus==0.8.2
|
||||||
yarl==1.6.3
|
yarl==1.6.3
|
||||||
|
|||||||
Reference in New Issue
Block a user