Added polls and reconnect. Closes #1 & #7, I hope-

This commit is contained in:
Casey 2022-08-26 18:37:36 +03:00
parent 239957bb81
commit 8088cca8f0
Signed by: hkc
GPG Key ID: F0F6CFE11CDB0960
10 changed files with 147 additions and 47 deletions

2
.gitignore vendored
View File

@ -3,3 +3,5 @@ __pycache__
config-*.ini
venv
# :3
tmp.py

5
TODO
View File

@ -1,6 +1,7 @@
[integrations,core] Add database support so remote messages are stored and can be used to reply to them
[integrations,discord] Add Discord functionality
[core] Somehow find a way to get your user ID by token
[core] Maybe get rid of `main.list` field and create one automatically on a startup?
[integrations] Add support for shellscript integration
[integrations,telegram] Add formatting option
[integrations] Add formatting option
[integrations] Add filters
[integrations,vk] Add VK integration

View File

@ -28,6 +28,11 @@ user = 107914495779447227
; address bar while you have that list open)
list = 1
; Should we automatically reconnect to the streaming socket?
; That option exists because it's not really a big deal when crossposter runs
; as a service and restarts automatically by the service manager.
auto-reconnect = yes
; Example Telegram integration. You can use it as a template
[module/telegram]

View File

@ -0,0 +1,33 @@
from asyncio import gather
from configparser import ConfigParser
from typing import List, Optional
from mastoposter.integrations.base import BaseIntegration
from mastoposter.integrations import DiscordIntegration, TelegramIntegration
from mastoposter.types import Status
def load_integrations_from(config: ConfigParser) -> List[BaseIntegration]:
modules: List[BaseIntegration] = []
for module_name in config.get("main", "modules").split():
module = config[f"module/{module_name}"]
if module["type"] == "telegram":
modules.append(
TelegramIntegration(
token=module["token"],
chat_id=module["chat"],
show_post_link=module.getboolean("show_post_link", fallback=True),
show_boost_from=module.getboolean("show_boost_from", fallback=True),
)
)
elif module["type"] == "discord":
modules.append(DiscordIntegration(webhook=module["webhook"]))
else:
raise ValueError("Invalid module type %r" % module["type"])
return modules
async def execute_integrations(
status: Status, sinks: List[BaseIntegration]
) -> List[Optional[str]]:
return await gather(*[sink.post(status) for sink in sinks], return_exceptions=True)

View File

@ -1,7 +1,7 @@
#!/usr/bin/env python3
from asyncio import run
from configparser import ConfigParser
from mastoposter.integrations import DiscordIntegration, TelegramIntegration
from mastoposter import execute_integrations, load_integrations_from
from mastoposter.sources import websocket_source
from typing import AsyncGenerator, Callable, List
from mastoposter.integrations.base import BaseIntegration
@ -30,8 +30,7 @@ async def listen(
):
continue
for drain in drains:
await drain.post(status)
await execute_integrations(status, drains)
def main(config_path: str):
@ -49,22 +48,7 @@ def main(config_path: str):
for k in _remove:
del conf[section][k]
modules: List[BaseIntegration] = []
for module_name in conf.get("main", "modules").split():
module = conf[f"module/{module_name}"]
if module["type"] == "telegram":
modules.append(
TelegramIntegration(
token=module["token"],
chat_id=module["chat"],
show_post_link=module.getboolean("show_post_link", fallback=True),
show_boost_from=module.getboolean("show_boost_from", fallback=True),
)
)
elif module["type"] == "discord":
modules.append(DiscordIntegration(webhook=module["webhook"]))
else:
raise ValueError("Invalid module type %r" % module["type"])
modules = load_integrations_from(conf)
url = "wss://{}/api/v1/streaming".format(conf["main"]["instance"])
run(
@ -73,6 +57,7 @@ def main(config_path: str):
modules,
conf["main"]["user"],
url=url,
reconnect=conf["main"].getboolean("auto_reconnect", fallback=False),
list=conf["main"]["list"],
access_token=conf["main"]["token"],
)

View File

@ -1,4 +1,5 @@
from abc import ABC, abstractmethod
from typing import Optional
from mastoposter.types import Status
@ -8,5 +9,5 @@ class BaseIntegration(ABC):
pass
@abstractmethod
async def post(self, status: Status) -> str:
async def post(self, status: Status) -> Optional[str]:
raise NotImplemented

View File

@ -66,7 +66,7 @@ class DiscordIntegration(BaseIntegration):
)
).json()
async def post(self, status: Status) -> str:
async def post(self, status: Status) -> Optional[str]:
source = status.reblog or status
embeds: List[DiscordEmbed] = []
@ -111,4 +111,4 @@ class DiscordIntegration(BaseIntegration):
embeds=embeds,
)
return ""
return None

View File

