from bs4 import BeautifulSoup
import MySQLdb
from pebble import ProcessPool
from concurrent.futures import TimeoutError
import markovify
import requests
from Crypto.PublicKey import RSA
from Crypto.Hash import SHA256
from Crypto.Signature import PKCS1_v1_5
from base64 import b64decode, b64encode
from mastodon import Mastodon, MastodonUnauthorizedError
import html, re, json
# Load the site configuration once at import time. The context manager makes
# sure the file handle is closed instead of being left to the garbage collector.
with open('config.json') as cfg_file:
    cfg = json.load(cfg_file)
class nlt_fixed(markovify.NewlineText):
    """Modified NewlineText whose sentence filter never rejects anything."""

    def test_sentence_input(self, sentence):
        """Accept `sentence` unconditionally."""
        # every sentence counts as valid input, whatever it contains
        return True
def extract_post(post):
post = html.unescape(post) # convert HTML escape codes to text
soup = BeautifulSoup(post, "html.parser")
for lb in soup.select("br"): # replace <br> with linebreak
lb.insert_after("\n")
lb.decompose()
for p in soup.select("p"): # ditto for <p>
p.insert_after("\n")
p.unwrap()
for ht in soup.select("a.hashtag"): # convert hashtags from links to text
ht.unwrap()
for link in soup.select("a"): #ocnvert 1:
id = args[1]
acct = args[3]
handle = args[0]
# print("Generating post for {}".format(handle))
bot, post = generate_output(handle)
# post will be None if there's no posts for the bot to learn from.
# in such a case, we should just exit without doing anything.
if post == None: return
client = Mastodon(
client_id = bot['client_id'],
client_secret = bot['client_secret'],
access_token = bot['secret'],
api_base_url = "https://{}".format(handle.split("@")[2])
)
db = MySQLdb.connect(
host = cfg['db_host'],
user=cfg['db_user'],
passwd=cfg['db_pass'],
db=cfg['db_name'],
use_unicode=True,
charset="utf8mb4"
)
c = db.cursor()
# print(post)
visibility = bot['post_privacy'] if len(args) == 1 else args[2]
visibilities = ['public', 'unlisted', 'private']
if visibilities.index(visibility) < visibilities.index(bot['post_privacy']):
# if post_privacy is set to a more restricted level than the visibility of the post we're replying to, use the user's setting
visibility = bot['post_privacy']
if acct is not None:
post = "{} {}".format(acct, post)
# ensure post isn't longer than bot['length']
# TODO: ehhhhhhhhh
post = post[:bot['length']]
# send toot!!
try:
client.status_post(post, id, visibility = visibility, spoiler_text = bot['content_warning'])
except MastodonUnauthorizedError:
# user has revoked the token given to the bot
# this needs to be dealt with properly later on, but for now, we'll just disable the bot
c.execute("UPDATE bots SET enabled = FALSE WHERE handle = %s", (handle,))
except:
print("Failed to submit post for {}".format(handle))
if id == None:
# this wasn't a reply, it was a regular post, so update the last post date
c.execute("UPDATE bots SET last_post = CURRENT_TIMESTAMP() WHERE handle = %s", (handle,))
db.commit()
c.close()
def task_done(future):
    """Done-callback for a scheduled task: report timeouts unless silenced.

    `future` must provide `result()` plus the `silent` and `function_data`
    attributes. Only TimeoutError is handled here; any other exception raised
    by `result()` propagates to the caller.
    """
    try:
        future.result()  # blocks until results are ready
    except TimeoutError:
        if future.silent:
            return
        print("Timed out on {}.".format(future.function_data))
def do_in_pool(function, data, timeout=30, silent=False):
    """Schedule `function(item)` for every item of `data` on a ProcessPool.

    Each task is scheduled with the given `timeout`. The task's future is
    tagged with the item it was given (`function_data`) and with the `silent`
    flag, and `task_done` is registered as its done-callback.
    """
    with ProcessPool(max_workers=5, max_tasks=10) as pool:
        for item in data:
            task = pool.schedule(function, args=[item], timeout=timeout)
            # task_done reads both attributes, so set them before registering it
            task.function_data = item
            task.silent = silent
            task.add_done_callback(task_done)
def get_key():
    """Return the key pair used for signing requests, creating it if absent.

    Fetches one row from the `http_auth_key` table. When the table is empty, a
    new 4096-bit RSA key pair is generated, inserted into the table and
    returned instead.

    Returns:
        A dict with PEM-encoded 'private' and 'public' entries. When the row
        comes from the database it carries whatever columns the table has.
    """
    db = MySQLdb.connect(
        host = cfg['db_host'],
        user=cfg['db_user'],
        passwd=cfg['db_pass'],
        db=cfg['db_name'],
        use_unicode=True,
        charset="utf8mb4"
    )
    try:
        dc = db.cursor(MySQLdb.cursors.DictCursor)
        try:
            dc.execute("SELECT * FROM http_auth_key")
            key = dc.fetchone()
            if key is None:
                # generate new key
                key = {}
                privkey = RSA.generate(4096)
                key['private'] = privkey.exportKey('PEM').decode('utf-8')
                key['public'] = privkey.publickey().exportKey('PEM').decode('utf-8')
                dc.execute("INSERT INTO http_auth_key (private, public) VALUES (%s, %s)", (key['private'], key['public']))
        finally:
            # release the cursor even if one of the queries raised
            dc.close()
        db.commit()
    finally:
        # the connection was previously never closed, leaking one per call
        db.close()
    return key
def signed_get(url, timeout = 10, additional_headers = None, request_json = True):
    """Build the signed headers for a GET request to `url`.

    Args:
        url: Target of the request.
        timeout: Currently unused; kept for the commented-out `requests.get`
            call at the end of the function.
        additional_headers: Optional mapping of extra headers. They override
            the JSON defaults on key collisions and are included in the
            signature. None (the default) means no extra headers.
        request_json: When true, start from JSON `Accept` and `Content-Type`
            headers.

    Returns:
        The headers dict of the constructed `requests.Request`, including the
        generated 'signature' entry. NOTE(review): no HTTP request is sent
        yet; the function only returns the headers.
    """
    headers = {}
    if request_json:
        headers = {
            "Accept": "application/json",
            "Content-Type": "application/json"
        }
    # a None default replaces the shared mutable `{}` default argument
    headers = {**headers, **(additional_headers or {})}

    # sign request headers
    key = RSA.importKey(get_key()['private'])

    # One "name: value" line per header, joined WITHOUT a trailing newline.
    # The old code called sigstring.rstrip("\n") and discarded the result, so
    # the trailing newline was signed as well.
    header_names = [header.lower() for header in headers]
    sigstring = '\n'.join(
        '{}: {}'.format(header.lower(), value) for header, value in headers.items()
    )

    pkcs = PKCS1_v1_5.new(key)
    h = SHA256.new()
    h.update(sigstring.encode('ascii'))
    signed_sigstring = b64encode(pkcs.sign(h)).decode('ascii')

    sig = {
        'keyId': "{}/actor".format(cfg['base_uri']),
        'algorithm': 'rsa-sha256',
        # list the names in the same lowercase form used in the signing string
        'headers': ' '.join(header_names),
        'signature': signed_sigstring
    }
    sig_header = ['{}="{}"'.format(k, v) for k, v in sig.items()]
    headers['signature'] = ','.join(sig_header)

    r = requests.Request('GET', url, headers)
    return r.headers
    # return requests.get(url, timeout = timeout)
# return requests.get(url, timeout = timeout)