@@ -354,3 +354,185 @@ def token_refresher() -> tuple[str, int]:
354354 )
355355
356356 assert os .path .exists (file_path )
357+
358+
@requires("hf_xet")
class TestXetProgressGranularity:
    """Test that xet_get uses the fine-grained 2-arg callback for tqdm progress.

    Every test calls ``xet_get`` with ``hf_xet.download_files`` patched out, then
    invokes the progress callback that ``xet_get`` handed to ``download_files``
    with hand-built mock progress updates.
    """

    def _make_mock_total_update(self, transfer_increment: float, transfer_total: float) -> Mock:
        """Create a mock PyTotalProgressUpdate with the given transfer fields."""
        update = Mock()
        update.total_transfer_bytes_completion_increment = transfer_increment
        update.total_transfer_bytes = transfer_total
        return update

    def _invoke_xet_get(self, tmp_path, expected_size: int) -> None:
        """Call ``xet_get`` on an empty file under ``tmp_path`` with mock connection data.

        ``refresh_xet_connection_info`` is patched to return a fixed
        ``XetConnectionInfo``. The caller must patch ``hf_xet.download_files``
        (and anything else it wants to observe) before calling this helper.
        """
        incomplete_path = tmp_path / "test_file.bin"
        incomplete_path.touch()

        xet_file_data = XetFileData(file_hash="mock_hash", refresh_route="mock/route")
        connection_info = XetConnectionInfo(
            endpoint="mock_endpoint", access_token="mock_token", expiration_unix_epoch=9999999999
        )

        with patch("huggingface_hub.file_download.refresh_xet_connection_info", return_value=connection_info):
            xet_get(
                incomplete_path=incomplete_path,
                xet_file_data=xet_file_data,
                headers={"authorization": "Bearer token"},
                expected_size=expected_size,
            )

    def _capture_callback_with_mock_bar(self, tmp_path, expected_size: int):
        """Run ``xet_get`` with a mock progress bar and return ``(callback, mock_bar)``.

        ``callback`` is the first entry of the ``progress_updater`` keyword that
        ``xet_get`` passed to the patched ``hf_xet.download_files``. ``mock_bar``
        is the object yielded by the patched ``_get_progress_bar_context`` and
        starts with ``n == 0``. ``mock_bar.update`` is left as a plain
        auto-created ``Mock`` attribute so callers can assert on its calls.
        """
        captured = []

        def capture_callback(*args, **kwargs):
            captured.append(kwargs["progress_updater"][0])

        mock_bar = Mock()
        mock_bar.n = 0

        with patch("hf_xet.download_files", side_effect=capture_callback):
            with patch("huggingface_hub.file_download._get_progress_bar_context") as mock_progress_cm:
                mock_progress_cm.return_value.__enter__ = Mock(return_value=mock_bar)
                mock_progress_cm.return_value.__exit__ = Mock(return_value=False)
                self._invoke_xet_get(tmp_path, expected_size)

        assert captured, "xet_get did not call hf_xet.download_files"
        # Keep the most recent callback, matching a plain "last assignment wins" capture.
        return captured[-1], mock_bar

    @staticmethod
    def _simulate_tqdm_update(mock_bar) -> None:
        """Make ``mock_bar.update(n)`` advance ``mock_bar.n`` by ``n``, as tqdm does."""

        def advance(n):
            mock_bar.n += n

        mock_bar.update = Mock(side_effect=advance)

    def test_callback_uses_two_arg_signature(self, tmp_path):
        """Verify xet_get passes a 2-arg callback to download_files, triggering
        xet-core's fine-grained network-level progress dispatch."""
        with patch("hf_xet.download_files") as mock_download:
            self._invoke_xet_get(tmp_path, expected_size=1000)

        mock_download.assert_called_once()
        callbacks = mock_download.call_args.kwargs["progress_updater"]
        assert len(callbacks) == 1
        callback = callbacks[0]

        # Call with 2 args (total_update, item_updates) to confirm it accepts them.
        # A 1-arg callback would raise TypeError here.
        total_update = self._make_mock_total_update(transfer_increment=200, transfer_total=1000)
        callback(total_update, [])  # should not raise

    def test_progress_bar_scales_network_to_file_size(self, tmp_path):
        """When transfer bytes differ from file size, the progress bar should
        scale to expected_size so it always reaches 100%."""
        expected_size = 10_000
        # Simulate xet transferring fewer bytes than file size (deduplication)
        transfer_total = 5_000

        callback, mock_bar = self._capture_callback_with_mock_bar(tmp_path, expected_size)

        # Install the tqdm-like side effect once; it does not change between updates.
        self._simulate_tqdm_update(mock_bar)

        # Simulate 5 updates of 1000 transfer bytes each (total: 5000 transfer bytes)
        for _ in range(5):
            total_update = self._make_mock_total_update(
                transfer_increment=1000,
                transfer_total=transfer_total,
            )
            callback(total_update, [])

        # After transferring 5000/5000 bytes, bar should be at expected_size (10000)
        assert mock_bar.n == expected_size

    def test_progress_bar_capped_at_expected_size(self, tmp_path):
        """Progress bar should never exceed expected_size."""
        expected_size = 1000
        transfer_total = 1000

        callback, mock_bar = self._capture_callback_with_mock_bar(tmp_path, expected_size)
        self._simulate_tqdm_update(mock_bar)

        # Send more transfer bytes than total (edge case)
        total_update = self._make_mock_total_update(transfer_increment=1200, transfer_total=transfer_total)
        callback(total_update, [])

        assert mock_bar.n <= expected_size

    def test_zero_increment_skipped(self, tmp_path):
        """Zero-increment updates should not call progress.update."""
        callback, mock_bar = self._capture_callback_with_mock_bar(tmp_path, expected_size=1000)

        total_update = self._make_mock_total_update(transfer_increment=0, transfer_total=1000)
        callback(total_update, [])

        mock_bar.update.assert_not_called()
0 commit comments