264 lines
9.6 KiB
Python
264 lines
9.6 KiB
Python
|
import pdb
|
||
|
from pprint import pformat
|
||
|
|
||
|
import argparse
|
||
|
from pathlib import Path
|
||
|
from typing import List, Optional, Tuple, Union, Dict
|
||
|
from threading import Event
|
||
|
from dataclasses import dataclass
|
||
|
from datetime import datetime
|
||
|
|
||
|
from wiki_postbot.clients.client import Client
|
||
|
from wiki_postbot.creds import Mediawiki_Creds, Slack_Creds
|
||
|
from wiki_postbot.interfaces.mediawiki import Wiki
|
||
|
from wiki_postbot.patterns.wikilink import Wikilink
|
||
|
|
||
|
from slack_sdk.web import WebClient
|
||
|
from slack_sdk.socket_mode import SocketModeClient
|
||
|
from slack_sdk.socket_mode.response import SocketModeResponse
|
||
|
from slack_sdk.socket_mode.request import SocketModeRequest
|
||
|
|
||
|
|
||
|
class SlackClient(Client):
    """Wikibot client that listens to Slack over Socket Mode.

    Message events are scanned for wikilinks; messages that contain any are
    wrapped in a :class:`SlackMessage` and handed to ``self.wiki.handle_slack``.
    Progress and outcome are shown as reactions on the original message.
    """

    def __init__(
            self,
            creds: Slack_Creds,
            wiki:Wiki,
            name:str="wikibot_slack",
            reply_channel="wikibot",
            log_dir:Path = Path('/var/log/wikibot'),
    ):
        """Wikibot but for slack!

        Args:
            creds: Slack credentials; ``bot_token`` and ``app_token`` are read in
                :meth:`_init_client`.
            wiki: wiki interface that receives messages containing wikilinks.
            name: client name, passed to the base ``Client``.
            reply_channel: channel *name* (not ID) that replies should go to.
            log_dir: log directory, passed to the base ``Client``.
        """
        super(SlackClient, self).__init__(wiki=wiki, name=name, log_dir=log_dir)

        self.creds = creds
        self._initialized = False
        # Both clients are created in run(), not at construction time.
        self.web_client = None # type: Optional[WebClient]
        self.socket_client = None # type: Optional[SocketModeClient]
        self.reply_channel_name = reply_channel
        self._reply_channel = None  # cached ID of the reply channel
        self._channel_inverse = None  # cached {channel_id: channel_name}

    def _list_channels(self) -> List[dict]:
        """Return all channels from ``conversations.list``, following cursor pagination.

        A single ``conversations.list`` call returns only one page. Further
        pages are signalled by a non-empty ``response_metadata.next_cursor``,
        so one call can miss channels in larger workspaces.
        """
        channels = []
        cursor = None
        while True:
            if cursor:
                response = self.web_client.conversations_list(cursor=cursor)
            else:
                response = self.web_client.conversations_list()
            channels.extend(response['channels'])
            cursor = (response.get('response_metadata') or {}).get('next_cursor')
            if not cursor:
                return channels

    @property
    def reply_channel(self) -> Union[str, None]:
        """ID of the channel named ``reply_channel_name``, or None if it can't be resolved.

        A successful lookup is cached. A failed lookup (no match, or several
        matches) returns None and is retried on the next access.
        """
        if self._reply_channel is None:
            self.logger.debug(f"Getting channel named {self.reply_channel_name}")
            matches = [c for c in self._list_channels() if c['name'] == self.reply_channel_name]
            if len(matches) == 1:
                self._reply_channel = matches[0]['id']
            elif len(matches) > 1:
                # Not inside an except block, so log an error rather than an exception traceback.
                self.logger.error("Got too many channels to reply to!")
            else:
                self.logger.error("Reply channel not found!")
        return self._reply_channel

    @property
    def channel_inverse(self) -> Dict[str,str]:
        """Maps channel IDs to channel names (fetched once, then cached)."""
        if self._channel_inverse is None:
            self.logger.debug("Getting inverse channel map")
            self._channel_inverse = {c['id']: c['name'] for c in self._list_channels()}
        return self._channel_inverse

    def _channel_name(self, channel_id: str) -> str:
        """Resolve a channel ID to its name, asking Slack directly on a cache miss.

        Channels created after the map was built, or not returned by
        ``conversations.list`` (depending on the channel types it lists), are
        looked up with ``conversations.info``. The result is added to the cache.
        """
        name = self.channel_inverse.get(channel_id)
        if name is None:
            info = self.web_client.conversations_info(channel=channel_id)
            name = info['channel']['name']
            self.channel_inverse[channel_id] = name
        return name

    def _init_client(self, creds: Slack_Creds) -> Tuple[WebClient, SocketModeClient]:
        """Build a WebClient from the bot token and a SocketModeClient that shares it."""
        web_client = WebClient(creds.bot_token)
        socket_client = SocketModeClient(app_token=creds.app_token, web_client=web_client)
        return web_client, socket_client

    def handle_event(self, client:SocketModeClient, req: SocketModeRequest):
        """Socket Mode listener: acknowledge Events API envelopes and process message events.

        Requests whose type is not ``events_api`` are logged and ignored without
        being acknowledged.
        """
        # Not every request type carries payload['type'], so use .get to avoid a KeyError.
        self.logger.debug(f"type: {req.type}, payload_type: {req.payload.get('type')}\n{pformat(req.payload)}")

        if req.type != "events_api":
            self.logger.debug(f'Unhandled event type: {req.type}')
            return

        # acknowledge we got it so it don't get resent
        client.send_socket_mode_response(SocketModeResponse(envelope_id=req.envelope_id))

        event = req.payload.get('event', {})
        if req.payload.get('type') != 'event_callback' or event.get('type', False) != 'message':
            self.logger.debug(f"Was event, but not a message. Payload type: {req.payload.get('type')}, Event type: {event.get('type', 'unknown')}")
            return

        # Handle messages
        self.logger.debug("Handling message")
        message_text = event.get('text')
        if not message_text:
            # Some message subtypes (e.g. edits/deletions) have no top-level text; nothing to parse.
            self.logger.debug(f"Message event has no text (subtype: {event.get('subtype')}), skipping")
            return

        if 'good bot' in message_text.lower():
            self.good_bot(client, req)

        try:
            wl = self.parse_wikilinks(message_text)
            if len(wl) > 0:
                self.logger.debug(f"Parsed wikilinks: {wl}")
                self.handle_wikilinks(req, wl, client)
            else:
                self.logger.debug("No wikilinks found")
        except Exception:
            self.logger.exception("Error parsing wikilinks! got exception...")

    def parse_wikilinks(self, message) -> List[Wikilink]:
        """Return the wikilinks found in ``message`` text (delegates to ``Wikilink.parse``)."""
        wikilinks = Wikilink.parse(message)
        return wikilinks

    def handle_wikilinks(self, message, wl: List[Wikilink], client:SocketModeClient):
        """Send a wikilink-bearing message to the wiki and mark the outcome with reactions.

        An hourglass reaction is shown while the wiki handles the message. It is
        then replaced by a check mark on success or an ``x`` on any failure.

        Args:
            message: the SocketModeRequest whose ``payload['event']`` is the Slack message.
            wl: wikilinks already parsed from the message. Not used directly here;
                the full message text is passed to the wiki instead.
            client: Socket Mode client whose ``web_client`` is used for Slack API calls.
        """
        self.react('hourglass', client, message)
        try:
            # expand fields in message
            event = message.payload['event']
            channel_id = event['channel']
            msg = SlackMessage(
                content=event['text'],
                user_id=event['user'],
                channel_id=channel_id,
                channel=self._channel_name(channel_id),
                timestamp=event['ts'],
            )
            msg.complete(client.web_client)
            self.logger.debug(f"Posting message:\n{msg}")
            result = self.wiki.handle_slack(msg)
            ok = result.ok
        except Exception:
            self.logger.exception("Error handling slack message")
            result = None
            ok = False

        self.react('hourglass', client, message, remove=True)
        self.react('white_check_mark' if ok else 'x', client, message)

        if result and result.reply:
            if self.reply_channel is None:
                # Not inside an except block, so log an error rather than an exception traceback.
                self.logger.error("Do not have channel to reply to!")
            else:
                # TODO: post result.reply to self.reply_channel
                self.logger.debug('TODO: should reply here!')

    def good_bot(self, client:SocketModeClient, req: SocketModeRequest):
        """Respond to praise by reacting with heart + fire = heart_on_fire."""
        self.logger.info('Got told we are a good bot ^_^')
        for emoji in ('heart', 'heavy_plus_sign', 'fire', 'heavy_equals_sign', 'heart_on_fire'):
            self.react(emoji, client, req)

    def react(self, emoji:str, client:SocketModeClient, req: SocketModeRequest, remove:bool=False):
        """Add (or, with ``remove=True``, remove) ``emoji`` as a reaction on the request's message.

        Failures are logged rather than raised. Reactions are only status
        decoration, and a rejected reaction (e.g. an unknown emoji name or a
        duplicate) should not abort message handling.
        """
        try:
            event = req.payload["event"]
            web = client.web_client
            call = web.reactions_remove if remove else web.reactions_add
            call(
                name=emoji,
                channel=event["channel"],
                timestamp=event["ts"],
            )
        except Exception:
            self.logger.exception(f"Could not {'remove' if remove else 'add'} reaction {emoji}")

    def run(self, creds: Optional[Slack_Creds] = None):
        """Connect over Socket Mode and block until interrupted.

        Args:
            creds: optional replacement credentials; when given they are stored on
                the instance, otherwise ``self.creds`` is used.
        """
        if creds is None:
            creds = self.creds
        else:
            self.creds = creds

        self.web_client, self.socket_client = self._init_client(creds)
        self.socket_client.socket_mode_request_listeners.append(self.handle_event)

        self.logger.debug("Connecting...")
        self.socket_client.connect()

        # Reading reply_channel resolves (and caches) the channel ID up front.
        self.logger.debug(f"Got Reply Channel ID {self.reply_channel} for {self.reply_channel_name}")

        try:
            self.logger.info("Slack Client Listening")
            Event().wait()
        except KeyboardInterrupt:
            self.logger.info("Quitting Slack client!")
        finally:
            # Release the socket connection however the wait ends.
            self.socket_client.close()
