mirror of
https://github.com/alexta69/metube.git
synced 2026-06-13 16:40:05 +00:00
add subscriptions; change persistence file format to JSON (closes #901, #76, #113, #170, #242, #444, #503, #555, #566)
This commit is contained in:
+167
-28
@@ -17,6 +17,7 @@ import re
|
||||
from watchfiles import DefaultFilter, Change, awatch
|
||||
|
||||
from ytdl import DownloadQueueNotifier, DownloadQueue, Download
|
||||
from subscriptions import SubscriptionManager, SubscriptionNotifier, SubscriptionInfo
|
||||
from yt_dlp.version import __version__ as yt_dlp_version
|
||||
|
||||
log = logging.getLogger('main')
|
||||
@@ -50,6 +51,9 @@ class Config:
|
||||
'OUTPUT_TEMPLATE_PLAYLIST': '%(playlist_title)s/%(title)s.%(ext)s',
|
||||
'OUTPUT_TEMPLATE_CHANNEL': '%(channel)s/%(title)s.%(ext)s',
|
||||
'DEFAULT_OPTION_PLAYLIST_ITEM_LIMIT' : '0',
|
||||
'SUBSCRIPTION_DEFAULT_CHECK_INTERVAL': '60',
|
||||
'SUBSCRIPTION_SCAN_PLAYLIST_END': '50',
|
||||
'SUBSCRIPTION_MAX_SEEN_IDS': '50000',
|
||||
'CLEAR_COMPLETED_AFTER': '0',
|
||||
'YTDL_OPTIONS': '{}',
|
||||
'YTDL_OPTIONS_FILE': '',
|
||||
@@ -114,6 +118,7 @@ class Config:
|
||||
'PUBLIC_HOST_URL',
|
||||
'PUBLIC_HOST_AUDIO_URL',
|
||||
'DEFAULT_OPTION_PLAYLIST_ITEM_LIMIT',
|
||||
'SUBSCRIPTION_DEFAULT_CHECK_INTERVAL',
|
||||
)
|
||||
|
||||
def frontend_safe(self) -> dict:
|
||||
@@ -272,6 +277,34 @@ dqueue = DownloadQueue(config, Notifier())
|
||||
app.on_startup.append(lambda app: dqueue.initialize())
|
||||
app.on_cleanup.append(lambda app: Download.shutdown_manager())
|
||||
|
||||
|
||||
class MetubeSubscriptionNotifier(SubscriptionNotifier):
    """Relay subscription events to Socket.IO clients via ``sio.emit``."""

    async def _broadcast(self, event: str, payload) -> None:
        """Serialize *payload* and emit it under the given event name."""
        encoded = serializer.encode(payload)
        await sio.emit(event, encoded)

    async def subscription_added(self, sub: SubscriptionInfo):
        """Log the new subscription and emit its public dict."""
        log.info("Subscription added: %s", sub.name)
        await self._broadcast('subscription_added', sub.to_public_dict())

    async def subscription_updated(self, sub: SubscriptionInfo):
        """Emit the updated subscription's public dict (no logging)."""
        await self._broadcast('subscription_updated', sub.to_public_dict())

    async def subscription_removed(self, sub_id: str):
        """Log the removal and emit the removed subscription id."""
        log.info("Subscription removed: %s", sub_id)
        await self._broadcast('subscription_removed', sub_id)

    async def subscriptions_all(self, subs: list[SubscriptionInfo]):
        """Emit the public dicts of all given subscriptions as one list."""
        snapshot = []
        for item in subs:
            snapshot.append(item.to_public_dict())
        await self._broadcast('subscriptions_all', snapshot)
|
||||
|
||||
|
||||
submgr = SubscriptionManager(config, dqueue, MetubeSubscriptionNotifier())


async def _subscription_manager_cleanup(app):
    """Coroutine wrapper so the synchronous ``submgr.close()`` can be awaited.

    aiohttp awaits every on_cleanup receiver; registering a lambda that
    returns the ``None`` result of a sync call would make it await ``None``.
    """
    submgr.close()


app.on_cleanup.append(_subscription_manager_cleanup)
|
||||
|
||||
|
||||
async def _subscription_loop_startup(app):
    """Start the subscription background loop from an awaitable startup hook.

    ``start_background_loop`` is synchronous, while aiohttp's on_startup
    receivers have to be awaitable, hence this thin coroutine adapter.
    """
    submgr.start_background_loop()


app.on_startup.append(_subscription_loop_startup)
|
||||
|
||||
class FileOpsFilter(DefaultFilter):
|
||||
def __call__(self, change_type: int, path: str) -> bool:
|
||||
# Check if this path matches our YTDL_OPTIONS_FILE
|
||||
@@ -332,27 +365,17 @@ async def _read_json_request(request: web.Request) -> dict:
|
||||
return post
|
||||
|
||||
|
||||
@routes.post(config.URL_PREFIX + 'add')
|
||||
async def add(request):
|
||||
log.info("Received request to add download")
|
||||
post = await _read_json_request(request)
|
||||
post = _migrate_legacy_request(post)
|
||||
log.info(
|
||||
"Add download request: type=%s quality=%s format=%s has_folder=%s auto_start=%s",
|
||||
post.get('download_type'),
|
||||
post.get('quality'),
|
||||
post.get('format'),
|
||||
bool(post.get('folder')),
|
||||
post.get('auto_start'),
|
||||
)
|
||||
def parse_download_options(post: dict) -> dict:
|
||||
"""Validate add/subscribe body; raise HTTPBadRequest on invalid input."""
|
||||
post = _migrate_legacy_request(dict(post))
|
||||
url = post.get('url')
|
||||
download_type = post.get('download_type')
|
||||
codec = post.get('codec')
|
||||
format = post.get('format')
|
||||
quality = post.get('quality')
|
||||
if not url or not quality or not download_type:
|
||||
log.error("Bad request: missing 'url', 'download_type', or 'quality'")
|
||||
raise web.HTTPBadRequest()
|
||||
raise web.HTTPBadRequest(reason="missing 'url', 'download_type', or 'quality'")
|
||||
url = str(url).strip()
|
||||
folder = post.get('folder')
|
||||
custom_name_prefix = post.get('custom_name_prefix')
|
||||
playlist_item_limit = post.get('playlist_item_limit')
|
||||
@@ -429,20 +452,54 @@ async def add(request):
|
||||
except (TypeError, ValueError) as exc:
|
||||
raise web.HTTPBadRequest(reason='playlist_item_limit must be an integer') from exc
|
||||
|
||||
return {
|
||||
'url': url,
|
||||
'download_type': download_type,
|
||||
'codec': codec,
|
||||
'format': format,
|
||||
'quality': quality,
|
||||
'folder': folder,
|
||||
'custom_name_prefix': custom_name_prefix,
|
||||
'playlist_item_limit': playlist_item_limit,
|
||||
'auto_start': auto_start,
|
||||
'split_by_chapters': split_by_chapters,
|
||||
'chapter_template': chapter_template,
|
||||
'subtitle_language': subtitle_language,
|
||||
'subtitle_mode': subtitle_mode,
|
||||
}
|
||||
|
||||
|
||||
@routes.post(config.URL_PREFIX + 'add')
async def add(request):
    """Handle POST ``add``: validate the JSON body and queue a download.

    Validation is delegated to ``parse_download_options``; an HTTPBadRequest
    raised there is logged and re-raised unchanged. The response body is the
    serialized result of ``dqueue.add``.
    """
    log.info("Received request to add download")
    post = await _read_json_request(request)
    try:
        o = parse_download_options(post)
    except web.HTTPBadRequest as e:
        log.error("Bad request: %s", e.reason)
        raise
    log.info(
        "Add download request: type=%s quality=%s format=%s has_folder=%s auto_start=%s",
        o['download_type'],
        o['quality'],
        o['format'],
        bool(o.get('folder')),
        o['auto_start'],
    )
    # Only the validated options are forwarded; the pre-refactor local names
    # (url, download_type, ...) no longer exist in this function.
    status = await dqueue.add(
        o['url'],
        o['download_type'],
        o['codec'],
        o['format'],
        o['quality'],
        o['folder'],
        o['custom_name_prefix'],
        o['playlist_item_limit'],
        o['auto_start'],
        o['split_by_chapters'],
        o['chapter_template'],
        o['subtitle_language'],
        o['subtitle_mode'],
    )
    return web.Response(text=serializer.encode(status))
|
||||
|
||||
@@ -451,6 +508,82 @@ async def cancel_add(request):
|
||||
dqueue.cancel_add()
|
||||
return web.Response(text=serializer.encode({'status': 'ok'}), content_type='application/json')
|
||||
|
||||
|
||||
@routes.post(config.URL_PREFIX + 'subscribe')
async def subscribe(request):
    """Handle POST ``subscribe``: validate options and create a subscription.

    The download options are validated by ``parse_download_options`` (its
    HTTPBadRequest propagates as-is). ``check_interval_minutes`` falls back to
    ``config.SUBSCRIPTION_DEFAULT_CHECK_INTERVAL`` when absent and must be an
    integer >= 1.
    """
    post = await _read_json_request(request)
    o = parse_download_options(post)

    raw_interval = post.get('check_interval_minutes')
    if raw_interval is None:
        raw_interval = config.SUBSCRIPTION_DEFAULT_CHECK_INTERVAL
    try:
        cic = int(raw_interval)
    except (TypeError, ValueError) as exc:
        raise web.HTTPBadRequest(reason='check_interval_minutes must be an integer') from exc
    if cic < 1:
        raise web.HTTPBadRequest(reason='check_interval_minutes must be at least 1')

    # Options forwarded verbatim to the manager under the same keyword name.
    passthrough = (
        'download_type',
        'codec',
        'format',
        'quality',
        'custom_name_prefix',
        'auto_start',
        'playlist_item_limit',
        'split_by_chapters',
        'chapter_template',
        'subtitle_language',
        'subtitle_mode',
    )
    forwarded = {key: o[key] for key in passthrough}
    result = await submgr.add_subscription(
        o['url'],
        check_interval_minutes=cic,
        folder=o['folder'] or '',
        **forwarded,
    )
    return web.Response(text=serializer.encode(result))
|
||||
|
||||
|
||||
@routes.get(config.URL_PREFIX + 'subscriptions')
async def subscriptions_list(request):
    """Handle GET ``subscriptions``: return all subscriptions' public dicts."""
    public_view = []
    for sub in submgr.list_all():
        public_view.append(sub.to_public_dict())
    return web.Response(text=serializer.encode(public_view))
|
||||
|
||||
|
||||
@routes.post(config.URL_PREFIX + 'subscriptions/update')
async def subscriptions_update(request):
    """Handle POST ``subscriptions/update``: change editable subscription fields.

    Only ``enabled``, ``check_interval_minutes`` and ``name`` are accepted.

    Raises:
        web.HTTPBadRequest: if ``id`` is missing, no editable field is present,
            or ``check_interval_minutes`` is not convertible to an integer.
    """
    post = await _read_json_request(request)
    sub_id = post.get('id')
    if not sub_id:
        raise web.HTTPBadRequest(reason='missing subscription id')

    editable = ('enabled', 'check_interval_minutes', 'name')
    changes = {key: post[key] for key in editable if key in post}
    if not changes:
        raise web.HTTPBadRequest(reason='no valid fields to update')

    if 'check_interval_minutes' in changes:
        # Validate here so malformed input yields a 400 (as in `subscribe`)
        # instead of an unhandled conversion error inside the manager.
        try:
            changes['check_interval_minutes'] = int(changes['check_interval_minutes'])
        except (TypeError, ValueError) as exc:
            raise web.HTTPBadRequest(reason='check_interval_minutes must be an integer') from exc

    log.info("Subscription update requested for %s: %s", sub_id, sorted(changes.keys()))
    result = await submgr.update_subscription(str(sub_id), changes)
    return web.Response(text=serializer.encode(result))
|
||||
|
||||
|
||||
@routes.post(config.URL_PREFIX + 'subscriptions/delete')
async def subscriptions_delete(request):
    """Handle POST ``subscriptions/delete``: remove the subscriptions in ``ids``.

    ``ids`` must be a non-empty list; each element is passed on as a string.
    """
    post = await _read_json_request(request)
    ids = post.get('ids')
    if not (ids and isinstance(ids, list)):
        raise web.HTTPBadRequest(reason='missing ids list')
    targets = list(map(str, ids))
    result = await submgr.delete_subscriptions(targets)
    return web.Response(text=serializer.encode(result))
