Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 22 additions & 31 deletions fbgemm_gpu/test/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,10 @@
import os
import subprocess
import unittest
from collections.abc import Callable, Generator
from contextlib import contextmanager
from functools import wraps
from typing import Any, Callable, Optional, Union
from typing import Any

import fbgemm_gpu
import hypothesis.strategies as st
Expand All @@ -27,8 +28,7 @@

# Skip pt2 compliant tag test for certain operators
# TODO: remove this once the operators are pt2 compliant
# pyre-ignore
additional_decorators: dict[str, list[Callable]] = {
additional_decorators: dict[str, list[Callable[..., Any]]] = {
# vbe_generate_metadata_cpu return different values from vbe_generate_metadata_meta
# this fails fake_tensor test as the test expects them to be the same
# fake_tensor test is added in failures_dict but failing fake_tensor test still cause pt2_compliant tag test to fail
Expand Down Expand Up @@ -115,14 +115,12 @@ class optests:
# ...
#
@staticmethod
# pyre-ignore[3]
def generate_opcheck_tests(
test_class: Optional[unittest.TestCase] = None,
test_class: unittest.TestCase | None = None,
*,
fast: bool = False,
# pyre-ignore[24]: Generic type `Callable` expects 2 type parameters.
additional_decorators: Optional[dict[str, Callable]] = None,
):
additional_decorators: dict[str, Callable[..., Any]] | None = None,
) -> unittest.TestCase | Callable[[unittest.TestCase], unittest.TestCase]:
if additional_decorators is None:
additional_decorators = {}

Expand Down Expand Up @@ -176,8 +174,9 @@ def is_inside_opcheck_mode() -> bool:
return optests.is_inside_opcheck_mode()

@staticmethod
# pyre-ignore[3]
def dontGenerateOpCheckTests(reason: str):
def dontGenerateOpCheckTests(
reason: str,
) -> Callable[[Callable[..., Any]], Callable[..., Any]]:
if hasattr(fbgemm_gpu, "open_source"):
return lambda fun: fun
import torch.testing._internal.optests as optests
Expand All @@ -187,10 +186,10 @@ def dontGenerateOpCheckTests(reason: str):

class TestSuite(unittest.TestCase):
@contextmanager
# pyre-ignore[2]
def assertNotRaised(self, exc_type) -> None:
def assertNotRaised(
self, exc_type: type[BaseException]
) -> Generator[None, None, None]:
try:
# pyre-ignore[7]
yield None
except exc_type as e:
raise self.failureException(e)
Expand All @@ -200,10 +199,8 @@ def assertNotRaised(self, exc_type) -> None:
# The problem with just torch.autograd.gradcheck is that it results in
# very slow tests when composed with generate_opcheck_tests.
def gradcheck(
# pyre-ignore[24]: Generic type `Callable` expects 2 type parameters.
f: Callable,
# pyre-ignore[2]
inputs: Union[torch.Tensor, tuple[Any, ...]],
f: Callable[..., Any],
inputs: torch.Tensor | tuple[Any, ...],
*args: Any,
**kwargs: Any,
) -> None:
Expand Down Expand Up @@ -241,14 +238,12 @@ def gpu_memory_lt_gb(x: int) -> tuple[bool, str]:
)


# pyre-fixme[3]: Return annotation cannot be `Any`.
def skipIfRocm(reason: str = "Test currently doesn't work on the ROCm stack") -> Any:
# pyre-fixme[3]: Return annotation cannot be `Any`.
# pyre-fixme[24]: Generic type `Callable` expects 2 type parameters.
def decorator(fn: Callable) -> Any:
def skipIfRocm(
reason: str = "Test currently doesn't work on the ROCm stack",
) -> Callable[[Callable[..., None]], Callable[..., None]]:
def decorator(fn: Callable[..., None]) -> Callable[..., None]:
@wraps(fn)
# pyre-fixme[3]: Return annotation cannot be `Any`.
def wrapper(*args: Any, **kwargs: Any) -> Any:
def wrapper(*args: Any, **kwargs: Any) -> None:
if TEST_WITH_ROCM:
raise unittest.SkipTest(reason)
else:
Expand All @@ -259,16 +254,12 @@ def wrapper(*args: Any, **kwargs: Any) -> Any:
return decorator


# pyre-fixme[3]: Return annotation cannot be `Any`.
def skipIfNotRocm(
reason: str = "Test currently doesn work only on the ROCm stack",
) -> Any:
# pyre-fixme[3]: Return annotation cannot be `Any`.
# pyre-fixme[24]: Generic type `Callable` expects 2 type parameters.
def decorator(fn: Callable) -> Any:
) -> Callable[[Callable[..., None]], Callable[..., None]]:
def decorator(fn: Callable[..., None]) -> Callable[..., None]:
@wraps(fn)
# pyre-fixme[3]: Return annotation cannot be `Any`.
def wrapper(*args: Any, **kwargs: Any) -> Any:
def wrapper(*args: Any, **kwargs: Any) -> None:
if TEST_WITH_ROCM:
fn(*args, **kwargs)
else:
Expand Down
Loading