|
||
|
|
||
|
@dataclass
class SlackMessage:
    """A Slack message plus the author/permalink details the wiki needs.

    The first five fields come from the Slack event. ``avatar``, ``permalink``
    and ``author`` start empty and are filled in by :meth:`complete`.
    """
    content: str  # the actual content of the message
    user_id: str
    channel_id: str
    # currently expected to be passed at instantiation because the client keeps a
    # reverse index. bad information hiding i know.
    channel: str
    # The unix epoch string timestamp stored as a slack event's ts attribute
    timestamp: str

    # these need to be filled in after instantiation by passing a webclient to the completion methods
    avatar: str = ''  # URL of image
    permalink: str = ''
    author: str = ''  # display name, falling back to the real name

    _complete: bool = False

    def get_permalink(self, client: "WebClient"):
        """Fetch and store this message's permalink via ``chat.getPermalink``."""
        permalink = client.chat_getPermalink(channel=self.channel_id, message_ts=self.timestamp)
        self.permalink = permalink['permalink']

    def get_user(self, client: "WebClient"):
        """Fetch and store the author's avatar URL and name via ``users.info``.

        Slack's ``display_name`` is an empty string for users who never set one,
        so fall back to ``real_name`` rather than recording an empty author.
        """
        profile = client.users_info(user=self.user_id)['user']['profile']
        self.avatar = profile['image_192']
        self.author = profile.get('display_name') or profile.get('real_name', '')

    def complete(self, client: "WebClient"):
        """Fill in permalink, avatar and author using ``client``, then mark the message complete."""
        self.get_permalink(client)
        self.get_user(client)
        self._complete = True

    @property
    def date_sent(self) -> datetime:
        """The send time as a naive local-time datetime (``datetime.fromtimestamp`` of ``ts``)."""
        return datetime.fromtimestamp(float(self.timestamp))
