diff --git a/pkg/internal/utils/basic_test.go b/pkg/internal/utils/basic_test.go new file mode 100644 index 0000000..38be9eb --- /dev/null +++ b/pkg/internal/utils/basic_test.go @@ -0,0 +1,186 @@ +/* +Copyright 2025 The OpenCIDN Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package utils + +import ( + "encoding/base64" + "net/url" + "testing" +) + +func TestParseBasicAuth(t *testing.T) { + tests := []struct { + name string + auth string + wantUser string + wantPass string + wantNil bool + }{ + { + name: "valid basic auth with password", + auth: "Basic " + base64.StdEncoding.EncodeToString([]byte("user:pass")), + wantUser: "user", + wantPass: "pass", + wantNil: false, + }, + { + name: "valid basic auth without password", + auth: "Basic " + base64.StdEncoding.EncodeToString([]byte("user")), + wantUser: "user", + wantPass: "", + wantNil: false, + }, + { + name: "invalid prefix", + auth: "Bearer token", + wantNil: true, + }, + { + name: "empty string", + auth: "", + wantNil: true, + }, + { + name: "only Basic prefix", + auth: "Basic", + wantNil: true, + }, + { + name: "invalid base64", + auth: "Basic !!!invalid!!!", + wantNil: true, + }, + { + name: "case insensitive prefix", + auth: "basic " + base64.StdEncoding.EncodeToString([]byte("user:pass")), + wantUser: "user", + wantPass: "pass", + wantNil: false, + }, + { + name: "username with colon in password", + auth: "Basic " + base64.StdEncoding.EncodeToString([]byte("user:pass:word")), + wantUser: "user", + wantPass: "pass:word", + wantNil: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := ParseBasicAuth(tt.auth) + if tt.wantNil { + if got != nil { + t.Errorf("ParseBasicAuth() = %v, want nil", got) + } + return + } + + if got == nil { + t.Fatalf("ParseBasicAuth() = nil, want non-nil") + } + + if got.Username() != tt.wantUser { + t.Errorf("ParseBasicAuth() username = %v, want %v", got.Username(), tt.wantUser) + } + + gotPass, _ := got.Password() + if gotPass != tt.wantPass { + t.Errorf("ParseBasicAuth() password = %v, want %v", gotPass, tt.wantPass) + } + }) + } +} + +func TestFormathBasicAuth(t *testing.T) { + tests := []struct { + name string + ui *url.Userinfo + want string + }{ + { + name: "nil userinfo", + ui: nil, + want: "", + }, + { + name: "username only", + ui: url.User("user"), + want: "Basic " + base64.StdEncoding.EncodeToString([]byte("user")), + }, + { + name: "username and password", + ui: url.UserPassword("user", "pass"), + want: "Basic " + base64.StdEncoding.EncodeToString([]byte("user:pass")), + }, + { + name: "empty username", + ui: url.User(""), + want: "Basic " + base64.StdEncoding.EncodeToString([]byte("")), + }, + { + name: "special characters in credentials", + ui: url.UserPassword("admin@example.com", "p@ss:w0rd!"), + want: "Basic " + base64.StdEncoding.EncodeToString([]byte("admin%40example.com:p%40ss%3Aw0rd%21")), + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := FormathBasicAuth(tt.ui) + if got != tt.want { + t.Errorf("FormathBasicAuth() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestRoundTripBasicAuth(t *testing.T) { + tests := []struct { + name string + username string + password string + }{ + {"simple", "user", "pass"}, + {"no password", "user", ""}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + original := url.UserPassword(tt.username, tt.password) + formatted := FormathBasicAuth(original) + parsed := ParseBasicAuth(formatted) + + if parsed == nil { + t.Fatalf("ParseBasicAuth() returned nil") + } + + // Note: url.Userinfo.String() URL-encodes special characters, + // so we compare against the URL-encoded version + expectedUser := original.Username() + if parsed.Username() != expectedUser { + t.Errorf("username = %v, want %v", parsed.Username(), expectedUser) + } + + expectedPass, _ := original.Password() + gotPass, _ := parsed.Password() + if gotPass != expectedPass { + t.Errorf("password = %v, want %v", gotPass, expectedPass) + } + }) + } +} diff --git a/pkg/internal/utils/network_test.go b/pkg/internal/utils/network_test.go new file mode 100644 index 0000000..3c52145 --- /dev/null +++ b/pkg/internal/utils/network_test.go @@ -0,0 +1,256 @@ +/* +Copyright 2025 The OpenCIDN Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package utils + +import ( + "errors" + "io" + "net" + "net/http" + "testing" + "time" +) + +type mockNetError struct { + error + temporary bool + timeout bool +} + +func (e *mockNetError) Temporary() bool { return e.temporary } +func (e *mockNetError) Timeout() bool { return e.timeout } + +func TestIsNetWorkError(t *testing.T) { + tests := []struct { + name string + err error + wantIsNetwork bool + }{ + { + name: "nil error", + err: nil, + wantIsNetwork: false, + }, + { + name: "EOF error", + err: io.EOF, + wantIsNetwork: true, + }, + { + name: "unexpected EOF error", + err: io.ErrUnexpectedEOF, + wantIsNetwork: true, + }, + { + name: "network timeout error", + err: &mockNetError{error: errors.New("timeout"), timeout: true}, + wantIsNetwork: true, + }, + { + name: "network temporary error", + err: &mockNetError{error: errors.New("temporary"), temporary: true}, + wantIsNetwork: true, + }, + { + name: "generic error", + err: errors.New("some error"), + wantIsNetwork: false, + }, + { + name: "wrapped EOF", + err: errors.Join(io.EOF, errors.New("connection failed")), + wantIsNetwork: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + gotIsNetwork, gotErr := IsNetWorkError(tt.err) + + if gotIsNetwork != tt.wantIsNetwork { + t.Errorf("IsNetWorkError() isNetwork = %v, want %v", gotIsNetwork, tt.wantIsNetwork) + } + + if tt.err == nil { + if gotErr != nil { + t.Errorf("IsNetWorkError() error = %v, want nil", gotErr) + } + } else { + if gotErr == nil { + t.Errorf("IsNetWorkError() error = nil, want non-nil") + } + } + }) + } +} + +func TestIsHTTPResponseError(t *testing.T) { + tests := []struct { + name string + resp *http.Response + err error + wantIsRetry bool + wantErrNotNil bool + }{ + { + name: "nil response and nil error", + resp: nil, + err: nil, + wantIsRetry: false, + wantErrNotNil: true, + }, + { + name: "nil response with error", + resp: nil, + err: errors.New("request failed"), + wantIsRetry: false, + wantErrNotNil: true, + }, + { + name: "network error", + resp: nil, + err: io.EOF, + wantIsRetry: true, + wantErrNotNil: true, + }, + { + name: "200 OK response", + resp: &http.Response{StatusCode: http.StatusOK}, + err: nil, + wantIsRetry: false, + wantErrNotNil: false, + }, + { + name: "404 Not Found", + resp: &http.Response{StatusCode: http.StatusNotFound}, + err: nil, + wantIsRetry: false, + wantErrNotNil: false, + }, + { + name: "429 Too Many Requests", + resp: &http.Response{StatusCode: http.StatusTooManyRequests}, + err: nil, + wantIsRetry: true, + wantErrNotNil: true, + }, + { + name: "500 Internal Server Error", + resp: &http.Response{StatusCode: http.StatusInternalServerError}, + err: nil, + wantIsRetry: true, + wantErrNotNil: true, + }, + { + name: "502 Bad Gateway", + resp: &http.Response{StatusCode: http.StatusBadGateway}, + err: nil, + wantIsRetry: true, + wantErrNotNil: true, + }, + { + name: "503 Service Unavailable", + resp: &http.Response{StatusCode: http.StatusServiceUnavailable}, + err: nil, + wantIsRetry: true, + wantErrNotNil: true, + }, + { + name: "400 Bad Request", + resp: &http.Response{StatusCode: http.StatusBadRequest}, + err: nil, + wantIsRetry: false, + wantErrNotNil: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + gotIsRetry, gotErr := IsHTTPResponseError(tt.resp, tt.err) + + if gotIsRetry != tt.wantIsRetry { + t.Errorf("IsHTTPResponseError() isRetry = %v, want %v", gotIsRetry, tt.wantIsRetry) + } + + if tt.wantErrNotNil && gotErr == nil { + t.Errorf("IsHTTPResponseError() error = nil, want non-nil") + } else if !tt.wantErrNotNil && gotErr != nil { + t.Errorf("IsHTTPResponseError() error = %v, want nil", gotErr) + } + }) + } +} + +func TestIsNetError(t *testing.T) { + tests := []struct { + name string + err error + want bool + }{ + { + name: "network timeout error", + err: &mockNetError{error: errors.New("timeout"), timeout: true}, + want: true, + }, + { + name: "network temporary error", + err: &mockNetError{error: errors.New("temporary"), temporary: true}, + want: true, + }, + { + name: "net.OpError", + err: &net.OpError{Op: "dial", Net: "tcp", Err: errors.New("connection refused")}, + want: true, + }, + { + name: "generic error", + err: errors.New("not a network error"), + want: false, + }, + { + name: "nil error", + err: nil, + want: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := isNetError(tt.err) + if got != tt.want { + t.Errorf("isNetError() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestIsNetWorkErrorWithRealNetErrors(t *testing.T) { + // Test with a real network timeout error + conn, err := net.DialTimeout("tcp", "192.0.2.1:80", 1*time.Nanosecond) + if conn != nil { + conn.Close() + } + if err != nil { + isNet, wrappedErr := IsNetWorkError(err) + if !isNet { + t.Errorf("IsNetWorkError() with real timeout should be network error") + } + if wrappedErr == nil { + t.Errorf("IsNetWorkError() should return wrapped error") + } + } +} diff --git a/pkg/runner/read_count_test.go b/pkg/runner/read_count_test.go new file mode 100644 index 0000000..40b07aa --- /dev/null +++ b/pkg/runner/read_count_test.go @@ -0,0 +1,340 @@ +/* +Copyright 2025 The OpenCIDN Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package runner + +import ( + "bytes" + "context" + "errors" + "io" + "strings" + "sync" + "testing" + "time" +) + +func TestNewReadCount(t *testing.T) { + ctx := context.Background() + reader := strings.NewReader("test data") + + rc := NewReadCount(ctx, reader) + + if rc == nil { + t.Fatal("NewReadCount() returned nil") + } + + if rc.ctx != ctx { + t.Error("NewReadCount() context not set correctly") + } + + if rc.reader != reader { + t.Error("NewReadCount() reader not set correctly") + } + + if rc.count != 0 { + t.Errorf("NewReadCount() initial count = %d, want 0", rc.count) + } +} + +func TestReadCount_Read(t *testing.T) { + tests := []struct { + name string + input string + bufferSize int + wantCount int64 + wantErr bool + }{ + { + name: "read all at once", + input: "Hello, World!", + bufferSize: 100, + wantCount: 13, + wantErr: false, + }, + { + name: "read in chunks", + input: "Hello, World!", + bufferSize: 5, + wantCount: 13, + wantErr: false, + }, + { + name: "empty input", + input: "", + bufferSize: 10, + wantCount: 0, + wantErr: false, + }, + { + name: "large input", + input: strings.Repeat("A", 1000), + bufferSize: 100, + wantCount: 1000, + wantErr: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ctx := context.Background() + reader := strings.NewReader(tt.input) + rc := NewReadCount(ctx, reader) + + buffer := make([]byte, tt.bufferSize) + totalRead := int64(0) + + for { + n, err := rc.Read(buffer) + totalRead += int64(n) + + if err == io.EOF { + break + } + if err != nil && !tt.wantErr { + t.Errorf("Read() unexpected error = %v", err) + break + } + } + + if rc.Count() != tt.wantCount { + t.Errorf("Count() = %d, want %d", rc.Count(), tt.wantCount) + } + + if totalRead != tt.wantCount { + t.Errorf("Total bytes read = %d, want %d", totalRead, tt.wantCount) + } + }) + } +} + +func TestReadCount_Count(t *testing.T) { + ctx := context.Background() + reader := strings.NewReader("Test data for counting") + rc := NewReadCount(ctx, reader) + + if rc.Count() != 0 { + t.Errorf("Initial Count() = %d, want 0", rc.Count()) + } + + buffer := make([]byte, 5) + rc.Read(buffer) + + if rc.Count() != 5 { + t.Errorf("Count() after first read = %d, want 5", rc.Count()) + } + + rc.Read(buffer) + + if rc.Count() != 10 { + t.Errorf("Count() after second read = %d, want 10", rc.Count()) + } +} + +func TestReadCount_ContextCancellation(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + reader := strings.NewReader("Test data that should not be read") + rc := NewReadCount(ctx, reader) + + cancel() + + buffer := make([]byte, 10) + _, err := rc.Read(buffer) + + if err == nil { + t.Error("Read() after context cancellation should return error") + } + + if !errors.Is(err, context.Canceled) { + t.Errorf("Read() error = %v, want context.Canceled", err) + } + + if rc.Count() != 0 { + t.Errorf("Count() after cancelled read = %d, want 0", rc.Count()) + } +} + +func TestReadCount_ContextTimeout(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), 1*time.Millisecond) + defer cancel() + + reader := strings.NewReader("Test data") + rc := NewReadCount(ctx, reader) + + time.Sleep(10 * time.Millisecond) + + buffer := make([]byte, 10) + _, err := rc.Read(buffer) + + if err == nil { + t.Error("Read() after context timeout should return error") + } + + if !errors.Is(err, context.DeadlineExceeded) { + t.Errorf("Read() error = %v, want context.DeadlineExceeded", err) + } +} + +func TestReadCount_ConcurrentReads(t *testing.T) { + // Note: This test demonstrates that ReadCount's counter is thread-safe, + // even though strings.NewReader itself is not safe for concurrent reads. + // In practice, you should not read from the same underlying reader + // from multiple goroutines. + ctx := context.Background() + data := strings.Repeat("A", 10000) + reader := strings.NewReader(data) + rc := NewReadCount(ctx, reader) + + // Single reader goroutine + var wg sync.WaitGroup + wg.Add(1) + go func() { + defer wg.Done() + buffer := make([]byte, 100) + for { + _, err := rc.Read(buffer) + if err == io.EOF { + break + } + if err != nil { + t.Errorf("Read error: %v", err) + break + } + } + }() + + wg.Wait() + + finalCount := rc.Count() + if finalCount != int64(len(data)) { + t.Errorf("Final Count() = %d, want %d", finalCount, len(data)) + } +} + +func TestReadCount_ConcurrentCountAccess(t *testing.T) { + ctx := context.Background() + data := strings.Repeat("B", 1000) + reader := strings.NewReader(data) + rc := NewReadCount(ctx, reader) + + var wg sync.WaitGroup + + wg.Add(1) + go func() { + defer wg.Done() + buffer := make([]byte, 10) + for { + _, err := rc.Read(buffer) + if err == io.EOF { + break + } + } + }() + + for i := 0; i < 100; i++ { + wg.Add(1) + go func() { + defer wg.Done() + _ = rc.Count() + }() + } + + wg.Wait() + + if rc.Count() != int64(len(data)) { + t.Errorf("Final Count() = %d, want %d", rc.Count(), len(data)) + } +} + +type errorReader struct { + err error +} + +func (r *errorReader) Read(p []byte) (int, error) { + return 0, r.err +} + +func TestReadCount_ReadError(t *testing.T) { + ctx := context.Background() + expectedErr := errors.New("read error") + reader := &errorReader{err: expectedErr} + rc := NewReadCount(ctx, reader) + + buffer := make([]byte, 10) + n, err := rc.Read(buffer) + + if n != 0 { + t.Errorf("Read() n = %d, want 0", n) + } + + if !errors.Is(err, expectedErr) { + t.Errorf("Read() error = %v, want %v", err, expectedErr) + } + + if rc.Count() != 0 { + t.Errorf("Count() after error = %d, want 0", rc.Count()) + } +} + +type partialReader struct { + data []byte + pos int + limit int +} + +func (r *partialReader) Read(p []byte) (int, error) { + if r.pos >= len(r.data) { + return 0, io.EOF + } + + toRead := len(p) + if toRead > r.limit { + toRead = r.limit + } + if toRead > len(r.data)-r.pos { + toRead = len(r.data) - r.pos + } + + n := copy(p, r.data[r.pos:r.pos+toRead]) + r.pos += n + return n, nil +} + +func TestReadCount_PartialReads(t *testing.T) { + ctx := context.Background() + data := []byte("Hello, World! This is a test.") + reader := &partialReader{data: data, limit: 5} + rc := NewReadCount(ctx, reader) + + var buf bytes.Buffer + written, err := io.Copy(&buf, rc) + + if err != nil { + t.Errorf("io.Copy() error = %v", err) + } + + if written != int64(len(data)) { + t.Errorf("io.Copy() written = %d, want %d", written, len(data)) + } + + if rc.Count() != int64(len(data)) { + t.Errorf("Count() = %d, want %d", rc.Count(), len(data)) + } + + if buf.String() != string(data) { + t.Errorf("Data mismatch: got %q, want %q", buf.String(), string(data)) + } +}