|
||||
|
||||
|
||||
@routes.post(config.URL_PREFIX + 'subscriptions/check')
async def subscriptions_check(request):
    """Handle POST ``subscriptions/check``: trigger an immediate check.

    ``ids`` is optional; when given it must be a list. A missing or empty
    ``ids`` is forwarded to the manager as ``None``.
    """
    post = await _read_json_request(request)
    ids = post.get('ids')
    if not (ids is None or isinstance(ids, list)):
        raise web.HTTPBadRequest(reason='ids must be a list')
    log.info("Subscription check-now requested for ids=%s", ids or "all-enabled")
    selection = None
    if ids:
        selection = [str(i) for i in ids]
    result = await submgr.check_now(selection)
    return web.Response(text=serializer.encode(result))
|
||||
|
||||
@routes.post(config.URL_PREFIX + 'delete')
|
||||
async def delete(request):
|
||||
post = await _read_json_request(request)
|
||||
@@ -554,6 +687,7 @@ async def history(request):
|
||||
async def connect(sid, environ):
|
||||
log.info(f"Client connected: {sid}")
|
||||
await sio.emit('all', serializer.encode(dqueue.get()), to=sid)
|
||||
await sio.emit('subscriptions_all', serializer.encode([s.to_public_dict() for s in submgr.list_all()]), to=sid)
|
||||
await sio.emit('configuration', serializer.encode(config.frontend_safe()), to=sid)
|
||||
if config.CUSTOM_DIRS:
|
||||
await sio.emit('custom_dirs', serializer.encode(get_custom_dirs()), to=sid)
|
||||
@@ -672,6 +806,11 @@ async def add_cors(request):
|
||||
|
||||
# Register the shared CORS preflight handler for each endpoint below,
# in the same order as the individual registrations they replace.
for _cors_endpoint in (
    'add',
    'cancel-add',
    'subscribe',
    'subscriptions',
    'subscriptions/update',
    'subscriptions/delete',
    'subscriptions/check',
    'upload-cookies',
    'delete-cookies',
):
    app.router.add_route('OPTIONS', config.URL_PREFIX + _cors_endpoint, add_cors)
|
||||
|
||||
|
||||
@@ -0,0 +1,156 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import base64
|
||||
import collections.abc
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import shelve
|
||||
import tempfile
|
||||
import time
|
||||
from datetime import datetime
|
||||
from typing import Any, Optional
|
||||
|
||||
log = logging.getLogger("state_store")
|
||||
|
||||
STATE_SCHEMA_VERSION = 2
|
||||
_BYTES_MARKER = "__metube_bytes__"
|
||||
_DATETIME_MARKER = "__metube_datetime__"
|
||||
|
||||
|
||||
def to_json_compatible(value: Any) -> Any:
    """Recursively convert *value* into a structure ``json`` can serialize.

    Scalars (None, bool, int, float, str) pass through. ``bytes`` become a
    single-key dict holding base64 text and ``datetime`` a single-key dict
    holding the ISO string. Mappings get string keys; every other iterable
    becomes a list.

    Raises:
        TypeError: if *value* is none of the supported kinds.
    """
    passthrough = (bool, int, float, str)
    if value is None or isinstance(value, passthrough):
        return value
    if isinstance(value, bytes):
        encoded = base64.b64encode(value).decode("ascii")
        return {_BYTES_MARKER: encoded}
    if isinstance(value, datetime):
        return {_DATETIME_MARKER: value.isoformat()}
    if isinstance(value, collections.abc.Mapping):
        converted = {}
        for key, item in value.items():
            converted[str(key)] = to_json_compatible(item)
        return converted
    if isinstance(value, collections.abc.Iterable):
        # Covers list, tuple, set, frozenset and any other iterable alike.
        return list(map(to_json_compatible, value))
    raise TypeError(f"Value of type {type(value).__name__} is not JSON serializable")
|
||||
|
||||
|
||||
def from_json_compatible(value: Any) -> Any:
    """Reverse ``to_json_compatible`` for decoded JSON data.

    A dict whose only key is the bytes marker is decoded from base64 to
    ``bytes``; one whose only key is the datetime marker is parsed with
    ``datetime.fromisoformat``. Other dicts and lists are converted
    recursively; anything else is returned unchanged.
    """
    if isinstance(value, dict):
        key_set = set(value)
        if key_set == {_BYTES_MARKER}:
            return base64.b64decode(value[_BYTES_MARKER].encode("ascii"))
        if key_set == {_DATETIME_MARKER}:
            return datetime.fromisoformat(value[_DATETIME_MARKER])
        restored = {}
        for key, item in value.items():
            restored[key] = from_json_compatible(item)
        return restored
    if isinstance(value, list):
        return list(map(from_json_compatible, value))
    return value
|
||||
|
||||
|
||||
def read_legacy_shelf(path: str) -> Optional[list[tuple[Any, Any]]]:
    """Read all ``(key, value)`` pairs from a legacy ``shelve`` database.

    Args:
        path: Shelf base path, as originally passed to ``shelve.open``.

    Returns:
        The list of items, or ``None`` when no shelf files exist at *path* or
        the shelf cannot be read (a warning is logged in the latter case).

    NOTE(security): shelve unpickles stored values; only point this at shelf
    files written by this application.
    """
    # Depending on the dbm backend, a shelf lives either at `path` itself or
    # in companion files with a backend-specific suffix (dbm.dumb: .dat/.dir,
    # ndbm: .db or .pag/.dir). Checking only `path` would skip those shelves.
    candidates = (path, f"{path}.db", f"{path}.dat", f"{path}.dir", f"{path}.pag")
    if not any(os.path.exists(candidate) for candidate in candidates):
        return None
    try:
        with shelve.open(path, "r") as shelf:
            return list(shelf.items())
    except Exception as exc:
        # Best-effort migration: dbm/unpickling failures are heterogeneous,
        # so report and carry on without legacy data.
        log.warning("Could not read legacy shelf at %s: %s", path, exc)
        return None
|
||||
|
||||
|
||||
class AtomicJsonStore:
    """Store one JSON object on disk, replacing the file atomically on save.

    Every saved document carries ``schema_version`` and ``kind`` keys next to
    the caller's data. Files that cannot be loaded are renamed aside rather
    than deleted.
    """

    def __init__(self, path: str, *, kind: str, schema_version: int = STATE_SCHEMA_VERSION):
        """Remember the target *path*, document *kind* and *schema_version*."""
        self.path = path
        self.kind = kind
        self.schema_version = schema_version

    def _ensure_parent(self) -> None:
        """Create the parent directory of ``self.path`` if it is missing."""
        parent = os.path.dirname(self.path)
        if parent and not os.path.isdir(parent):
            os.makedirs(parent, exist_ok=True)

    def _build_payload(self, data: dict[str, Any]) -> dict[str, Any]:
        """Return *data* wrapped with this store's ``schema_version``/``kind``.

        The two envelope keys always reflect this store: same-named keys in
        *data* (e.g. from re-saving a previously loaded payload) are ignored
        so a stale schema version or kind cannot be written back.
        """
        payload: dict[str, Any] = {
            "schema_version": self.schema_version,
            "kind": self.kind,
        }
        reserved = frozenset(payload)
        for key, value in data.items():
            if key not in reserved:
                payload[key] = value
        return payload

    def load(self) -> Optional[dict[str, Any]]:
        """Load and return the stored JSON object.

        Returns:
            The decoded payload, or ``None`` if the file does not exist or
            could not be loaded. On any load failure (unreadable file, invalid
            JSON, non-object document, ``kind`` mismatch) the file is passed
            to ``quarantine_invalid_file``.
        """
        if not os.path.exists(self.path):
            return None
        try:
            with open(self.path, encoding="utf-8") as f:
                payload = json.load(f)
            if not isinstance(payload, dict):
                raise ValueError("State file must contain a JSON object")
            if payload.get("kind") != self.kind:
                raise ValueError(
                    f"State file kind mismatch: expected {self.kind}, got {payload.get('kind')}"
                )
            return payload
        except Exception as exc:
            self.quarantine_invalid_file(exc)
            return None

    def save(self, data: dict[str, Any]) -> None:
        """Write *data* (plus envelope keys) to ``self.path`` atomically.

        The document is written to a temporary file in the same directory,
        flushed and fsynced, then moved over the target with ``os.replace``;
        finally the directory itself is fsynced on a best-effort basis.

        Raises:
            Whatever the serialization or file operations raise; the temporary
            file is removed before the error propagates.
        """
        self._ensure_parent()
        payload = self._build_payload(data)
        parent = os.path.dirname(self.path) or "."
        fd, tmp_path = tempfile.mkstemp(
            prefix=f".{os.path.basename(self.path)}.",
            suffix=".tmp",
            dir=parent,
            text=True,
        )
        try:
            with os.fdopen(fd, "w", encoding="utf-8") as f:
                json.dump(payload, f, ensure_ascii=False, separators=(",", ":"))
                f.write("\n")
                f.flush()
                os.fsync(f.fileno())
            os.replace(tmp_path, self.path)
            self._fsync_directory(parent)
        except BaseException:
            # BaseException so that interruption/cancellation does not leave a
            # stray temp file behind either; the error is always re-raised.
            try:
                os.remove(tmp_path)
            except OSError:
                pass
            raise

    def quarantine_invalid_file(self, exc: Exception) -> None:
        """Rename the state file to ``<path>.invalid.<timestamp>`` and log why.

        Does nothing if the file no longer exists. A failed rename is logged
        and otherwise ignored.
        """
        if not os.path.exists(self.path):
            return
        ts = time.strftime("%Y%m%d%H%M%S")
        backup_path = f"{self.path}.invalid.{ts}"
        try:
            os.replace(self.path, backup_path)
            log.warning(
                "State file at %s was invalid (%s); moved it to %s",
                self.path,
                exc,
                backup_path,
            )
        except OSError as move_exc:
            log.warning(
                "State file at %s was invalid (%s) and could not be moved aside: %s",
                self.path,
                exc,
                move_exc,
            )

    @staticmethod
    def _fsync_directory(path: str) -> None:
        """Best-effort fsync of directory *path*; OSError is swallowed."""
        try:
            flags = os.O_RDONLY
            if hasattr(os, "O_DIRECTORY"):
                flags |= os.O_DIRECTORY
            fd = os.open(path, flags)
        except OSError:
            return
        try:
            os.fsync(fd)
        except OSError:
            pass
        finally:
            os.close(fd)