|
||
|
|
||
|
|
||
|
|
||
|
def argparser() -> argparse.ArgumentParser:
    """Build the command-line parser for the slack bot entry point.

    Options:
        -d/--directory: config directory holding credential files and logs
            (parsed as a Path, default ``/etc/wikibot/``).
        -w/--wiki: URL of the wiki (default None).
    """
    description = "A slack bot for posting messages with wikilinks to an associated mediawiki wiki"
    p = argparse.ArgumentParser(prog="slack_bot", description=description)
    p.add_argument(
        '-d', '--directory',
        type=Path,
        default='/etc/wikibot/',
        help="Directory that stores credential files and logs",
    )
    p.add_argument('-w', '--wiki', type=str, help="URL of wiki")
    return p
|
||
|
|
||
|
def main():
    """Entry point: load credentials, log in to the wiki, and run the Slack client.

    Credentials are read from ``slack_creds.json`` and ``mediawiki_creds.json``
    in the configured directory. Logs go to its ``logs`` subdirectory.
    """
    args = argparser().parse_args()
    config_dir = Path(args.directory)
    logs = config_dir / "logs"

    slack_creds = Slack_Creds.from_json(config_dir / 'slack_creds.json')
    wiki_creds = Mediawiki_Creds.from_json(config_dir / 'mediawiki_creds.json')

    wiki = Wiki(url=args.wiki, log_dir=logs, creds=wiki_creds)
    wiki.login(wiki_creds)

    SlackClient(creds=slack_creds, wiki=wiki, log_dir=logs).run()


if __name__ == "__main__":
    main()
|