diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index 2e69dad..99e2b25 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -2,7 +2,6 @@ name: goreleaser on: - pull_request: push: tags: - "*" diff --git a/main.go b/main.go index 6dd3e94..b0d0507 100644 --- a/main.go +++ b/main.go @@ -57,7 +57,7 @@ func (rl *rateLimitedWriter) Write(p []byte) (int, error) { if rl.limiter == nil { return rl.w.Write(p) } - + // Write in chunks to avoid blocking for too long written := 0 for written < len(p) { @@ -65,12 +65,12 @@ func (rl *rateLimitedWriter) Write(p []byte) (int, error) { if chunkSize > len(p)-written { chunkSize = len(p) - written } - + // Wait for permission to write if err := rl.limiter.WaitN(rl.ctx, chunkSize); err != nil { return written, err } - + n, err := rl.w.Write(p[written : written+chunkSize]) written += n if err != nil { @@ -87,6 +87,13 @@ type downloadPart struct { endByte uint64 } +// atomicCounter is padded to 64 bytes to prevent false sharing between CPU cores. +// Without padding, adjacent counters share a cache line, causing cache thrashing. +type atomicCounter struct { + val uint64 + _ [56]byte // pad to 64 bytes (cache line size) +} + type download struct { uri string filesize uint64 @@ -101,48 +108,49 @@ type download struct { ctx context.Context progress *DownloadProgress progressMutex sync.Mutex + partDownloaded []atomicCounter // cache-line padded atomic counters for lock-free progress } // DownloadProgress tracks the progress of a download type DownloadProgress struct { - Version int `json:"version"` - URI string `json:"uri"` - FileSize uint64 `json:"file_size"` - Filename string `json:"filename"` - Parts map[int]*PartProgress `json:"parts"` - Created time.Time `json:"created"` - LastUpdated time.Time `json:"last_updated"` - Completed bool `json:"completed"` + Version int `json:"version"` + URI string `json:"uri"` + FileSize uint64 `json:"file_size"` + Filename string `json:"filename"` + Parts map[int]*PartProgress `json:"parts"` + Created time.Time `json:"created"` + LastUpdated time.Time `json:"last_updated"` + Completed bool `json:"completed"` } // PartProgress tracks the progress of a single download part type PartProgress struct { - Index int `json:"index"` - StartByte uint64 `json:"start_byte"` - EndByte uint64 `json:"end_byte"` - Downloaded uint64 `json:"downloaded"` - Completed bool `json:"completed"` - LastModified time.Time `json:"last_modified"` + Index int `json:"index"` + StartByte uint64 `json:"start_byte"` + EndByte uint64 `json:"end_byte"` + Downloaded uint64 `json:"downloaded"` + Completed bool `json:"completed"` + LastModified time.Time `json:"last_modified"` } type config struct { - boost int - retries int + boost int + retries int } const ( // HTTP client timeouts - httpTimeout = 30 * time.Second // For HEAD requests and initial connections + httpTimeout = 30 * time.Second // For HEAD requests and initial connections idleConnTimeout = 90 * time.Second tlsHandshakeTimeout = 10 * time.Second - + // Connection pool settings maxIdleConns = 100 maxIdleConnsPerHost = 10 - + // Download settings - minDownloadTimeout = 60 * time.Second // Minimum timeout for downloads - timeoutPerMB = 3 * time.Second // Additional timeout per MB of part size + minDownloadTimeout = 60 * time.Second // Minimum timeout for downloads + timeoutPerMB = 3 * time.Second // Additional timeout per MB of part size ) // Global HTTP client with proper timeouts and connection pooling @@ -153,7 +161,8 @@ var httpClient = &http.Client{ MaxIdleConnsPerHost: maxIdleConnsPerHost, IdleConnTimeout: idleConnTimeout, TLSHandshakeTimeout: tlsHandshakeTimeout, - DisableCompression: true, // We're downloading files, compression is usually not helpful + ForceAttemptHTTP2: false, // Use multiple TCP connections for boosted downloads + DisableCompression: true, // We're downloading files, compression is usually not helpful }, } @@ -162,35 +171,35 @@ func loadConfig() config { boost: 8, retries: 3, } - + // Try to load from user's home directory home, err := os.UserHomeDir() if err != nil { return cfg } - + configPath := filepath.Join(home, ".dlrc") file, err := os.Open(configPath) if err != nil { return cfg // File doesn't exist, use defaults } defer file.Close() - + scanner := bufio.NewScanner(file) for scanner.Scan() { line := strings.TrimSpace(scanner.Text()) if line == "" || strings.HasPrefix(line, "#") { continue } - + parts := strings.SplitN(line, "=", 2) if len(parts) != 2 { continue } - + key := strings.TrimSpace(parts[0]) value := strings.TrimSpace(parts[1]) - + switch key { case "boost": if v, err := strconv.Atoi(value); err == nil && v > 0 { @@ -202,7 +211,7 @@ func loadConfig() config { } } } - + return cfg } @@ -211,11 +220,11 @@ func parseBandwidthLimit(limit string) (int64, error) { if limit == "" { return 0, nil } - + // Remove "/s" suffix if present limit = strings.TrimSuffix(strings.ToUpper(limit), "/S") limit = strings.TrimSpace(limit) - + // Extract numeric part and unit var numStr string var unit string @@ -227,16 +236,16 @@ func parseBandwidthLimit(limit string) (int64, error) { unit = limit[i:] break } - + if numStr == "" { numStr = limit } - + num, err := strconv.ParseFloat(numStr, 64) if err != nil { return 0, fmt.Errorf("invalid bandwidth limit: %s", limit) } - + // Convert to bytes per second multiplier := float64(1) switch strings.ToUpper(strings.TrimSpace(unit)) { @@ -251,14 +260,14 @@ func parseBandwidthLimit(limit string) (int64, error) { default: return 0, fmt.Errorf("unknown unit: %s", unit) } - + return int64(num * multiplier), nil } func main() { // Load config file first cfg := loadConfig() - + filenamePtr := flag.String("filename", "", "custom filename") boostPtr := flag.Int("boost", cfg.boost, "number of concurrent downloads") retriesPtr := flag.Int("retries", cfg.retries, "max retries for failed parts") @@ -266,9 +275,9 @@ func main() { noResumePtr := flag.Bool("no-resume", false, "disable auto-resume") limitPtr := flag.String("limit", "", "bandwidth limit (e.g. 1M, 500K, 100KB/s)") checksumPtr := flag.String("checksum", "", "verify download with checksum (format: algorithm:hash, e.g. sha256:abc123...)") - + flag.Parse() - + if *boostPtr < 1 { fmt.Fprintln(os.Stderr, "Boost must be greater than 0") os.Exit(1) @@ -290,7 +299,7 @@ func main() { // Create a context that can be cancelled ctx, cancel := context.WithCancel(context.Background()) defer cancel() - + // Handle signals for graceful shutdown sigc := make(chan os.Signal, 1) signal.Notify(sigc, syscall.SIGHUP, syscall.SIGINT, syscall.SIGTERM, syscall.SIGQUIT) @@ -351,7 +360,7 @@ func main() { } fmt.Println("Download completed:", dl.filename) - + // Verify checksum if provided if *checksumPtr != "" { fmt.Printf("Verifying checksum...") @@ -370,7 +379,7 @@ func (dl *download) FetchMetadata() error { return fmt.Errorf("failed to create HEAD request: %w", err) } req.Header.Set("User-Agent", "dl/1.1.1") - + resp, err := httpClient.Do(req) if err != nil { return fmt.Errorf("HEAD request failed: %w", err) @@ -451,6 +460,14 @@ func (dl *download) saveProgress() error { return nil } + // Sync from atomic counters to progress struct (lock-free reads) + for i := range dl.partDownloaded { + if partProgress, ok := dl.progress.Parts[i]; ok { + partProgress.Downloaded = atomic.LoadUint64(&dl.partDownloaded[i].val) + partProgress.LastModified = time.Now() + } + } + dl.progress.LastUpdated = time.Now() data, err := json.MarshalIndent(dl.progress, "", " ") @@ -573,10 +590,11 @@ func (dl *download) Fetch() error { var err error var existingSize int64 var resumeFromProgress bool - + // Prepare chunk boundaries first (needed for progress tracking) if dl.boost > 1 && dl.supportsRange { dl.parts = make([]downloadPart, dl.boost) + dl.partDownloaded = make([]atomicCounter, dl.boost) // cache-line padded counters for lock-free progress for i := 0; i < dl.boost; i++ { start, end := dl.calculatePartBoundary(i) dl.parts[i] = downloadPart{ @@ -587,32 +605,32 @@ func (dl *download) Fetch() error { } } } - + // Check for existing progress if dl.resume { if err := dl.loadProgress(); err != nil { fmt.Printf("Warning: could not load progress file: %v\n", err) } - + // Check if file exists if stat, err := os.Stat(dl.outputPath()); err == nil { existingSize = stat.Size() - + // If we have valid progress metadata, use it if dl.progress != nil && !dl.progress.Completed { resumeFromProgress = true totalDownloaded := dl.getTotalDownloaded() - fmt.Printf("Resuming download using progress file (%.1f%% complete)\n", + fmt.Printf("Resuming download using progress file (%.1f%% complete)\n", float64(totalDownloaded)/float64(dl.filesize)*100) } else if existingSize >= int64(dl.filesize) { fmt.Printf("File %s already fully downloaded (%d bytes)\n", dl.filename, existingSize) return nil } else if existingSize > 0 && dl.boost == 1 { // For single-threaded downloads without progress file, resume from file size - fmt.Printf("Resuming download from %d bytes (%.1f%% complete)\n", + fmt.Printf("Resuming download from %d bytes (%.1f%% complete)\n", existingSize, float64(existingSize)/float64(dl.filesize)*100) } - + // Open for writing, don't truncate outFile, err = os.OpenFile(dl.outputPath(), os.O_RDWR, 0644) if err != nil { @@ -628,7 +646,7 @@ func (dl *download) Fetch() error { dl.resume = false } } - + // Create new file if needed if outFile == nil { outFile, err = os.Create(dl.outputPath()) @@ -637,7 +655,7 @@ func (dl *download) Fetch() error { } } defer outFile.Close() - + // Set up the file size if dl.boost > 1 && dl.supportsRange && !dl.resume { // Use sparse file if supported @@ -656,7 +674,7 @@ func (dl *download) Fetch() error { } } } - + // Initialize progress tracking if not resuming if dl.progress == nil && dl.boost > 1 && dl.supportsRange { dl.initProgress() @@ -665,12 +683,21 @@ func (dl *download) Fetch() error { } } + // Initialize atomic counters from existing progress (for resume) + if resumeFromProgress && dl.progress != nil { + for i := range dl.parts { + if partProgress, ok := dl.progress.Parts[i]; ok { + atomic.StoreUint64(&dl.partDownloaded[i].val, partProgress.Downloaded) + } + } + } + // Create a progress bar spanning the entire file bar := progressbar.DefaultBytes( int64(dl.filesize), "Downloading", ) - + // Set initial progress if resuming if resumeFromProgress { bar.Set64(int64(dl.getTotalDownloaded())) @@ -687,20 +714,20 @@ func (dl *download) Fetch() error { } downloadSizeMB := downloadSize / (1024 * 1024) timeout := minDownloadTimeout + (time.Duration(downloadSizeMB) * timeoutPerMB) - + // Create a custom HTTP client with appropriate timeout singleClient := &http.Client{ - Timeout: timeout, + Timeout: timeout, Transport: httpClient.Transport, } - + // Single-stream download req, err := http.NewRequestWithContext(dl.ctx, "GET", dl.uri, nil) if err != nil { return fmt.Errorf("failed to create request: %w", err) } req.Header.Set("User-Agent", "dl/1.1.1") - + // For resume, request from where we left off startOffset := int64(0) if dl.resume && existingSize > 0 { @@ -719,7 +746,7 @@ func (dl *download) Fetch() error { if startOffset > 0 { expectedStatus = http.StatusPartialContent } - + if resp.StatusCode != expectedStatus { return fmt.Errorf("server returned status %d (%s) for download", resp.StatusCode, resp.Status) } @@ -729,19 +756,19 @@ func (dl *download) Fetch() error { w: outFile, offset: startOffset, } - + // Create rate limiter if bandwidth limit is set var writer io.Writer = io.MultiWriter(ow, bar) if dl.bandwidthLimit > 0 { limiter := rate.NewLimiter(rate.Limit(dl.bandwidthLimit), int(dl.bandwidthLimit)) writer = &rateLimitedWriter{w: writer, limiter: limiter, ctx: dl.ctx} } - + _, err = io.Copy(writer, resp.Body) if err != nil { return err } - + // Clean up progress file for single-threaded downloads if dl.progress != nil { _ = dl.removeProgress() @@ -757,11 +784,11 @@ func (dl *download) Fetch() error { // Start a goroutine to periodically save progress ctx, cancel := context.WithCancel(dl.ctx) defer cancel() - + go func() { ticker := time.NewTicker(2 * time.Second) defer ticker.Stop() - + for { select { case <-ctx.Done(): @@ -779,6 +806,37 @@ func (dl *download) Fetch() error { } }() + // UI update ticker - updates progress bar without mutex contention in download loop + go func() { + uiTicker := time.NewTicker(100 * time.Millisecond) + defer uiTicker.Stop() + + var lastTotal uint64 + for { + select { + case <-ctx.Done(): + // Final update to ensure bar reaches 100% + var finalTotal uint64 + for i := range dl.partDownloaded { + finalTotal += atomic.LoadUint64(&dl.partDownloaded[i].val) + } + if diff := finalTotal - lastTotal; diff > 0 { + bar.Add64(int64(diff)) + } + return + case <-uiTicker.C: + var currentTotal uint64 + for i := range dl.partDownloaded { + currentTotal += atomic.LoadUint64(&dl.partDownloaded[i].val) + } + if diff := currentTotal - lastTotal; diff > 0 { + bar.Add64(int64(diff)) + lastTotal = currentTotal + } + } + } + }() + // Launch goroutines for _, part := range dl.parts { // Skip completed parts when resuming @@ -786,11 +844,11 @@ func (dl *download) Fetch() error { fmt.Printf("Part %d already completed, skipping\n", part.index) continue } - + wg.Add(1) go func(dp downloadPart) { defer wg.Done() - if err := dl.fetchPartRange(dp, outFile, bar); err != nil { + if err := dl.fetchPartRange(dp, outFile); err != nil { errCh <- err } }(part) @@ -814,7 +872,7 @@ func (dl *download) Fetch() error { return e } } - + // Mark as completed and remove progress file if dl.progress != nil { dl.progress.Completed = true @@ -826,25 +884,25 @@ func (dl *download) Fetch() error { fmt.Printf("Warning: could not remove progress file: %v\n", err) } } - + return nil } -func (dl *download) fetchPartRange(p downloadPart, outFile *os.File, bar *progressbar.ProgressBar) error { +func (dl *download) fetchPartRange(p downloadPart, outFile *os.File) error { var lastErr error baseDelay := 1 * time.Second - + for i := 0; i < dl.retries; i++ { - err := dl.fetchPartOnce(p, outFile, bar) + err := dl.fetchPartOnce(p, outFile) if err == nil { return nil } lastErr = err - + if i < dl.retries-1 { // Exponential backoff: 1s, 2s, 4s delay := baseDelay * time.Duration(math.Pow(2, float64(i))) - fmt.Printf("Part %d failed (attempt %d/%d): %v. Retrying in %v...\n", + fmt.Printf("Part %d failed (attempt %d/%d): %v. Retrying in %v...\n", p.index, i+1, dl.retries, err, delay) time.Sleep(delay) } @@ -852,18 +910,18 @@ func (dl *download) fetchPartRange(p downloadPart, outFile *os.File, bar *progre return fmt.Errorf("part %d failed after %d retries: %w", p.index, dl.retries, lastErr) } -func (dl *download) fetchPartOnce(p downloadPart, outFile *os.File, bar *progressbar.ProgressBar) error { +func (dl *download) fetchPartOnce(p downloadPart, outFile *os.File) error { // Check if we need to resume this part startByte := p.startByte alreadyDownloaded := uint64(0) - + if dl.progress != nil && dl.progress.Parts[p.index] != nil { partProgress := dl.progress.Parts[p.index] if partProgress.Downloaded > 0 { // Resume from where we left off alreadyDownloaded = partProgress.Downloaded startByte = p.startByte + alreadyDownloaded - + // If we've already downloaded everything for this part, mark as complete expectedSize := p.endByte - p.startByte + 1 if alreadyDownloaded >= expectedSize { @@ -877,10 +935,10 @@ func (dl *download) fetchPartOnce(p downloadPart, outFile *os.File, bar *progres partSize := p.endByte - startByte + 1 partSizeMB := partSize / (1024 * 1024) timeout := minDownloadTimeout + (time.Duration(partSizeMB) * timeoutPerMB) - + // Create a custom HTTP client with appropriate timeout for this part partClient := &http.Client{ - Timeout: timeout, + Timeout: timeout, Transport: httpClient.Transport, // Reuse the transport for connection pooling } @@ -899,44 +957,25 @@ func (dl *download) fetchPartOnce(p downloadPart, outFile *os.File, bar *progres defer resp.Body.Close() if resp.StatusCode != http.StatusPartialContent && resp.StatusCode != http.StatusOK { - return fmt.Errorf("server returned status %d (%s) for part %d (bytes %d-%d)", + return fmt.Errorf("server returned status %d (%s) for part %d (bytes %d-%d)", resp.StatusCode, resp.Status, p.index, startByte, p.endByte) } - // Create a custom writer that tracks progress - type progressTracker struct { - w io.WriterAt - offset int64 - index int - downloaded *uint64 - dl *download - bar *progressbar.ProgressBar - } - - downloaded := alreadyDownloaded - tracker := &progressTracker{ - w: outFile, - offset: int64(startByte), - index: p.index, - downloaded: &downloaded, - dl: dl, - bar: bar, - } - - // Custom Write method that updates progress - trackerWriter := func(p []byte) (int, error) { - n, err := tracker.w.WriteAt(p, tracker.offset) + // Track write offset for this part + writeOffset := int64(startByte) + partIndex := p.index + + // Custom Write method that updates progress via atomic counter (no mutex, no locks) + // Progress bar is updated by a separate ticker goroutine, not here + trackerWriter := func(data []byte) (int, error) { + n, err := outFile.WriteAt(data, writeOffset) if n > 0 { - tracker.offset += int64(n) - atomic.AddUint64(tracker.downloaded, uint64(n)) - // Update progress tracking - tracker.dl.updatePartProgress(tracker.index, atomic.LoadUint64(tracker.downloaded), false) - // Update progress bar - tracker.bar.Add(n) + writeOffset += int64(n) + atomic.AddUint64(&dl.partDownloaded[partIndex].val, uint64(n)) } return n, err } - + // Create rate limiter if bandwidth limit is set var writer io.Writer = WriterFunc(trackerWriter) if dl.bandwidthLimit > 0 { @@ -948,14 +987,16 @@ func (dl *download) fetchPartOnce(p downloadPart, outFile *os.File, bar *progres limiter := rate.NewLimiter(rate.Limit(perPartLimit), int(perPartLimit)) writer = &rateLimitedWriter{w: writer, limiter: limiter, ctx: dl.ctx} } - - if _, copyErr := io.Copy(writer, resp.Body); copyErr != nil { + + // Use 1MB buffer to reduce syscall overhead + buf := make([]byte, 1024*1024) + if _, copyErr := io.CopyBuffer(writer, resp.Body, buf); copyErr != nil { return fmt.Errorf("error writing part %d: %w", p.index, copyErr) } - - // Mark part as completed - dl.updatePartProgress(p.index, atomic.LoadUint64(&downloaded), true) - + + // Mark part as completed (this one mutex call per part is fine) + dl.updatePartProgress(p.index, atomic.LoadUint64(&dl.partDownloaded[p.index].val), true) + return nil } @@ -1004,17 +1045,17 @@ func verifyChecksum(filepath string, checksumStr string) error { if len(parts) != 2 { return fmt.Errorf("invalid checksum format, expected algorithm:hash") } - + algorithm := strings.ToLower(parts[0]) expectedHash := strings.ToLower(parts[1]) - + // Open file file, err := os.Open(filepath) if err != nil { return fmt.Errorf("failed to open file: %w", err) } defer file.Close() - + // Create hasher based on algorithm var hasher hash.Hash switch algorithm { @@ -1025,17 +1066,17 @@ func verifyChecksum(filepath string, checksumStr string) error { default: return fmt.Errorf("unsupported hash algorithm: %s (supported: md5, sha256)", algorithm) } - + // Calculate hash if _, err := io.Copy(hasher, file); err != nil { return fmt.Errorf("failed to calculate hash: %w", err) } - + actualHash := hex.EncodeToString(hasher.Sum(nil)) - + if actualHash != expectedHash { return fmt.Errorf("checksum mismatch: expected %s, got %s", expectedHash, actualHash) } - + return nil }