|
||||
@@ -0,0 +1,666 @@
|
||||
"""Channel/playlist subscriptions: periodic yt-dlp flat extract + queue new videos."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import copy
|
||||
import logging
|
||||
import os
|
||||
import time
|
||||
import types
|
||||
import uuid
|
||||
from dataclasses import dataclass, field, fields
|
||||
from typing import Any, Optional
|
||||
|
||||
import yt_dlp
|
||||
import yt_dlp.networking.impersonate
|
||||
from state_store import AtomicJsonStore, read_legacy_shelf
|
||||
|
||||
log = logging.getLogger("subscriptions")
|
||||
|
||||
VIDEO_ONLY_MSG = (
|
||||
"This URL points to a single video, not a channel or playlist. Use Download instead."
|
||||
)
|
||||
_MEDIA_HINT_FIELDS = (
|
||||
"duration",
|
||||
"timestamp",
|
||||
"release_timestamp",
|
||||
"upload_date",
|
||||
"view_count",
|
||||
"live_status",
|
||||
"availability",
|
||||
)
|
||||
|
||||
|
||||
def _impersonate_opt(ytdl_options: dict) -> dict:
    """Return a copy of *ytdl_options* with ``impersonate`` parsed, if present.

    When the key exists, its value is replaced by the result of
    ``ImpersonateTarget.from_str``. The input dict is not modified.
    """
    opts = dict(ytdl_options)
    if "impersonate" not in opts:
        return opts
    parse_target = yt_dlp.networking.impersonate.ImpersonateTarget.from_str
    opts["impersonate"] = parse_target(opts["impersonate"])
    return opts
|
||||
|
||||
|
||||
def _build_ydl_params(config, *, playlistend: Optional[int] = None) -> dict:
    """Assemble the yt-dlp parameter dict used for flat extraction.

    Built-in defaults come first and are then overlaid with
    ``config.YTDL_OPTIONS``; ``playlistend`` is applied last and only when it
    is a positive number.
    """
    debug_enabled = logging.getLogger().isEnabledFor(logging.DEBUG)
    base: dict[str, Any] = {
        "quiet": not debug_enabled,
        "verbose": debug_enabled,
        "no_color": True,
        "extract_flat": True,
        "ignore_no_formats_error": True,
        "lazy_playlist": True,
        "paths": {"home": config.DOWNLOAD_DIR, "temp": config.TEMP_DIR},
    }
    base.update(config.YTDL_OPTIONS)
    params = _impersonate_opt(base)
    if playlistend is None or not playlistend > 0:
        return params
    params["playlistend"] = playlistend
    return params
|
||||
|
||||
|
||||
def _is_media_entry(entry: Any) -> bool:
    """Tell whether a flat-extraction entry looks like a single media item.

    Rejected: non-dicts, entries typed playlist/multi_video/channel, entries
    carrying nested ``entries``, and entries without any URL. If the extractor
    key mentions playlist/channel/tab, at least one of the media hint fields
    must be set for the entry to count as media.
    """
    if not isinstance(entry, dict):
        return False
    kind = str(entry.get("_type") or "")
    if kind in ("playlist", "multi_video", "channel") or entry.get("entries"):
        return False
    if not _entry_video_url(entry):
        return False
    extractor = str(entry.get("ie_key") or entry.get("extractor_key") or "").lower()
    container_like = False
    for token in ("playlist", "channel", "tab"):
        if token in extractor:
            container_like = True
            break
    if not container_like:
        return True
    for hint in _MEDIA_HINT_FIELDS:
        if entry.get(hint) is not None:
            return True
    return False
|
||||
|
||||
|
||||
def extract_flat_playlist(config, url: str, playlistend: int, *, _depth: int = 0):
    """Return ``(info_dict, entries_list)`` for a playlist/channel URL.

    ``(None, [])`` is returned when nothing was extracted and ``(info, [])``
    when the result is not a playlist or channel. For playlists/channels the
    media entries are returned; if there are none, up to five nested entry
    URLs are tried one level deep, and failing that all non-empty entries
    are returned.
    """
    ydl_params = _build_ydl_params(config, playlistend=playlistend)
    with yt_dlp.YoutubeDL(params=ydl_params) as ydl:
        info = ydl.extract_info(url, download=False)
    if not info:
        return None, []

    kind = info.get("_type") or "video"
    if kind not in ("playlist", "channel"):
        # Single videos and bare URL results are not subscribable.
        return info, []

    raw_entries = info.get("entries") or []
    if isinstance(raw_entries, types.GeneratorType):
        raw_entries = list(raw_entries)
    # Drop None placeholders from incomplete flat playlists
    present = [item for item in raw_entries if item]
    media = list(filter(_is_media_entry, present))
    if media:
        return info, media
    if _depth >= 1:
        return info, present

    for candidate in present[:5]:
        nested_url = _entry_video_url(candidate)
        if not nested_url:
            continue
        nested_info, nested_entries = extract_flat_playlist(
            config,
            nested_url,
            playlistend,
            _depth=_depth + 1,
        )
        if nested_entries:
            return nested_info, nested_entries
    return info, present
|
||||
|
||||
|
||||
def _entry_video_url(entry: dict) -> Optional[str]:
    """Return the entry's ``webpage_url`` if truthy, otherwise its ``url``."""
    page_url = entry.get("webpage_url")
    if page_url:
        return page_url
    return entry.get("url")
|
||||
|
||||
|
||||
def _entry_id(entry: dict) -> Optional[str]:
    """Return the entry's ``id`` as a string, or its URL when ``id`` is None."""
    raw_id = entry.get("id")
    if raw_id is None:
        return _entry_video_url(entry)
    return str(raw_id)
|
||||
|
||||
|
||||
@dataclass
class SubscriptionInfo:
    """One channel/playlist subscription together with its download options."""

    id: str
    name: str
    url: str
    enabled: bool = True
    check_interval_minutes: int = 60
    download_type: str = "video"
    codec: str = "auto"
    format: str = "any"
    quality: str = "best"
    folder: str = ""
    custom_name_prefix: str = ""
    auto_start: bool = True
    playlist_item_limit: int = 0
    split_by_chapters: bool = False
    chapter_template: str = ""
    subtitle_language: str = "en"
    subtitle_mode: str = "prefer_manual"
    last_checked: Optional[float] = None
    seen_ids: list[str] = field(default_factory=list)
    error: Optional[str] = None
    timestamp: float = field(default_factory=time.time)

    def seen_set(self) -> set[str]:
        """Return the seen ids as a set."""
        return {*self.seen_ids}

    def to_public_dict(self) -> dict:
        """Return the subset of fields exposed to clients.

        ``seen_ids`` is reported only as its length under ``seen_count``.
        """
        copied_as_is = (
            "id",
            "name",
            "url",
            "enabled",
            "check_interval_minutes",
            "download_type",
            "codec",
            "format",
            "quality",
            "folder",
            "last_checked",
        )
        public = {attr: getattr(self, attr) for attr in copied_as_is}
        public["seen_count"] = len(self.seen_ids)
        public["error"] = self.error
        return public
|
||||
|
||||
|
||||
def _subscription_to_record(sub: SubscriptionInfo) -> dict[str, Any]:
    """Convert *sub* into the plain dict written to the state file.

    ``seen_ids`` is copied into a new list; ``timestamp`` is not included.
    """
    copied_as_is = (
        "id",
        "name",
        "url",
        "enabled",
        "check_interval_minutes",
        "download_type",
        "codec",
        "format",
        "quality",
        "folder",
        "custom_name_prefix",
        "auto_start",
        "playlist_item_limit",
        "split_by_chapters",
        "chapter_template",
        "subtitle_language",
        "subtitle_mode",
        "last_checked",
    )
    record: dict[str, Any] = {attr: getattr(sub, attr) for attr in copied_as_is}
    record["seen_ids"] = list(sub.seen_ids)
    record["error"] = sub.error
    return record
|
||||
|
||||
|
||||
def _subscription_from_record(record: Any) -> Optional[SubscriptionInfo]:
    """Build a ``SubscriptionInfo`` from a stored record.

    An existing ``SubscriptionInfo`` is returned unchanged. A dict is filtered
    down to known dataclass fields first. ``None`` is returned for any other
    type or when construction raises ``TypeError``.
    """
    if isinstance(record, SubscriptionInfo):
        return record
    if not isinstance(record, dict):
        return None
    known = {f.name for f in fields(SubscriptionInfo)}
    kwargs = {key: value for key, value in record.items() if key in known}
    try:
        return SubscriptionInfo(**kwargs)
    except TypeError:
        return None
|
||||
|
||||
|
||||
class SubscriptionNotifier:
    """Callback interface for subscription events (Socket.IO / UI updates).

    Every hook raises ``NotImplementedError`` here; subclasses override them.
    """

    async def subscription_added(self, sub: SubscriptionInfo) -> None:
        """Called with a newly added subscription."""
        raise NotImplementedError

    async def subscription_updated(self, sub: SubscriptionInfo) -> None:
        """Called with a subscription whose data changed."""
        raise NotImplementedError

    async def subscription_removed(self, sub_id: str) -> None:
        """Called with the id of a removed subscription."""
        raise NotImplementedError

    async def subscriptions_all(self, subs: list[SubscriptionInfo]) -> None:
        """Called with a full list of subscriptions."""
        raise NotImplementedError
|
||||
|
||||
|
||||
class SubscriptionManager:
|
||||
def __init__(self, config, download_queue, notifier: SubscriptionNotifier):
    """Wire up collaborators, prepare the state store and load saved data.

    The state directory (``config.STATE_DIR``) is created when missing.
    Stored subscriptions are loaded as the last step.
    """
    self.config = config
    self.dqueue = download_queue
    self.notifier = notifier

    state_dir = config.STATE_DIR
    # No-op when the directory already exists.
    os.makedirs(state_dir, exist_ok=True)

    self._legacy_path = os.path.join(state_dir, "subscriptions")
    self._path = os.path.join(state_dir, "subscriptions.json")
    self._store = AtomicJsonStore(self._path, kind="subscriptions")

    self._subs: dict[str, SubscriptionInfo] = {}
    self._url_index: dict[str, str] = {}  # normalized url -> id
    self._pending_urls: set[str] = set()
    self._lock = asyncio.Lock()
    self._loop_task: Optional[asyncio.Task] = None

    self._load_all()
|
||||
|
||||
def close(self) -> None:
    """Shutdown hook; a no-op because no persistent shelf handle is kept."""
    return None
|
||||
|
||||
def _normalize_url(self, url: str) -> str:
    """Return *url* without surrounding whitespace; falsy input gives ``""``."""
    if not url:
        return ""
    return url.strip()
|
||||
|
||||
def _normalize_seen_ids(self, seen_ids: list[str]) -> list[str]:
    """De-duplicate *seen_ids* (order kept), stringify, and cap the length.

    The cap comes from ``config.SUBSCRIPTION_MAX_SEEN_IDS`` (50000 when the
    attribute is absent); entries beyond the cap are cut from the end.
    """
    limit = int(getattr(self.config, "SUBSCRIPTION_MAX_SEEN_IDS", 50000))
    unique = list(map(str, dict.fromkeys(seen_ids)))
    if len(unique) <= limit:
        return unique
    return unique[:limit]
|
||||
|
||||
def _load_all(self) -> None:
    """Populate the in-memory maps from the JSON store or the legacy shelf.

    The JSON store is preferred; the legacy shelf is read only when the
    store yields nothing. The state file is rewritten when data came from
    the legacy shelf, or when the stored schema version or the normalized
    records differ from what was read.
    """
    payload = self._store.load()
    migrated = False
    if payload is None:
        legacy_items = read_legacy_shelf(self._legacy_path)
        records = [value for _key, value in legacy_items] if legacy_items else []
        migrated = bool(records)
    else:
        records = payload.get("items") or []

    compact_records = []
    for sub in self._iter_valid_subs(records):
        sub.seen_ids = self._normalize_seen_ids(sub.seen_ids)
        self._subs[sub.id] = sub
        self._url_index[self._normalize_url(sub.url)] = sub.id
        compact_records.append(_subscription_to_record(sub))

    if migrated:
        needs_rewrite = True
    elif payload is None:
        needs_rewrite = False
    else:
        version_changed = payload.get("schema_version") != self._store.schema_version
        needs_rewrite = version_changed or compact_records != records
    if needs_rewrite:
        self._store.save({"items": compact_records})
|
||||
|
||||
def _iter_valid_subs(self, records: list[Any]) -> list[SubscriptionInfo]:
    """Convert *records* to subscriptions, dropping those that fail to parse."""
    parsed = map(_subscription_from_record, records)
    return [sub for sub in parsed if sub is not None]
|
||||
|
||||
def _save_locked(self) -> None:
    """Write every in-memory subscription to the state store."""
    items = []
    for sub in self._subs.values():
        items.append(_subscription_to_record(sub))
    self._store.save({"items": items})
|
||||
|
||||
async def _queue_subscription_entries(
    self,
    entries: list[dict],
    *,
    download_type: str,
    codec: str,
    format: str,
    quality: str,
    folder: str,
    custom_name_prefix: str,
    playlist_item_limit: int,
    auto_start: bool,
    split_by_chapters: bool,
    chapter_template: str,
    subtitle_language: str,
    subtitle_mode: str,
) -> tuple[list[str], list[str]]:
    """Hand each entry to the download queue with the given options.

    Entries lacking an id or URL are skipped. Each remaining entry is copied,
    marked as a video with ``webpage_url`` set, and passed to
    ``dqueue.add_entry``.

    Returns:
        ``(queued_ids, queue_errors)``: ids of entries that were queued, and
        the error messages for entries whose result had status ``"error"``.
    """
    queued_ids: list[str] = []
    queue_errors: list[str] = []
    # Empty strings are passed to the queue as None.
    target_folder = folder or None
    template = chapter_template or None

    for ent in entries:
        eid = _entry_id(ent)
        vurl = _entry_video_url(ent)
        if not (eid and vurl):
            continue
        queue_entry = {**ent, "_type": "video", "webpage_url": vurl}
        result = await self.dqueue.add_entry(
            queue_entry,
            download_type,
            codec,
            format,
            quality,
            target_folder,
            custom_name_prefix,
            playlist_item_limit,
            auto_start,
            split_by_chapters,
            template,
            subtitle_language,
            subtitle_mode,
        )
        failed = isinstance(result, dict) and result.get("status") == "error"
        if not failed:
            queued_ids.append(eid)
            continue
        msg = str(result.get("msg") or f"Queueing failed for {vurl}")
        queue_errors.append(msg)
        log.warning("Subscription queueing failed for %s: %s", vurl, msg)
    return queued_ids, queue_errors
|
||||
|
||||
def list_all(self) -> list[SubscriptionInfo]:
    """Return a new list holding every subscription currently in memory."""
    return [*self._subs.values()]
