diff --git a/beets/autotag/distance.py b/beets/autotag/distance.py index cc0873c8f3..e0214a0380 100644 --- a/beets/autotag/distance.py +++ b/beets/autotag/distance.py @@ -133,6 +133,7 @@ def __init__(self) -> None: self.tracks: dict[TrackInfo, Distance] = {} @cached_classproperty + @classmethod def _weights(cls) -> dict[str, float]: """A dictionary from keys to floating-point weights.""" weights_view = config["match"]["distance_weights"] diff --git a/beets/autotag/hooks.py b/beets/autotag/hooks.py index c8284c3124..97a40ecf9e 100644 --- a/beets/autotag/hooks.py +++ b/beets/autotag/hooks.py @@ -143,6 +143,7 @@ class Info(AttrDict[Any]): LEGACY_TO_LIST_FIELD: ClassVar[dict[str, str]] @cached_classproperty + @classmethod def nullable_fields(cls) -> set[str]: """Return fields that may be cleared when new metadata is applied.""" return set(config["overwrite_null"][cls.type.lower()].as_str_seq()) diff --git a/beets/dbcore/db.py b/beets/dbcore/db.py index 140f1d4d0a..ba437faf79 100755 --- a/beets/dbcore/db.py +++ b/beets/dbcore/db.py @@ -326,6 +326,7 @@ class Model(ABC, Generic[D]): """ @cached_classproperty + @classmethod def _types(cls) -> dict[str, types.Type]: """Optional types for non-fixed (flexible and computed) fields.""" return {} @@ -336,6 +337,7 @@ def _types(cls) -> dict[str, types.Type]: """ @cached_classproperty + @classmethod def _queries(cls) -> dict[str, FieldQueryType]: """Named queries that use a field-like `name:value` syntax but which do not relate to any specific field. @@ -354,11 +356,13 @@ def _queries(cls) -> dict[str, FieldQueryType]: """ @cached_classproperty + @classmethod def _relation(cls): """The model that this model is closely related to.""" return cls @cached_classproperty + @classmethod def relation_join(cls) -> str: """Return the join required to include the related table in the query. @@ -367,14 +371,17 @@ def relation_join(cls) -> str: return "" @cached_classproperty + @classmethod def all_db_fields(cls) -> set[str]: return cls._fields.keys() | cls._relation._fields.keys() @cached_classproperty + @classmethod def shared_db_fields(cls) -> set[str]: return cls._fields.keys() & cls._relation._fields.keys() @cached_classproperty + @classmethod def other_db_fields(cls) -> set[str]: """Fields in the related table.""" return cls._relation._fields.keys() - cls.shared_db_fields @@ -1070,6 +1077,7 @@ class Migration(ABC): db: Database @cached_classproperty + @classmethod def name(cls) -> str: """Class name (except Migration) converted to snake case.""" name = cls.__name__.removesuffix("Migration") # type: ignore[attr-defined] diff --git a/beets/importer/tasks.py b/beets/importer/tasks.py index cbc12b62bc..16384a46bb 100644 --- a/beets/importer/tasks.py +++ b/beets/importer/tasks.py @@ -868,6 +868,7 @@ def is_archive(cls, path): return False @util.cached_classproperty + @classmethod def handlers(cls) -> list[ArchiveHandler]: """Returns a list of archive handlers. diff --git a/beets/library/models.py b/beets/library/models.py index a325458bdf..88d1ad759d 100644 --- a/beets/library/models.py +++ b/beets/library/models.py @@ -50,10 +50,12 @@ class LibModel(dbcore.Model["Library"]): length: float @cached_classproperty + @classmethod def _fields(cls) -> dict[str, types.Type]: return {f: TYPE_BY_FIELD[f] for f in sorted(cls._field_names)} @cached_classproperty + @classmethod def _types(cls) -> dict[str, types.Type]: """Return the types of the fields in this model.""" return { @@ -62,10 +64,12 @@ def _types(cls) -> dict[str, types.Type]: } @cached_classproperty + @classmethod def _queries(cls) -> dict[str, FieldQueryType]: return plugins.named_queries(cls) # type: ignore[arg-type] @cached_classproperty + @classmethod def writable_media_fields(cls) -> set[str]: return set(MediaFile.fields()) & cls._fields.keys() @@ -300,6 +304,7 @@ class Album(LibModel): _search_fields = ("album", "albumartist", "genres") @cached_classproperty + @classmethod def _types(cls) -> dict[str, types.Type]: return {**super()._types, "path": TYPE_BY_FIELD["path"]} @@ -314,10 +319,12 @@ def _types(cls) -> dict[str, types.Type]: _format_config_key = "format_album" @cached_classproperty + @classmethod def _relation(cls) -> type[Item]: return Item @cached_classproperty + @classmethod def relation_join(cls) -> str: """Return FROM clause which joins on related album items. @@ -700,6 +707,7 @@ class Item(LibModel): } @cached_classproperty + @classmethod def _queries(cls) -> dict[str, FieldQueryType]: return {**super()._queries, "singleton": dbcore.query.SingletonQuery} @@ -709,10 +717,12 @@ def _queries(cls) -> dict[str, FieldQueryType]: __album: Album | None = None @cached_classproperty + @classmethod def _relation(cls) -> type[Album]: return Album @cached_classproperty + @classmethod def relation_join(cls) -> str: """Return the FROM clause which includes related albums. @@ -1237,6 +1247,7 @@ class DefaultTemplateFunctions: _prefix = "tmpl_" @cached_classproperty + @classmethod def _func_names(cls) -> list[str]: """Names of tmpl_* functions in this class.""" return [s for s in dir(cls) if s.startswith(cls._prefix)] diff --git a/beets/metadata_plugins.py b/beets/metadata_plugins.py index bb9dbbf101..dacf06b980 100644 --- a/beets/metadata_plugins.py +++ b/beets/metadata_plugins.py @@ -179,6 +179,7 @@ class MetadataSourcePlugin(BeetsPlugin, metaclass=abc.ABCMeta): DEFAULT_DATA_SOURCE_MISMATCH_PENALTY = 0.5 @cached_classproperty + @classmethod def data_source(cls) -> str: """The data source name for this plugin. diff --git a/beets/util/__init__.py b/beets/util/__init__.py index 1733c53ba6..c5bf80ee81 100644 --- a/beets/util/__init__.py +++ b/beets/util/__init__.py @@ -62,6 +62,7 @@ MAX_FILENAME_LENGTH = 200 WINDOWS_MAGIC_PREFIX = "\\\\?\\" T = TypeVar("T") +R_co = TypeVar("R_co", covariant=True) StrPath = str | Path PathLike = StrPath | bytes Replacements = Sequence[tuple[Pattern[str], str]] @@ -1065,47 +1066,40 @@ def _worker(item: T) -> Any: pool.map(_worker, items) -class cached_classproperty(Generic[T]): +class cached_classproperty(Generic[T, R_co]): """Descriptor implementing cached class properties. + Must be used in combination with @classmethod. + Provides class-level dynamic property behavior where the getter function is called once per class and the result is cached for subsequent access. Unlike instance properties, this operates on the class rather than instances. """ - cache: ClassVar[dict[tuple[type[object], str], object]] = {} - - name: str = "" - - # Ideally, we would like to use `Callable[[type[T]], Any]` here, - # however, `mypy` is unable to see this as a **class** property, and thinks - # that this callable receives an **instance** of the object, failing the - # type check, for example: - # >>> class Album: - # >>> @cached_classproperty - # >>> def foo(cls): - # >>> reveal_type(cls) # mypy: revealed type is "Album" - # >>> return cls.bar - # - # Argument 1 to "cached_classproperty" has incompatible type - # "Callable[[Album], ...]"; expected "Callable[[type[Album]], ...]" - # - # Therefore, we just use `Any` here, which is not ideal, but works. - def __init__(self, getter: Callable[..., T]) -> None: + _cache: ClassVar[dict[tuple[type[object], str], object]] = {} + + def __init__(self, getter: Callable[[type[T]], R_co], /) -> None: """Initialize the descriptor with the property getter function.""" - self.getter: Callable[..., T] = getter + self.getter: Callable[[type[T]], R_co] = getter + self.name: str - def __set_name__(self, owner: object, name: str) -> None: + def __set_name__(self, owner: type[T], name: str, /) -> None: """Capture the attribute name this descriptor is assigned to.""" self.name = name - def __get__(self, instance: object, owner: type[object]) -> T: + def __get__(self, instance: T | None, owner: type[T], /) -> R_co: """Compute and cache if needed, and return the property value.""" key: tuple[type[object], str] = owner, self.name - if key not in self.cache: - self.cache[key] = self.getter(owner) - - return cast(T, self.cache[key]) + try: + return cast(R_co, type(self)._cache[key]) + except KeyError: + obj: R_co = self.getter.__func__(owner) # type: ignore[attr-defined] + type(self)._cache[key] = obj + return obj + + @classmethod + def clear_cache(cls) -> None: + cls._cache.clear() class LazySharedInstance(Generic[T]): diff --git a/test/conftest.py b/test/conftest.py index bff8117747..da580ca1ec 100644 --- a/test/conftest.py +++ b/test/conftest.py @@ -100,7 +100,7 @@ def pytest_assertrepr_compare(op, left, right): @pytest.fixture(autouse=True) def clear_cached_classproperty(): - cached_classproperty.cache.clear() + cached_classproperty.clear_cache() @pytest.fixture(scope="module") diff --git a/test/test_dbcore.py b/test/test_dbcore.py index 96b1b3bbcd..40c5502b2d 100644 --- a/test/test_dbcore.py +++ b/test/test_dbcore.py @@ -69,12 +69,14 @@ class ModelFixture1(LibModel): _indices = (Index("field_one_index", ("field_one",)),) @cached_classproperty + @classmethod def _types(cls): return { "some_float_field": dbcore.types.FLOAT, } @cached_classproperty + @classmethod def _queries(cls): return { "some_query": QueryFixture,