diff --git a/Makefile b/Makefile index 14bfd335..e80df798 100644 --- a/Makefile +++ b/Makefile @@ -35,7 +35,7 @@ # # ============================================================================ -.PHONY: format lint lint-pylint lint-flake8 test test-unit test-cov check all install clean help +.PHONY: format lint lint-fix test test-unit test-cov check all install clean help # Default paths (can be overridden with FILES=path/to/file.py) SRC_DIR = src/nba_api @@ -46,11 +46,9 @@ FILES ?= $(SRC_DIR) $(TEST_DIR) help: @echo "NBA API Makefile Commands:" @echo "" - @echo " make install - Install dependencies with poetry" - @echo " make format - Format code with black and isort" - @echo " make lint - Run all linters (flake8 + pylint)" - @echo " make lint-flake8 - Run only flake8" - @echo " make lint-pylint - Run only pylint" + @echo " make install - Install dependencies with poetry" + @echo " make format - Format code with ruff" + @echo " make lint - Run ruff linter" @echo " make test - Run all unit tests" @echo " make test-unit - Run unit tests (same as test)" @echo " make test-cov - Run tests with coverage report" @@ -60,31 +58,26 @@ help: @echo "" @echo "Examples with specific files:" @echo " make lint FILES=src/nba_api/stats/endpoints/dunkscoreleaders.py" - @echo " make lint-pylint FILES='src/nba_api/stats/endpoints/dunkscoreleaders.py tests/unit/stats/endpoints/test_dunkscoreleaders.py'" @echo " make format FILES=src/nba_api/stats/endpoints/dunkscoreleaders.py" # Install dependencies install: poetry install --sync -# Format code with isort and black (isort first to take precedence) +# Format code with ruff format: @echo "Formatting: $(FILES)" - poetry run isort $(FILES) - poetry run black $(FILES) + poetry run ruff format $(FILES) -# Run flake8 linter -lint-flake8: - @echo "Running flake8 on: $(FILES)" - poetry run flake8 $(FILES) +# Run ruff linter +lint: + @echo "Linting: $(FILES)" + poetry run ruff check $(FILES) -# Run pylint -lint-pylint: - @echo "Running pylint on: $(FILES)" - poetry run pylint $(FILES) || true - -# Run all linters -lint: lint-flake8 lint-pylint +# Run ruff linter with autofix +lint-fix: + @echo "Lint-fixing: $(FILES)" + poetry run ruff check $(FILES) --fix # Run all unit tests test-unit: diff --git a/src/nba_api/library/http.py b/src/nba_api/library/http.py index aeac18ee..5184e45f 100644 --- a/src/nba_api/library/http.py +++ b/src/nba_api/library/http.py @@ -30,44 +30,53 @@ class NBAResponse: - def __init__(self, response, status_code, url): + def __init__(self, response: str, status_code: int | None, url: str | None) -> None: self._response = response self._status_code = status_code self._url = url + self._dict_cache: dict | None = None + self._json_cache: str | None = None - def get_response(self): + def get_response(self) -> str: return self._response - def get_dict(self): - return json.loads(self._response) + def get_dict(self) -> dict: + if self._dict_cache is None: + self._dict_cache = json.loads(self._response) + return self._dict_cache - def get_json(self): - return json.dumps(self.get_dict()) + def get_json(self) -> str: + if self._json_cache is None: + self._json_cache = json.dumps(self.get_dict()) + return self._json_cache - def valid_json(self): + def valid_json(self) -> bool: try: self.get_dict() except ValueError: return False return True - def get_url(self): + def get_url(self) -> str | None: return self._url + def get_status_code(self) -> int | None: + return self._status_code + class NBAHTTP: - nba_response = NBAResponse + nba_response: type[NBAResponse] = NBAResponse - base_url = None + base_url: str | None = None - parameters = None + parameters: tuple | None = None - headers = None + headers: dict[str, str] | None = None - _session = None + _session: requests.Session | None = None @classmethod - def get_session(cls): + def get_session(cls) -> requests.Session: session = cls._session if session is None: session = requests.Session() @@ -75,32 +84,34 @@ def get_session(cls): return session @classmethod - def set_session(cls, session) -> None: + def set_session(cls, session: requests.Session) -> None: cls._session = session - def clean_contents(self, contents): + def clean_contents(self, contents: str) -> str: return contents def send_api_request( self, - endpoint, - parameters, - referer=None, - proxy=None, - headers=None, - timeout=None, - raise_exception_on_error=False, - ): + endpoint: str, + parameters: dict[str, str | None], + referer: str | None = None, + proxy: str | list[str] | None = None, + headers: dict[str, str] | None = None, + timeout: int | None = None, + raise_exception_on_error: bool = False, + ) -> NBAResponse: if not self.base_url: raise Exception("Cannot use send_api_request from _HTTP class.") base_url = self.base_url.format(endpoint=endpoint) endpoint = endpoint.lower() self.parameters = parameters - request_headers = self.headers if headers is None else headers - - if referer: - request_headers["Referer"] = referer + if headers is not None: + request_headers = headers + elif referer: + request_headers = {**self.headers, "Referer": referer} + else: + request_headers = self.headers if proxy is None: request_proxy = PROXY @@ -126,8 +137,8 @@ def send_api_request( contents = None file_path = None - # Sort parameters by key... for some reason this matters for some requests... - parameters = sorted(parameters.items(), key=lambda kv: kv[0]) + # tuples are faster to handle and iterate + parameters = tuple(sorted(parameters.items(), key=lambda kv: kv[0])) if DEBUG and DEBUG_STORAGE: print(endpoint, parameters) @@ -151,6 +162,7 @@ def send_api_request( if os.path.isfile(file_path): with open(file_path) as f: contents = f.read() + status_code = 200 print("loading from file...") if not contents: @@ -173,7 +185,14 @@ def send_api_request( data = self.nba_response(response=contents, status_code=status_code, url=url) - if raise_exception_on_error and not data.valid_json(): - raise Exception("InvalidResponse: Response is not in a valid JSON format.") + if raise_exception_on_error: + if status_code is not None and status_code >= 400: + raise Exception( + f"HTTPError: Request failed with status code {status_code}." + ) + if not data.valid_json(): + raise Exception( + "InvalidResponse: Response is not in a valid JSON format." + ) return data diff --git a/src/nba_api/live/nba/endpoints/_base.py b/src/nba_api/live/nba/endpoints/_base.py index 382006c6..f99c19b9 100644 --- a/src/nba_api/live/nba/endpoints/_base.py +++ b/src/nba_api/live/nba/endpoints/_base.py @@ -1,30 +1,36 @@ import json +from typing import Any class Endpoint: class DataSet: - key = None - data = {} + key: str | None = None - def __init__(self, data=None): + def __init__(self, data: dict[str, Any] | list | None = None) -> None: if data is None: data = {} self.data = data - def get_json(self): + def get_json(self) -> str: return json.dumps(self.data) - def get_dict(self): + def get_dict(self) -> dict[str, Any] | list: return self.data - def get_request_url(self): + nba_response: Any = None + data_sets: list[DataSet] | None = None + + def get_request_url(self) -> str: return self.nba_response.get_url() - def get_response(self): + def get_response(self) -> str: return self.nba_response.get_response() - def get_dict(self): + def get_status_code(self) -> int: + return self.nba_response.get_status_code() + + def get_dict(self) -> dict[str, Any]: return self.nba_response.get_dict() - def get_json(self): + def get_json(self) -> str: return self.nba_response.get_json() diff --git a/src/nba_api/live/nba/library/http.py b/src/nba_api/live/nba/library/http.py index 38289981..fef6e792 100644 --- a/src/nba_api/live/nba/library/http.py +++ b/src/nba_api/live/nba/library/http.py @@ -1,9 +1,9 @@ from nba_api.library import http try: - from nba_api.library.debug.debug import STATS_HEADERS + from nba_api.library.debug.debug import LIVE_HEADERS except ImportError: - STATS_HEADERS = { + LIVE_HEADERS = { "Accept": "text/html,application/xhtml+xml,application/xml;q=0.9,image/avif,image/webp,image/apng,*/*;q=0.8,application/signed-exchange;v=b3;q=0.9", "Accept-Encoding": "gzip, deflate, br", "Accept-Language": "en-US,en;q=0.9", @@ -17,9 +17,9 @@ class NBALiveHTTP(http.NBAHTTP): nba_response = http.NBAResponse base_url = "https://cdn.nba.com/static/json/liveData/{endpoint}" - headers = STATS_HEADERS + headers = LIVE_HEADERS - def clean_contents(self, contents): + def clean_contents(self, contents: str) -> str: if '{"Message":"An error has occurred."}' in contents: return "An error has occurred." return contents diff --git a/src/nba_api/stats/endpoints/_base.py b/src/nba_api/stats/endpoints/_base.py index 1520de65..11433e5b 100644 --- a/src/nba_api/stats/endpoints/_base.py +++ b/src/nba_api/stats/endpoints/_base.py @@ -16,7 +16,6 @@ class Endpoint: class DataSet: key: str | None = None - data: dict[str, Any] = {} def __init__(self, data: dict[str, Any]) -> None: self.data = data @@ -36,8 +35,9 @@ def get_data_frame(self) -> DataFrame: Exception: If pandas is not installed. """ if not PANDAS: - raise Exception( - "Import Missing - Failed to import DataFrame from pandas." + raise ImportError( + "Failed to import DataFrame from pandas. " + "Install pandas: pip install pandas" ) if "headers" not in self.data or not self.data["headers"]: @@ -53,14 +53,8 @@ def get_data_frame(self) -> DataFrame: len(self.data["headers"]) ): # Extend column names for level to full length level = self.data["headers"][i] - level_names.append( - level["name"] if "name" in level else "LEVEL_" + str(i) - ) - column_names = ( - [""] * level["columnsToSkip"] - if "columnsToSkip" in level - else [] - ) + level_names.append(level.get("name", f"LEVEL_{i}")) + column_names = [""] * level.get("columnsToSkip", 0) column_names += list( np.repeat( np.array(level["columnNames"]), @@ -74,7 +68,7 @@ def get_data_frame(self) -> DataFrame: return DataFrame(self.data["data"], columns=midx) nba_response: Any = None - data_sets: list[DataSet] = [] + data_sets: list[DataSet] | None = None def get_request_url(self) -> str: """Return the URL of the request.""" @@ -82,12 +76,16 @@ def get_request_url(self) -> str: def get_available_data(self) -> list[str]: """Return the keys of the available data sets.""" - return self.get_normalized_dict().keys() + return list(self.get_normalized_dict().keys()) def get_response(self) -> str: """Return the raw response string.""" return self.nba_response.get_response() + def get_status_code(self) -> int: + """Return the HTTP status code of the response.""" + return self.nba_response.get_status_code() + def get_dict(self) -> dict[str, Any]: """Return the response as a dictionary.""" return self.nba_response.get_dict() diff --git a/src/nba_api/stats/endpoints/_parsers/__init__.py b/src/nba_api/stats/endpoints/_parsers/__init__.py index 100a13c8..32ee40ea 100644 --- a/src/nba_api/stats/endpoints/_parsers/__init__.py +++ b/src/nba_api/stats/endpoints/_parsers/__init__.py @@ -86,10 +86,10 @@ def get_parser_for_endpoint(endpoint, nba_dict): nba_dict (dict): The raw API response dictionary. Returns: - Parser instance configured with the provided data. - - Raises: - KeyError: If the endpoint doesn't have a registered parser. + Parser instance configured with the provided data, or None if + no parser is registered for this endpoint. """ - parser_class = _PARSER_REGISTRY[endpoint] + parser_class = _PARSER_REGISTRY.get(endpoint) + if parser_class is None: + return None return parser_class(nba_dict) diff --git a/src/nba_api/stats/library/http.py b/src/nba_api/stats/library/http.py index 558ee814..f113c1f8 100644 --- a/src/nba_api/stats/library/http.py +++ b/src/nba_api/stats/library/http.py @@ -29,72 +29,75 @@ class NBAStatsResponse(http.NBAResponse): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self._endpoint = None + self._normalized_dict_cache = None @staticmethod def _build_rows(headers, row_set): return [dict(zip(headers, raw_row, strict=False)) for raw_row in row_set] + @staticmethod + def _get_legacy_results(raw_dict): + return raw_dict.get("resultSets") or raw_dict.get("resultSet") + def get_normalized_dict(self): + if self._normalized_dict_cache is not None: + return self._normalized_dict_cache + raw_data = self.get_dict() data = {} - legacy_headers = ["resultSets", "resultSet"] - is_legacy = set(legacy_headers) & set(raw_data.keys()) + legacy_headers = {"resultSets", "resultSet"} + raw_keys = raw_data.keys() + is_legacy = bool(legacy_headers & raw_keys) if is_legacy: - if "resultSets" in raw_data: - results = raw_data["resultSets"] - if "Meta" in results: - return results - else: - results = raw_data["resultSet"] + results = self._get_legacy_results(raw_data) + if results and "Meta" in results: + self._normalized_dict_cache = results + return results if isinstance(results, dict): results = [results] for result in results: name = result["name"] data[name] = self._build_rows(result["headers"], result["rowSet"]) elif self._endpoint is not None: - try: - from nba_api.stats.endpoints._parsers import get_parser_for_endpoint + from nba_api.stats.endpoints._parsers import get_parser_for_endpoint - endpoint_parser = get_parser_for_endpoint(self._endpoint, raw_data) + endpoint_parser = get_parser_for_endpoint(self._endpoint, raw_data) + if endpoint_parser is not None: for name, dataset in endpoint_parser.get_data_sets().items(): data[name] = self._build_rows(dataset["headers"], dataset["data"]) - except (KeyError, ImportError): - pass + self._normalized_dict_cache = data return data def get_normalized_json(self): + if self._normalized_dict_cache is not None: + return json.dumps(self._normalized_dict_cache) return json.dumps(self.get_normalized_dict()) def get_parameters(self): - if not self.valid_json() or "parameters" not in self.get_dict(): + raw = self.get_dict() if self.valid_json() else None + if raw is None or "parameters" not in raw: return None - parameters = self.get_dict()["parameters"] + parameters = raw["parameters"] if isinstance(parameters, dict): return parameters - parameters = {} - for parameter in self.get_dict()["parameters"]: + result = {} + for parameter in parameters: for key, value in parameter.items(): - parameters.update({key: value}) - return parameters + result[key] = value + return result def get_headers_from_data_sets(self): raw_dict = self.get_dict() - legacy_headers = ["resultSets", "resultSet"] - is_legacy = set(legacy_headers) & set(raw_dict.keys()) - if not is_legacy: + results = self._get_legacy_results(raw_dict) + if results is None: return {} - - if "resultSets" in raw_dict: - results = raw_dict["resultSets"] - else: - results = raw_dict["resultSet"] if isinstance(results, dict): if "name" not in results: return {} @@ -108,10 +111,9 @@ def get_data_sets(self, endpoint=None): self._endpoint = endpoint if endpoint is None: - if "resultSets" in raw_dict: - results = raw_dict["resultSets"] - else: - results = raw_dict["resultSet"] + results = self._get_legacy_results(raw_dict) + if results is None: + return {} if isinstance(results, dict): if "name" not in results: return {} @@ -132,6 +134,8 @@ def get_data_sets(self, endpoint=None): from nba_api.stats.endpoints._parsers import get_parser_for_endpoint endpoint_parser = get_parser_for_endpoint(endpoint, self.get_dict()) + if endpoint_parser is None: + return {} return endpoint_parser.get_data_sets() diff --git a/src/nba_api/stats/static/players.py b/src/nba_api/stats/static/players.py index 4b4e007a..e23f446f 100644 --- a/src/nba_api/stats/static/players.py +++ b/src/nba_api/stats/static/players.py @@ -1,3 +1,4 @@ +import functools import re import unicodedata @@ -11,17 +12,22 @@ wnba_players, ) +# Pre-built index for O(1) ID lookup +_players_by_id = {p[player_index_id]: p for p in players} +_wnba_players_by_id = {p[player_index_id]: p for p in wnba_players} -def _find_players(regex_pattern, row_id, players=players): - players_found = [] - for player in players: - if re.search( - _strip_accents(regex_pattern), - _strip_accents(str(player[row_id])), - flags=re.I, - ): - players_found.append(_get_player_dict(player)) - return players_found +# Pre-computed cached lists +_cached_players = None +_cached_active_players = None +_cached_inactive_players = None +_cached_wnba_players = None +_cached_wnba_active_players = None +_cached_wnba_inactive_players = None + + +@functools.lru_cache(maxsize=128) +def _compile_regex(pattern): + return re.compile(_strip_accents(pattern), flags=re.I) def _strip_accents(inputstr: str) -> str: @@ -36,38 +42,72 @@ def _strip_accents(inputstr: str) -> str: ) -def _find_player_by_id(player_id, players=players): - regex_pattern = f"^{player_id}$" - players_list = _find_players(regex_pattern, player_index_id, players=players) - if len(players_list) > 1: - raise Exception("Found more than 1 id") - elif not players_list: - return None - else: - return players_list[0] - - -def _get_players(players=players): - players_list = [] - for player in players: - players_list.append(_get_player_dict(player)) - return players_list - - -def _get_active_players(players=players): - players_list = [] - for player in players: - if player[player_index_is_active]: - players_list.append(_get_player_dict(player)) - return players_list - - -def _get_inactive_players(players=players): - players_list = [] - for player in players: - if not player[player_index_is_active]: - players_list.append(_get_player_dict(player)) - return players_list +def _find_players(regex_pattern, row_id, players=players): + compiled = _compile_regex(regex_pattern) + return [ + _get_player_dict(player) + for player in players + if compiled.search(_strip_accents(str(player[row_id]))) + ] + + +def _find_player_by_id(player_id, _index=_players_by_id): + player = _index.get(player_id) + return _get_player_dict(player) if player is not None else None + + +def _get_players(players=players, _cache=False): + global _cached_players, _cached_wnba_players + if _cache: + if players is wnba_players: + if _cached_wnba_players is None: + _cached_wnba_players = [_get_player_dict(p) for p in players] + return _cached_wnba_players + else: + if _cached_players is None: + _cached_players = [_get_player_dict(p) for p in players] + return _cached_players + return [_get_player_dict(p) for p in players] + + +def _get_active_players(players=players, _cache=False): + global _cached_active_players, _cached_wnba_active_players + if _cache: + if players is wnba_players: + if _cached_wnba_active_players is None: + _cached_wnba_active_players = [ + _get_player_dict(p) for p in players if p[player_index_is_active] + ] + return _cached_wnba_active_players + else: + if _cached_active_players is None: + _cached_active_players = [ + _get_player_dict(p) for p in players if p[player_index_is_active] + ] + return _cached_active_players + return [_get_player_dict(p) for p in players if p[player_index_is_active]] + + +def _get_inactive_players(players=players, _cache=False): + global _cached_inactive_players, _cached_wnba_inactive_players + if _cache: + if players is wnba_players: + if _cached_wnba_inactive_players is None: + _cached_wnba_inactive_players = [ + _get_player_dict(p) + for p in players + if not p[player_index_is_active] + ] + return _cached_wnba_inactive_players + else: + if _cached_inactive_players is None: + _cached_inactive_players = [ + _get_player_dict(p) + for p in players + if not p[player_index_is_active] + ] + return _cached_inactive_players + return [_get_player_dict(p) for p in players if not p[player_index_is_active]] def _get_player_dict(player_row): @@ -80,57 +120,57 @@ def _get_player_dict(player_row): } -def find_players_by_full_name(regex_pattern): +def find_players_by_full_name(regex_pattern: str) -> list[dict]: return _find_players(regex_pattern, player_index_full_name) -def find_players_by_first_name(regex_pattern): +def find_players_by_first_name(regex_pattern: str) -> list[dict]: return _find_players(regex_pattern, player_index_first_name) -def find_players_by_last_name(regex_pattern): +def find_players_by_last_name(regex_pattern: str) -> list[dict]: return _find_players(regex_pattern, player_index_last_name) -def find_player_by_id(player_id): +def find_player_by_id(player_id: int) -> dict | None: return _find_player_by_id(player_id) -def get_players(): - return _get_players() +def get_players() -> list[dict]: + return _get_players(_cache=True) -def get_active_players(): - return _get_active_players() +def get_active_players() -> list[dict]: + return _get_active_players(_cache=True) -def get_inactive_players(): - return _get_inactive_players() +def get_inactive_players() -> list[dict]: + return _get_inactive_players(_cache=True) -def find_wnba_players_by_full_name(regex_pattern): +def find_wnba_players_by_full_name(regex_pattern: str) -> list[dict]: return _find_players(regex_pattern, player_index_full_name, players=wnba_players) -def find_wnba_players_by_first_name(regex_pattern): +def find_wnba_players_by_first_name(regex_pattern: str) -> list[dict]: return _find_players(regex_pattern, player_index_first_name, players=wnba_players) -def find_wnba_players_by_last_name(regex_pattern): +def find_wnba_players_by_last_name(regex_pattern: str) -> list[dict]: return _find_players(regex_pattern, player_index_last_name, players=wnba_players) -def find_wnba_player_by_id(player_id): - return _find_player_by_id(player_id, players=wnba_players) +def find_wnba_player_by_id(player_id: int) -> dict | None: + return _find_player_by_id(player_id, _index=_wnba_players_by_id) -def get_wnba_players(): - return _get_players(players=wnba_players) +def get_wnba_players() -> list[dict]: + return _get_players(players=wnba_players, _cache=True) -def get_wnba_active_players(): - return _get_active_players(players=wnba_players) +def get_wnba_active_players() -> list[dict]: + return _get_active_players(players=wnba_players, _cache=True) -def get_wnba_inactive_players(): - return _get_inactive_players(players=wnba_players) +def get_wnba_inactive_players() -> list[dict]: + return _get_inactive_players(players=wnba_players, _cache=True) diff --git a/src/nba_api/stats/static/teams.py b/src/nba_api/stats/static/teams.py index e494f8fe..ba20c9fe 100644 --- a/src/nba_api/stats/static/teams.py +++ b/src/nba_api/stats/static/teams.py @@ -1,4 +1,6 @@ +import functools import re +import unicodedata from nba_api.stats.library.data import ( team_index_abbreviation, @@ -13,57 +15,72 @@ wnba_teams, ) +# Pre-built indexes for O(1) lookups +_teams_by_id = {t[team_index_id]: t for t in teams} +_teams_by_abbreviation = {t[team_index_abbreviation]: t for t in teams} +_wnba_teams_by_id = {t[team_index_id]: t for t in wnba_teams} +_wnba_teams_by_abbreviation = {t[team_index_abbreviation]: t for t in wnba_teams} + +# Pre-computed cached lists +_cached_teams = None +_cached_wnba_teams = None + + +@functools.lru_cache(maxsize=128) +def _compile_regex(pattern): + return re.compile(_strip_accents(pattern), flags=re.I) + + +def _strip_accents(inputstr: str) -> str: + normalizedstr = unicodedata.normalize("NFD", inputstr) + return "".join(c for c in normalizedstr if unicodedata.category(c) != "Mn") + def _find_teams(regex_pattern, row_id, teams=teams): - teams_found = [] - for team in teams: - if re.search(regex_pattern, str(team[row_id]), flags=re.I): - teams_found.append(_get_team_dict(team)) - return teams_found - - -def _find_team_name_by_id(team_id, teams=teams): - regex_pattern = f"^{team_id}$" - teams_list = _find_teams(regex_pattern, team_index_id, teams=teams) - if len(teams_list) > 1: - raise Exception("Found more than 1 id") - elif not teams_list: - return None - else: - return teams_list[0] - - -def _find_team_by_abbreviation(abbreviation, teams=teams): - regex_pattern = f"^{abbreviation}$" - teams_list = _find_teams(regex_pattern, team_index_abbreviation, teams=teams) - if len(teams_list) > 1: - raise Exception("Found more than 1 id") - elif not teams_list: - return None - else: - return teams_list[0] + compiled = _compile_regex(regex_pattern) + return [ + _get_team_dict(team) + for team in teams + if compiled.search(_strip_accents(str(team[row_id]))) + ] + + +def _find_team_name_by_id(team_id, _index=_teams_by_id): + team = _index.get(team_id) + return _get_team_dict(team) if team is not None else None + + +def _find_team_by_abbreviation(abbreviation, _index=_teams_by_abbreviation): + team = _index.get(abbreviation.upper()) + return _get_team_dict(team) if team is not None else None def _find_teams_by_championship_year(year, teams=teams): - for team in teams: - if year in team[team_index_championship_year]: - result = team[team_index_full_name] - return result + return [ + _get_team_dict(team) + for team in teams + if year in team[team_index_championship_year] + ] def _find_teams_by_year_founded(year, teams=teams): - teams_found = [] - for team in teams: - if team[team_index_year_founded] == year: - teams_found.append(_get_team_dict(team)) - return teams_found + return [ + _get_team_dict(team) for team in teams if team[team_index_year_founded] == year + ] -def _get_teams(teams=teams): - teams_list = [] - for team in teams: - teams_list.append(_get_team_dict(team)) - return teams_list +def _get_teams(teams=teams, _cache=False): + global _cached_teams, _cached_wnba_teams + if _cache: + if teams is wnba_teams: + if _cached_wnba_teams is None: + _cached_wnba_teams = [_get_team_dict(t) for t in teams] + return _cached_wnba_teams + else: + if _cached_teams is None: + _cached_teams = [_get_team_dict(t) for t in teams] + return _cached_teams + return [_get_team_dict(t) for t in teams] def _get_team_dict(team_row): @@ -78,73 +95,73 @@ def _get_team_dict(team_row): } -def find_teams_by_full_name(regex_pattern): +def find_teams_by_full_name(regex_pattern: str) -> list[dict]: return _find_teams(regex_pattern, team_index_full_name) -def find_teams_by_state(regex_pattern): +def find_teams_by_state(regex_pattern: str) -> list[dict]: return _find_teams(regex_pattern, team_index_state) -def find_teams_by_city(regex_pattern): +def find_teams_by_city(regex_pattern: str) -> list[dict]: return _find_teams(regex_pattern, team_index_city) -def find_teams_by_nickname(regex_pattern): +def find_teams_by_nickname(regex_pattern: str) -> list[dict]: return _find_teams(regex_pattern, team_index_nickname) -def find_teams_by_year_founded(year): +def find_teams_by_year_founded(year: int) -> list[dict]: return _find_teams_by_year_founded(year) -def find_teams_by_championship_year(year): +def find_teams_by_championship_year(year: int) -> list[dict]: return _find_teams_by_championship_year(year) -def find_team_by_abbreviation(abbreviation): +def find_team_by_abbreviation(abbreviation: str) -> dict | None: return _find_team_by_abbreviation(abbreviation) -def find_team_name_by_id(team_id): +def find_team_name_by_id(team_id: int) -> dict | None: return _find_team_name_by_id(team_id) -def get_teams(): - return _get_teams() +def get_teams() -> list[dict]: + return _get_teams(_cache=True) -def find_wnba_teams_by_full_name(regex_pattern): +def find_wnba_teams_by_full_name(regex_pattern: str) -> list[dict]: return _find_teams(regex_pattern, team_index_full_name, teams=wnba_teams) -def find_wnba_teams_by_state(regex_pattern): +def find_wnba_teams_by_state(regex_pattern: str) -> list[dict]: return _find_teams(regex_pattern, team_index_state, teams=wnba_teams) -def find_wnba_teams_by_city(regex_pattern): +def find_wnba_teams_by_city(regex_pattern: str) -> list[dict]: return _find_teams(regex_pattern, team_index_city, teams=wnba_teams) -def find_wnba_teams_by_nickname(regex_pattern): +def find_wnba_teams_by_nickname(regex_pattern: str) -> list[dict]: return _find_teams(regex_pattern, team_index_nickname, teams=wnba_teams) -def find_wnba_teams_by_year_founded(year): +def find_wnba_teams_by_year_founded(year: int) -> list[dict]: return _find_teams_by_year_founded(year, teams=wnba_teams) -def find_wnba_teams_by_championship_year(year): +def find_wnba_teams_by_championship_year(year: int) -> list[dict]: return _find_teams_by_championship_year(year, teams=wnba_teams) -def find_wnba_team_by_abbreviation(abbreviation): - return _find_team_by_abbreviation(abbreviation, teams=wnba_teams) +def find_wnba_team_by_abbreviation(abbreviation: str) -> dict | None: + return _find_team_by_abbreviation(abbreviation, _index=_wnba_teams_by_abbreviation) -def find_wnba_team_name_by_id(team_id): - return _find_team_name_by_id(team_id, teams=wnba_teams) +def find_wnba_team_name_by_id(team_id: int) -> dict | None: + return _find_team_name_by_id(team_id, _index=_wnba_teams_by_id) -def get_wnba_teams(): - return _get_teams(teams=wnba_teams) +def get_wnba_teams() -> list[dict]: + return _get_teams(teams=wnba_teams, _cache=True) diff --git a/tests/unit/http/test_http_fixes.py b/tests/unit/http/test_http_fixes.py new file mode 100644 index 00000000..51370f41 --- /dev/null +++ b/tests/unit/http/test_http_fixes.py @@ -0,0 +1,164 @@ +"""Tests for bug fixes in HTTP layer.""" + +from unittest.mock import Mock + +import pytest +import requests + +from nba_api.library.http import NBAHTTP, NBAResponse +from nba_api.stats.library.http import NBAStatsHTTP, NBAStatsResponse + + +@pytest.fixture(autouse=True) +def cleanup(): + NBAHTTP._session = None + yield + NBAHTTP._session = None + + +@pytest.fixture +def mock_session(): + session = Mock(spec=requests.Session) + mock_response = Mock() + mock_response.text = '{"resultSets": []}' + mock_response.status_code = 200 + mock_response.url = "https://stats.nba.com/stats/test" + session.get.return_value = mock_response + return session + + +class TestHeadersMutationFix: + """Headers dict must not be mutated when referer is passed.""" + + def test_referer_does_not_mutate_class_headers(self, mock_session): + NBAHTTP.set_session(mock_session) + http = NBAStatsHTTP() + original_headers = http.headers.copy() + + http.send_api_request( + endpoint="test", + parameters={}, + referer="https://www.nba.com/game/123", + ) + + assert http.headers == original_headers + + def test_referer_is_sent_in_request(self, mock_session): + NBAHTTP.set_session(mock_session) + http = NBAStatsHTTP() + + http.send_api_request( + endpoint="test", + parameters={}, + referer="https://www.nba.com/game/123", + ) + + call_kwargs = mock_session.get.call_args.kwargs + assert call_kwargs["headers"]["Referer"] == "https://www.nba.com/game/123" + + def test_no_referer_uses_original_headers_object(self, mock_session): + NBAHTTP.set_session(mock_session) + http = NBAStatsHTTP() + original_id = id(http.headers) + + http.send_api_request( + endpoint="test", + parameters={}, + ) + + call_kwargs = mock_session.get.call_args.kwargs + assert id(call_kwargs["headers"]) == original_id + + def test_custom_headers_override(self, mock_session): + NBAHTTP.set_session(mock_session) + http = NBAStatsHTTP() + custom = {"X-Custom": "value"} + + http.send_api_request( + endpoint="test", + parameters={}, + headers=custom, + ) + + call_kwargs = mock_session.get.call_args.kwargs + assert call_kwargs["headers"] == custom + + +class TestGetParserReturnsNone: + """get_parser_for_endpoint returns None for unregistered endpoints.""" + + def test_unregistered_endpoint_returns_none(self): + from nba_api.stats.endpoints._parsers import get_parser_for_endpoint + + result = get_parser_for_endpoint("nonexistent_endpoint_xyz", {}) + assert result is None + + def test_registered_endpoint_returns_parser(self): + from nba_api.stats.endpoints._parsers import get_parser_for_endpoint + + parser = get_parser_for_endpoint("boxscoreadvancedv3", {"some": "data"}) + assert parser is not None + + +class TestGetLegacyResults: + """_get_legacy_results helper extracts resultSets/resultSet correctly.""" + + def test_result_sets_plural(self): + data = {"resultSets": [{"name": "A", "headers": [], "rowSet": []}]} + result = NBAStatsResponse._get_legacy_results(data) + assert result == [{"name": "A", "headers": [], "rowSet": []}] + + def test_result_set_singular(self): + data = {"resultSet": {"name": "B", "headers": [], "rowSet": []}} + result = NBAStatsResponse._get_legacy_results(data) + assert result == {"name": "B", "headers": [], "rowSet": []} + + def test_neither_key_returns_none(self): + data = {"someOtherKey": "value"} + result = NBAStatsResponse._get_legacy_results(data) + assert result is None + + def test_prefers_result_sets_over_result_set(self): + data = { + "resultSets": [{"name": "A"}], + "resultSet": {"name": "B"}, + } + result = NBAStatsResponse._get_legacy_results(data) + assert result == [{"name": "A"}] + + +class TestGetDataSetsMissingKeys: + """get_data_sets returns {} when legacy keys are missing.""" + + def test_no_legacy_keys_returns_empty(self): + resp = NBAStatsResponse( + response='{"meta": {"version": 1}}', + status_code=200, + url="https://test", + ) + assert resp.get_data_sets() == {} + + def test_no_legacy_keys_with_endpoint_no_parser(self): + resp = NBAStatsResponse( + response='{"meta": {"version": 1}}', + status_code=200, + url="https://test", + ) + assert resp.get_data_sets(endpoint="nonexistent_xyz") == {} + + +class TestNBAResponseTypeHints: + """Verify NBAResponse handles typed None values correctly.""" + + def test_status_code_none(self): + resp = NBAResponse(response="{}", status_code=None, url=None) + assert resp.get_status_code() is None + assert resp.get_url() is None + + def test_valid_json_check(self): + resp = NBAResponse(response="not json", status_code=200, url="http://test") + assert resp.valid_json() is False + + def test_valid_json_success(self): + resp = NBAResponse(response='{"ok": true}', status_code=200, url="http://test") + assert resp.valid_json() is True