diff --git a/.gitignore b/.gitignore index 6550409..d388167 100644 --- a/.gitignore +++ b/.gitignore @@ -1,2 +1,2 @@ -dl +/dl *.part* diff --git a/Makefile b/Makefile index ac528ab..05e3232 100644 --- a/Makefile +++ b/Makefile @@ -24,7 +24,7 @@ all: help .PHONY: build build: @echo "Building $(BINARY_NAME)..." - @$(GO_BUILD) $(LDFLAGS) -o $(BINARY_NAME) . + @$(GO_BUILD) $(LDFLAGS) -o $(BINARY_NAME) ./cmd/dl # Build for all platforms .PHONY: build-all @@ -34,21 +34,21 @@ build-all: build-linux build-darwin build-windows .PHONY: build-linux build-linux: @echo "Building for Linux..." - @GOOS=linux GOARCH=amd64 $(GO_BUILD) $(LDFLAGS) -o $(BINARY_NAME)-linux-amd64 . - @GOOS=linux GOARCH=arm64 $(GO_BUILD) $(LDFLAGS) -o $(BINARY_NAME)-linux-arm64 . + @GOOS=linux GOARCH=amd64 $(GO_BUILD) $(LDFLAGS) -o $(BINARY_NAME)-linux-amd64 ./cmd/dl + @GOOS=linux GOARCH=arm64 $(GO_BUILD) $(LDFLAGS) -o $(BINARY_NAME)-linux-arm64 ./cmd/dl # Build for macOS .PHONY: build-darwin build-darwin: @echo "Building for macOS..." - @GOOS=darwin GOARCH=amd64 $(GO_BUILD) $(LDFLAGS) -o $(BINARY_NAME)-darwin-amd64 . - @GOOS=darwin GOARCH=arm64 $(GO_BUILD) $(LDFLAGS) -o $(BINARY_NAME)-darwin-arm64 . + @GOOS=darwin GOARCH=amd64 $(GO_BUILD) $(LDFLAGS) -o $(BINARY_NAME)-darwin-amd64 ./cmd/dl + @GOOS=darwin GOARCH=arm64 $(GO_BUILD) $(LDFLAGS) -o $(BINARY_NAME)-darwin-arm64 ./cmd/dl # Build for Windows .PHONY: build-windows build-windows: @echo "Building for Windows..." - @GOOS=windows GOARCH=amd64 $(GO_BUILD) $(LDFLAGS) -o $(BINARY_NAME)-windows-amd64.exe . + @GOOS=windows GOARCH=amd64 $(GO_BUILD) $(LDFLAGS) -o $(BINARY_NAME)-windows-amd64.exe ./cmd/dl # Run tests .PHONY: test @@ -101,7 +101,7 @@ clean: .PHONY: install install: @echo "Installing $(BINARY_NAME)..." - @$(GO) install $(LDFLAGS) . + @$(GO) install $(LDFLAGS) ./cmd/dl @echo "Installed to $(shell go env GOPATH)/bin/$(BINARY_NAME)" @if ! echo $$PATH | grep -q "$(shell go env GOPATH)/bin"; then \ echo ""; \ @@ -136,7 +136,7 @@ test-resume: build .PHONY: dev dev: @echo "Building with race detector..." - @$(GO_BUILD) -race $(LDFLAGS) -o $(BINARY_NAME) . + @$(GO_BUILD) -race $(LDFLAGS) -o $(BINARY_NAME) ./cmd/dl @echo "Ready for development testing" # Show help @@ -158,4 +158,4 @@ help: @echo " make test-download - Test download functionality" @echo " make test-resume - Test resume functionality" @echo " make dev - Build with race detector for development" - @echo " make help - Show this help message" \ No newline at end of file + @echo " make help - Show this help message" diff --git a/README.md b/README.md index ecea45d..9e36b23 100644 --- a/README.md +++ b/README.md @@ -24,7 +24,7 @@ Requires Go 1.23+ ```bash git clone https://github.com/mgomes/dl.git cd dl -go build +go build -o dl ./cmd/dl ``` ## Usage @@ -33,6 +33,44 @@ go build dl [url2] [url3] ... ``` +## Library usage + +```go +package main + +import ( + "context" + "fmt" + "os" + + "github.com/mgomes/dl" +) + +func main() { + ctx := context.Background() + downloader := dl.Downloader{ + URI: "https://example.com/file.zip", + Context: ctx, + WorkingDir: ".", + Boost: 4, + Retries: 3, + Resume: true, + } + + if err := downloader.FetchMetadata(); err != nil { + fmt.Printf("metadata error: %v\n", err) + os.Exit(1) + } + + if err := downloader.Fetch(); err != nil { + fmt.Printf("download error: %v\n", err) + os.Exit(1) + } + + fmt.Println("saved to", downloader.OutputPath()) +} +``` + ### Options ``` diff --git a/checksum.go b/cmd/dl/checksum.go similarity index 100% rename from checksum.go rename to cmd/dl/checksum.go diff --git a/checksum_test.go b/cmd/dl/checksum_test.go similarity index 100% rename from checksum_test.go rename to cmd/dl/checksum_test.go diff --git a/config.go b/cmd/dl/config.go similarity index 100% rename from config.go rename to cmd/dl/config.go diff --git a/config_test.go b/cmd/dl/config_test.go similarity index 100% rename from config_test.go rename to cmd/dl/config_test.go diff --git a/main.go b/cmd/dl/main.go similarity index 74% rename from main.go rename to cmd/dl/main.go index 7f37f56..e12bdff 100644 --- a/main.go +++ b/cmd/dl/main.go @@ -8,6 +8,8 @@ import ( "os" "os/signal" "syscall" + + "github.com/mgomes/dl" ) func main() { @@ -52,21 +54,22 @@ func main() { }() for _, uri := range fileURIs { - var dl download - dl.uri = uri - dl.boost = *boostPtr - dl.retries = *retriesPtr - dl.resume = *resumePtr && !*noResumePtr - dl.bandwidthLimit = bandwidthLimit - dl.ctx = ctx - - if err := dl.FetchMetadata(); err != nil { + downloader := dl.Downloader{ + URI: uri, + Boost: *boostPtr, + Retries: *retriesPtr, + Resume: *resumePtr && !*noResumePtr, + BandwidthLimit: bandwidthLimit, + Context: ctx, + } + + if err := downloader.FetchMetadata(); err != nil { fmt.Fprintf(os.Stderr, "Error fetching metadata for %s: %v\n", uri, err) os.Exit(1) } if *filenamePtr != "" { - dl.filename = *filenamePtr + downloader.Filename = *filenamePtr } wd, err := os.Getwd() @@ -74,32 +77,31 @@ func main() { fmt.Fprintf(os.Stderr, "Error getting working directory: %v\n", err) os.Exit(1) } - dl.workingDir = wd + downloader.WorkingDir = wd - fmt.Println("Downloading:", dl.filename) + fmt.Println("Downloading:", downloader.Filename) - if !dl.supportsRange && dl.boost > 1 { + if !downloader.SupportsRange() && downloader.Boost > 1 { fmt.Println("Server does not support partial content. Falling back to single-threaded download.") - dl.boost = 1 + downloader.Boost = 1 } - if err := dl.Fetch(); err != nil { + downloader.Progress = newProgressBarReporter(downloader.FileSize()) + + if err := downloader.Fetch(); err != nil { if errors.Is(err, context.Canceled) { fmt.Println("\nDownload cancelled") os.Exit(1) } fmt.Fprintf(os.Stderr, "Error while downloading: %v\n", err) - if dl.progress == nil { - _ = os.Remove(dl.outputPath()) - } os.Exit(1) } - fmt.Println("Download completed:", dl.filename) + fmt.Println("Download completed:", downloader.Filename) if *checksumPtr != "" { fmt.Printf("Verifying checksum...") - if err := verifyChecksum(dl.outputPath(), *checksumPtr); err != nil { + if err := verifyChecksum(downloader.OutputPath(), *checksumPtr); err != nil { fmt.Fprintf(os.Stderr, "\nChecksum verification failed: %v\n", err) os.Exit(1) } diff --git a/cmd/dl/progress.go b/cmd/dl/progress.go new file mode 100644 index 0000000..08b0fa4 --- /dev/null +++ b/cmd/dl/progress.go @@ -0,0 +1,42 @@ +package main + +import "github.com/schollz/progressbar/v3" + +type progressBarReporter struct { + bar *progressbar.ProgressBar +} + +func newProgressBarReporter(total uint64) *progressBarReporter { + return &progressBarReporter{ + bar: progressbar.DefaultBytes(int64(total), "Downloading"), + } +} + +func (p *progressBarReporter) SetTotal(total uint64) { + if p.bar == nil { + p.bar = progressbar.DefaultBytes(int64(total), "Downloading") + return + } + p.bar.ChangeMax64(int64(total)) +} + +func (p *progressBarReporter) SetDownloaded(downloaded uint64) { + if p.bar == nil { + return + } + _ = p.bar.Set64(int64(downloaded)) +} + +func (p *progressBarReporter) AddDownloaded(delta uint64) { + if p.bar == nil { + return + } + _ = p.bar.Add64(int64(delta)) +} + +func (p *progressBarReporter) Done() { + if p.bar == nil { + return + } + _ = p.bar.Finish() +} diff --git a/download.go b/download.go index cecb231..02585e1 100644 --- a/download.go +++ b/download.go @@ -1,4 +1,4 @@ -package main +package dl import ( "context" @@ -14,7 +14,6 @@ import ( "sync/atomic" "time" - "github.com/schollz/progressbar/v3" "golang.org/x/time/rate" ) @@ -52,25 +51,49 @@ type atomicCounter struct { _ [56]byte } -type download struct { - uri string - filesize uint64 - filename string - workingDir string - boost int - retries int - resume bool - bandwidthLimit int64 - parts []downloadPart - supportsRange bool - ctx context.Context - progress *DownloadProgress - progressMutex sync.Mutex - partDownloaded []atomicCounter +type Downloader struct { + URI string + Filename string + WorkingDir string + Boost int + Retries int + Resume bool + BandwidthLimit int64 + Progress ProgressReporter + Context context.Context + + fileSize uint64 + supportsRange bool + parts []downloadPart + progress *DownloadProgress + progressMutex sync.Mutex + partDownloaded []atomicCounter + metadataFetched bool } -func (dl *download) FetchMetadata() error { - req, err := http.NewRequestWithContext(dl.ctx, "HEAD", dl.uri, nil) +const ( + DefaultBoost = 8 + DefaultRetries = 3 +) + +func (dl *Downloader) FileSize() uint64 { + return dl.fileSize +} + +func (dl *Downloader) SupportsRange() bool { + return dl.supportsRange +} + +func (dl *Downloader) FetchMetadata() error { + if dl.URI == "" { + return fmt.Errorf("uri is required") + } + + if dl.Context == nil { + dl.Context = context.Background() + } + + req, err := http.NewRequestWithContext(dl.Context, "HEAD", dl.URI, nil) if err != nil { return fmt.Errorf("failed to create HEAD request: %w", err) } @@ -91,7 +114,7 @@ func (dl *download) FetchMetadata() error { return fmt.Errorf("server did not provide Content-Length header, cannot determine file size") } - dl.filesize, err = strconv.ParseUint(contentLength, 10, 64) + dl.fileSize, err = strconv.ParseUint(contentLength, 10, 64) if err != nil { return fmt.Errorf("invalid Content-Length '%s': %w", contentLength, err) } @@ -99,63 +122,84 @@ func (dl *download) FetchMetadata() error { acceptRanges := resp.Header.Get("Accept-Ranges") dl.supportsRange = strings.ToLower(acceptRanges) == "bytes" - contentDisposition := resp.Header.Get("Content-Disposition") - if contentDisposition != "" { - _, params, err := mime.ParseMediaType(contentDisposition) - if err == nil && params["filename"] != "" { - dl.filename = params["filename"] + if dl.Filename == "" { + contentDisposition := resp.Header.Get("Content-Disposition") + if contentDisposition != "" { + _, params, err := mime.ParseMediaType(contentDisposition) + if err == nil && params["filename"] != "" { + dl.Filename = params["filename"] + } else { + dl.Filename = dl.filenameFromURI() + } } else { - dl.filename = dl.filenameFromURI() + dl.Filename = dl.filenameFromURI() } - } else { - dl.filename = dl.filenameFromURI() } + dl.metadataFetched = true + return nil } -func (dl *download) Fetch() error { +func (dl *Downloader) Fetch() error { var outFile *os.File var err error var existingSize int64 var resumeFromProgress bool - if dl.boost > 1 && dl.supportsRange { - dl.parts = make([]downloadPart, dl.boost) - dl.partDownloaded = make([]atomicCounter, dl.boost) - for i := 0; i < dl.boost; i++ { + if err := dl.ensureDefaults(); err != nil { + return err + } + + if !dl.metadataFetched { + if err := dl.FetchMetadata(); err != nil { + return err + } + } + + reporter := dl.progressReporter() + reporter.SetTotal(dl.fileSize) + + if dl.Boost > 1 && dl.supportsRange { + dl.parts = make([]downloadPart, dl.Boost) + dl.partDownloaded = make([]atomicCounter, dl.Boost) + for i := range dl.Boost { start, end := dl.calculatePartBoundary(i) dl.parts[i] = downloadPart{ index: i, - uri: dl.uri, + uri: dl.URI, startByte: start, endByte: end, } } } - if dl.resume { + if dl.Resume { if err := dl.loadProgress(); err != nil { fmt.Printf("Warning: could not load progress file: %v\n", err) } - if stat, err := os.Stat(dl.outputPath()); err == nil { + if stat, err := os.Stat(dl.OutputPath()); err == nil { existingSize = stat.Size() if dl.progress != nil && !dl.progress.Completed { resumeFromProgress = true totalDownloaded := dl.getTotalDownloaded() fmt.Printf("Resuming download using progress file (%.1f%% complete)\n", - float64(totalDownloaded)/float64(dl.filesize)*100) - } else if existingSize >= int64(dl.filesize) { - fmt.Printf("File %s already fully downloaded (%d bytes)\n", dl.filename, existingSize) + float64(totalDownloaded)/float64(dl.fileSize)*100) + reporter.SetDownloaded(totalDownloaded) + } else if existingSize >= int64(dl.fileSize) { + fmt.Printf("File %s already fully downloaded (%d bytes)\n", dl.Filename, existingSize) + reporter.SetDownloaded(dl.fileSize) + reporter.Done() return nil - } else if existingSize > 0 && dl.boost == 1 { + } else if existingSize > 0 && dl.Boost == 1 { fmt.Printf("Resuming download from %d bytes (%.1f%% complete)\n", - existingSize, float64(existingSize)/float64(dl.filesize)*100) + existingSize, float64(existingSize)/float64(dl.fileSize)*100) + reporter.SetDownloaded(uint64(existingSize)) } - outFile, err = os.OpenFile(dl.outputPath(), os.O_RDWR, 0644) + outFile, err = os.OpenFile(dl.OutputPath(), os.O_RDWR, 0644) if err != nil { return fmt.Errorf("cannot open file for resume: %w", err) } @@ -165,34 +209,34 @@ func (dl *download) Fetch() error { dl.progress = nil resumeFromProgress = false } - dl.resume = false + dl.Resume = false } } if outFile == nil { - outFile, err = os.Create(dl.outputPath()) + outFile, err = os.Create(dl.OutputPath()) if err != nil { return fmt.Errorf("cannot create output file: %w", err) } } defer outFile.Close() - if dl.boost > 1 && dl.supportsRange && !dl.resume { - if supportsSparseFiles(dl.workingDir) { - if err := createSparseFile(outFile, int64(dl.filesize)); err != nil { + if dl.Boost > 1 && dl.supportsRange && !dl.Resume { + if supportsSparseFiles(dl.WorkingDir) { + if err := createSparseFile(outFile, int64(dl.fileSize)); err != nil { fmt.Printf("Warning: sparse file creation failed, using regular allocation: %v\n", err) - if err = outFile.Truncate(int64(dl.filesize)); err != nil { + if err = outFile.Truncate(int64(dl.fileSize)); err != nil { return fmt.Errorf("error setting file size: %w", err) } } } else { - if err = outFile.Truncate(int64(dl.filesize)); err != nil { + if err = outFile.Truncate(int64(dl.fileSize)); err != nil { return fmt.Errorf("error setting file size: %w", err) } } } - if dl.progress == nil && dl.boost > 1 && dl.supportsRange { + if dl.progress == nil && dl.Boost > 1 && dl.supportsRange { dl.initProgress() if err := dl.saveProgress(); err != nil { fmt.Printf("Warning: could not save initial progress: %v\n", err) @@ -207,28 +251,24 @@ func (dl *download) Fetch() error { } } - bar := progressbar.DefaultBytes( - int64(dl.filesize), - "Downloading", - ) - - if resumeFromProgress { - bar.Set64(int64(dl.getTotalDownloaded())) - } else if dl.resume && existingSize > 0 && dl.boost == 1 { - bar.Set64(existingSize) + if dl.Boost == 1 || !dl.supportsRange { + err = dl.fetchSingleStream(outFile, reporter, existingSize) + } else { + err = dl.fetchMultiPart(outFile, reporter, resumeFromProgress) } - if dl.boost == 1 || !dl.supportsRange { - return dl.fetchSingleStream(outFile, bar, existingSize) + if err != nil { + return err } - return dl.fetchMultiPart(outFile, bar, resumeFromProgress) + reporter.Done() + return nil } -func (dl *download) fetchSingleStream(outFile *os.File, bar *progressbar.ProgressBar, existingSize int64) error { - downloadSize := dl.filesize - if dl.resume && existingSize > 0 { - downloadSize = dl.filesize - uint64(existingSize) +func (dl *Downloader) fetchSingleStream(outFile *os.File, reporter ProgressReporter, existingSize int64) error { + downloadSize := dl.fileSize + if dl.Resume && existingSize > 0 { + downloadSize = dl.fileSize - uint64(existingSize) } downloadSizeMB := downloadSize / (1024 * 1024) timeout := minDownloadTimeout + (time.Duration(downloadSizeMB) * timeoutPerMB) @@ -238,17 +278,16 @@ func (dl *download) fetchSingleStream(outFile *os.File, bar *progressbar.Progres Transport: httpClient.Transport, } - req, err := http.NewRequestWithContext(dl.ctx, "GET", dl.uri, nil) + req, err := http.NewRequestWithContext(dl.Context, "GET", dl.URI, nil) if err != nil { return fmt.Errorf("failed to create request: %w", err) } req.Header.Set("User-Agent", "dl/1.1.1") startOffset := int64(0) - if dl.resume && existingSize > 0 { + if dl.Resume && existingSize > 0 { req.Header.Set("Range", fmt.Sprintf("bytes=%d-", existingSize)) startOffset = existingSize - bar.Set64(existingSize) } resp, err := singleClient.Do(req) @@ -271,10 +310,10 @@ func (dl *download) fetchSingleStream(outFile *os.File, bar *progressbar.Progres offset: startOffset, } - var writer io.Writer = io.MultiWriter(ow, bar) - if dl.bandwidthLimit > 0 { - limiter := rate.NewLimiter(rate.Limit(dl.bandwidthLimit), int(dl.bandwidthLimit)) - writer = &rateLimitedWriter{w: writer, limiter: limiter, ctx: dl.ctx} + var writer io.Writer = io.MultiWriter(ow, &progressWriter{reporter: reporter}) + if dl.BandwidthLimit > 0 { + limiter := rate.NewLimiter(rate.Limit(dl.BandwidthLimit), int(dl.BandwidthLimit)) + writer = &rateLimitedWriter{w: writer, limiter: limiter, ctx: dl.Context} } _, err = io.Copy(writer, resp.Body) @@ -288,12 +327,11 @@ func (dl *download) fetchSingleStream(outFile *os.File, bar *progressbar.Progres return nil } -func (dl *download) fetchMultiPart(outFile *os.File, bar *progressbar.ProgressBar, resumeFromProgress bool) error { +func (dl *Downloader) fetchMultiPart(outFile *os.File, reporter ProgressReporter, resumeFromProgress bool) error { var wg sync.WaitGroup - errCh := make(chan error, dl.boost) - progressCh := make(chan struct{}, 1) + errCh := make(chan error, dl.Boost) - ctx, cancel := context.WithCancel(dl.ctx) + ctx, cancel := context.WithCancel(dl.Context) defer cancel() go func() { @@ -308,39 +346,6 @@ func (dl *download) fetchMultiPart(outFile *os.File, bar *progressbar.ProgressBa if err := dl.saveProgress(); err != nil { fmt.Printf("Warning: could not save progress: %v\n", err) } - case <-progressCh: - if err := dl.saveProgress(); err != nil { - fmt.Printf("Warning: could not save progress: %v\n", err) - } - } - } - }() - - go func() { - uiTicker := time.NewTicker(100 * time.Millisecond) - defer uiTicker.Stop() - - var lastTotal uint64 - for { - select { - case <-ctx.Done(): - var finalTotal uint64 - for i := range dl.partDownloaded { - finalTotal += atomic.LoadUint64(&dl.partDownloaded[i].val) - } - if diff := finalTotal - lastTotal; diff > 0 { - bar.Add64(int64(diff)) - } - return - case <-uiTicker.C: - var currentTotal uint64 - for i := range dl.partDownloaded { - currentTotal += atomic.LoadUint64(&dl.partDownloaded[i].val) - } - if diff := currentTotal - lastTotal; diff > 0 { - bar.Add64(int64(diff)) - lastTotal = currentTotal - } } } }() @@ -351,13 +356,12 @@ func (dl *download) fetchMultiPart(outFile *os.File, bar *progressbar.ProgressBa continue } - wg.Add(1) - go func(dp downloadPart) { - defer wg.Done() - if err := dl.fetchPartRange(dp, outFile); err != nil { + part := part + wg.Go(func() { + if err := dl.fetchPartRange(part, outFile, reporter); err != nil { errCh <- err } - }(part) + }) } wg.Wait() @@ -389,28 +393,28 @@ func (dl *download) fetchMultiPart(outFile *os.File, bar *progressbar.ProgressBa return nil } -func (dl *download) fetchPartRange(p downloadPart, outFile *os.File) error { +func (dl *Downloader) fetchPartRange(p downloadPart, outFile *os.File, reporter ProgressReporter) error { var lastErr error baseDelay := 1 * time.Second - for i := 0; i < dl.retries; i++ { - err := dl.fetchPartOnce(p, outFile) + for i := range dl.Retries { + err := dl.fetchPartOnce(p, outFile, reporter) if err == nil { return nil } lastErr = err - if i < dl.retries-1 { + if i < dl.Retries-1 { delay := baseDelay * time.Duration(math.Pow(2, float64(i))) fmt.Printf("Part %d failed (attempt %d/%d): %v. Retrying in %v...\n", - p.index, i+1, dl.retries, err, delay) + p.index, i+1, dl.Retries, err, delay) time.Sleep(delay) } } - return fmt.Errorf("part %d failed after %d retries: %w", p.index, dl.retries, lastErr) + return fmt.Errorf("part %d failed after %d retries: %w", p.index, dl.Retries, lastErr) } -func (dl *download) fetchPartOnce(p downloadPart, outFile *os.File) error { +func (dl *Downloader) fetchPartOnce(p downloadPart, outFile *os.File, reporter ProgressReporter) error { startByte := p.startByte alreadyDownloaded := uint64(0) @@ -438,7 +442,7 @@ func (dl *download) fetchPartOnce(p downloadPart, outFile *os.File) error { } byteRange := fmt.Sprintf("bytes=%d-%d", startByte, p.endByte) - req, err := http.NewRequestWithContext(dl.ctx, "GET", p.uri, nil) + req, err := http.NewRequestWithContext(dl.Context, "GET", p.uri, nil) if err != nil { return fmt.Errorf("failed to create request: %w", err) } @@ -464,18 +468,19 @@ func (dl *download) fetchPartOnce(p downloadPart, outFile *os.File) error { if n > 0 { writeOffset += int64(n) atomic.AddUint64(&dl.partDownloaded[partIndex].val, uint64(n)) + reporter.AddDownloaded(uint64(n)) } return n, err } var writer io.Writer = WriterFunc(trackerWriter) - if dl.bandwidthLimit > 0 { - perPartLimit := dl.bandwidthLimit / int64(dl.boost) + if dl.BandwidthLimit > 0 { + perPartLimit := dl.BandwidthLimit / int64(dl.Boost) if perPartLimit < 1024 { perPartLimit = 1024 } limiter := rate.NewLimiter(rate.Limit(perPartLimit), int(perPartLimit)) - writer = &rateLimitedWriter{w: writer, limiter: limiter, ctx: dl.ctx} + writer = &rateLimitedWriter{w: writer, limiter: limiter, ctx: dl.Context} } buf := make([]byte, 1024*1024) @@ -488,21 +493,21 @@ func (dl *download) fetchPartOnce(p downloadPart, outFile *os.File) error { return nil } -func (dl *download) calculatePartBoundary(part int) (uint64, uint64) { - chunkSize := dl.filesize / uint64(dl.boost) +func (dl *Downloader) calculatePartBoundary(part int) (uint64, uint64) { + chunkSize := dl.fileSize / uint64(dl.Boost) startByte := uint64(part) * chunkSize var endByte uint64 - if part == dl.boost-1 { - endByte = dl.filesize - 1 + if part == dl.Boost-1 { + endByte = dl.fileSize - 1 } else { endByte = startByte + chunkSize - 1 } return startByte, endByte } -func (dl *download) filenameFromURI() string { - splitURI := strings.Split(dl.uri, "/") +func (dl *Downloader) filenameFromURI() string { + splitURI := strings.Split(dl.URI, "/") filename := splitURI[len(splitURI)-1] if idx := strings.Index(filename, "?"); idx != -1 { filename = filename[:idx] @@ -511,6 +516,42 @@ func (dl *download) filenameFromURI() string { return filename } -func (dl *download) outputPath() string { - return fmt.Sprintf("%s%c%s", dl.workingDir, os.PathSeparator, dl.filename) +func (dl *Downloader) OutputPath() string { + return fmt.Sprintf("%s%c%s", dl.WorkingDir, os.PathSeparator, dl.Filename) +} + +func (dl *Downloader) ensureDefaults() error { + if dl.URI == "" { + return fmt.Errorf("uri is required") + } + + if dl.Context == nil { + dl.Context = context.Background() + } + + if dl.WorkingDir == "" { + wd, err := os.Getwd() + if err != nil { + return fmt.Errorf("failed to get working directory: %w", err) + } + dl.WorkingDir = wd + } + + if dl.Boost <= 0 { + dl.Boost = DefaultBoost + } + + if dl.Retries <= 0 { + dl.Retries = DefaultRetries + } + + return nil +} + +func (dl *Downloader) progressReporter() ProgressReporter { + if dl.Progress == nil { + return noopProgressReporter{} + } + + return dl.Progress } diff --git a/download_test.go b/download_test.go index 2a543c3..d3af470 100644 --- a/download_test.go +++ b/download_test.go @@ -1,4 +1,4 @@ -package main +package dl import ( "context" @@ -10,9 +10,9 @@ import ( ) func TestCalculatePartBoundary(t *testing.T) { - dl := download{ - filesize: 1000, - boost: 4, + dl := Downloader{ + fileSize: 1000, + Boost: 4, } tests := []struct { @@ -38,9 +38,9 @@ func TestCalculatePartBoundary(t *testing.T) { } func TestCalculatePartBoundaryUneven(t *testing.T) { - dl := download{ - filesize: 1000, - boost: 3, + dl := Downloader{ + fileSize: 1000, + Boost: 3, } start0, end0 := dl.calculatePartBoundary(0) @@ -73,7 +73,7 @@ func TestFilenameFromURI(t *testing.T) { for _, tt := range tests { t.Run(tt.uri, func(t *testing.T) { - dl := download{uri: tt.uri} + dl := Downloader{URI: tt.uri} result := dl.filenameFromURI() if result != tt.expected { t.Errorf("for URI %s, expected %s but got %s", tt.uri, tt.expected, result) @@ -83,13 +83,13 @@ func TestFilenameFromURI(t *testing.T) { } func TestOutputPath(t *testing.T) { - dl := download{ - workingDir: "/home/user/downloads", - filename: "test.zip", + dl := Downloader{ + WorkingDir: "/home/user/downloads", + Filename: "test.zip", } expected := "/home/user/downloads/test.zip" - result := dl.outputPath() + result := dl.OutputPath() if result != expected { t.Errorf("expected %s, got %s", expected, result) } @@ -107,23 +107,23 @@ func TestFetchMetadata(t *testing.T) { })) defer server.Close() - dl := download{ - uri: server.URL + "/test", - ctx: context.Background(), + dl := Downloader{ + URI: server.URL + "/test", + Context: context.Background(), } if err := dl.FetchMetadata(); err != nil { t.Fatalf("FetchMetadata failed: %v", err) } - if dl.filesize != 12345 { - t.Errorf("expected filesize=12345, got %d", dl.filesize) + if dl.FileSize() != 12345 { + t.Errorf("expected filesize=12345, got %d", dl.FileSize()) } - if !dl.supportsRange { + if !dl.SupportsRange() { t.Error("expected supportsRange=true") } - if dl.filename != "test-file.zip" { - t.Errorf("expected filename='test-file.zip', got '%s'", dl.filename) + if dl.Filename != "test-file.zip" { + t.Errorf("expected filename='test-file.zip', got '%s'", dl.Filename) } } @@ -134,20 +134,20 @@ func TestFetchMetadataNoRange(t *testing.T) { })) defer server.Close() - dl := download{ - uri: server.URL + "/file.zip", - ctx: context.Background(), + dl := Downloader{ + URI: server.URL + "/file.zip", + Context: context.Background(), } if err := dl.FetchMetadata(); err != nil { t.Fatalf("FetchMetadata failed: %v", err) } - if dl.supportsRange { + if dl.SupportsRange() { t.Error("expected supportsRange=false when Accept-Ranges not set") } - if dl.filename != "file.zip" { - t.Errorf("expected filename='file.zip', got '%s'", dl.filename) + if dl.Filename != "file.zip" { + t.Errorf("expected filename='file.zip', got '%s'", dl.Filename) } } @@ -157,9 +157,9 @@ func TestFetchMetadataNoContentLength(t *testing.T) { })) defer server.Close() - dl := download{ - uri: server.URL + "/test", - ctx: context.Background(), + dl := Downloader{ + URI: server.URL + "/test", + Context: context.Background(), } err := dl.FetchMetadata() @@ -174,9 +174,9 @@ func TestFetchMetadataServerError(t *testing.T) { })) defer server.Close() - dl := download{ - uri: server.URL + "/test", - ctx: context.Background(), + dl := Downloader{ + URI: server.URL + "/test", + Context: context.Background(), } err := dl.FetchMetadata() @@ -204,11 +204,11 @@ func TestFetchSingleStream(t *testing.T) { } defer os.RemoveAll(tmpDir) - dl := download{ - uri: server.URL + "/test.txt", - ctx: context.Background(), - boost: 1, - workingDir: tmpDir, + dl := Downloader{ + URI: server.URL + "/test.txt", + Context: context.Background(), + Boost: 1, + WorkingDir: tmpDir, } if err := dl.FetchMetadata(); err != nil { @@ -219,7 +219,7 @@ func TestFetchSingleStream(t *testing.T) { t.Fatalf("Fetch failed: %v", err) } - downloaded, err := os.ReadFile(dl.outputPath()) + downloaded, err := os.ReadFile(dl.OutputPath()) if err != nil { t.Fatalf("failed to read downloaded file: %v", err) } @@ -263,12 +263,12 @@ func TestFetchMultiPart(t *testing.T) { } defer os.RemoveAll(tmpDir) - dl := download{ - uri: server.URL + "/test.bin", - ctx: context.Background(), - boost: 4, - retries: 3, - workingDir: tmpDir, + dl := Downloader{ + URI: server.URL + "/test.bin", + Context: context.Background(), + Boost: 4, + Retries: 3, + WorkingDir: tmpDir, } if err := dl.FetchMetadata(); err != nil { @@ -279,7 +279,7 @@ func TestFetchMultiPart(t *testing.T) { t.Fatalf("Fetch failed: %v", err) } - downloaded, err := os.ReadFile(dl.outputPath()) + downloaded, err := os.ReadFile(dl.OutputPath()) if err != nil { t.Fatalf("failed to read downloaded file: %v", err) } diff --git a/file.go b/file.go index 2435c48..acddc7a 100644 --- a/file.go +++ b/file.go @@ -1,4 +1,4 @@ -package main +package dl import ( "fmt" diff --git a/file_test.go b/file_test.go index ee967de..b7c095b 100644 --- a/file_test.go +++ b/file_test.go @@ -1,4 +1,4 @@ -package main +package dl import ( "os" diff --git a/go.mod b/go.mod index b0fb8e0..397a286 100644 --- a/go.mod +++ b/go.mod @@ -1,15 +1,15 @@ module github.com/mgomes/dl -go 1.23.0 +go 1.25 -toolchain go1.24.2 - -require github.com/schollz/progressbar/v3 v3.7.2 +require ( + github.com/schollz/progressbar/v3 v3.7.2 + golang.org/x/time v0.11.0 +) require ( github.com/mattn/go-runewidth v0.0.9 // indirect github.com/mitchellh/colorstring v0.0.0-20190213212951-d06e56a500db // indirect golang.org/x/crypto v0.0.0-20201112155050-0c6587e931a9 // indirect golang.org/x/sys v0.0.0-20220325203850-36772127a21f // indirect - golang.org/x/time v0.11.0 // indirect ) diff --git a/progress.go b/progress.go index 2650eba..47d508e 100644 --- a/progress.go +++ b/progress.go @@ -1,4 +1,4 @@ -package main +package dl import ( "encoding/json" @@ -28,11 +28,11 @@ type PartProgress struct { LastModified time.Time `json:"last_modified"` } -func (dl *download) progressFilePath() string { - return fmt.Sprintf("%s%c.%s.dl_progress", dl.workingDir, os.PathSeparator, dl.filename) +func (dl *Downloader) progressFilePath() string { + return fmt.Sprintf("%s%c.%s.dl_progress", dl.WorkingDir, os.PathSeparator, dl.Filename) } -func (dl *download) loadProgress() error { +func (dl *Downloader) loadProgress() error { progressPath := dl.progressFilePath() data, err := os.ReadFile(progressPath) if err != nil { @@ -47,7 +47,7 @@ func (dl *download) loadProgress() error { return fmt.Errorf("failed to parse progress file: %w", err) } - if progress.URI != dl.uri || progress.FileSize != dl.filesize { + if progress.URI != dl.URI || progress.FileSize != dl.fileSize { fmt.Println("Progress file is for a different download, starting fresh") return nil } @@ -56,7 +56,7 @@ func (dl *download) loadProgress() error { return nil } -func (dl *download) saveProgress() error { +func (dl *Downloader) saveProgress() error { dl.progressMutex.Lock() defer dl.progressMutex.Unlock() @@ -90,16 +90,16 @@ func (dl *download) saveProgress() error { return nil } -func (dl *download) removeProgress() error { +func (dl *Downloader) removeProgress() error { return os.Remove(dl.progressFilePath()) } -func (dl *download) initProgress() { +func (dl *Downloader) initProgress() { dl.progress = &DownloadProgress{ Version: 1, - URI: dl.uri, - FileSize: dl.filesize, - Filename: dl.filename, + URI: dl.URI, + FileSize: dl.fileSize, + Filename: dl.Filename, Parts: make(map[int]*PartProgress), Created: time.Now(), LastUpdated: time.Now(), @@ -117,7 +117,7 @@ func (dl *download) initProgress() { } } -func (dl *download) updatePartProgress(index int, downloaded uint64, completed bool) { +func (dl *Downloader) updatePartProgress(index int, downloaded uint64, completed bool) { dl.progressMutex.Lock() defer dl.progressMutex.Unlock() @@ -128,7 +128,7 @@ func (dl *download) updatePartProgress(index int, downloaded uint64, completed b } } -func (dl *download) getTotalDownloaded() uint64 { +func (dl *Downloader) getTotalDownloaded() uint64 { dl.progressMutex.Lock() defer dl.progressMutex.Unlock() diff --git a/progress_reporter.go b/progress_reporter.go new file mode 100644 index 0000000..01d941b --- /dev/null +++ b/progress_reporter.go @@ -0,0 +1,26 @@ +package dl + +type ProgressReporter interface { + SetTotal(total uint64) + SetDownloaded(downloaded uint64) + AddDownloaded(delta uint64) + Done() +} + +type noopProgressReporter struct{} + +func (noopProgressReporter) SetTotal(uint64) {} +func (noopProgressReporter) SetDownloaded(uint64) {} +func (noopProgressReporter) AddDownloaded(uint64) {} +func (noopProgressReporter) Done() {} + +type progressWriter struct { + reporter ProgressReporter +} + +func (pw *progressWriter) Write(p []byte) (int, error) { + if pw.reporter != nil { + pw.reporter.AddDownloaded(uint64(len(p))) + } + return len(p), nil +} diff --git a/progress_test.go b/progress_test.go index 8c48594..45c9462 100644 --- a/progress_test.go +++ b/progress_test.go @@ -1,4 +1,4 @@ -package main +package dl import ( "os" @@ -7,9 +7,9 @@ import ( ) func TestProgressFilePath(t *testing.T) { - dl := download{ - workingDir: "/tmp", - filename: "test.zip", + dl := Downloader{ + WorkingDir: "/tmp", + Filename: "test.zip", } expected := "/tmp/.test.zip.dl_progress" @@ -20,11 +20,11 @@ func TestProgressFilePath(t *testing.T) { } func TestInitProgress(t *testing.T) { - dl := download{ - uri: "http://example.com/file.zip", - filesize: 1000, - filename: "file.zip", - boost: 4, + dl := Downloader{ + URI: "http://example.com/file.zip", + fileSize: 1000, + Filename: "file.zip", + Boost: 4, parts: []downloadPart{ {index: 0, startByte: 0, endByte: 249}, {index: 1, startByte: 250, endByte: 499}, @@ -38,11 +38,11 @@ func TestInitProgress(t *testing.T) { if dl.progress == nil { t.Fatal("progress should not be nil") } - if dl.progress.URI != dl.uri { - t.Errorf("expected URI %s, got %s", dl.uri, dl.progress.URI) + if dl.progress.URI != dl.URI { + t.Errorf("expected URI %s, got %s", dl.URI, dl.progress.URI) } - if dl.progress.FileSize != dl.filesize { - t.Errorf("expected filesize %d, got %d", dl.filesize, dl.progress.FileSize) + if dl.progress.FileSize != dl.fileSize { + t.Errorf("expected filesize %d, got %d", dl.fileSize, dl.progress.FileSize) } if len(dl.progress.Parts) != 4 { t.Errorf("expected 4 parts, got %d", len(dl.progress.Parts)) @@ -69,12 +69,12 @@ func TestSaveAndLoadProgress(t *testing.T) { } defer os.RemoveAll(tmpDir) - dl := download{ - uri: "http://example.com/file.zip", - filesize: 1000, - filename: "file.zip", - workingDir: tmpDir, - boost: 2, + dl := Downloader{ + URI: "http://example.com/file.zip", + fileSize: 1000, + Filename: "file.zip", + WorkingDir: tmpDir, + Boost: 2, partDownloaded: make([]atomicCounter, 2), parts: []downloadPart{ {index: 0, startByte: 0, endByte: 499}, @@ -95,11 +95,11 @@ func TestSaveAndLoadProgress(t *testing.T) { t.Error("progress file was not created") } - dl2 := download{ - uri: "http://example.com/file.zip", - filesize: 1000, - filename: "file.zip", - workingDir: tmpDir, + dl2 := Downloader{ + URI: "http://example.com/file.zip", + fileSize: 1000, + Filename: "file.zip", + WorkingDir: tmpDir, } if err := dl2.loadProgress(); err != nil { @@ -124,12 +124,12 @@ func TestLoadProgressMismatch(t *testing.T) { } defer os.RemoveAll(tmpDir) - dl := download{ - uri: "http://example.com/file.zip", - filesize: 1000, - filename: "file.zip", - workingDir: tmpDir, - boost: 1, + dl := Downloader{ + URI: "http://example.com/file.zip", + fileSize: 1000, + Filename: "file.zip", + WorkingDir: tmpDir, + Boost: 1, partDownloaded: make([]atomicCounter, 1), parts: []downloadPart{ {index: 0, startByte: 0, endByte: 999}, @@ -138,11 +138,11 @@ func TestLoadProgressMismatch(t *testing.T) { dl.initProgress() dl.saveProgress() - dl2 := download{ - uri: "http://example.com/different.zip", - filesize: 1000, - filename: "file.zip", - workingDir: tmpDir, + dl2 := Downloader{ + URI: "http://example.com/different.zip", + fileSize: 1000, + Filename: "file.zip", + WorkingDir: tmpDir, } if err := dl2.loadProgress(); err != nil { @@ -161,11 +161,11 @@ func TestLoadProgressNoFile(t *testing.T) { } defer os.RemoveAll(tmpDir) - dl := download{ - uri: "http://example.com/file.zip", - filesize: 1000, - filename: "file.zip", - workingDir: tmpDir, + dl := Downloader{ + URI: "http://example.com/file.zip", + fileSize: 1000, + Filename: "file.zip", + WorkingDir: tmpDir, } if err := dl.loadProgress(); err != nil { @@ -184,12 +184,12 @@ func TestRemoveProgress(t *testing.T) { } defer os.RemoveAll(tmpDir) - dl := download{ - uri: "http://example.com/file.zip", - filesize: 1000, - filename: "file.zip", - workingDir: tmpDir, - boost: 1, + dl := Downloader{ + URI: "http://example.com/file.zip", + fileSize: 1000, + Filename: "file.zip", + WorkingDir: tmpDir, + Boost: 1, partDownloaded: make([]atomicCounter, 1), parts: []downloadPart{ {index: 0, startByte: 0, endByte: 999}, @@ -213,11 +213,11 @@ func TestRemoveProgress(t *testing.T) { } func TestUpdatePartProgress(t *testing.T) { - dl := download{ - uri: "http://example.com/file.zip", - filesize: 1000, - filename: "file.zip", - boost: 2, + dl := Downloader{ + URI: "http://example.com/file.zip", + fileSize: 1000, + Filename: "file.zip", + Boost: 2, parts: []downloadPart{ {index: 0, startByte: 0, endByte: 499}, {index: 1, startByte: 500, endByte: 999}, @@ -243,11 +243,11 @@ func TestUpdatePartProgress(t *testing.T) { } func TestGetTotalDownloaded(t *testing.T) { - dl := download{ - uri: "http://example.com/file.zip", - filesize: 1000, - filename: "file.zip", - boost: 4, + dl := Downloader{ + URI: "http://example.com/file.zip", + fileSize: 1000, + Filename: "file.zip", + Boost: 4, parts: []downloadPart{ {index: 0, startByte: 0, endByte: 249}, {index: 1, startByte: 250, endByte: 499}, @@ -269,7 +269,7 @@ func TestGetTotalDownloaded(t *testing.T) { } func TestGetTotalDownloadedNilProgress(t *testing.T) { - dl := download{} + dl := Downloader{} total := dl.getTotalDownloaded() if total != 0 { t.Errorf("expected total=0 for nil progress, got %d", total) diff --git a/writers.go b/writers.go index b905379..1f2544c 100644 --- a/writers.go +++ b/writers.go @@ -1,4 +1,4 @@ -package main +package dl import ( "context" diff --git a/writers_test.go b/writers_test.go index 204bd38..f392a98 100644 --- a/writers_test.go +++ b/writers_test.go @@ -1,4 +1,4 @@ -package main +package dl import ( "bytes"