Skip to content
Open
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 28 additions & 21 deletions plextraktsync/commands/watch.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
from __future__ import annotations

from click import ClickException
from trakt.errors import OAuthRefreshException

from plextraktsync.factory import factory
from plextraktsync.watch.events import (
ActivityNotification,
Expand All @@ -11,26 +14,30 @@


def watch(server: str):
factory.run_config.update(
server=server,
)
ws = factory.web_socket_listener
updater = factory.watch_state_updater
try:
factory.watch_fatal_error.clear()
factory.run_config.update(
server=server,
)
ws = factory.web_socket_listener
updater = factory.watch_state_updater

ws.on(ServerStarted, updater.on_start)
ws.on(
PlaySessionStateNotification,
updater.on_play,
state=["playing", "stopped", "paused"],
)
ws.on(
ActivityNotification,
updater.on_activity,
type="library.refresh.items",
event="ended",
progress=100,
)
ws.on(TimelineEntry, updater.on_delete, state=9, metadata_state="deleted")
ws.on(Error, updater.on_error)
ws.on(ServerStarted, updater.on_start)
ws.on(
PlaySessionStateNotification,
updater.on_play,
state=["playing", "stopped", "paused"],
)
ws.on(
ActivityNotification,
updater.on_activity,
type="library.refresh.items",
event="ended",
progress=100,
)
ws.on(TimelineEntry, updater.on_delete, state=9, metadata_state="deleted")
ws.on(Error, updater.on_error)

ws.listen()
ws.listen()
except OAuthRefreshException as e:
raise ClickException(f"Trakt error: Unable to refresh token: {e}") from e
Comment thread
simonc56 marked this conversation as resolved.
10 changes: 8 additions & 2 deletions plextraktsync/factory/Factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,7 +211,7 @@ def enable_self_update(self):
def web_socket_listener(self):
from plextraktsync.watch.WebSocketListener import WebSocketListener

return WebSocketListener(plex=self.plex_server)
return WebSocketListener(plex=self.plex_server, fatal_error=self.watch_fatal_error)

@cached_property
def watch_state_updater(self):
Expand All @@ -224,6 +224,12 @@ def watch_state_updater(self):
config=self.config,
)

@cached_property
def watch_fatal_error(self):
from plextraktsync.watch.FatalErrorState import FatalErrorState

return FatalErrorState()

@cached_property
def logging(self):
import logging
Expand Down Expand Up @@ -324,7 +330,7 @@ def queue(self):
TraktMarkWatchedWorker(),
TraktScrobbleWorker(),
]
task = BackgroundTask(self.batch_delay_timer, *workers)
task = BackgroundTask(self.batch_delay_timer, *workers, fatal_error=self.watch_fatal_error)
queue = Queue(task)

return queue
Expand Down
4 changes: 3 additions & 1 deletion plextraktsync/media/MediaFactory.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

from plexapi.exceptions import PlexApiException
from requests import RequestException
from trakt.errors import TraktException
from trakt.errors import OAuthRefreshException, TraktException

