diff --git a/fbgemm_gpu/test/test_utils.py b/fbgemm_gpu/test/test_utils.py index 7ae5974da6..d3d03d0532 100644 --- a/fbgemm_gpu/test/test_utils.py +++ b/fbgemm_gpu/test/test_utils.py @@ -10,9 +10,10 @@ import os import subprocess import unittest +from collections.abc import Callable, Generator from contextlib import contextmanager from functools import wraps -from typing import Any, Callable, Optional, Union +from typing import Any import fbgemm_gpu import hypothesis.strategies as st @@ -27,8 +28,7 @@ # Skip pt2 compliant tag test for certain operators # TODO: remove this once the operators are pt2 compliant -# pyre-ignore -additional_decorators: dict[str, list[Callable]] = { +additional_decorators: dict[str, list[Callable[..., Any]]] = { # vbe_generate_metadata_cpu return different values from vbe_generate_metadata_meta # this fails fake_tensor test as the test expects them to be the same # fake_tensor test is added in failures_dict but failing fake_tensor test still cause pt2_compliant tag test to fail @@ -115,14 +115,12 @@ class optests: # ... # @staticmethod - # pyre-ignore[3] def generate_opcheck_tests( - test_class: Optional[unittest.TestCase] = None, + test_class: unittest.TestCase | None = None, *, fast: bool = False, - # pyre-ignore[24]: Generic type `Callable` expects 2 type parameters. - additional_decorators: Optional[dict[str, Callable]] = None, - ): + additional_decorators: dict[str, Callable[..., Any]] | None = None, + ) -> unittest.TestCase | Callable[[unittest.TestCase], unittest.TestCase]: if additional_decorators is None: additional_decorators = {} @@ -176,8 +174,9 @@ def is_inside_opcheck_mode() -> bool: return optests.is_inside_opcheck_mode() @staticmethod - # pyre-ignore[3] - def dontGenerateOpCheckTests(reason: str): + def dontGenerateOpCheckTests( + reason: str, + ) -> Callable[[Callable[..., Any]], Callable[..., Any]]: if hasattr(fbgemm_gpu, "open_source"): return lambda fun: fun import torch.testing._internal.optests as optests @@ -187,10 +186,10 @@ def dontGenerateOpCheckTests(reason: str): class TestSuite(unittest.TestCase): @contextmanager - # pyre-ignore[2] - def assertNotRaised(self, exc_type) -> None: + def assertNotRaised( + self, exc_type: type[BaseException] + ) -> Generator[None, None, None]: try: - # pyre-ignore[7] yield None except exc_type as e: raise self.failureException(e) @@ -200,10 +199,8 @@ def assertNotRaised(self, exc_type) -> None: # The problem with just torch.autograd.gradcheck is that it results in # very slow tests when composed with generate_opcheck_tests. def gradcheck( - # pyre-ignore[24]: Generic type `Callable` expects 2 type parameters. - f: Callable, - # pyre-ignore[2] - inputs: Union[torch.Tensor, tuple[Any, ...]], + f: Callable[..., Any], + inputs: torch.Tensor | tuple[Any, ...], *args: Any, **kwargs: Any, ) -> None: @@ -241,14 +238,12 @@ def gpu_memory_lt_gb(x: int) -> tuple[bool, str]: ) -# pyre-fixme[3]: Return annotation cannot be `Any`. -def skipIfRocm(reason: str = "Test currently doesn't work on the ROCm stack") -> Any: - # pyre-fixme[3]: Return annotation cannot be `Any`. - # pyre-fixme[24]: Generic type `Callable` expects 2 type parameters. - def decorator(fn: Callable) -> Any: +def skipIfRocm( + reason: str = "Test currently doesn't work on the ROCm stack", +) -> Callable[[Callable[..., None]], Callable[..., None]]: + def decorator(fn: Callable[..., None]) -> Callable[..., None]: @wraps(fn) - # pyre-fixme[3]: Return annotation cannot be `Any`. - def wrapper(*args: Any, **kwargs: Any) -> Any: + def wrapper(*args: Any, **kwargs: Any) -> None: if TEST_WITH_ROCM: raise unittest.SkipTest(reason) else: @@ -259,16 +254,12 @@ def wrapper(*args: Any, **kwargs: Any) -> Any: return decorator -# pyre-fixme[3]: Return annotation cannot be `Any`. def skipIfNotRocm( reason: str = "Test currently doesn work only on the ROCm stack", -) -> Any: - # pyre-fixme[3]: Return annotation cannot be `Any`. - # pyre-fixme[24]: Generic type `Callable` expects 2 type parameters. - def decorator(fn: Callable) -> Any: +) -> Callable[[Callable[..., None]], Callable[..., None]]: + def decorator(fn: Callable[..., None]) -> Callable[..., None]: @wraps(fn) - # pyre-fixme[3]: Return annotation cannot be `Any`. - def wrapper(*args: Any, **kwargs: Any) -> Any: + def wrapper(*args: Any, **kwargs: Any) -> None: if TEST_WITH_ROCM: fn(*args, **kwargs) else: