@@ -75,9 +75,9 @@ def test_get_model_config_http_error(self, mock_logger, mock_urlopen):
7575
7676 with self .assertRaises (ValueError ) as context :
7777 _get_model_config_properties_from_hf ("non-existent-model" )
78-
79- self .assertIn ("Did not find a config.json " , str (context .exception ))
80- mock_logger .warning .assert_called_once ( )
78+
79+ self .assertIn ("Did not find any supported model config file " , str (context .exception ))
80+ self . assertEqual ( mock_logger .warning .call_count , 3 )
8181
8282 @patch ('urllib.request.urlopen' )
8383 @patch ('sagemaker.serve.utils.hf_utils.logger' )
@@ -87,9 +87,9 @@ def test_get_model_config_url_error(self, mock_logger, mock_urlopen):
8787
8888 with self .assertRaises (ValueError ) as context :
8989 _get_model_config_properties_from_hf ("model-id" )
90-
91- self .assertIn ("Did not find a config.json " , str (context .exception ))
92- mock_logger .warning .assert_called_once ( )
90+
91+ self .assertIn ("Did not find any supported model config file " , str (context .exception ))
92+ self . assertEqual ( mock_logger .warning .call_count , 3 )
9393
9494 @patch ('urllib.request.urlopen' )
9595 @patch ('sagemaker.serve.utils.hf_utils.logger' )
@@ -99,9 +99,9 @@ def test_get_model_config_timeout_error(self, mock_logger, mock_urlopen):
9999
100100 with self .assertRaises (ValueError ) as context :
101101 _get_model_config_properties_from_hf ("model-id" )
102-
103- self .assertIn ("Did not find a config.json " , str (context .exception ))
104- mock_logger .warning .assert_called_once ( )
102+
103+ self .assertIn ("Did not find any supported model config file " , str (context .exception ))
104+ self . assertEqual ( mock_logger .warning .call_count , 3 )
105105
106106 @patch ('urllib.request.urlopen' )
107107 @patch ('sagemaker.serve.utils.hf_utils.logger' )
@@ -115,9 +115,9 @@ def test_get_model_config_json_decode_error(self, mock_logger, mock_urlopen):
115115 with patch ('json.load' , side_effect = JSONDecodeError ("msg" , "doc" , 0 )):
116116 with self .assertRaises (ValueError ) as context :
117117 _get_model_config_properties_from_hf ("model-id" )
118-
119- self .assertIn ("Did not find a config.json " , str (context .exception ))
120- mock_logger .warning .assert_called_once ( )
118+
119+ self .assertIn ("Did not find any supported model config file " , str (context .exception ))
120+ self . assertEqual ( mock_logger .warning .call_count , 3 )
121121
122122 @patch ('urllib.request.urlopen' )
123123 def test_get_model_config_url_format (self , mock_urlopen ):
@@ -137,6 +137,84 @@ def test_get_model_config_url_format(self, mock_urlopen):
137137 actual_url = mock_urlopen .call_args [0 ][0 ]
138138 self .assertEqual (actual_url , expected_url )
139139
140+ @patch ("urllib.request.urlopen" )
141+ def test_get_model_config_falls_back_to_model_index (self , mock_urlopen ):
142+ """Test fallback to model_index.json when config.json is missing."""
143+ config_missing_error = HTTPError (
144+ "https://huggingface.co/org/model/raw/main/config.json" , 404 , "Not Found" , {}, None
145+ )
146+ model_index_config = {"_class_name" : "FluxPipeline" , "_diffusers_version" : "0.31.0" }
147+
148+ mock_model_index_response = Mock ()
149+ mock_model_index_response .__enter__ = Mock (return_value = mock_model_index_response )
150+ mock_model_index_response .__exit__ = Mock (return_value = False )
151+
152+ def _urlopen_side_effect (request ):
153+ url = request .full_url if hasattr (request , "full_url" ) else request
154+ if url .endswith ("/config.json" ):
155+ raise config_missing_error
156+ if url .endswith ("/model_index.json" ):
157+ return mock_model_index_response
158+ raise AssertionError (f"Unexpected URL called: { url } " )
159+
160+ mock_urlopen .side_effect = _urlopen_side_effect
161+
162+ with patch ("json.load" , side_effect = [model_index_config ]):
163+ result = _get_model_config_properties_from_hf ("org/model-name" )
164+
165+ self .assertEqual (result , model_index_config )
166+
167+ @patch ("urllib.request.urlopen" )
168+ @patch ("sagemaker.serve.utils.hf_utils.logger" )
169+ def test_get_model_config_dual_file_error_when_both_missing (self , mock_logger , mock_urlopen ):
170+ """Test error when all known config files are missing."""
171+ mock_urlopen .side_effect = HTTPError ("url" , 404 , "Not Found" , {}, None )
172+
173+ with self .assertRaises (ValueError ) as context :
174+ _get_model_config_properties_from_hf ("model-id" )
175+
176+ self .assertIn (
177+ "Expected one of: config.json, model_index.json, adapter_config.json" ,
178+ str (context .exception ),
179+ )
180+ self .assertEqual (mock_urlopen .call_count , 3 )
181+ self .assertEqual (mock_logger .warning .call_count , 3 )
182+
183+ @patch ("urllib.request.urlopen" )
184+ def test_get_model_config_falls_back_to_adapter_config (self , mock_urlopen ):
185+ """Test fallback to adapter_config.json when config/model_index are missing."""
186+ config_missing_error = HTTPError (
187+ "https://huggingface.co/org/model/raw/main/config.json" , 404 , "Not Found" , {}, None
188+ )
189+ model_index_missing_error = HTTPError (
190+ "https://huggingface.co/org/model/raw/main/model_index.json" , 404 , "Not Found" , {}, None
191+ )
192+ adapter_config = {
193+ "base_model_name_or_path" : "LiquidAI/LFM2.5-1.2B-Instruct" ,
194+ "peft_type" : "LORA" ,
195+ }
196+
197+ mock_adapter_response = Mock ()
198+ mock_adapter_response .__enter__ = Mock (return_value = mock_adapter_response )
199+ mock_adapter_response .__exit__ = Mock (return_value = False )
200+
201+ def _urlopen_side_effect (request ):
202+ url = request .full_url if hasattr (request , "full_url" ) else request
203+ if url .endswith ("/config.json" ):
204+ raise config_missing_error
205+ if url .endswith ("/model_index.json" ):
206+ raise model_index_missing_error
207+ if url .endswith ("/adapter_config.json" ):
208+ return mock_adapter_response
209+ raise AssertionError (f"Unexpected URL called: { url } " )
210+
211+ mock_urlopen .side_effect = _urlopen_side_effect
212+
213+ with patch ("json.load" , side_effect = [adapter_config ]):
214+ result = _get_model_config_properties_from_hf ("org/model-name" )
215+
216+ self .assertEqual (result , adapter_config )
217+
140218
141219if __name__ == "__main__" :
142220 unittest .main ()
0 commit comments