from plextraktsync.factory import logging
from plextraktsync.media.Media import Media
Expand Down Expand Up @@ -66,6 +66,8 @@ def resolve_guid(self, guid: PlexGuid, show: Media = None):
tm = self.trakt.find_episode_guid(guid, show.seasons)
else:
tm = self.trakt.find_by_guid(guid)
except OAuthRefreshException:
raise
except (TraktException, RequestException) as e:
self.logger.warning(
f"{guid.title_link}: Skipping {guid}: Trakt errors: {e}",
Expand Down
16 changes: 15 additions & 1 deletion plextraktsync/queue/BackgroundTask.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
from queue import Empty
from typing import TYPE_CHECKING

from trakt.errors import OAuthRefreshException

from plextraktsync.factory import logging

if TYPE_CHECKING:
Expand All @@ -20,10 +22,11 @@ class BackgroundTask:

logger = logging.getLogger(__name__)

def __init__(self, timer: Timer = None, *tasks):
def __init__(self, timer: Timer = None, *tasks, fatal_error=None):
self.queues = defaultdict(list)
self.timer = timer
self.tasks = tasks
self.fatal_error = fatal_error

def check_timer(self):
if not self.timer:
Expand All @@ -40,6 +43,11 @@ def timed_events(self):
for task in self.tasks:
try:
task(self.queues)
except OAuthRefreshException as e:
if self.fatal_error is not None:
self.fatal_error.set(e)
return
raise
Comment on lines +46 to +50
except Exception as e:
self.logger.error(f"Got exception while working on {task}: {e}")

Expand All @@ -60,6 +68,9 @@ def __call__(self, queue: SimpleQueue):
"""

while True:
if self.fatal_error is not None:
self.fatal_error.raise_if_set()

try:
message = queue.get(timeout=1)
except Empty:
Expand All @@ -71,3 +82,6 @@ def __call__(self, queue: SimpleQueue):
self.process_message(message)

self.check_timer()

if self.fatal_error is not None:
self.fatal_error.raise_if_set()
9 changes: 8 additions & 1 deletion plextraktsync/watch/EventDispatcher.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from __future__ import annotations

from trakt.errors import OAuthRefreshException

from plextraktsync.factory import logging
from plextraktsync.watch.EventFactory import EventFactory
from plextraktsync.watch.events import Error, ServerStarted
Expand All @@ -8,9 +10,10 @@
class EventDispatcher:
logger = logging.getLogger(__name__)

def __init__(self):
def __init__(self, fatal_error=None):
self.event_listeners = []
self.event_factory = EventFactory()
self.fatal_error = fatal_error

def on(self, event_type, listener, **kwargs):
self.event_listeners.append(
Expand Down Expand Up @@ -38,6 +41,10 @@ def dispatch(self, event):

try:
listener["listener"](event)
except OAuthRefreshException as e:
if self.fatal_error is not None:
self.fatal_error.set(e)
raise
except Exception as e:
self.logger.error(f"{type(e).__name__} was raised: {e}")

Expand Down
25 changes: 25 additions & 0 deletions plextraktsync/watch/FatalErrorState.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
from __future__ import annotations

from threading import Lock


class FatalErrorState:
def __init__(self):
self._error = None
self._lock = Lock()

def clear(self):
with self._lock:
self._error = None

def set(self, error: Exception):
with self._lock:
if self._error is None:
self._error = error

def raise_if_set(self):
with self._lock:
error = self._error

if error is not None:
raise error
10 changes: 8 additions & 2 deletions plextraktsync/watch/WebSocketListener.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,23 +14,29 @@
class WebSocketListener:
logger = logging.getLogger(__name__)

def __init__(self, plex: PlexServer, poll_interval=5, restart_interval=15):
def __init__(self, plex: PlexServer, poll_interval=5, restart_interval=15, fatal_error=None):
self.plex = plex
self.poll_interval = poll_interval
self.restart_interval = restart_interval
self.dispatcher = EventDispatcher()
self.fatal_error = fatal_error
self.dispatcher = EventDispatcher(fatal_error=fatal_error)

def on(self, event_type, listener, **kwargs):
self.dispatcher.on(event_type, listener, **kwargs)

def listen(self):
self.logger.info("Listening for events!")
while True:
if self.fatal_error is not None:
self.fatal_error.raise_if_set()

notifier = self.plex.startAlertListener(callback=self.dispatcher.event_handler)
self.dispatcher.event_handler(ServerStarted(notifier=notifier))

while notifier.is_alive():
sleep(self.poll_interval)
if self.fatal_error is not None:
self.fatal_error.raise_if_set()

self.dispatcher.event_handler(Error(msg="Server closed connection"))
self.logger.error(f"Listener finished. Restarting in {self.restart_interval} seconds")
Expand Down
15 changes: 15 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from os.path import dirname
from os.path import join as join_path

from trakt.errors import OAuthRefreshException
from trakt.tv import TVShow

from plextraktsync.factory import Factory
Expand Down Expand Up @@ -35,3 +36,17 @@ def make(cls=None, **kwargs) -> TVShow:
cls = cls if cls is not None else "object"
# https://stackoverflow.com/a/2827726/2314626
return type(cls, (object,), kwargs)


def make_oauth_refresh_exception(
error: str = "invalid_grant",
error_description: str = "The provided authorization grant is invalid.",
) -> OAuthRefreshException:
response = make(
cls="Response",
json=lambda self: {
"error": error,
"error_description": error_description,
},
)()
return OAuthRefreshException(response)
30 changes: 29 additions & 1 deletion tests/test_events.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,14 @@
#!/usr/bin/env python3 -m pytest
from __future__ import annotations

import pytest
from trakt.errors import OAuthRefreshException

from plextraktsync.watch.EventDispatcher import EventDispatcher
from plextraktsync.watch.EventFactory import EventFactory
from plextraktsync.watch.events import ActivityNotification
from tests.conftest import load_mock
from plextraktsync.watch.FatalErrorState import FatalErrorState
from tests.conftest import load_mock, make_oauth_refresh_exception


def test_events():
Expand Down Expand Up @@ -73,3 +77,27 @@ def test_event_dispatcher():
dispatcher = EventDispatcher().on(ActivityNotification, lambda x: events.append(x), event=["ended"], progress=99)
dispatcher.event_handler(raw_events[4])
assert len(events) == 0, "No match for event=ended and progress=99"


def test_event_dispatcher_reraises_oauth_refresh_exception():
fatal_error = FatalErrorState()
dispatcher = EventDispatcher(fatal_error=fatal_error).on(
ActivityNotification,
lambda _: (_ for _ in ()).throw(make_oauth_refresh_exception()),
)

with pytest.raises(OAuthRefreshException):
dispatcher.dispatch(
ActivityNotification(
Activity={
"event": "ended",
"type": "library.refresh.items",
"progress": 100,
"Context": {"key": "/library/metadata/1"},
},
event="ended",
)
)

with pytest.raises(OAuthRefreshException):
fatal_error.raise_if_set()
73 changes: 73 additions & 0 deletions tests/test_watch_fatal_error.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
#!/usr/bin/env python3 -m pytest
from __future__ import annotations

import pytest
from trakt.errors import OAuthRefreshException

from plextraktsync.queue.BackgroundTask import BackgroundTask
from plextraktsync.watch.FatalErrorState import FatalErrorState
from plextraktsync.watch.WebSocketListener import WebSocketListener
from tests.conftest import make_oauth_refresh_exception


def test_background_task_records_oauth_refresh_exception():
fatal_error = FatalErrorState()

def task(_queues):
raise make_oauth_refresh_exception()

background_task = BackgroundTask(None, task, fatal_error=fatal_error)
background_task.timed_events()

with pytest.raises(OAuthRefreshException):
fatal_error.raise_if_set()


def test_background_task_continues_without_fatal_error():
fatal_error = FatalErrorState()
calls = []

class Queue:
def __init__(self):
self.messages = iter([
("test", 1),
None,
])

def get(self, timeout):
return next(self.messages)

def task(queues):
calls.append(list(queues["test"]))

background_task = BackgroundTask(None, task, fatal_error=fatal_error)
background_task(Queue())

assert calls == [[1]]


def test_websocket_listener_raises_recorded_oauth_refresh_exception(monkeypatch):
fatal_error = FatalErrorState()

class Notifier:
def __init__(self):
self.calls = 0

def is_alive(self):
self.calls += 1
return self.calls == 1

class Plex:
def startAlertListener(self, callback):
self.callback = callback
return Notifier()

def fail_sleep(_interval):
fatal_error.set(make_oauth_refresh_exception())

monkeypatch.setattr("plextraktsync.watch.WebSocketListener.sleep", fail_sleep)

listener = WebSocketListener(Plex(), poll_interval=0, fatal_error=fatal_error)

with pytest.raises(OAuthRefreshException):
listener.listen()
Loading