@@ -354,3 +354,268 @@ def token_refresher() -> tuple[str, int]:
354354 )
355355
356356 assert os .path .exists (file_path )
357+
358+
359+ def _make_mock_total_update (transfer_increment : float , transfer_total : float ):
360+ """Create a mock PyTotalProgressUpdate with the given transfer fields."""
361+ update = Mock ()
362+ update .total_transfer_bytes_completion_increment = transfer_increment
363+ update .total_transfer_bytes = transfer_total
364+ return update
365+
366+
367+ @requires ("hf_xet" )
368+ class TestXetProgressGranularity :
369+ """Test that xet_get uses the fine-grained 2-arg callback for tqdm progress."""
370+
371+ _XET_FILE_DATA = XetFileData (file_hash = "mock_hash" , refresh_route = "mock/route" )
372+
373+ def _call_xet_get_and_capture (self , tmp_path , mock_download , expected_size = 1000 , mock_progress_cm = None ):
374+ """Call xet_get and return the captured progress callback."""
375+ incomplete_path = tmp_path / "test_file.bin"
376+ incomplete_path .touch ()
377+
378+ captured = {}
379+
380+ def capture (* args , ** kwargs ):
381+ captured ["callback" ] = kwargs ["progress_updater" ][0 ]
382+
383+ mock_download .side_effect = capture
384+
385+ if mock_progress_cm is not None :
386+ mock_bar = Mock ()
387+ mock_bar .n = 0
388+ mock_progress_cm .return_value .__enter__ = Mock (return_value = mock_bar )
389+ mock_progress_cm .return_value .__exit__ = Mock (return_value = False )
390+ else :
391+ mock_bar = None
392+
393+ xet_get (
394+ incomplete_path = incomplete_path ,
395+ xet_file_data = self ._XET_FILE_DATA ,
396+ headers = {"authorization" : "Bearer token" },
397+ expected_size = expected_size ,
398+ )
399+
400+ return captured ["callback" ], mock_bar
401+
402+ @patch ("huggingface_hub.file_download._get_progress_bar_context" )
403+ @patch ("hf_xet.download_files" )
404+ @patch (
405+ "huggingface_hub.file_download.refresh_xet_connection_info" ,
406+ return_value = XetConnectionInfo (
407+ endpoint = "mock_endpoint" , access_token = "mock_token" , expiration_unix_epoch = 9999999999
408+ ),
409+ )
410+ def test_callback_uses_two_arg_signature (self , _mock_conn , mock_download , mock_progress_cm , tmp_path ):
411+ """Verify xet_get passes a 2-arg callback to download_files, triggering
412+ xet-core's fine-grained network-level progress dispatch."""
413+ callback , _ = self ._call_xet_get_and_capture (tmp_path , mock_download , mock_progress_cm = mock_progress_cm )
414+
415+ # Call with 2 args (total_update, item_updates) to confirm it accepts them.
416+ # A 1-arg callback would raise TypeError here.
417+ total_update = _make_mock_total_update (transfer_increment = 200 , transfer_total = 1000 )
418+ callback (total_update , []) # should not raise
419+
420+ @patch ("huggingface_hub.file_download._get_progress_bar_context" )
421+ @patch ("hf_xet.download_files" )
422+ @patch (
423+ "huggingface_hub.file_download.refresh_xet_connection_info" ,
424+ return_value = XetConnectionInfo (
425+ endpoint = "mock_endpoint" , access_token = "mock_token" , expiration_unix_epoch = 9999999999
426+ ),
427+ )
428+ def test_progress_bar_scales_network_to_file_size (self , _mock_conn , mock_download , mock_progress_cm , tmp_path ):
429+ """When transfer bytes differ from file size, the progress bar should
430+ scale to expected_size so it always reaches 100%."""
431+ expected_size = 10_000
432+ transfer_total = 5_000 # fewer bytes due to deduplication
433+
434+ callback , mock_bar = self ._call_xet_get_and_capture (
435+ tmp_path , mock_download , expected_size = expected_size , mock_progress_cm = mock_progress_cm
436+ )
437+
438+ def update_side_effect (n ):
439+ mock_bar .n += n
440+
441+ mock_bar .update = Mock (side_effect = update_side_effect )
442+
443+ # Simulate 5 updates of 1000 transfer bytes each (total: 5000 transfer bytes)
444+ for _ in range (5 ):
445+ total_update = _make_mock_total_update (transfer_increment = 1000 , transfer_total = transfer_total )
446+ callback (total_update , [])
447+
448+ # After transferring 5000/5000 bytes, bar should be at expected_size (10000)
449+ assert mock_bar .n == expected_size
450+
451+ @patch ("huggingface_hub.file_download._get_progress_bar_context" )
452+ @patch ("hf_xet.download_files" )
453+ @patch (
454+ "huggingface_hub.file_download.refresh_xet_connection_info" ,
455+ return_value = XetConnectionInfo (
456+ endpoint = "mock_endpoint" , access_token = "mock_token" , expiration_unix_epoch = 9999999999
457+ ),
458+ )
459+ def test_progress_bar_capped_at_expected_size (self , _mock_conn , mock_download , mock_progress_cm , tmp_path ):
460+ """Progress bar should never exceed expected_size."""
461+ expected_size = 1000
462+
463+ callback , mock_bar = self ._call_xet_get_and_capture (
464+ tmp_path , mock_download , expected_size = expected_size , mock_progress_cm = mock_progress_cm
465+ )
466+
467+ def update_side_effect (n ):
468+ mock_bar .n += n
469+
470+ mock_bar .update = Mock (side_effect = update_side_effect )
471+
472+ # Send more transfer bytes than total (edge case)
473+ total_update = _make_mock_total_update (transfer_increment = 1200 , transfer_total = 1000 )
474+ callback (total_update , [])
475+
476+ assert mock_bar .n <= expected_size
477+
478+ @patch ("huggingface_hub.file_download._get_progress_bar_context" )
479+ @patch ("hf_xet.download_files" )
480+ @patch (
481+ "huggingface_hub.file_download.refresh_xet_connection_info" ,
482+ return_value = XetConnectionInfo (
483+ endpoint = "mock_endpoint" , access_token = "mock_token" , expiration_unix_epoch = 9999999999
484+ ),
485+ )
486+ def test_zero_increment_skipped (self , _mock_conn , mock_download , mock_progress_cm , tmp_path ):
487+ """Zero-increment updates should not call progress.update."""
488+ callback , mock_bar = self ._call_xet_get_and_capture (tmp_path , mock_download , mock_progress_cm = mock_progress_cm )
489+
490+ total_update = _make_mock_total_update (transfer_increment = 0 , transfer_total = 1000 )
491+ callback (total_update , [])
492+
493+ mock_bar .update .assert_not_called ()
494+
495+ @patch ("huggingface_hub.file_download._get_progress_bar_context" )
496+ @patch ("hf_xet.download_files" )
497+ @patch (
498+ "huggingface_hub.file_download.refresh_xet_connection_info" ,
499+ return_value = XetConnectionInfo (
500+ endpoint = "mock_endpoint" , access_token = "mock_token" , expiration_unix_epoch = 9999999999
501+ ),
502+ )
503+ def test_expected_size_none_passes_raw_bytes (self , _mock_conn , mock_download , mock_progress_cm , tmp_path ):
504+ """When expected_size is None, raw transfer bytes are passed through."""
505+ callback , mock_bar = self ._call_xet_get_and_capture (
506+ tmp_path , mock_download , expected_size = None , mock_progress_cm = mock_progress_cm
507+ )
508+
509+ def update_side_effect (n ):
510+ mock_bar .n += n
511+
512+ mock_bar .update = Mock (side_effect = update_side_effect )
513+
514+ total_update = _make_mock_total_update (transfer_increment = 500 , transfer_total = 0 )
515+ callback (total_update , [])
516+
517+ assert mock_bar .n == 500
518+
519+ @patch ("huggingface_hub.file_download._get_progress_bar_context" )
520+ @patch ("hf_xet.download_files" )
521+ @patch (
522+ "huggingface_hub.file_download.refresh_xet_connection_info" ,
523+ return_value = XetConnectionInfo (
524+ endpoint = "mock_endpoint" , access_token = "mock_token" , expiration_unix_epoch = 9999999999
525+ ),
526+ )
527+ def test_expected_size_none_with_known_transfer_total (self , _mock_conn , mock_download , mock_progress_cm , tmp_path ):
528+ """When expected_size is None, raw bytes pass through even if transfer_total is known."""
529+ callback , mock_bar = self ._call_xet_get_and_capture (
530+ tmp_path , mock_download , expected_size = None , mock_progress_cm = mock_progress_cm
531+ )
532+
533+ def update_side_effect (n ):
534+ mock_bar .n += n
535+
536+ mock_bar .update = Mock (side_effect = update_side_effect )
537+
538+ total_update = _make_mock_total_update (transfer_increment = 500 , transfer_total = 2000 )
539+ callback (total_update , [])
540+
541+ assert mock_bar .n == 500
542+
543+ @patch ("huggingface_hub.file_download._get_progress_bar_context" )
544+ @patch ("hf_xet.download_files" )
545+ @patch (
546+ "huggingface_hub.file_download.refresh_xet_connection_info" ,
547+ return_value = XetConnectionInfo (
548+ endpoint = "mock_endpoint" , access_token = "mock_token" , expiration_unix_epoch = 9999999999
549+ ),
550+ )
551+ def test_transfer_total_zero_skips_when_expected_size_set (
552+ self , _mock_conn , mock_download , mock_progress_cm , tmp_path
553+ ):
554+ """When expected_size is set but transfer_total is 0 (not yet known),
555+ updates are skipped to avoid injecting unscaled bytes."""
556+ callback , mock_bar = self ._call_xet_get_and_capture (
557+ tmp_path , mock_download , expected_size = 1000 , mock_progress_cm = mock_progress_cm
558+ )
559+
560+ total_update = _make_mock_total_update (transfer_increment = 500 , transfer_total = 0 )
561+ callback (total_update , [])
562+
563+ mock_bar .update .assert_not_called ()
564+
565+
566+ @requires ("hf_xet" )
567+ class TestMakeXetProgressCallback :
568+ """Direct tests for make_xet_progress_callback shared helper."""
569+
570+ def test_multi_file_shared_bar (self ):
571+ """Multiple callbacks sharing one bar should each contribute independently."""
572+ from huggingface_hub .file_download import make_xet_progress_callback
573+
574+ mock_bar = Mock ()
575+ mock_bar .n = 0
576+
577+ def update_side_effect (n ):
578+ mock_bar .n += n
579+
580+ mock_bar .update = Mock (side_effect = update_side_effect )
581+
582+ # Two files: 600 bytes and 400 bytes, sharing a bar with total=1000
583+ cb_a = make_xet_progress_callback (mock_bar , file_size = 600 )
584+ cb_b = make_xet_progress_callback (mock_bar , file_size = 400 )
585+
586+ # File A: 50% done (transfers 500/1000 network bytes -> contributes 300 of 600 file bytes)
587+ cb_a (_make_mock_total_update (transfer_increment = 500 , transfer_total = 1000 ), [])
588+ assert mock_bar .n == 300
589+
590+ # File B: 100% done (transfers 800/800 -> contributes 400 of 400 file bytes)
591+ cb_b (_make_mock_total_update (transfer_increment = 800 , transfer_total = 800 ), [])
592+ assert mock_bar .n == 700 # 300 + 400
593+
594+ # File A: 100% done (transfers remaining 500/1000 -> contributes remaining 300)
595+ cb_a (_make_mock_total_update (transfer_increment = 500 , transfer_total = 1000 ), [])
596+ assert mock_bar .n == 1000 # 600 + 400
597+
598+ def test_no_regression_on_duplicate_progress (self ):
599+ """When cumulative doesn't advance (e.g. duplicate update), bar should not update."""
600+ from huggingface_hub .file_download import make_xet_progress_callback
601+
602+ mock_bar = Mock ()
603+ mock_bar .n = 0
604+
605+ def update_side_effect (n ):
606+ mock_bar .n += n
607+
608+ mock_bar .update = Mock (side_effect = update_side_effect )
609+
610+ cb = make_xet_progress_callback (mock_bar , file_size = 1000 )
611+
612+ # First update: 500/1000 transfer -> 500 file bytes
613+ cb (_make_mock_total_update (transfer_increment = 500 , transfer_total = 1000 ), [])
614+ assert mock_bar .n == 500
615+ assert mock_bar .update .call_count == 1
616+
617+ # Tiny increment that doesn't move int() forward (1 byte of 1000 transfer = 0.001 * 1000 = 1)
618+ # contributed = int(501/1000 * 1000) = 501, advance = 501 - 500 = 1
619+ cb (_make_mock_total_update (transfer_increment = 1 , transfer_total = 1000 ), [])
620+ assert mock_bar .n == 501
621+ assert mock_bar .update .call_count == 2
0 commit comments