|
||||
|
||||
def get(self, sub_id: str) -> Optional[SubscriptionInfo]:
    """Return the subscription stored under *sub_id*, or ``None`` if unknown."""
    try:
        return self._subs[sub_id]
    except KeyError:
        return None
|
||||
|
||||
def start_background_loop(self) -> None:
    """Start the periodic check task unless one is still running.

    A done-callback logs the exception if the task ends with one; a
    cancelled task is not reported.
    """
    current = self._loop_task
    if current is not None and not current.done():
        return

    def _report_failure(task) -> None:
        if task.cancelled():
            return
        failure = task.exception()
        if failure:
            log.error("Subscription loop failed: %s", failure)

    self._loop_task = asyncio.create_task(self._periodic_loop())
    self._loop_task.add_done_callback(_report_failure)
|
||||
|
||||
async def _periodic_loop(self) -> None:
    """Loop forever: sleep 60 seconds, then run the due checks.

    An ``Exception`` raised by a check round is logged and the loop goes on.
    """
    pause_seconds = 60
    while True:
        await asyncio.sleep(pause_seconds)
        try:
            await self.run_due_checks()
        except Exception as exc:
            log.exception("Subscription periodic check error: %s", exc)
|
||||
|
||||
async def run_due_checks(self) -> None:
    """Check every enabled subscription whose interval has elapsed.

    Due subscriptions are collected while holding the lock and then checked
    one after another via ``_check_one_unlocked``. A subscription that was
    never checked is always due; the interval is at least 60 seconds.
    """
    now = time.time()

    def _is_due(sub) -> bool:
        if not sub.enabled:
            return False
        interval_sec = max(60, int(sub.check_interval_minutes) * 60)
        if sub.last_checked is None:
            return True
        return now - sub.last_checked >= interval_sec

    async with self._lock:
        due = [sub for sub in self._subs.values() if _is_due(sub)]
    for sub in due:
        await self._check_one_unlocked(sub)
|
||||
|
||||
async def add_subscription(
    self,
    url: str,
    *,
    check_interval_minutes: int,
    download_type: str,
    codec: str,
    format: str,
    quality: str,
    folder: str,
    custom_name_prefix: str,
    auto_start: bool,
    playlist_item_limit: int,
    split_by_chapters: bool,
    chapter_template: str,
    subtitle_language: str,
    subtitle_mode: str,
) -> dict:
    """Resolve *url*, create a subscription for it and persist it.

    The entries found during the initial scan are recorded as already seen.

    Returns:
        ``{"status": "ok", "subscription": <public dict>}`` on success, or
        ``{"status": "error", "msg": ...}`` when the URL is empty, already
        subscribed (or currently being added), cannot be resolved, or is not
        a playlist/channel.

    Raises:
        Whatever ``_save_locked`` raises; the in-memory insert is rolled back
        before the error propagates.
    """
    url = self._normalize_url(url)
    if not url:
        return {"status": "error", "msg": "Missing URL"}

    async with self._lock:
        if url in self._url_index or url in self._pending_urls:
            return {"status": "error", "msg": "This URL is already subscribed"}
        # Reserve the URL so concurrent requests are rejected while the
        # (slow) extraction below runs without the lock held.
        self._pending_urls.add(url)

    try:
        scan_first = max(int(getattr(self.config, "SUBSCRIPTION_SCAN_PLAYLIST_END", 50)), 1)
        loop = asyncio.get_running_loop()
        try:
            # extract_flat_playlist is synchronous and performs the yt-dlp
            # extraction; run it in the default executor so the event loop
            # keeps serving other requests meanwhile.
            info, entries = await loop.run_in_executor(
                None, extract_flat_playlist, self.config, url, scan_first
            )
        except yt_dlp.utils.YoutubeDLError as exc:
            return {"status": "error", "msg": str(exc)}

        if not info:
            return {"status": "error", "msg": "Could not resolve URL"}

        etype = info.get("_type") or "video"
        if etype not in ("playlist", "channel"):
            return {"status": "error", "msg": VIDEO_ONLY_MSG}

        name = (
            info.get("title")
            or info.get("channel")
            or info.get("playlist_title")
            or info.get("uploader")
            or url
        )

        all_ids: list[str] = []
        for ent in entries:
            if not _is_media_entry(ent):
                continue
            eid = _entry_id(ent)
            if eid:
                all_ids.append(eid)

        sub = SubscriptionInfo(
            id=str(uuid.uuid4()),
            name=str(name),
            url=url,
            enabled=True,
            check_interval_minutes=max(1, int(check_interval_minutes)),
            download_type=download_type,
            codec=codec,
            format=format,
            quality=quality,
            folder=folder or "",
            custom_name_prefix=custom_name_prefix or "",
            auto_start=bool(auto_start),
            playlist_item_limit=int(playlist_item_limit),
            split_by_chapters=bool(split_by_chapters),
            chapter_template=chapter_template or "",
            subtitle_language=subtitle_language,
            subtitle_mode=subtitle_mode,
            last_checked=time.time(),
            seen_ids=list(dict.fromkeys(all_ids)),
            error=None,
        )

        async with self._lock:
            if url in self._url_index:
                return {"status": "error", "msg": "This URL is already subscribed"}
            self._subs[sub.id] = sub
            self._url_index[url] = sub.id
            try:
                self._save_locked()
            except Exception:
                # Keep memory consistent with disk when the save fails.
                self._subs.pop(sub.id, None)
                self._url_index.pop(url, None)
                raise

        await self.notifier.subscription_added(sub)
        return {"status": "ok", "subscription": sub.to_public_dict()}
    finally:
        async with self._lock:
            self._pending_urls.discard(url)
|
||||
|
||||
async def delete_subscriptions(self, ids: list[str]) -> dict:
    """Delete the subscriptions named by ``ids`` and persist the result.

    Ids that are not present are ignored. If writing the state fails, both
    in-memory maps are restored from their snapshots and the exception is
    re-raised. One ``subscription_removed`` notification is sent per id that
    was actually deleted.

    Returns:
        ``{"status": "ok"}``.
    """
    deleted_ids: list[str] = []
    async with self._lock:
        # Shallow snapshots suffice for rollback: only membership changes.
        subs_backup = self._subs.copy()
        index_backup = self._url_index.copy()
        for sub_id in ids:
            found = self._subs.pop(sub_id, None)
            if not found:
                continue
            self._url_index.pop(self._normalize_url(found.url), None)
            deleted_ids.append(sub_id)
        if deleted_ids:
            try:
                self._save_locked()
            except Exception:
                self._subs = subs_backup
                self._url_index = index_backup
                raise
    for sub_id in deleted_ids:
        await self.notifier.subscription_removed(sub_id)
    return {"status": "ok"}
|
||||
|
||||
async def update_subscription(self, sub_id: str, changes: dict) -> dict:
    """Apply ``changes`` to one subscription, persist it and notify listeners.

    Recognised keys in ``changes``:
      * ``enabled`` - coerced with ``bool``;
      * ``check_interval_minutes`` - coerced with ``int``, minimum 1;
      * ``name`` - coerced with ``str``; ignored when falsy.
    Any other key is ignored.

    Returns:
        ``{"status": "ok", "subscription": <public dict>}`` on success, or
        ``{"status": "error", "msg": "Subscription not found"}`` when
        ``sub_id`` is unknown.

    Raises:
        ValueError, TypeError: ``check_interval_minutes`` cannot be converted
            to ``int``. The subscription is left untouched in that case.
        Exception: anything ``_save_locked`` raises; the pre-change copy is
            put back into ``self._subs`` before re-raising.
    """
    async with self._lock:
        sub = self._subs.get(sub_id)
        if not sub:
            return {"status": "error", "msg": "Subscription not found"}

        # Coerce every requested value *before* touching the subscription, so
        # a bad value (e.g. a non-numeric interval) cannot leave it
        # half-updated in memory while the saved state still has old values.
        updates: dict = {}
        if "enabled" in changes:
            updates["enabled"] = bool(changes["enabled"])
        if "check_interval_minutes" in changes:
            updates["check_interval_minutes"] = max(1, int(changes["check_interval_minutes"]))
        if "name" in changes and changes["name"]:
            updates["name"] = str(changes["name"])

        previous = copy.deepcopy(sub)
        old_enabled = sub.enabled
        for attr, value in updates.items():
            setattr(sub, attr, value)

        try:
            self._save_locked()
        except Exception:
            self._subs[sub_id] = previous
            raise

    if "enabled" in updates and sub.enabled != old_enabled:
        log.info(
            "Subscription %s %s",
            sub.name,
            "resumed" if sub.enabled else "paused",
        )
    await self.notifier.subscription_updated(sub)
    return {"status": "ok", "subscription": sub.to_public_dict()}
|
||||
|
||||
async def check_now(self, ids: Optional[list[str]] = None) -> dict:
    """Check subscriptions immediately.

    With a non-empty ``ids`` list, the listed subscriptions that exist are
    checked regardless of their ``enabled`` flag. With ``None`` or an empty
    list, every enabled subscription is checked.

    Returns:
        ``{"status": "ok"}``.
    """
    async with self._lock:
        if ids:
            targets = [self._subs[sub_id] for sub_id in ids if sub_id in self._subs]
        else:
            targets = [candidate for candidate in self._subs.values() if candidate.enabled]
    log.info(
        "Manual subscription check requested for %d subscription(s)",
        len(targets),
    )
    # _check_one_unlocked takes self._lock itself, so it runs outside it.
    for target in targets:
        await self._check_one_unlocked(target)
    return {"status": "ok"}
|
||||
|
||||
async def _check_one_unlocked(self, sub: SubscriptionInfo) -> None:
    """Scan one subscription's feed and queue the entries not seen before.

    Must be called without ``self._lock`` held: the lock is taken here only
    around reads and writes of the stored subscription, not across the
    extraction or queueing calls.

    Outcomes:
      * extraction raises ``YoutubeDLError`` - the message is stored as the
        subscription's ``error``;
      * the URL resolves to a single video, or yields no media entries -
        ``VIDEO_ONLY_MSG`` is stored as the ``error``;
      * otherwise unseen entries go to ``_queue_subscription_entries``; the
        ids it reports as queued are put in front of ``seen_ids`` (capped at
        ``SUBSCRIPTION_MAX_SEEN_IDS``), ``last_checked`` is refreshed and
        ``error`` becomes the first three queue errors joined by ``"; "``,
        or ``None`` when there were none.

    If the subscription disappears from ``self._subs`` while the check is
    running, nothing is saved and no notification is sent.

    Raises:
        Exception: anything ``_save_locked`` raises; the pre-change copy of
            the subscription is restored first.
    """
    sid = sub.id
    # Clamp to at least 1, as add_subscription does, so a zero or negative
    # setting is never passed to the extractor.
    scan = max(int(getattr(self.config, "SUBSCRIPTION_SCAN_PLAYLIST_END", 50)), 1)
    log.info("Checking subscription: %s", sub.name)
    try:
        info, entries = extract_flat_playlist(self.config, sub.url, scan)
    except yt_dlp.utils.YoutubeDLError as exc:
        cur = await self._record_check_error(sid, str(exc))
        log.warning("Subscription check failed for %s: %s", (cur or sub).name, exc)
        if cur:
            await self.notifier.subscription_updated(cur)
        return

    entries = [ent for ent in entries if _is_media_entry(ent)]

    etype = (info or {}).get("_type") or "video"
    if etype == "video" or not entries:
        cur = await self._record_check_error(sid, VIDEO_ONLY_MSG)
        log.warning(
            "Subscription %s no longer resolves to a subscribable feed",
            (cur or sub).name,
        )
        if cur:
            await self.notifier.subscription_updated(cur)
        return

    # Snapshot what the queueing step needs so it can run without the lock.
    async with self._lock:
        cur = self._subs.get(sid)
        if not cur:
            return
        seen = cur.seen_set()
        seen_ids_snapshot = list(cur.seen_ids)
        queue_options = {
            "download_type": cur.download_type,
            "codec": cur.codec,
            "format": cur.format,
            "quality": cur.quality,
            "folder": cur.folder,
            "custom_name_prefix": cur.custom_name_prefix,
            "playlist_item_limit": cur.playlist_item_limit,
            "auto_start": cur.auto_start,
            "split_by_chapters": cur.split_by_chapters,
            "chapter_template": cur.chapter_template or "",
            "subtitle_language": cur.subtitle_language,
            "subtitle_mode": cur.subtitle_mode,
        }

    new_entries: list[dict] = []
    for ent in entries:
        eid = _entry_id(ent)
        if eid and eid not in seen:
            new_entries.append(ent)

    queued_ids, queue_errors = await self._queue_subscription_entries(
        new_entries,
        **queue_options,
    )
    log.info(
        "Subscription check finished for %s: %d new, %d queued, %d failed",
        sub.name,
        len(new_entries),
        len(queued_ids),
        len(queue_errors),
    )

    # Only ids that were actually queued become "seen"; entries that failed
    # to queue stay unseen. dict.fromkeys de-duplicates and keeps order.
    merged = list(dict.fromkeys(queued_ids + seen_ids_snapshot))
    max_seen = int(getattr(self.config, "SUBSCRIPTION_MAX_SEEN_IDS", 50000))
    if len(merged) > max_seen:
        merged = merged[:max_seen]

    async with self._lock:
        cur = self._subs.get(sid)
        if not cur:
            return
        previous = copy.deepcopy(cur)
        cur.seen_ids = merged
        cur.last_checked = time.time()
        cur.error = "; ".join(queue_errors[:3]) if queue_errors else None
        try:
            self._save_locked()
        except Exception:
            self._subs[sid] = previous
            raise
    await self.notifier.subscription_updated(cur)