@ -4,7 +4,7 @@ from typing import Any, List, Mapping, Optional, Union
from bs4 import BeautifulSoup, Tag, PageElement
from httpx import AsyncClient
from mastoposter.integrations.base import BaseIntegration
from mastoposter.types import Attachment, Status
from mastoposter.types import Attachment, Poll, Status
@dataclass
@ -111,6 +111,20 @@ class TelegramIntegration(BaseIntegration):
media=media_list,
)
async def _post_poll(
self, poll: Poll, reply_to: Optional[str] = None
) -> TGResponse:
return await self._tg_request(
"sendPoll",
disable_notification=True,
disable_web_page_preview=True,
chat_id=self.chat_id,
question=f"Poll:{poll.id}",
reply_to_message_id=reply_to,
allow_multiple_answers=poll.multiple,
options=[opt.title for opt in poll.options],
)
@classmethod
def node_to_text(cls, el: PageElement) -> str:
if isinstance(el, Tag):
@ -126,7 +140,7 @@ class TelegramIntegration(BaseIntegration):
return str.join("", map(cls.node_to_text, el.children))
return escape(str(el))
async def post(self, status: Status) -> str:
async def post(self, status: Status) -> Optional[str]:
source = status.reblog or status
text = self.node_to_text(BeautifulSoup(source.content, features="lxml"))
text = text.rstrip()
@ -148,20 +162,33 @@ class TelegramIntegration(BaseIntegration):
+ text
)
ids = []
if not source.media_attachments:
msg = await self._post_plaintext(text)
if (res := await self._post_plaintext(text)).ok:
if res.result:
ids.append(res.result["message_id"])
elif len(source.media_attachments) == 1:
msg = await self._post_media(text, source.media_attachments[0])
if (
res := await self._post_media(text, source.media_attachments[0])
).ok and res.result is not None:
ids.append(res.result["message_id"])
else:
msg = await self._post_mediagroup(text, source.media_attachments)
if (
res := await self._post_mediagroup(text, source.media_attachments)
).ok and res.result is not None:
ids.append(res.result["message_id"])
if not msg.ok:
# raise Exception(msg.error, msg.params)
return "" # XXX: silently ignore for now
if source.poll:
if (
res := await self._post_poll(
source.poll, reply_to=ids[0] if ids else None
)
).ok and res.result:
ids.append(res.result["message_id"])
if msg.result:
return msg.result.get("message_id", "")
return ""
return str.join(",", map(str, ids))
def __repr__(self) -> str:
return (

View File

@ -2,14 +2,18 @@ from json import loads
from typing import AsyncGenerator
from urllib.parse import urlencode
from mastoposter.types import Status
async def websocket_source(url: str, **params) -> AsyncGenerator[Status, None]:
async def websocket_source(
url: str, reconnect: bool = False, **params
) -> AsyncGenerator[Status, None]:
from websockets.client import connect
from websockets.exceptions import WebSocketException
url = f"{url}?" + urlencode({"stream": "list", **params})
while True:
try:
async with connect(url) as ws:
while (msg := await ws.recv()) != None:
event = loads(msg)
@ -17,3 +21,6 @@ async def websocket_source(url: str, **params) -> AsyncGenerator[Status, None]:
raise Exception(event["error"])
if event["event"] == "update":
yield Status.from_dict(loads(event["payload"]))
except WebSocketException:
if not reconnect:
raise

View File

@ -1,4 +1,4 @@
from dataclasses import dataclass
from dataclasses import dataclass, field
from datetime import datetime
from typing import Optional, List, Literal
@ -208,6 +208,43 @@ class Tag:
return cls(**data)
@dataclass
class Poll:
@dataclass
class PollOption:
title: str
votes_count: Optional[int] = None
id: str
expires_at: Optional[datetime]
expired: bool
multiple: bool
votes_count: int
voters_count: Optional[int] = None
options: List[PollOption] = field(default_factory=list)
emojis: List[Emoji] = field(default_factory=list)
@classmethod
def from_dict(cls, data: dict) -> "Poll":
return cls(
id=data["id"],
expires_at=(
datetime.fromisoformat(data["expires_at"].rstrip("Z"))
if data.get("expires_at") is not None
else None
),
expired=data["expired"],
multiple=data["multiple"],
votes_count=data["votes_count"],
voters_count=(
int(data["voters_count"])
if data.get("voters_count") is not None
else None
),
options=[cls.PollOption(**opt) for opt in data["options"]],
)
@dataclass
class Status:
id: str
@ -227,7 +264,7 @@ class Status:
in_reply_to_id: Optional[str] = None
in_reply_to_account_id: Optional[str] = None
reblog: Optional["Status"] = None
poll: Optional[dict] = None
poll: Optional[Poll] = None
card: Optional[dict] = None
language: Optional[str] = None
text: Optional[str] = None
@ -262,7 +299,9 @@ class Status:
if data.get("reblog") is not None
else None
),
poll=data.get("poll"),
poll=(
Poll.from_dict(data["poll"]) if data.get("poll") is not None else None
),
card=data.get("card"),
language=data.get("language"),
text=data.get("text"),