diff --git a/functions.py b/functions.py index 5321941..00bc351 100644 --- a/functions.py +++ b/functions.py @@ -1,5 +1,14 @@ from bs4 import BeautifulSoup -import html, re +import MySQLdb +import markovify +from mastodon import Mastodon +import html, re, json + +cfg = json.load(open('config.json')) + +class nlt_fixed(markovify.NewlineText): # modified version of NewlineText that never rejects sentences + def test_sentence_input(self, sentence): + return True # all sentences are valid <3 def extract_post(post): post = html.unescape(post) # convert HTML escape codes to text @@ -24,3 +33,70 @@ def extract_post(post): text = re.sub("https://([^/]+)/users/([^ ]+)", r"@\2@\1", text) # put pleroma-style mentions back in text = text.rstrip("\n") # remove trailing newline(s) return text + +def make_post(handle): + handle = handle[0] + db = MySQLdb.connect( + host = cfg['db_host'], + user=cfg['db_user'], + passwd=cfg['db_pass'], + db=cfg['db_name'] + ) + print("Generating post for {}".format(handle)) + c = db.cursor() + c.execute(""" + SELECT + learn_from_cw, client_id, client_secret, secret + FROM + bots, credentials + WHERE + bots.credentials_id = (SELECT + credentials_id + FROM + bots + WHERE + handle = %s) + """, (handle,)) + + bot = c.fetchone() + client = Mastodon( + client_id = bot[1], + client_secret = bot[2], + access_token = bot[3], + api_base_url = "https://{}".format(handle.split("@")[2]) + ) + + # by default, only select posts that don't have CWs. + # if learn_from_cw, then also select posts with CWs + cw_list = [False] + if bot[0]: + cw_list = [False, True] + + # select 1000 random posts for the bot to learn from + c.execute("SELECT content FROM posts WHERE fedi_id IN (SELECT fedi_id FROM bot_learned_accounts WHERE bot_id = %s) AND cw IN %s ORDER BY RAND() LIMIT 1000", (handle, cw_list)) + + # this line is a little gross/optimised but here's what it does + # 1. fetch all of the results from the above query + # 2. turn (('this',), ('format')) into ('this', 'format') + # 3. convert the tuple to a list + # 4. join the list into a string separated by newlines + posts = "\n".join(list(sum(c.fetchall(), ()))) + + model = nlt_fixed(posts) + tries = 0 + sentence = None + # even with such a high tries value for markovify, it still sometimes returns none. + # so we implement our own tries function as well, and try ten times. + while sentence is None and tries < 10: + sentence = model.make_short_sentence(500, tries = 10000) + tries += 1 + + # TODO: mention handling + + if sentence == None: + # TODO: send an error email + pass + else: + client.status_post(sentence) + + # TODO: update date of last post diff --git a/service.py b/service.py index f68af7f..1c75fea 100755 --- a/service.py +++ b/service.py @@ -9,10 +9,6 @@ import functions cfg = json.load(open('config.json')) -class nlt_fixed(markovify.NewlineText): # modified version of NewlineText that never rejects sentences - def test_sentence_input(self, sentence): - return True # all sentences are valid <3 - def scrape_posts(account): handle = account[0] outbox = account[1] @@ -88,67 +84,6 @@ def scrape_posts(account): db.commit() c.close() -def make_post(handle): - handle = handle[0] - print("Generating post for {}".format(handle)) - c = db.cursor() - c.execute(""" - SELECT - learn_from_cw, client_id, client_secret, secret - FROM - bots, credentials - WHERE - bots.credentials_id = (SELECT - credentials_id - FROM - bots - WHERE - handle = %s) - """, (handle,)) - - bot = c.fetchone() - client = Mastodon( - client_id = bot[1], - client_secret = bot[2], - access_token = bot[3], - api_base_url = "https://{}".format(handle.split("@")[2]) - ) - - # by default, only select posts that don't have CWs. - # if learn_from_cw, then also select posts with CWs - cw_list = [False] - if bot[0]: - cw_list = [False, True] - - # select 1000 random posts for the bot to learn from - c.execute("SELECT content FROM posts WHERE fedi_id IN (SELECT fedi_id FROM bot_learned_accounts WHERE bot_id = %s) AND cw IN %s ORDER BY RAND() LIMIT 1000", (handle, cw_list)) - - # this line is a little gross/optimised but here's what it does - # 1. fetch all of the results from the above query - # 2. turn (('this',), ('format')) into ('this', 'format') - # 3. convert the tuple to a list - # 4. join the list into a string separated by newlines - posts = "\n".join(list(sum(c.fetchall(), ()))) - - model = nlt_fixed(posts) - tries = 0 - sentence = None - # even with such a high tries value for markovify, it still sometimes returns none. - # so we implement our own tries function as well, and try ten times. - while sentence is None and tries < 10: - sentence = model.make_short_sentence(500, tries = 10000) - tries += 1 - - # TODO: mention handling - - if sentence == None: - # TODO: send an error email - pass - else: - client.status_post(sentence) - - # TODO: update date of last post - print("Establishing DB connection") db = MySQLdb.connect( host = cfg['db_host'], @@ -173,6 +108,6 @@ cursor.execute("SELECT handle FROM bots WHERE enabled = TRUE") bots = cursor.fetchall() with Pool(8) as p: - p.map(make_post, bots) + p.map(functions.make_post, bots) #TODO: other cron tasks should be done here, like updating profile pictures