mirror of
https://github.com/alexta69/metube.git
synced 2026-06-13 16:40:05 +00:00
fix pickle (closes #814)
This commit is contained in:
@@ -5,6 +5,7 @@ from __future__ import annotations
|
||||
import os
|
||||
import tempfile
|
||||
import unittest
|
||||
from unittest.mock import patch
|
||||
|
||||
from ytdl import DownloadInfo, PersistentQueue
|
||||
|
||||
@@ -71,6 +72,48 @@ class PersistentQueueTests(unittest.TestCase):
|
||||
pq2.load()
|
||||
self.assertTrue(pq2.exists("http://load.example"))
|
||||
|
||||
def test_put_rollbacks_in_memory_queue_when_shelf_write_fails(self):
|
||||
with tempfile.TemporaryDirectory() as tmp:
|
||||
path = os.path.join(tmp, "queue")
|
||||
pq = PersistentQueue("queue", path)
|
||||
dl = _FakeDownload(_make_info("http://rollback.example"))
|
||||
self.assertFalse(pq.exists("http://rollback.example"))
|
||||
|
||||
orig_open = __import__("shelve").open
|
||||
|
||||
def bad_open(filename, flag="c", *args, **kwargs):
|
||||
if flag == "w":
|
||||
raise OSError("simulated shelf failure")
|
||||
return orig_open(filename, flag, *args, **kwargs)
|
||||
|
||||
with patch("ytdl.shelve.open", bad_open):
|
||||
with self.assertRaises(OSError):
|
||||
pq.put(dl)
|
||||
|
||||
self.assertFalse(pq.exists("http://rollback.example"))
|
||||
|
||||
def test_put_rollbacks_to_previous_download_when_replace_fails(self):
|
||||
with tempfile.TemporaryDirectory() as tmp:
|
||||
path = os.path.join(tmp, "queue")
|
||||
pq = PersistentQueue("queue", path)
|
||||
first = _FakeDownload(_make_info("http://same.example"))
|
||||
second = _FakeDownload(_make_info("http://same.example"))
|
||||
second.info.title = "Replaced title"
|
||||
pq.put(first)
|
||||
|
||||
orig_open = __import__("shelve").open
|
||||
|
||||
def bad_open(filename, flag="c", *args, **kwargs):
|
||||
if flag == "w":
|
||||
raise OSError("simulated shelf failure")
|
||||
return orig_open(filename, flag, *args, **kwargs)
|
||||
|
||||
with patch("ytdl.shelve.open", bad_open):
|
||||
with self.assertRaises(OSError):
|
||||
pq.put(second)
|
||||
|
||||
self.assertEqual(pq.get("http://same.example").info.title, "Title")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
||||
@@ -2,15 +2,17 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import pickle
|
||||
import tempfile
|
||||
import threading
|
||||
import unittest
|
||||
from pathlib import Path
|
||||
|
||||
from ytdl import (
|
||||
DownloadInfo,
|
||||
_convert_generators_to_lists,
|
||||
_convert_srt_to_txt_file,
|
||||
_outtmpl_substitute_field,
|
||||
_sanitize_entry_for_pickle,
|
||||
_sanitize_path_component,
|
||||
)
|
||||
|
||||
@@ -35,17 +37,41 @@ class OuttmplSubstituteFieldTests(unittest.TestCase):
|
||||
self.assertEqual(_outtmpl_substitute_field("%(other)s", "title", "x"), "%(other)s")
|
||||
|
||||
|
||||
class ConvertGeneratorsToListsTests(unittest.TestCase):
|
||||
class SanitizeEntryForPickleTests(unittest.TestCase):
|
||||
def test_nested(self):
|
||||
def g():
|
||||
yield 1
|
||||
|
||||
obj = {"a": g(), "b": [g()]}
|
||||
out = _convert_generators_to_lists(obj)
|
||||
out = _sanitize_entry_for_pickle(obj)
|
||||
self.assertEqual(out, {"a": [1], "b": [[1]]})
|
||||
pickle.dumps(out)
|
||||
|
||||
def test_plain(self):
|
||||
self.assertEqual(_convert_generators_to_lists(5), 5)
|
||||
self.assertEqual(_sanitize_entry_for_pickle(5), 5)
|
||||
|
||||
def test_set_converted_to_list(self):
|
||||
obj = {"s": {1, 2}}
|
||||
out = _sanitize_entry_for_pickle(obj)
|
||||
self.assertEqual(sorted(out["s"]), [1, 2])
|
||||
pickle.dumps(out)
|
||||
|
||||
def test_map_iterator(self):
|
||||
out = _sanitize_entry_for_pickle({"m": map(int, ["1", "2"])})
|
||||
self.assertEqual(out, {"m": [1, 2]})
|
||||
|
||||
def test_lock_replaced_with_none(self):
|
||||
lock = threading.Lock()
|
||||
out = _sanitize_entry_for_pickle({"k": lock})
|
||||
self.assertIsNone(out["k"])
|
||||
pickle.dumps(out)
|
||||
|
||||
def test_ordered_dict(self):
|
||||
from collections import OrderedDict
|
||||
|
||||
od = OrderedDict([("z", 1), ("a", 2)])
|
||||
out = _sanitize_entry_for_pickle(od)
|
||||
self.assertEqual(out, {"z": 1, "a": 2})
|
||||
|
||||
|
||||
class ConvertSrtToTxtTests(unittest.TestCase):
|
||||
|
||||
+48
-13
@@ -1,6 +1,9 @@
|
||||
import os
|
||||
import shutil
|
||||
import yt_dlp
|
||||
import collections
|
||||
import collections.abc
|
||||
import pickle
|
||||
from collections import OrderedDict
|
||||
import shelve
|
||||
import time
|
||||
@@ -78,16 +81,40 @@ def _outtmpl_substitute_field(template: str, field: str, value: Any) -> str:
|
||||
|
||||
return pattern.sub(replacement, template)
|
||||
|
||||
def _convert_generators_to_lists(obj):
|
||||
"""Recursively convert generators to lists in a dictionary to make it pickleable."""
|
||||
if isinstance(obj, types.GeneratorType):
|
||||
return list(obj)
|
||||
elif isinstance(obj, dict):
|
||||
return {k: _convert_generators_to_lists(v) for k, v in obj.items()}
|
||||
elif isinstance(obj, (list, tuple)):
|
||||
return type(obj)(_convert_generators_to_lists(item) for item in obj)
|
||||
else:
|
||||
_MAX_ENTRY_SANITIZE_DEPTH = 64
|
||||
|
||||
|
||||
def _sanitize_entry_for_pickle(obj, _depth=0):
|
||||
"""Recursively normalize yt-dlp ``info_dict`` data so it can be stored in shelve/pickle.
|
||||
|
||||
Live streams and newer yt-dlp versions may nest generators, iterators, sets, or
|
||||
non-serializable objects (e.g. locks) inside the extracted metadata. The previous
|
||||
helper only walked plain dict/list/tuple and only expanded ``types.GeneratorType``.
|
||||
"""
|
||||
if _depth > _MAX_ENTRY_SANITIZE_DEPTH:
|
||||
return None
|
||||
if obj is None or isinstance(obj, (bool, int, float, str, bytes)):
|
||||
return obj
|
||||
if isinstance(obj, types.GeneratorType):
|
||||
return _sanitize_entry_for_pickle(list(obj), _depth + 1)
|
||||
if isinstance(obj, collections.abc.Mapping):
|
||||
return {k: _sanitize_entry_for_pickle(v, _depth + 1) for k, v in obj.items()}
|
||||
if isinstance(obj, (list, tuple)):
|
||||
return type(obj)(_sanitize_entry_for_pickle(x, _depth + 1) for x in obj)
|
||||
if isinstance(obj, (set, frozenset)):
|
||||
return [_sanitize_entry_for_pickle(x, _depth + 1) for x in obj]
|
||||
if isinstance(obj, collections.deque):
|
||||
return [_sanitize_entry_for_pickle(x, _depth + 1) for x in obj]
|
||||
if isinstance(obj, collections.abc.Iterator):
|
||||
try:
|
||||
return _sanitize_entry_for_pickle(list(obj), _depth + 1)
|
||||
except Exception:
|
||||
return None
|
||||
try:
|
||||
pickle.dumps(obj, protocol=pickle.HIGHEST_PROTOCOL)
|
||||
return obj
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
|
||||
def _convert_srt_to_txt_file(subtitle_path: str):
|
||||
@@ -178,8 +205,8 @@ class DownloadInfo:
|
||||
self.size = None
|
||||
self.timestamp = time.time_ns()
|
||||
self.error = error
|
||||
# Convert generators to lists to make entry pickleable
|
||||
self.entry = _convert_generators_to_lists(entry) if entry is not None else None
|
||||
# Strip non-pickleable values (generators, iterators, locks, etc.) for shelve
|
||||
self.entry = _sanitize_entry_for_pickle(entry) if entry is not None else None
|
||||
self.playlist_item_limit = playlist_item_limit
|
||||
self.split_by_chapters = split_by_chapters
|
||||
self.chapter_template = chapter_template
|
||||
@@ -501,9 +528,17 @@ class PersistentQueue:
|
||||
|
||||
def put(self, value):
|
||||
key = value.info.url
|
||||
old = self.dict.get(key)
|
||||
self.dict[key] = value
|
||||
with shelve.open(self.path, 'w') as shelf:
|
||||
shelf[key] = value.info
|
||||
try:
|
||||
with shelve.open(self.path, 'w') as shelf:
|
||||
shelf[key] = value.info
|
||||
except Exception:
|
||||
if old is None:
|
||||
del self.dict[key]
|
||||
else:
|
||||
self.dict[key] = old
|
||||
raise
|
||||
|
||||
def delete(self, key):
|
||||
if key in self.dict:
|
||||
|
||||
Reference in New Issue
Block a user