async def _record_check_error(self, sid: str, message: str) -> Optional[SubscriptionInfo]:
    """Store ``message`` as the ``error`` of subscription ``sid`` and save.

    Returns the stored subscription, or ``None`` when ``sid`` is no longer in
    ``self._subs``. If saving fails, the pre-change copy is restored and the
    exception is re-raised.
    """
    async with self._lock:
        cur = self._subs.get(sid)
        if not cur:
            return None
        previous = copy.deepcopy(cur)
        cur.error = message
        try:
            self._save_locked()
        except Exception:
            self._subs[sid] = previous
            raise
        return cur
|
||||
|
||||
async def emit_all(self) -> None:
    """Send the result of ``list_all()`` to the notifier's ``subscriptions_all``."""
    snapshot = self.list_all()
    await self.notifier.subscriptions_all(snapshot)
|
||||
@@ -144,3 +144,34 @@ async def test_start_pending_moves_to_queue(dq_env):
|
||||
with patch.object(DownloadQueue, "_DownloadQueue__start_download", AsyncMock()):
|
||||
await dq.start_pending([url])
|
||||
assert not dq.pending.exists(url)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_add_entry_queues_single_video_without_reextracting(dq_env):
    """``add_entry`` must put a ready-made video entry into ``pending``
    without calling the queue's private ``__extract_info``."""
    video_url = "https://example.com/watch?v=1"
    entry = {
        "_type": "video",
        "id": "vid1",
        "title": "Test Video",
        "url": video_url,
        "webpage_url": video_url,
        "playlist_index": "01",
        "playlist_title": "Playlist",
    }
    dq = DownloadQueue(dq_env, AsyncMock())

    # Any call into the extractor raises, failing the test immediately.
    no_extract = patch.object(
        DownloadQueue,
        "_DownloadQueue__extract_info",
        side_effect=AssertionError("should not re-extract"),
    )
    with no_extract:
        result = await dq.add_entry(
            entry, "video", "auto", "any", "best", "", "", 0, auto_start=False
        )

    assert result["status"] == "ok"
    assert dq.pending.exists(video_url)
