Skip to content

Commit 6558d01

Browse files
Improve SDK v3 Hugging Face support (#5736)
* Bug fixes for HF models * Fix serialization/deserialization issues in core * Remove unnecessary comments * feat: add support for model_index.json fallback in HF config retrieval * feat: add support for PEFT models with adapter_config.json config file --------- Co-authored-by: aviruthen <91846056+aviruthen@users.noreply.github.com>
1 parent 91ca011 commit 6558d01

7 files changed

Lines changed: 149 additions & 34 deletions

File tree

sagemaker-core/src/sagemaker/core/utils/utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -273,7 +273,7 @@ def pascal_to_snake(pascal_str):
273273

274274

275275
def is_not_primitive(obj):
276-
return not isinstance(obj, (int, float, str, bool, datetime.datetime))
276+
return not isinstance(obj, (int, float, str, bool, datetime.datetime, bytes))
277277

278278

279279
def is_not_str_dict(obj):

sagemaker-serve/src/sagemaker/serve/builder/schema_builder.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -196,6 +196,11 @@ def _get_deserializer(self, obj):
196196
return StringDeserializer()
197197
if _is_jsonable(obj):
198198
return JSONDeserializer()
199+
if isinstance(obj, dict) and "content_type" in obj:
200+
try:
201+
return BytesDeserializer()
202+
except ValueError as e:
203+
logger.error(e)
199204

200205
raise ValueError(
201206
(

sagemaker-serve/src/sagemaker/serve/model_builder_servers.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -687,7 +687,20 @@ def _build_for_transformers(self) -> Model:
687687
hf_model_id, self.env_vars.get("HUGGING_FACE_HUB_TOKEN")
688688
)
689689
elif isinstance(self.model, str): # Only set HF_MODEL_ID if model is a string
690+
# Get model metadata for task detection
691+
hf_model_md = self.get_huggingface_model_metadata(
692+
self.model, self.env_vars.get("HUGGING_FACE_HUB_TOKEN")
693+
)
694+
model_task = hf_model_md.get("pipeline_tag")
695+
if model_task:
696+
self.env_vars.update({"HF_TASK": model_task})
697+
690698
self.env_vars.update({"HF_MODEL_ID": self.model})
699+
700+
# Add HuggingFace token if available
701+
if self.env_vars.get("HUGGING_FACE_HUB_TOKEN"):
702+
self.env_vars["HF_TOKEN"] = self.env_vars.get("HUGGING_FACE_HUB_TOKEN")
703+
691704
# Get HF config for string model IDs
692705
if hasattr(self.env_vars, "HF_API_TOKEN"):
693706
self.hf_model_config = _get_model_config_properties_from_hf(

sagemaker-serve/src/sagemaker/serve/model_builder_utils.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1004,6 +1004,11 @@ def _hf_schema_builder_init(self, model_task: str) -> None:
10041004
sample_inputs,
10051005
sample_outputs,
10061006
) = remote_hf_schema_helper.get_resolved_hf_schema_for_task(model_task)
1007+
# Unwrap list outputs for binary tasks (text-to-image, audio, etc.)
1008+
# Remote schema retriever returns [{'data': b'...', 'content_type': '...'}]
1009+
# but SchemaBuilder expects {'data': b'...', 'content_type': '...'}
1010+
if isinstance(sample_outputs, list) and len(sample_outputs) > 0:
1011+
sample_outputs = sample_outputs[0]
10071012

10081013
self.schema_builder = SchemaBuilder(sample_inputs, sample_outputs)
10091014

sagemaker-serve/src/sagemaker/serve/utils/hf_utils.py

Lines changed: 31 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
# ANY KIND, either express or implied. See the License for the specific
1212
# language governing permissions and limitations under the License.
1313
"""Utility functions for fetching model information from HuggingFace Hub"""
14+
1415
from __future__ import absolute_import
1516
import json
1617
import urllib.request
@@ -24,30 +25,39 @@
2425
def _get_model_config_properties_from_hf(model_id: str, hf_hub_token: str = None):
2526
"""Placeholder docstring"""
2627

27-
config_url = f"https://huggingface.co/{model_id}/raw/main/config.json"
28+
config_files = ["config.json", "model_index.json", "adapter_config.json"]
29+
2830
model_config = None
29-
try:
30-
if hf_hub_token:
31-
config_url = urllib.request.Request(
32-
config_url, headers={"Authorization": "Bearer " + hf_hub_token}
33-
)
34-
with urllib.request.urlopen(config_url) as response:
35-
model_config = json.load(response)
36-
except (HTTPError, URLError, TimeoutError, JSONDecodeError) as e:
37-
if "HTTP Error 401: Unauthorized" in str(e):
38-
raise ValueError(
39-
"Trying to access a gated/private HuggingFace model without valid credentials. "
40-
"Please provide a HUGGING_FACE_HUB_TOKEN in env_vars"
31+
for config_file in config_files:
32+
config_url = f"https://huggingface.co/{model_id}/raw/main/{config_file}"
33+
request = config_url
34+
35+
try:
36+
if hf_hub_token:
37+
request = urllib.request.Request(
38+
config_url, headers={"Authorization": "Bearer " + hf_hub_token}
39+
)
40+
41+
with urllib.request.urlopen(request) as response:
42+
model_config = json.load(response)
43+
break
44+
except (HTTPError, URLError, TimeoutError, JSONDecodeError) as e:
45+
if "HTTP Error 401: Unauthorized" in str(e):
46+
raise ValueError(
47+
"Trying to access a gated/private HuggingFace model without valid credentials. "
48+
"Please provide a HUGGING_FACE_HUB_TOKEN in env_vars"
49+
)
50+
51+
logger.warning(
52+
"Exception encountered while trying to read config file %s. Details: %s",
53+
config_url,
54+
e,
4155
)
42-
logger.warning(
43-
"Exception encountered while trying to read config file %s. " "Details: %s",
44-
config_url,
45-
e,
46-
)
56+
4757
if not model_config:
58+
allowed_files = ", ".join(config_files)
4859
raise ValueError(
49-
f"Did not find a config.json or model_index.json file in huggingface hub for "
50-
f"{model_id}. Please make sure a config.json exists (or model_index.json for Stable "
51-
f"Diffusion Models) for this model in the huggingface hub"
60+
f"Did not find any supported model config file in Hugging Face Hub for {model_id}. "
61+
f"Expected one of: {allowed_files}"
5262
)
5363
return model_config

sagemaker-serve/tests/unit/servers/test_model_builder_servers.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -781,6 +781,10 @@ def test_build_with_hf_model_string(
781781
result = self.builder._build_for_transformers()
782782

783783
self.assertEqual(self.builder.env_vars["HF_MODEL_ID"], "gpt2")
784+
mock_hf_config.assert_called_once_with(
785+
"gpt2",
786+
"token",
787+
)
784788
mock_create.assert_called_once()
785789

786790
@patch("sagemaker.serve.model_builder_servers._get_nb_instance")

sagemaker-serve/tests/unit/utils/test_hf_utils.py

Lines changed: 90 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -75,9 +75,9 @@ def test_get_model_config_http_error(self, mock_logger, mock_urlopen):
7575

7676
with self.assertRaises(ValueError) as context:
7777
_get_model_config_properties_from_hf("non-existent-model")
78-
79-
self.assertIn("Did not find a config.json", str(context.exception))
80-
mock_logger.warning.assert_called_once()
78+
79+
self.assertIn("Did not find any supported model config file", str(context.exception))
80+
self.assertEqual(mock_logger.warning.call_count, 3)
8181

8282
@patch('urllib.request.urlopen')
8383
@patch('sagemaker.serve.utils.hf_utils.logger')
@@ -87,9 +87,9 @@ def test_get_model_config_url_error(self, mock_logger, mock_urlopen):
8787

8888
with self.assertRaises(ValueError) as context:
8989
_get_model_config_properties_from_hf("model-id")
90-
91-
self.assertIn("Did not find a config.json", str(context.exception))
92-
mock_logger.warning.assert_called_once()
90+
91+
self.assertIn("Did not find any supported model config file", str(context.exception))
92+
self.assertEqual(mock_logger.warning.call_count, 3)
9393

9494
@patch('urllib.request.urlopen')
9595
@patch('sagemaker.serve.utils.hf_utils.logger')
@@ -99,9 +99,9 @@ def test_get_model_config_timeout_error(self, mock_logger, mock_urlopen):
9999

100100
with self.assertRaises(ValueError) as context:
101101
_get_model_config_properties_from_hf("model-id")
102-
103-
self.assertIn("Did not find a config.json", str(context.exception))
104-
mock_logger.warning.assert_called_once()
102+
103+
self.assertIn("Did not find any supported model config file", str(context.exception))
104+
self.assertEqual(mock_logger.warning.call_count, 3)
105105

106106
@patch('urllib.request.urlopen')
107107
@patch('sagemaker.serve.utils.hf_utils.logger')
@@ -115,9 +115,9 @@ def test_get_model_config_json_decode_error(self, mock_logger, mock_urlopen):
115115
with patch('json.load', side_effect=JSONDecodeError("msg", "doc", 0)):
116116
with self.assertRaises(ValueError) as context:
117117
_get_model_config_properties_from_hf("model-id")
118-
119-
self.assertIn("Did not find a config.json", str(context.exception))
120-
mock_logger.warning.assert_called_once()
118+
119+
self.assertIn("Did not find any supported model config file", str(context.exception))
120+
self.assertEqual(mock_logger.warning.call_count, 3)
121121

122122
@patch('urllib.request.urlopen')
123123
def test_get_model_config_url_format(self, mock_urlopen):
@@ -137,6 +137,84 @@ def test_get_model_config_url_format(self, mock_urlopen):
137137
actual_url = mock_urlopen.call_args[0][0]
138138
self.assertEqual(actual_url, expected_url)
139139

140+
@patch("urllib.request.urlopen")
141+
def test_get_model_config_falls_back_to_model_index(self, mock_urlopen):
142+
"""Test fallback to model_index.json when config.json is missing."""
143+
config_missing_error = HTTPError(
144+
"https://huggingface.co/org/model/raw/main/config.json", 404, "Not Found", {}, None
145+
)
146+
model_index_config = {"_class_name": "FluxPipeline", "_diffusers_version": "0.31.0"}
147+
148+
mock_model_index_response = Mock()
149+
mock_model_index_response.__enter__ = Mock(return_value=mock_model_index_response)
150+
mock_model_index_response.__exit__ = Mock(return_value=False)
151+
152+
def _urlopen_side_effect(request):
153+
url = request.full_url if hasattr(request, "full_url") else request
154+
if url.endswith("/config.json"):
155+
raise config_missing_error
156+
if url.endswith("/model_index.json"):
157+
return mock_model_index_response
158+
raise AssertionError(f"Unexpected URL called: {url}")
159+
160+
mock_urlopen.side_effect = _urlopen_side_effect
161+
162+
with patch("json.load", side_effect=[model_index_config]):
163+
result = _get_model_config_properties_from_hf("org/model-name")
164+
165+
self.assertEqual(result, model_index_config)
166+
167+
@patch("urllib.request.urlopen")
168+
@patch("sagemaker.serve.utils.hf_utils.logger")
169+
def test_get_model_config_dual_file_error_when_both_missing(self, mock_logger, mock_urlopen):
170+
"""Test error when all known config files are missing."""
171+
mock_urlopen.side_effect = HTTPError("url", 404, "Not Found", {}, None)
172+
173+
with self.assertRaises(ValueError) as context:
174+
_get_model_config_properties_from_hf("model-id")
175+
176+
self.assertIn(
177+
"Expected one of: config.json, model_index.json, adapter_config.json",
178+
str(context.exception),
179+
)
180+
self.assertEqual(mock_urlopen.call_count, 3)
181+
self.assertEqual(mock_logger.warning.call_count, 3)
182+
183+
@patch("urllib.request.urlopen")
184+
def test_get_model_config_falls_back_to_adapter_config(self, mock_urlopen):
185+
"""Test fallback to adapter_config.json when config/model_index are missing."""
186+
config_missing_error = HTTPError(
187+
"https://huggingface.co/org/model/raw/main/config.json", 404, "Not Found", {}, None
188+
)
189+
model_index_missing_error = HTTPError(
190+
"https://huggingface.co/org/model/raw/main/model_index.json", 404, "Not Found", {}, None
191+
)
192+
adapter_config = {
193+
"base_model_name_or_path": "LiquidAI/LFM2.5-1.2B-Instruct",
194+
"peft_type": "LORA",
195+
}
196+
197+
mock_adapter_response = Mock()
198+
mock_adapter_response.__enter__ = Mock(return_value=mock_adapter_response)
199+
mock_adapter_response.__exit__ = Mock(return_value=False)
200+
201+
def _urlopen_side_effect(request):
202+
url = request.full_url if hasattr(request, "full_url") else request
203+
if url.endswith("/config.json"):
204+
raise config_missing_error
205+
if url.endswith("/model_index.json"):
206+
raise model_index_missing_error
207+
if url.endswith("/adapter_config.json"):
208+
return mock_adapter_response
209+
raise AssertionError(f"Unexpected URL called: {url}")
210+
211+
mock_urlopen.side_effect = _urlopen_side_effect
212+
213+
with patch("json.load", side_effect=[adapter_config]):
214+
result = _get_model_config_properties_from_hf("org/model-name")
215+
216+
self.assertEqual(result, adapter_config)
217+
140218

141219
if __name__ == "__main__":
142220
unittest.main()

0 commit comments

Comments
 (0)