@@ -354,3 +354,143 @@ def token_refresher() -> tuple[str, int]:
354354 )
355355
356356 assert os .path .exists (file_path )
357+
358+
@requires("hf_xet")
class TestXetProgressGranularity:
    """Test that xet_get uses the fine-grained 2-arg callback for tqdm progress.

    Every test patches ``hf_xet.download_files`` so no real download happens,
    captures the progress callback that ``xet_get`` hands to it, and then
    drives that callback manually with mocked progress updates.
    """

    # Shared fixtures: a fake file descriptor and the connection info returned
    # by the patched ``refresh_xet_connection_info`` in every test below.
    _XET_FILE_DATA = XetFileData(file_hash="mock_hash", refresh_route="mock/route")
    _CONNECTION_INFO = XetConnectionInfo(
        endpoint="mock_endpoint", access_token="mock_token", expiration_unix_epoch=9999999999
    )

    @staticmethod
    def _make_mock_total_update(transfer_increment: float, transfer_total: float) -> Mock:
        """Create a mock PyTotalProgressUpdate with the given transfer fields.

        Args:
            transfer_increment: Value exposed as
                ``total_transfer_bytes_completion_increment``.
            transfer_total: Value exposed as ``total_transfer_bytes``.

        Returns:
            A ``Mock`` carrying only those two attributes explicitly.
        """
        update = Mock()
        update.total_transfer_bytes_completion_increment = transfer_increment
        update.total_transfer_bytes = transfer_total
        return update

    @staticmethod
    def _track_updates(mock_bar: Mock) -> None:
        """Make ``mock_bar.update(n)`` accumulate ``n`` into ``mock_bar.n``.

        A plain ``Mock`` would only record the calls; the assertions on
        ``mock_bar.n`` need the counter to actually move.
        """

        def accumulate(n):
            mock_bar.n += n

        mock_bar.update = Mock(side_effect=accumulate)

    def _call_xet_get_and_capture(self, tmp_path, mock_download, expected_size=1000, mock_progress_cm=None):
        """Call xet_get and return the captured progress callback.

        Args:
            tmp_path: Directory in which the (empty) incomplete file is created.
            mock_download: Mock replacing ``hf_xet.download_files``; its
                ``side_effect`` is overwritten to capture the callback.
            expected_size: Forwarded to ``xet_get`` as ``expected_size``.
            mock_progress_cm: Optional mock replacing
                ``_get_progress_bar_context``. When given, its context manager
                is configured to yield a mock bar whose ``n`` starts at 0.

        Returns:
            A ``(callback, mock_bar)`` tuple. ``callback`` is the first entry
            of the ``progress_updater`` keyword passed to ``download_files``;
            ``mock_bar`` is ``None`` when ``mock_progress_cm`` is ``None``.

        Raises:
            AssertionError: If ``download_files`` was never invoked, so no
                callback could be captured.
        """
        incomplete_path = tmp_path / "test_file.bin"
        incomplete_path.touch()

        captured = {}

        def capture(*args, **kwargs):
            captured["callback"] = kwargs["progress_updater"][0]

        mock_download.side_effect = capture

        if mock_progress_cm is not None:
            mock_bar = Mock()
            mock_bar.n = 0
            # Entering the patched progress context yields our mock bar;
            # ``__exit__`` returns False so exceptions are not swallowed.
            mock_progress_cm.return_value.__enter__ = Mock(return_value=mock_bar)
            mock_progress_cm.return_value.__exit__ = Mock(return_value=False)
        else:
            mock_bar = None

        xet_get(
            incomplete_path=incomplete_path,
            xet_file_data=self._XET_FILE_DATA,
            headers={"authorization": "Bearer token"},
            expected_size=expected_size,
        )

        # Fail with an explicit message instead of a bare KeyError below.
        if "callback" not in captured:
            raise AssertionError("xet_get did not call hf_xet.download_files; no progress callback captured")

        return captured["callback"], mock_bar

    @patch("huggingface_hub.file_download._get_progress_bar_context")
    @patch("hf_xet.download_files")
    @patch("huggingface_hub.file_download.refresh_xet_connection_info", return_value=_CONNECTION_INFO)
    def test_callback_uses_two_arg_signature(self, _mock_conn, mock_download, mock_progress_cm, tmp_path):
        """Verify xet_get passes a 2-arg callback to download_files, triggering
        xet-core's fine-grained network-level progress dispatch."""
        callback, _ = self._call_xet_get_and_capture(tmp_path, mock_download, mock_progress_cm=mock_progress_cm)

        # Call with 2 args (total_update, item_updates) to confirm it accepts them.
        # A 1-arg callback would raise TypeError here.
        total_update = self._make_mock_total_update(transfer_increment=200, transfer_total=1000)
        callback(total_update, [])  # should not raise

    @patch("huggingface_hub.file_download._get_progress_bar_context")
    @patch("hf_xet.download_files")
    @patch("huggingface_hub.file_download.refresh_xet_connection_info", return_value=_CONNECTION_INFO)
    def test_progress_bar_scales_network_to_file_size(self, _mock_conn, mock_download, mock_progress_cm, tmp_path):
        """When transfer bytes differ from file size, the progress bar should
        scale to expected_size so it always reaches 100%."""
        expected_size = 10_000
        transfer_total = 5_000  # fewer bytes due to deduplication

        callback, mock_bar = self._call_xet_get_and_capture(
            tmp_path, mock_download, expected_size=expected_size, mock_progress_cm=mock_progress_cm
        )
        self._track_updates(mock_bar)

        # Simulate 5 updates of 1000 transfer bytes each (total: 5000 transfer bytes)
        for _ in range(5):
            total_update = self._make_mock_total_update(transfer_increment=1000, transfer_total=transfer_total)
            callback(total_update, [])

        # After transferring 5000/5000 bytes, bar should be at expected_size (10000)
        assert mock_bar.n == expected_size

    @patch("huggingface_hub.file_download._get_progress_bar_context")
    @patch("hf_xet.download_files")
    @patch("huggingface_hub.file_download.refresh_xet_connection_info", return_value=_CONNECTION_INFO)
    def test_progress_bar_capped_at_expected_size(self, _mock_conn, mock_download, mock_progress_cm, tmp_path):
        """Progress bar should never exceed expected_size."""
        expected_size = 1000

        callback, mock_bar = self._call_xet_get_and_capture(
            tmp_path, mock_download, expected_size=expected_size, mock_progress_cm=mock_progress_cm
        )
        self._track_updates(mock_bar)

        # Send more transfer bytes than total (edge case)
        total_update = self._make_mock_total_update(transfer_increment=1200, transfer_total=1000)
        callback(total_update, [])

        assert mock_bar.n <= expected_size

    @patch("huggingface_hub.file_download._get_progress_bar_context")
    @patch("hf_xet.download_files")
    @patch("huggingface_hub.file_download.refresh_xet_connection_info", return_value=_CONNECTION_INFO)
    def test_zero_increment_skipped(self, _mock_conn, mock_download, mock_progress_cm, tmp_path):
        """Zero-increment updates should not call progress.update."""
        callback, mock_bar = self._call_xet_get_and_capture(tmp_path, mock_download, mock_progress_cm=mock_progress_cm)

        total_update = self._make_mock_total_update(transfer_increment=0, transfer_total=1000)
        callback(total_update, [])

        mock_bar.update.assert_not_called()
0 commit comments