|
||||
|
||||
@@ -1,12 +1,39 @@
|
||||
"""Integration tests for ``PersistentQueue`` (shelve-backed storage)."""
|
||||
"""Integration tests for ``PersistentQueue`` using the JSON state store."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import os
|
||||
import shelve
|
||||
import sys
|
||||
import tempfile
|
||||
import types
|
||||
import unittest
|
||||
from unittest.mock import patch
|
||||
|
||||
# Minimal stand-ins for the yt_dlp modules (presumably so the project imports
# that follow work without the real package - confirm). ``setdefault`` leaves
# any module already present in ``sys.modules`` untouched.
class _ImpersonateTarget:
    """Stub whose ``from_str`` returns its argument unchanged."""

    @staticmethod
    def from_str(value):
        return value


fake_utils = types.ModuleType("yt_dlp.utils")
fake_utils.STR_FORMAT_RE_TMPL = r"(?P<prefix>)%\((?P<has_key>{})\)(?P<format>[-0-9.]*{})"
fake_utils.STR_FORMAT_TYPES = "diouxXeEfFgGcrsa"

fake_impersonate = types.ModuleType("yt_dlp.networking.impersonate")
fake_impersonate.ImpersonateTarget = _ImpersonateTarget

fake_networking = types.ModuleType("yt_dlp.networking")
fake_networking.impersonate = fake_impersonate

fake_yt_dlp = types.ModuleType("yt_dlp")
fake_yt_dlp.networking = fake_networking
fake_yt_dlp.utils = fake_utils

# Each fake registers under its own module name, in the original order.
for _fake in (fake_yt_dlp, fake_networking, fake_impersonate, fake_utils):
    sys.modules.setdefault(_fake.__name__, _fake)
|
||||
|
||||
from ytdl import DownloadInfo, PersistentQueue
|
||||
|
||||
|
||||
@@ -36,6 +63,12 @@ def _make_info(url: str = "https://example.com/v") -> DownloadInfo:
|
||||
)
|
||||
|
||||
|
||||
def _create_legacy_shelf(path: str, *infos: DownloadInfo) -> None:
    """Write ``infos`` into a shelve database at ``path``, keyed by ``info.url``."""
    with shelve.open(path, "c") as legacy:
        legacy.update((item.url, item) for item in infos)
|
||||
|
||||
|
||||
class PersistentQueueTests(unittest.TestCase):
|
||||
def test_put_get_delete_roundtrip(self):
|
||||
with tempfile.TemporaryDirectory() as tmp:
|
||||
@@ -43,6 +76,7 @@ class PersistentQueueTests(unittest.TestCase):
|
||||
pq = PersistentQueue("queue", path)
|
||||
dl = _FakeDownload(_make_info("http://a.example"))
|
||||
pq.put(dl)
|
||||
self.assertTrue(os.path.exists(path + ".json"))
|
||||
self.assertTrue(pq.exists("http://a.example"))
|
||||
self.assertFalse(pq.empty())
|
||||
got = pq.get("http://a.example")
|
||||
@@ -63,7 +97,7 @@ class PersistentQueueTests(unittest.TestCase):
|
||||
keys = [k for k, _ in pq.saved_items()]
|
||||
self.assertEqual(keys, ["http://first.example", "http://second.example"])
|
||||
|
||||
def test_load_restores_from_shelve(self):
|
||||
def test_load_restores_from_json(self):
|
||||
with tempfile.TemporaryDirectory() as tmp:
|
||||
path = os.path.join(tmp, "queue")
|
||||
pq1 = PersistentQueue("queue", path)
|
||||
@@ -72,21 +106,159 @@ class PersistentQueueTests(unittest.TestCase):
|
||||
pq2.load()
|
||||
self.assertTrue(pq2.exists("http://load.example"))
|
||||
|
||||
def test_put_rollbacks_in_memory_queue_when_shelf_write_fails(self):
|
||||
def test_load_imports_legacy_shelve(self):
    """``load()`` picks up a legacy shelve entry and leaves a ``.json`` file behind."""
    legacy_url = "http://legacy.example"
    with tempfile.TemporaryDirectory() as tmp:
        base = os.path.join(tmp, "queue")
        _create_legacy_shelf(base, _make_info(legacy_url))

        queue = PersistentQueue("queue", base)
        queue.load()

        self.assertTrue(queue.exists(legacy_url))
        self.assertTrue(os.path.exists(base + ".json"))
|
||||
|
||||
def test_queue_persists_only_compact_entry_subset(self):
    """The saved ``entry`` keeps the playlist/channel index and title keys
    and drops ``formats`` and ``description``."""
    kept = {
        "playlist_index": "01",
        "playlist_title": "Playlist",
        "channel_index": "02",
        "channel_title": "Channel",
    }
    with tempfile.TemporaryDirectory() as tmp:
        base = os.path.join(tmp, "queue")
        queue = PersistentQueue("queue", base)
        info = _make_info("http://entry.example")
        info.entry = {
            **kept,
            "formats": [{"id": "huge"}],
            "description": "very large payload",
        }
        queue.put(_FakeDownload(info))

        with open(base + ".json", encoding="utf-8") as fh:
            stored = json.load(fh)["items"][0]["info"]

        self.assertEqual(stored["entry"], kept)
        for dropped in ("formats", "description"):
            self.assertNotIn(dropped, stored["entry"])
|
||||
|
||||
def test_completed_queue_does_not_persist_entry_or_transient_progress(self):
    """A ``completed`` queue record has no ``entry``, ``percent``, ``speed``
    or ``eta`` key but keeps ``filename``."""
    with tempfile.TemporaryDirectory() as tmp:
        base = os.path.join(tmp, "completed")
        queue = PersistentQueue("completed", base)

        info = _make_info("http://done.example")
        info.status = "finished"
        info.percent, info.speed, info.eta = 88, 123, 9
        info.entry = {
            "playlist_index": "01",
            "playlist_title": "Playlist",
            "formats": [{"id": "huge"}],
        }
        info.filename = "done.mp4"
        queue.put(_FakeDownload(info))

        with open(base + ".json", encoding="utf-8") as fh:
            stored = json.load(fh)["items"][0]["info"]

        for absent in ("entry", "percent", "speed", "eta"):
            self.assertNotIn(absent, stored)
        self.assertEqual(stored["filename"], "done.mp4")
|
||||
|
||||
def test_invalid_json_is_quarantined_and_legacy_is_imported(self):
    """With an unparsable ``.json`` next to a legacy shelf, ``load()`` still
    yields the legacy entry and a ``queue.json.invalid.*`` file appears."""
    legacy_url = "http://legacy.example"
    with tempfile.TemporaryDirectory() as tmp:
        base = os.path.join(tmp, "queue")
        _create_legacy_shelf(base, _make_info(legacy_url))
        with open(base + ".json", "w", encoding="utf-8") as fh:
            fh.write("{not valid json")

        queue = PersistentQueue("queue", base)
        queue.load()

        self.assertTrue(queue.exists(legacy_url))
        quarantined = [
            name for name in os.listdir(tmp) if name.startswith("queue.json.invalid.")
        ]
        self.assertTrue(bool(quarantined))
|
||||
|
||||
def test_loading_old_json_rewrites_to_compact_format(self):
    """Loading a ``schema_version`` 1 file rewrites it as version 2 with the
    ``entry`` trimmed and ``percent``/``speed``/``eta`` removed."""
    old_info = {
        "id": "id1",
        "title": "Title",
        "url": "http://legacy-json.example",
        "quality": "best",
        "download_type": "video",
        "codec": "auto",
        "format": "any",
        "folder": "",
        "custom_name_prefix": "",
        "playlist_item_limit": 0,
        "split_by_chapters": False,
        "chapter_template": "",
        "subtitle_language": "en",
        "subtitle_mode": "prefer_manual",
        "status": "pending",
        "timestamp": 1,
        "entry": {
            "playlist_index": "01",
            "playlist_title": "Playlist",
            "formats": [{"id": "huge"}],
        },
        "percent": 15,
        "speed": 20,
        "eta": 30,
    }
    old_payload = {
        "schema_version": 1,
        "kind": "persistent_queue:queue",
        "items": [{"key": "http://legacy-json.example", "info": old_info}],
    }
    with tempfile.TemporaryDirectory() as tmp:
        base = os.path.join(tmp, "queue")
        with open(base + ".json", "w", encoding="utf-8") as fh:
            json.dump(old_payload, fh)

        queue = PersistentQueue("queue", base)
        queue.load()

        with open(base + ".json", encoding="utf-8") as fh:
            rewritten = json.load(fh)

        stored = rewritten["items"][0]["info"]
        self.assertEqual(rewritten["schema_version"], 2)
        self.assertEqual(
            stored["entry"],
            {"playlist_index": "01", "playlist_title": "Playlist"},
        )
        for absent in ("percent", "speed", "eta"):
            self.assertNotIn(absent, stored)
|
||||
|
||||
def test_put_rollbacks_in_memory_queue_when_state_write_fails(self):
|
||||
with tempfile.TemporaryDirectory() as tmp:
|
||||
path = os.path.join(tmp, "queue")
|
||||
pq = PersistentQueue("queue", path)
|
||||
dl = _FakeDownload(_make_info("http://rollback.example"))
|
||||
self.assertFalse(pq.exists("http://rollback.example"))
|
||||
|
||||
orig_open = __import__("shelve").open
|
||||
orig_save = __import__("state_store").AtomicJsonStore.save
|
||||
|
||||
def bad_open(filename, flag="c", *args, **kwargs):
|
||||
if flag == "w":
|
||||
def bad_save(store, data):
|
||||
if store.path == path + ".json":
|
||||
raise OSError("simulated shelf failure")
|
||||
return orig_open(filename, flag, *args, **kwargs)
|
||||
return orig_save(store, data)
|
||||
|
||||
with patch("ytdl.shelve.open", bad_open):
|
||||
with patch("ytdl.AtomicJsonStore.save", bad_save):
|
||||
with self.assertRaises(OSError):
|
||||
pq.put(dl)
|
||||
|
||||
@@ -101,14 +273,14 @@ class PersistentQueueTests(unittest.TestCase):
|
||||
second.info.title = "Replaced title"
|
||||
pq.put(first)
|
||||
|
||||
orig_open = __import__("shelve").open
|
||||
orig_save = __import__("state_store").AtomicJsonStore.save
|
||||
|
||||
def bad_open(filename, flag="c", *args, **kwargs):
|
||||
if flag == "w":
|
||||
def bad_save(store, data):
|
||||
if store.path == path + ".json":
|
||||
raise OSError("simulated shelf failure")
|
||||
return orig_open(filename, flag, *args, **kwargs)
|
||||
return orig_save(store, data)
|
||||
|
||||
with patch("ytdl.shelve.open", bad_open):
|
||||
with patch("ytdl.AtomicJsonStore.save", bad_save):
|
||||
with self.assertRaises(OSError):
|
||||
pq.put(second)
|
||||
|
||||
|
||||
@@ -0,0 +1,53 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
import tempfile
|
||||
import unittest
|
||||
from datetime import datetime
|
||||
|
||||
from state_store import AtomicJsonStore, from_json_compatible, to_json_compatible
|
||||
|
||||
|
||||
class StateStoreTests(unittest.TestCase):
    """Tests for ``AtomicJsonStore`` and the JSON-compatibility helpers."""

    KIND = "persistent_queue:queue"

    def test_save_and_load_roundtrip(self):
        """Saved data is returned by ``load()`` with ``kind`` and ``schema_version`` 2."""
        with tempfile.TemporaryDirectory() as tmp:
            store = AtomicJsonStore(os.path.join(tmp, "queue.json"), kind=self.KIND)
            store.save({"items": [{"key": "a", "info": {"title": "hello"}}]})

            loaded = store.load()

            self.assertEqual(loaded["kind"], "persistent_queue:queue")
            self.assertEqual(loaded["schema_version"], 2)
            self.assertEqual(loaded["items"][0]["info"]["title"], "hello")

    def test_invalid_file_is_quarantined(self):
        """An unparsable file makes ``load()`` return ``None`` and leaves a
        ``queue.json.invalid.*`` file in the directory."""
        with tempfile.TemporaryDirectory() as tmp:
            target = os.path.join(tmp, "queue.json")
            with open(target, "w", encoding="utf-8") as fh:
                fh.write("{broken")

            loaded = AtomicJsonStore(target, kind=self.KIND).load()

            self.assertIsNone(loaded)
            quarantined = [
                name for name in os.listdir(tmp) if name.startswith("queue.json.invalid.")
            ]
            self.assertTrue(bool(quarantined))

    def test_json_compat_helpers_roundtrip_bytes_and_datetime(self):
        """``bytes`` and ``datetime`` survive the round trip; a tuple comes back as a list."""
        moment = datetime(2024, 1, 2, 3, 4, 5)
        original = {"payload": b"abc", "timestamp": moment, "items": (1, 2, 3)}

        restored = from_json_compatible(to_json_compatible(original))

        self.assertEqual(restored["payload"], b"abc")
        self.assertEqual(restored["timestamp"], datetime(2024, 1, 2, 3, 4, 5))
        self.assertEqual(restored["items"], [1, 2, 3])
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
@@ -0,0 +1,443 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import os
|
||||
import shelve
|
||||
import sys
|
||||
import tempfile
|
||||
import types
|
||||
import unittest
|
||||
from unittest.mock import patch
|
||||
|
||||
# Minimal stand-ins for the yt_dlp modules (presumably so ``subscriptions``
# can be imported below without the real package - confirm). ``setdefault``
# leaves any module already present in ``sys.modules`` untouched.
class _ImpersonateTarget:
    """Stub whose ``from_str`` returns its argument unchanged."""

    @staticmethod
    def from_str(value):
        return value


fake_impersonate = types.ModuleType("yt_dlp.networking.impersonate")
fake_impersonate.ImpersonateTarget = _ImpersonateTarget

fake_networking = types.ModuleType("yt_dlp.networking")
fake_networking.impersonate = fake_impersonate

fake_yt_dlp = types.ModuleType("yt_dlp")
fake_yt_dlp.networking = fake_networking
# ``utils`` is a plain namespace here and is not registered as a module.
fake_yt_dlp.utils = types.SimpleNamespace(YoutubeDLError=Exception)

for _fake in (fake_yt_dlp, fake_networking, fake_impersonate):
    sys.modules.setdefault(_fake.__name__, _fake)
|
||||
|
||||
from subscriptions import SubscriptionManager, extract_flat_playlist
|
||||
|
||||
|
||||
class _Config:
    """Configuration stub: all three directory settings point at ``state_dir``."""

    def __init__(self, state_dir: str):
        self.STATE_DIR = self.DOWNLOAD_DIR = self.TEMP_DIR = state_dir
        self.SUBSCRIPTION_SCAN_PLAYLIST_END = 50
        self.SUBSCRIPTION_MAX_SEEN_IDS = 50000
        self.YTDL_OPTIONS = {}
|
||||
|
||||
|
||||
class _Queue:
    """Download-queue stub that records ``add_entry`` calls.

    Set ``fail`` to ``True`` to make ``add_entry`` report an error instead of
    recording the call.
    """

    def __init__(self):
        self.fail = False
        self.entries = []

    async def add(self, *args, **kwargs):
        """Accept anything and do nothing."""
        return None

    async def add_entry(self, entry, *args, **kwargs):
        """Record ``(entry, args, kwargs)`` and report success unless ``fail`` is set."""
        if not self.fail:
            self.entries.append((entry, args, kwargs))
            return {"status": "ok"}
        return {"status": "error", "msg": "queue failed"}
|
||||
|
||||
|
||||
class _Notifier:
    """Notifier stub: every callback is an awaitable that returns ``None``."""

    async def subscription_added(self, sub):
        """Ignore the added subscription."""

    async def subscription_updated(self, sub):
        """Ignore the updated subscription."""

    async def subscription_removed(self, sub_id):
        """Ignore the removed subscription id."""

    async def subscriptions_all(self, subs):
        """Ignore the full subscription list."""
|
||||
|
||||
|
||||
def _create_legacy_shelf(path: str, record) -> None:
    """Create a shelve database at ``path`` holding ``record`` under the key ``"sub-1"``."""
    with shelve.open(path, flag="c") as legacy:
        legacy.update({"sub-1": record})
|
||||
|
||||
|
||||
class SubscriptionPersistenceTests(unittest.IsolatedAsyncioTestCase):
    """Persistence, legacy import and feed-check behaviour of ``SubscriptionManager``."""

    CHANNEL_URL = "https://example.com/channel"

    # -- fixtures: each call returns a fresh object --------------------------

    @staticmethod
    def _add_options() -> dict:
        """Keyword arguments every test passes to ``add_subscription``."""
        return {
            "check_interval_minutes": 60,
            "download_type": "video",
            "codec": "auto",
            "format": "any",
            "quality": "best",
            "folder": "",
            "custom_name_prefix": "",
            "auto_start": True,
            "playlist_item_limit": 0,
            "split_by_chapters": False,
            "chapter_template": "",
            "subtitle_language": "en",
            "subtitle_mode": "prefer_manual",
        }

    @staticmethod
    def _channel() -> dict:
        """Root info dict returned by the patched ``extract_flat_playlist``."""
        return {"_type": "channel", "title": "Channel"}

    @staticmethod
    def _video(video_id: str, title: str) -> dict:
        """Flat playlist entry for one video."""
        return {
            "id": video_id,
            "title": title,
            "webpage_url": f"https://example.com/{video_id}",
        }

    @staticmethod
    def _legacy_record() -> dict:
        """Record stored in the legacy shelve database."""
        return {
            "id": "sub-1",
            "name": "Channel",
            "url": "https://example.com/channel",
            "timestamp": 1.0,
        }

    async def _subscribe(self, mgr):
        """Call ``add_subscription`` for the channel URL with the shared options."""
        return await mgr.add_subscription(self.CHANNEL_URL, **self._add_options())

    # -- loading and migration ------------------------------------------------

    def test_load_imports_legacy_subscription_shelf(self):
        with tempfile.TemporaryDirectory() as tmp:
            json_path = os.path.join(tmp, "subscriptions.json")
            _create_legacy_shelf(os.path.join(tmp, "subscriptions"), self._legacy_record())

            mgr = SubscriptionManager(_Config(tmp), _Queue(), _Notifier())

            self.assertEqual(len(mgr.list_all()), 1)
            self.assertTrue(os.path.exists(json_path))
            with open(json_path, encoding="utf-8") as fh:
                stored = json.load(fh)
            self.assertEqual(stored["schema_version"], 2)
            self.assertNotIn("timestamp", stored["items"][0])

    def test_invalid_json_is_quarantined_and_legacy_is_imported(self):
        with tempfile.TemporaryDirectory() as tmp:
            json_path = os.path.join(tmp, "subscriptions.json")
            _create_legacy_shelf(os.path.join(tmp, "subscriptions"), self._legacy_record())
            with open(json_path, "w", encoding="utf-8") as fh:
                fh.write("{not valid json")

            mgr = SubscriptionManager(_Config(tmp), _Queue(), _Notifier())

            self.assertEqual(len(mgr.list_all()), 1)
            quarantined = [
                name
                for name in os.listdir(tmp)
                if name.startswith("subscriptions.json.invalid.")
            ]
            self.assertTrue(bool(quarantined))

    def test_load_rewrites_old_json_and_trims_seen_ids(self):
        old_item = {
            "id": "sub-1",
            "name": "Channel",
            "url": "https://example.com/channel",
            "enabled": True,
            **self._add_options(),
            "last_checked": None,
            "seen_ids": ["a", "b", "a", "c"],
            "error": None,
            "timestamp": 123,
        }
        with tempfile.TemporaryDirectory() as tmp:
            json_path = os.path.join(tmp, "subscriptions.json")
            cfg = _Config(tmp)
            cfg.SUBSCRIPTION_MAX_SEEN_IDS = 2
            with open(json_path, "w", encoding="utf-8") as fh:
                json.dump(
                    {"schema_version": 1, "kind": "subscriptions", "items": [old_item]},
                    fh,
                )

            mgr = SubscriptionManager(cfg, _Queue(), _Notifier())
            self.assertEqual(mgr.list_all()[0].seen_ids, ["a", "b"])

            with open(json_path, encoding="utf-8") as fh:
                stored = json.load(fh)

            self.assertEqual(stored["schema_version"], 2)
            self.assertEqual(stored["items"][0]["seen_ids"], ["a", "b"])
            self.assertNotIn("timestamp", stored["items"][0])

    # -- add_subscription -----------------------------------------------------

    async def test_add_subscription_rolls_back_when_state_write_fails(self):
        with tempfile.TemporaryDirectory() as tmp:
            mgr = SubscriptionManager(_Config(tmp), _Queue(), _Notifier())
            real_save = __import__("state_store").AtomicJsonStore.save

            def failing_save(store, data):
                # Fail only for the manager's own state file.
                if store.path != mgr._path:
                    return real_save(store, data)
                raise OSError("simulated shelf failure")

            feed = (
                self._channel(),
                [{"id": "v1", "webpage_url": "https://example.com/v1"}],
            )
            with patch("subscriptions.extract_flat_playlist", return_value=feed):
                with patch("subscriptions.AtomicJsonStore.save", failing_save):
                    with self.assertRaises(OSError):
                        await self._subscribe(mgr)

            self.assertEqual(mgr.list_all(), [])
            self.assertNotIn("https://example.com/channel", mgr._url_index)

    async def test_add_subscription_marks_existing_videos_seen_without_queueing(self):
        with tempfile.TemporaryDirectory() as tmp:
            queue = _Queue()
            mgr = SubscriptionManager(_Config(tmp), queue, _Notifier())
            feed = (
                self._channel(),
                [
                    self._video("v1", "One"),
                    self._video("v2", "Two"),
                    self._video("v3", "Three"),
                ],
            )

            with patch("subscriptions.extract_flat_playlist", return_value=feed):
                result = await self._subscribe(mgr)

            assert result["status"] == "ok"
            sub = mgr.list_all()[0]
            self.assertEqual(sub.seen_ids, ["v1", "v2", "v3"])
            self.assertIsNone(sub.error)
            self.assertEqual(queue.entries, [])

    async def test_add_subscription_skips_collection_tab_entries(self):
        with tempfile.TemporaryDirectory() as tmp:
            queue = _Queue()
            mgr = SubscriptionManager(_Config(tmp), queue, _Notifier())
            live_tab = {
                "_type": "url",
                "ie_key": "YoutubeTab",
                "title": "Channel - Live",
                "url": "https://example.com/live",
                "webpage_url": "https://example.com/live",
            }
            video = {
                "_type": "url",
                "ie_key": "Youtube",
                "id": "v1",
                "title": "One",
                "duration": 10,
                "webpage_url": "https://example.com/v1",
            }

            with patch(
                "subscriptions.extract_flat_playlist",
                return_value=(self._channel(), [live_tab, video]),
            ):
                result = await self._subscribe(mgr)

            self.assertEqual(result["status"], "ok")
            sub = mgr.list_all()[0]
            self.assertEqual(sub.seen_ids, ["v1"])
            self.assertEqual(queue.entries, [])

    # -- check_now ------------------------------------------------------------

    async def test_check_now_keeps_failed_queue_items_unseen_and_sets_error(self):
        with tempfile.TemporaryDirectory() as tmp:
            queue = _Queue()
            mgr = SubscriptionManager(_Config(tmp), queue, _Notifier())
            # First response feeds add_subscription, the second feeds check_now.
            feeds = [
                (self._channel(), [self._video("v1", "One")]),
                (self._channel(), [self._video("v2", "Two")]),
            ]

            with patch("subscriptions.extract_flat_playlist", side_effect=feeds):
                result = await self._subscribe(mgr)
                queue.fail = True
                await mgr.check_now([result["subscription"]["id"]])

            sub = mgr.list_all()[0]
            self.assertEqual(sub.error, "queue failed")
            self.assertEqual(sub.seen_ids, ["v1"])

    async def test_check_now_queues_new_video_and_updates_seen_ids(self):
        with tempfile.TemporaryDirectory() as tmp:
            queue = _Queue()
            mgr = SubscriptionManager(_Config(tmp), queue, _Notifier())
            # First response feeds add_subscription, the second feeds check_now.
            feeds = [
                (self._channel(), [self._video("v1", "One")]),
                (self._channel(), [self._video("v2", "Two"), self._video("v1", "One")]),
            ]

            with patch("subscriptions.extract_flat_playlist", side_effect=feeds):
                result = await self._subscribe(mgr)
                await mgr.check_now([result["subscription"]["id"]])

            sub = mgr.list_all()[0]
            self.assertIsNotNone(sub.last_checked)
            self.assertIsNone(sub.error)
            self.assertEqual(sub.seen_ids[:2], ["v2", "v1"])
            queued_urls = [entry["webpage_url"] for entry, _, _ in queue.entries]
            self.assertEqual(queued_urls, ["https://example.com/v2"])
|
||||
|
||||
class ExtractFlatPlaylistTests(unittest.TestCase):
    """Tests for ``extract_flat_playlist`` against a scripted fake ``YoutubeDL``."""

    def test_descends_one_level_when_root_entries_are_nested_collections(self):
        """A root whose only entry is a ``YoutubeTab`` collection is followed
        once: the returned info is the nested playlist and the entries are
        its videos."""
        # Served in order: first the channel root, then the nested tab.
        responses = iter(
            [
                {
                    "_type": "channel",
                    "entries": [
                        {
                            "_type": "url",
                            "ie_key": "YoutubeTab",
                            "title": "Channel - Videos",
                            "url": "https://example.com/videos",
                            "webpage_url": "https://example.com/videos",
                        }
                    ],
                },
                {
                    "_type": "playlist",
                    "entries": [
                        {
                            "_type": "url",
                            "ie_key": "Youtube",
                            "id": "v1",
                            "title": "One",
                            "duration": 10,
                            "webpage_url": "https://example.com/v1",
                        }
                    ],
                },
            ]
        )

        class _FakeYDL:
            """Context-manager stand-in that replays ``responses``."""

            def __init__(self, params):
                self.params = params

            def __enter__(self):
                return self

            def __exit__(self, exc_type, exc, tb):
                return False

            def extract_info(self, url, download=False):
                return next(responses)

        # TemporaryDirectory (rather than mkdtemp) removes the directory when
        # the test ends, matching the other tests in this file.
        with tempfile.TemporaryDirectory() as state_dir:
            cfg = _Config(state_dir)
            with patch("subscriptions.yt_dlp.YoutubeDL", _FakeYDL, create=True):
                info, entries = extract_flat_playlist(cfg, "https://example.com/channel", 50)

        self.assertEqual(info.get("_type"), "playlist")
        self.assertEqual(
            [entry["webpage_url"] for entry in entries],
            ["https://example.com/v1"],
        )
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
@@ -3,13 +3,39 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import pickle
|
||||
import sys
|
||||
import tempfile
|
||||
import threading
|
||||
import types
|
||||
import unittest
|
||||
from pathlib import Path
|
||||
|
||||
fake_yt_dlp = types.ModuleType("yt_dlp")
|
||||
fake_networking = types.ModuleType("yt_dlp.networking")
|
||||
fake_impersonate = types.ModuleType("yt_dlp.networking.impersonate")
|
||||
fake_utils = types.ModuleType("yt_dlp.utils")
|
||||
|
||||
|
||||
class _ImpersonateTarget:
    """Stub used as ``ImpersonateTarget`` on the fake impersonate module."""

    @staticmethod
    def from_str(value):
        """Return *value* unchanged; no parsing is performed."""
        return value
|
||||
|
||||
|
||||
fake_impersonate.ImpersonateTarget = _ImpersonateTarget
|
||||
fake_networking.impersonate = fake_impersonate
|
||||
fake_utils.STR_FORMAT_RE_TMPL = r"(?P<prefix>)%\((?P<has_key>{})\)(?P<format>[-0-9.]*{})"
|
||||
fake_utils.STR_FORMAT_TYPES = "diouxXeEfFgGcrsa"
|
||||
fake_yt_dlp.networking = fake_networking
|
||||
fake_yt_dlp.utils = fake_utils
|
||||
sys.modules.setdefault("yt_dlp", fake_yt_dlp)
|
||||
sys.modules.setdefault("yt_dlp.networking", fake_networking)
|
||||
sys.modules.setdefault("yt_dlp.networking.impersonate", fake_impersonate)
|
||||
sys.modules.setdefault("yt_dlp.utils", fake_utils)
|
||||
|
||||
from ytdl import (
|
||||
DownloadInfo,
|
||||
_compact_persisted_entry,
|
||||
_convert_srt_to_txt_file,
|
||||
_outtmpl_substitute_field,
|
||||
_sanitize_entry_for_pickle,
|
||||
@@ -167,6 +193,53 @@ class DownloadInfoSetstateTests(unittest.TestCase):
|
||||
di.__setstate__(state)
|
||||
self.assertEqual(di.subtitle_files, [])
|
||||
|
||||
def test_missing_optional_fields_are_defaulted(self):
    """Optional attributes absent from the state are defaulted on restore."""
    dropped_fields = (
        "folder",
        "custom_name_prefix",
        "playlist_item_limit",
        "split_by_chapters",
        "chapter_template",
    )
    state = self._base_state(
        download_type="video",
        codec="auto",
        format="any",
        quality="best",
    )
    # Remove every optional field so __setstate__ has to supply it.
    for field_name in dropped_fields:
        state.pop(field_name)

    restored = DownloadInfo.__new__(DownloadInfo)
    restored.__setstate__(state)

    self.assertEqual(restored.folder, "")
    self.assertEqual(restored.custom_name_prefix, "")
    self.assertEqual(restored.playlist_item_limit, 0)
    self.assertFalse(restored.split_by_chapters)
    self.assertEqual(restored.chapter_template, "")
|
||||
|
||||
|
||||
class CompactPersistedEntryTests(unittest.TestCase):
    """Tests for ``_compact_persisted_entry``."""

    def test_keeps_only_playlist_and_channel_keys(self):
        """Only ``playlist*`` / ``channel*`` keys survive compaction."""
        expected = {
            "playlist_index": "01",
            "playlist_title": "Playlist",
            "channel_index": "02",
            "channel_title": "Channel",
        }
        # Same key order as the expected dict, followed by the bulky extras
        # that compaction is supposed to drop.
        entry = {
            **expected,
            "formats": [{"id": "huge"}],
            "description": "big blob",
        }

        self.assertEqual(_compact_persisted_entry(entry), expected)

    def test_returns_none_when_no_restart_relevant_keys_exist(self):
        """An entry without any matching key compacts to ``None``."""
        outcome = _compact_persisted_entry({"id": "x", "title": "y"})
        self.assertIsNone(outcome)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
||||
+203
-99
@@ -3,24 +3,23 @@ import shutil
|
||||
import yt_dlp
|
||||
import collections
|
||||
import collections.abc
|
||||
import copy
|
||||
import pickle
|
||||
from collections import OrderedDict
|
||||
import shelve
|
||||
import time
|
||||
import asyncio
|
||||
import multiprocessing
|
||||
import logging
|
||||
import re
|
||||
import types
|
||||
import dbm
|
||||
import subprocess
|
||||
from typing import Any
|
||||
from typing import Any, Optional
|
||||
from functools import lru_cache
|
||||
|
||||
import yt_dlp.networking.impersonate
|
||||
from yt_dlp.utils import STR_FORMAT_RE_TMPL, STR_FORMAT_TYPES
|
||||
from dl_formats import get_format, get_opts, AUDIO_FORMATS
|
||||
from datetime import datetime
|
||||
from state_store import AtomicJsonStore, from_json_compatible, read_legacy_shelf, to_json_compatible
|
||||
|
||||
log = logging.getLogger('ytdl')
|
||||
|
||||
@@ -250,8 +249,100 @@ class DownloadInfo:
|
||||
|
||||
if not getattr(self, "codec", None):
|
||||
self.codec = "auto"
|
||||
if not hasattr(self, "folder"):
|
||||
self.folder = ""
|
||||
if not hasattr(self, "custom_name_prefix"):
|
||||
self.custom_name_prefix = ""
|
||||
if not hasattr(self, "playlist_item_limit"):
|
||||
self.playlist_item_limit = 0
|
||||
if not hasattr(self, "split_by_chapters"):
|
||||
self.split_by_chapters = False
|
||||
if not hasattr(self, "chapter_template"):
|
||||
self.chapter_template = ""
|
||||
if not hasattr(self, "subtitle_language"):
|
||||
self.subtitle_language = "en"
|
||||
if not hasattr(self, "subtitle_mode"):
|
||||
self.subtitle_mode = "prefer_manual"
|
||||
if not hasattr(self, "entry"):
|
||||
self.entry = None
|
||||
if not hasattr(self, "subtitle_files"):
|
||||
self.subtitle_files = []
|
||||
if not hasattr(self, "chapter_files"):
|
||||
self.chapter_files = []
|
||||
|
||||
|
||||
# Attribute names that _download_info_to_record copies from a DownloadInfo
# into its persisted record. The order here is the order in which the keys
# are written.
_PERSISTED_DOWNLOAD_FIELDS = (
    "id", "title", "url",
    "quality", "download_type", "codec", "format",
    "folder", "custom_name_prefix", "playlist_item_limit",
    "split_by_chapters", "chapter_template",
    "subtitle_language", "subtitle_mode",
    "status", "timestamp", "error", "msg",
    "filename", "size", "chapter_files",
)
|
||||
|
||||
|
||||
def _compact_persisted_entry(entry: Any) -> Optional[dict[str, Any]]:
    """Reduce *entry* to the items whose key starts with ``playlist`` or ``channel``.

    Returns ``None`` when *entry* is not a dict or when no key carries one of
    those prefixes; otherwise returns a new dict holding only the matching
    items, in their original order.
    """
    if not isinstance(entry, dict):
        return None

    kept: dict[str, Any] = {}
    for name, payload in entry.items():
        if name.startswith(("playlist", "channel")):
            kept[name] = payload

    # An empty result is reported as None rather than {}.
    return kept if kept else None
|
||||
|
||||
|
||||
def _download_info_to_record(
    info: DownloadInfo,
    *,
    include_entry: bool,
) -> dict[str, Any]:
    """Convert *info* into a plain record dict for persistence.

    Every name in ``_PERSISTED_DOWNLOAD_FIELDS`` that exists on *info* with a
    non-``None`` value is passed through ``to_json_compatible`` and stored
    under the same key. When *include_entry* is true and the compacted
    ``info.entry`` is not ``None``, it is stored under ``"entry"`` as well.
    """
    record: dict[str, Any] = {}
    for field_name in _PERSISTED_DOWNLOAD_FIELDS:
        # A missing attribute and an explicit None are both skipped.
        field_value = getattr(info, field_name, None)
        if field_value is None:
            continue
        record[field_name] = to_json_compatible(field_value)

    if not include_entry:
        return record

    trimmed_entry = _compact_persisted_entry(getattr(info, "entry", None))
    if trimmed_entry is not None:
        record["entry"] = to_json_compatible(trimmed_entry)
    return record
|
||||
|
||||
|
||||
def _download_info_from_record(record: dict[str, Any]) -> DownloadInfo:
    """Rebuild a ``DownloadInfo`` from a persisted *record*.

    The instance is created without calling ``__init__``; each record value is
    passed through ``from_json_compatible`` and the result is applied with
    ``__setstate__``. Attributes still missing afterwards receive the defaults
    listed below.
    """
    restored = DownloadInfo.__new__(DownloadInfo)

    state = {}
    for name, raw_value in record.items():
        state[name] = from_json_compatible(raw_value)
    restored.__setstate__(state)

    # (attribute, default) pairs applied only where the attribute is absent.
    fallbacks = (
        ("msg", None),
        ("percent", None),
        ("speed", None),
        ("eta", None),
        ("status", "pending"),
        ("size", None),
        ("error", None),
    )
    for attribute, default in fallbacks:
        if not hasattr(restored, attribute):
            setattr(restored, attribute, default)
    return restored
|
||||
|
||||
class Download:
|
||||
manager = None
|
||||
@@ -502,11 +593,9 @@ class PersistentQueue:
|
||||
pdir = os.path.dirname(path)
|
||||
if not os.path.isdir(pdir):
|
||||
os.mkdir(pdir)
|
||||
with shelve.open(path, 'c'):
|
||||
pass
|
||||
|
||||
self.path = path
|
||||
self.repair()
|
||||
self.legacy_path = path
|
||||
self.path = f"{path}.json"
|
||||
self.store = AtomicJsonStore(self.path, kind=f"persistent_queue:{name}")
|
||||
self.dict = OrderedDict()
|
||||
|
||||
def load(self):
|
||||
@@ -523,16 +612,75 @@ class PersistentQueue:
|
||||
return self.dict.items()
|
||||
|
||||
def saved_items(self):
|
||||
with shelve.open(self.path, 'r') as shelf:
|
||||
return sorted(shelf.items(), key=lambda item: item[1].timestamp)
|
||||
items = [
|
||||
(item["key"], _download_info_from_record(item["info"]))
|
||||
for item in self._load_state_items()
|
||||
]
|
||||
return sorted(items, key=lambda item: item[1].timestamp)
|
||||
|
||||
def _should_persist_entry(self) -> bool:
    """Tell whether records written by this queue should carry the entry.

    The answer is ``False`` only when ``self.identifier`` is ``"completed"``.
    """
    is_completed_queue = self.identifier == "completed"
    return not is_completed_queue
|
||||
|
||||
def _serialize_items(self):
    """Return one ``{"key", "info"}`` record per item held in ``self.dict``.

    Records follow the iteration order of ``self.dict``; each ``"info"`` value
    is produced by ``_download_info_to_record`` from the download's ``info``.
    """
    serialized = []
    for key, download in self.dict.items():
        info_record = _download_info_to_record(
            download.info,
            include_entry=self._should_persist_entry(),
        )
        serialized.append({"key": key, "info": info_record})
    return serialized
|
||||
|
||||
def _save_dict(self):
    """Save the serialized queue items to ``self.store`` under ``"items"``."""
    save = self.store.save
    save({"items": self._serialize_items()})
|
||||
|
||||
def _load_state_items(self):
    """Return the persisted item records, migrating or compacting as needed.

    * No stored payload: read the legacy shelf at ``self.legacy_path``. If
      that yields nothing (``None``) return ``[]``; otherwise convert its
      items, ordered by ``timestamp``, save them to the store and return them.
    * Stored payload without an ``"items"`` list: log a warning, return ``[]``.
    * Stored payload with an ``"items"`` list: round-trip every well-formed
      item through ``DownloadInfo``; re-save when the schema version differs
      from the store's or the round-trip changed anything; return the result.
    """
    payload = self.store.load()

    if payload is None:
        legacy_items = read_legacy_shelf(self.legacy_path)
        if legacy_items is None:
            return []

        with_entry = self._should_persist_entry()
        migrated = []
        for key, legacy_info in sorted(legacy_items, key=lambda pair: pair[1].timestamp):
            migrated.append(
                {
                    "key": key,
                    "info": _download_info_to_record(legacy_info, include_entry=with_entry),
                }
            )
        self.store.save({"items": migrated})
        return migrated

    stored_items = payload.get("items")
    if not isinstance(stored_items, list):
        log.warning("PersistentQueue:%s state file did not contain an items list", self.identifier)
        return []

    with_entry = self._should_persist_entry()
    compacted = []
    for stored in stored_items:
        # Skip anything that is not a {"key": ..., "info": ...} mapping.
        if not isinstance(stored, dict) or "key" not in stored or "info" not in stored:
            continue
        compacted.append(
            {
                "key": stored["key"],
                "info": _download_info_to_record(
                    _download_info_from_record(stored["info"]),
                    include_entry=with_entry,
                ),
            }
        )

    # Rewrite the file when it is on another schema version or when the
    # round-trip altered the stored records.
    stale_schema = payload.get("schema_version") != self.store.schema_version
    if stale_schema or compacted != stored_items:
        self.store.save({"items": compacted})
    return compacted
|
||||
|
||||
def put(self, value):
|
||||
key = value.info.url
|
||||
old = self.dict.get(key)
|
||||
self.dict[key] = value
|
||||
try:
|
||||
with shelve.open(self.path, 'w') as shelf:
|
||||
shelf[key] = value.info
|
||||
self._save_dict()
|
||||
except Exception:
|
||||
if old is None:
|
||||
del self.dict[key]
|
||||
@@ -542,9 +690,13 @@ class PersistentQueue:
|
||||
|
||||
def delete(self, key):
|
||||
if key in self.dict:
|
||||
old = self.dict[key]
|
||||
del self.dict[key]
|
||||
with shelve.open(self.path, 'w') as shelf:
|
||||
shelf.pop(key, None)
|
||||
try:
|
||||
self._save_dict()
|
||||
except Exception:
|
||||
self.dict[key] = old
|
||||
raise
|
||||
|
||||
def next(self):
|
||||
k, v = next(iter(self.dict.items()))
|
||||
@@ -553,90 +705,6 @@ class PersistentQueue:
|
||||
def empty(self):
    """Return ``True`` when ``self.dict`` holds no items, else ``False``."""
    return not self.dict
|
||||
|
||||
def repair(self):
|
||||
# check DB format
|
||||
type_check = subprocess.run(
|
||||
["file", self.path],
|
||||
capture_output=True,
|
||||
text=True
|
||||
)
|
||||
db_type = type_check.stdout.lower()
|
||||
|
||||
# create backup (<queue>.old)
|
||||
try:
|
||||
shutil.copy2(self.path, f"{self.path}.old")
|
||||
except Exception as e:
|
||||
# if we cannot backup then its not safe to attempt a repair
|
||||
# since it could be due to a filesystem error
|
||||
log.debug(f"PersistentQueue:{self.identifier} backup failed, skipping repair")
|
||||
return
|
||||
|
||||
if "gnu dbm" in db_type:
|
||||
# perform gdbm repair
|
||||
log_prefix = f"PersistentQueue:{self.identifier} repair (dbm/file)"
|
||||
log.debug(f"{log_prefix} started")
|
||||
try:
|
||||
result = subprocess.run(
|
||||
["gdbmtool", self.path],
|
||||
input="recover verbose summary\n",
|
||||
text=True,
|
||||
capture_output=True,
|
||||
timeout=60
|
||||
)
|
||||
log.debug(f"{log_prefix} {result.stdout}")
|
||||
if result.stderr:
|
||||
log.debug(f"{log_prefix} failed: {result.stderr}")
|
||||
except FileNotFoundError:
|
||||
log.debug(f"{log_prefix} failed: 'gdbmtool' was not found")
|
||||
|
||||
# perform null key cleanup
|
||||
log_prefix = f"PersistentQueue:{self.identifier} repair (null keys)"
|
||||
log.debug(f"{log_prefix} started")
|
||||
deleted = 0
|
||||
try:
|
||||
with dbm.open(self.path, "w") as db:
|
||||
for key in list(db.keys()):
|
||||
if len(key) > 0 and all(b == 0x00 for b in key):
|
||||
log.debug(f"{log_prefix} deleting key of length {len(key)} (all NUL bytes)")
|
||||
del db[key]
|
||||
deleted += 1
|
||||
log.debug(f"{log_prefix} done - deleted {deleted} key(s)")
|
||||
except dbm.error:
|
||||
log.debug(f"{log_prefix} failed: db type is dbm.gnu, but the module is not available (dbm.error; module support may be missing or the file may be corrupted)")
|
||||
|
||||
elif "sqlite" in db_type:
|
||||
# perform sqlite3 recovery
|
||||
log_prefix = f"PersistentQueue:{self.identifier} repair (sqlite3/file)"
|
||||
log.debug(f"{log_prefix} started")
|
||||
try:
|
||||
recover_proc = subprocess.Popen(
|
||||
["sqlite3", self.path, ".recover"],
|
||||
stdout=subprocess.PIPE,
|
||||
stderr=subprocess.PIPE,
|
||||
text=True,
|
||||
)
|
||||
run_result = subprocess.run(
|
||||
["sqlite3", f"{self.path}.tmp"],
|
||||
stdin=recover_proc.stdout,
|
||||
capture_output=True,
|
||||
text=True,
|
||||
timeout=60,
|
||||
)
|
||||
if recover_proc.stdout is not None:
|
||||
recover_proc.stdout.close()
|
||||
recover_stderr = recover_proc.stderr.read() if recover_proc.stderr is not None else ""
|
||||
recover_proc.wait(timeout=60)
|
||||
if run_result.stderr or recover_stderr:
|
||||
error_text = " ".join(part for part in [recover_stderr.strip(), run_result.stderr.strip()] if part)
|
||||
log.debug(f"{log_prefix} failed: {error_text}")
|
||||
else:
|
||||
shutil.move(f"{self.path}.tmp", self.path)
|
||||
log.debug(f"{log_prefix}{run_result.stdout or ' was successful, no output'}")
|
||||
except FileNotFoundError:
|
||||
log.debug(f"{log_prefix} failed: 'sqlite3' was not found")
|
||||
except subprocess.TimeoutExpired:
|
||||
log.debug(f"{log_prefix} failed: sqlite recovery timed out")
|
||||
|
||||
class DownloadQueue:
|
||||
def __init__(self, config, notifier):
|
||||
self.config = config
|
||||
@@ -949,6 +1017,42 @@ class DownloadQueue:
|
||||
_add_gen,
|
||||
)
|
||||
|
||||
async def add_entry(
    self,
    entry,
    download_type,
    codec,
    format,
    quality,
    folder,
    custom_name_prefix,
    playlist_item_limit,
    auto_start=True,
    split_by_chapters=False,
    chapter_template=None,
    subtitle_language="en",
    subtitle_mode="prefer_manual",
):
    """Add an already-extracted *entry* using the given download options.

    A dict *entry* is deep-copied before use; any other value is forwarded
    as-is. All options are handed positionally to the private ``__add_entry``
    together with a fresh empty set and a trailing ``None``, and whatever that
    call returns is returned.
    """
    if isinstance(entry, dict):
        private_entry = copy.deepcopy(entry)
    else:
        private_entry = entry

    return await self.__add_entry(
        private_entry,
        download_type,
        codec,
        format,
        quality,
        folder,
        custom_name_prefix,
        playlist_item_limit,
        auto_start,
        split_by_chapters,
        chapter_template,
        subtitle_language,
        subtitle_mode,
        set(),
        None,
    )
|
||||
|
||||
async def start_pending(self, ids):
|
||||
for id in ids:
|
||||
if not self.pending.exists(id):
|
||||
|
||||
Reference in New Issue
Block a user