Compare commits

...

5 Commits

2 changed files with 100 additions and 16 deletions

View File

@ -37,7 +37,7 @@ from mastoposter import (
__description__, __description__,
) )
from mastoposter.integrations import FilteredIntegration from mastoposter.integrations import FilteredIntegration
from mastoposter.sources import websocket_source from mastoposter.sources import websocket_source, single_status_source
from mastoposter.types import Account, Status from mastoposter.types import Account, Status
from mastoposter.utils import normalize_config from mastoposter.utils import normalize_config
@ -64,8 +64,8 @@ async def listen(
source: Callable[..., AsyncGenerator[Status, None]], source: Callable[..., AsyncGenerator[Status, None]],
drains: List[FilteredIntegration], drains: List[FilteredIntegration],
user: str, user: str,
replies_to_other_accounts_should_not_be_skipped: bool = False,
/, /,
replies_to_other_accounts_should_not_be_skipped: bool = False,
**kwargs, **kwargs,
): ):
logger.info("Starting listening...") logger.info("Starting listening...")
@ -110,6 +110,14 @@ def main():
"config", nargs="?", default=getenv("MASTOPOSTER_CONFIG_FILE") "config", nargs="?", default=getenv("MASTOPOSTER_CONFIG_FILE")
) )
parser.add_argument("-v", action="version", version=__version__) parser.add_argument("-v", action="version", version=__version__)
parser.add_argument(
"--single-status", nargs="?", type=str,
help="process single status and exit"
)
parser.add_argument(
"--no-skip-replies", action="store_true",
help="override replies_to_other_accounts_should_not_be_skipped to true"
)
args = parser.parse_args() args = parser.parse_args()
if not args.config: if not args.config:
@ -142,22 +150,38 @@ def main():
"wss://{}/api/v1/streaming".format(conf["main"]["instance"]), "wss://{}/api/v1/streaming".format(conf["main"]["instance"]),
) )
replies_to_other_accounts_should_not_be_skipped = conf[
"main"
].getboolean(
"replies_to_other_accounts_should_not_be_skipped", False
)
if args.no_skip_replies:
replies_to_other_accounts_should_not_be_skipped = True
source = websocket_source
source_params = dict(
url=url,
replies_to_other_accounts_should_not_be_skipped=(
replies_to_other_accounts_should_not_be_skipped),
reconnect=conf["main"].getboolean("auto_reconnect", False),
reconnect_delay=conf["main"].getfloat("reconnect_delay", 1.0),
connect_timeout=conf["main"].getfloat("connect_timeout", 60.0),
list=conf["main"]["list"],
access_token=conf["main"]["token"],
)
if args.single_status:
source = single_status_source
source_params["status_url"] = args.single_status
source_params["retries"] = retries
run( run(
listen( listen(
websocket_source, source,
modules, modules,
user_id, user_id,
url=url, **source_params
replies_to_other_accounts_should_not_be_skipped=conf[
"main"
].getboolean(
"replies_to_other_accounts_should_not_be_skipped", False
),
reconnect=conf["main"].getboolean("auto_reconnect", False),
reconnect_delay=conf["main"].getfloat("reconnect_delay", 1.0),
connect_timeout=conf["main"].getfloat("connect_timeout", 60.0),
list=conf["main"]["list"],
access_token=conf["main"]["token"],
) )
) )

View File

@ -16,8 +16,8 @@ GNU General Public License for more details.
from asyncio import exceptions, sleep from asyncio import exceptions, sleep
from json import loads from json import loads
from logging import getLogger from logging import getLogger
from typing import AsyncGenerator from typing import AsyncGenerator, List
from urllib.parse import urlencode from urllib.parse import urlencode, urlparse
from mastoposter.types import Status from mastoposter.types import Status
logger = getLogger("sources") logger = getLogger("sources")
@ -66,3 +66,63 @@ async def websocket_source(
"but we're not done yet" "but we're not done yet"
) )
await sleep(reconnect_delay) await sleep(reconnect_delay)
async def single_status_source(
status_url: str, url: str = None, access_token: str = None,
retries: int = 5, **kwargs
) -> AsyncGenerator[Status, None]:
# TODO: catch exceptions
from httpx import Client, HTTPTransport
user_authority = urlparse(url).netloc if url is not None else None
try:
status_url = \
f"https://{user_authority}/api/v1/statuses/{int(status_url)}"
except ValueError:
pass
parsed_status_url = urlparse(status_url)
with Client(transport=HTTPTransport(retries=retries)) as c:
status: Status
if parsed_status_url.path.startswith("/api/v1/statuses/"):
if parsed_status_url.netloc != user_authority:
access_token = None
# headers = {}
# if access_token is not None:
# headers['Authorization'] = 'Bearer ' + access_token
params = {}
if access_token is not None:
params['access_token'] = access_token
rq = c.get(
status_url,
params=params,
)
status = Status.from_dict(rq.json())
else:
search_instance = user_authority if user_authority is not None \
else parsed_status_url.netloc
if search_instance != user_authority:
access_token = None
params = {}
if access_token is not None:
params["access_token"] = access_token
params["q"] = status_url
params["resolve"] = "true"
rq = c.get(
f"https://{search_instance}/api/v2/search",
params=params,
)
statuses: List[dict] = rq.json().get("statuses", [])
if len(statuses) < 1:
logger.error("Instance %s hasn't found status %r",
search_instance, status_url)
return
status = Status.from_dict(statuses[0])
yield status