diff --git a/fbgemm_gpu/fbgemm_gpu/quantize_comm.py b/fbgemm_gpu/fbgemm_gpu/quantize_comm.py index 2875494a2e..8535374b7b 100644 --- a/fbgemm_gpu/fbgemm_gpu/quantize_comm.py +++ b/fbgemm_gpu/fbgemm_gpu/quantize_comm.py @@ -13,13 +13,12 @@ import logging +from dataclasses import dataclass from typing import TypeVar -_T = TypeVar("_T") - import torch -# fmt:skip +import fbgemm_gpu.quantize.quantize_ops # noqa: F401 from fbgemm_gpu.quantize_utils import ( bf16_to_fp32, fp16_to_fp32, @@ -33,9 +32,8 @@ ) from fbgemm_gpu.split_embedding_configs import SparseType from torch.autograd.profiler import record_function # usort:skip -from dataclasses import dataclass -import fbgemm_gpu.quantize.quantize_ops # noqa F401 +_T = TypeVar("_T") logger: logging.Logger = logging.getLogger()