diff --git a/README.md b/README.md index 7d88b2a1..b3a0255c 100644 --- a/README.md +++ b/README.md @@ -130,7 +130,9 @@ For concrete numbers on your own setup, run the included benchmark tool against ### Does it support SSH? -No. `git-sync` supports smart HTTP/HTTPS only. +Yes. `git-sync` supports SSH remotes through the local `ssh` binary, including +`ssh://`, SCP-style `git@host:path.git`, and `git+ssh://` URLs. See +[docs/usage.md](docs/usage.md) for details and current caveats. ### Does it run as a daemon or watch for changes? diff --git a/docs/protocol.md b/docs/protocol.md index 0cfe4537..cce7ab17 100644 --- a/docs/protocol.md +++ b/docs/protocol.md @@ -20,7 +20,8 @@ Out of scope (with pointers): - The pack format itself (object types, deltas, index format) — see [Git's pack-format docs](https://git-scm.com/docs/pack-format) - Dumb HTTP — `git-sync` does not support it -- SSH transport — `git-sync` is HTTPS-only +- Full SSH transport details — `git-sync` supports SSH, but this document is + focused on the Smart HTTP wire flow - Bundle URI, partial clones, and other newer extensions ## Smart HTTP Overview diff --git a/docs/usage.md b/docs/usage.md index c9bbfe1e..1090acd7 100644 --- a/docs/usage.md +++ b/docs/usage.md @@ -121,6 +121,30 @@ git-sync sync \ ``` +### SSH remotes + +`git-sync` also supports SSH remotes. Accepted forms include: + +- `ssh://git@example.com/org/repo.git` +- `git@example.com:org/repo.git` +- `git+ssh://example.com/org/repo.git` + +SSH transport shells out to the local `ssh` binary, so host aliases, +`IdentityFile`, agent-backed keys, and other `~/.ssh/config` behavior come +from your existing SSH setup rather than separate `git-sync` flags. + +`git-sync` runs SSH with `BatchMode=yes`, which avoids interactive password or +host-key prompts during syncs. On first contact with a host, add it to +`known_hosts` ahead of time or configure `StrictHostKeyChecking=accept-new` +for that host in your SSH config. + +Current limitation: `--progress` and `--show-stats` do not yet include +byte-counted SSH transfer metrics, so `--progress` and `--stats` omit +SSH-side throughput. + +If `ssh` is not available on `PATH`, `git-sync` fails early with a clear +`locate ssh binary` error before contacting either remote. + ## Sync Behavior `sync` picks the bootstrap relay path automatically when the target is empty. For non-empty targets, safe fast-forward updates also use a relay path that streams the source pack directly into target `receive-pack` without local materialization. Anything not relay-eligible (force, prune, deletes, tag retargets) falls back to a materialized path bounded by `--materialized-max-objects`. diff --git a/internal/gitproto/conn.go b/internal/gitproto/conn.go new file mode 100644 index 00000000..b4f46166 --- /dev/null +++ b/internal/gitproto/conn.go @@ -0,0 +1,18 @@ +package gitproto + +import ( + "context" + "io" + "net/url" +) + +// Conn represents a connection to a Git remote transport such as Smart HTTP +// or SSH. +type Conn interface { + RequestInfoRefs(ctx context.Context, service string, gitProtocol string) ([]byte, error) + PostRPCStreamBody(ctx context.Context, service string, body io.Reader, v2 bool, phase string) (io.ReadCloser, error) + Endpoint() *url.URL + ProgressWriter() io.Writer + SetProgressWriter(w io.Writer) + Close() error +} diff --git a/internal/gitproto/fetch.go b/internal/gitproto/fetch.go index db20d38b..1a7f9b51 100644 --- a/internal/gitproto/fetch.go +++ b/internal/gitproto/fetch.go @@ -55,7 +55,7 @@ func (s *RefService) SupportsBootstrapBatch() bool { func (s *RefService) FetchToStore( ctx context.Context, store storer.Storer, - conn *Conn, + conn Conn, desired map[plumbing.ReferenceName]DesiredRef, targetRefs map[plumbing.ReferenceName]plumbing.Hash, ) error { @@ -81,7 +81,7 @@ func (s *RefService) FetchToStore( // Caller must close the returned ReadCloser. func (s *RefService) FetchPack( ctx context.Context, - conn *Conn, + conn Conn, desired map[plumbing.ReferenceName]DesiredRef, targetRefs map[plumbing.ReferenceName]plumbing.Hash, ) (io.ReadCloser, error) { @@ -102,7 +102,7 @@ func (s *RefService) FetchPack( func (s *RefService) FetchCommitGraph( ctx context.Context, store storer.Storer, - conn *Conn, + conn Conn, ref DesiredRef, haves []plumbing.Hash, ) error { @@ -156,7 +156,7 @@ func (s *RefService) Capabilities() []string { func fetchToStoreV2( ctx context.Context, store storer.Storer, - conn *Conn, + conn Conn, caps *V2Capabilities, desired map[plumbing.ReferenceName]DesiredRef, targetRefs map[plumbing.ReferenceName]plumbing.Hash, @@ -194,12 +194,12 @@ func fetchToStoreV2( return err } defer ioutil.CheckClose(reader, &err) - return storeV2FetchPack(store, reader, verbose, conn.ProgressOut) + return storeV2FetchPack(store, reader, verbose, conn.ProgressWriter()) } func fetchPackV2( ctx context.Context, - conn *Conn, + conn Conn, caps *V2Capabilities, desired map[plumbing.ReferenceName]DesiredRef, targetRefs map[plumbing.ReferenceName]plumbing.Hash, @@ -240,7 +240,7 @@ func fetchPackV2( if err != nil { return nil, err } - packStream, err := openV2PackStream(reader, verbose, conn.ProgressOut) + packStream, err := openV2PackStream(reader, verbose, conn.ProgressWriter()) if err != nil { _ = reader.Close() return nil, err @@ -431,7 +431,7 @@ func buildV1UploadPackBody( func fetchToStoreV1( ctx context.Context, store storer.Storer, - conn *Conn, + conn Conn, adv *packp.AdvRefs, desired map[plumbing.ReferenceName]DesiredRef, targetRefs map[plumbing.ReferenceName]plumbing.Hash, @@ -456,7 +456,7 @@ func fetchToStoreV1( if drainErr := drainTrailingNAKs(buffered); drainErr != nil { return fmt.Errorf("drain server response: %w", drainErr) } - sbReader := buildSidebandReader(caps, buffered, progressSink(verbose, "source: ", conn.ProgressOut)) + sbReader := buildSidebandReader(caps, buffered, progressSink(verbose, "source: ", conn.ProgressWriter())) if err := packfile.UpdateObjectStorage(store, sbReader); err != nil { return fmt.Errorf("update object storage: %w", err) } @@ -465,7 +465,7 @@ func fetchToStoreV1( func fetchPackV1( ctx context.Context, - conn *Conn, + conn Conn, adv *packp.AdvRefs, desired map[plumbing.ReferenceName]DesiredRef, targetRefs map[plumbing.ReferenceName]plumbing.Hash, @@ -491,7 +491,7 @@ func fetchPackV1( return nil, fmt.Errorf("drain server response: %w", drainErr) } return &wrappedRC{ - Reader: buildSidebandReader(caps, buffered, progressSink(verbose, "source: ", conn.ProgressOut)), + Reader: buildSidebandReader(caps, buffered, progressSink(verbose, "source: ", conn.ProgressWriter())), Closer: reader, }, nil } diff --git a/internal/gitproto/fetch_test.go b/internal/gitproto/fetch_test.go index 69e30a77..a9119f33 100644 --- a/internal/gitproto/fetch_test.go +++ b/internal/gitproto/fetch_test.go @@ -362,7 +362,7 @@ func TestFetchPackV1ContextCanceled(t *testing.T) { if err != nil { t.Fatalf("parse endpoint: %v", err) } - conn := NewConn(ep, "source", nil, roundTripperFunc(func(req *http.Request) (*http.Response, error) { + conn := NewHTTPConn(ep, "source", nil, roundTripperFunc(func(req *http.Request) (*http.Response, error) { started <- struct{}{} <-req.Context().Done() return nil, req.Context().Err() @@ -411,7 +411,7 @@ func TestFetchPackV2ContextCanceled(t *testing.T) { if err != nil { t.Fatalf("parse endpoint: %v", err) } - conn := NewConn(ep, "source", nil, roundTripperFunc(func(req *http.Request) (*http.Response, error) { + conn := NewHTTPConn(ep, "source", nil, roundTripperFunc(func(req *http.Request) (*http.Response, error) { started <- struct{}{} <-req.Context().Done() return nil, req.Context().Err() @@ -463,7 +463,7 @@ func TestFetchToStoreV2ContextCanceled(t *testing.T) { if err != nil { t.Fatalf("parse endpoint: %v", err) } - conn := NewConn(ep, "source", nil, roundTripperFunc(func(req *http.Request) (*http.Response, error) { + conn := NewHTTPConn(ep, "source", nil, roundTripperFunc(func(req *http.Request) (*http.Response, error) { started <- struct{}{} <-req.Context().Done() return nil, req.Context().Err() @@ -514,7 +514,7 @@ func TestFetchToStoreV2ClosesBodyOnDecodeError(t *testing.T) { t.Fatalf("parse endpoint: %v", err) } body := &trackingReadCloser{ReadCloser: io.NopCloser(bytes.NewBufferString(FormatPktLine("bogus\n") + "0000"))} - conn := NewConn(ep, "source", nil, roundTripperFunc(func(req *http.Request) (*http.Response, error) { + conn := NewHTTPConn(ep, "source", nil, roundTripperFunc(func(req *http.Request) (*http.Response, error) { return &http.Response{ StatusCode: http.StatusOK, Request: req, @@ -552,7 +552,7 @@ func TestFetchToStoreV2ContextCanceledMidStream(t *testing.T) { } ctx, cancel := context.WithCancel(context.Background()) - conn := NewConn(ep, "source", nil, roundTripperFunc(func(req *http.Request) (*http.Response, error) { + conn := NewHTTPConn(ep, "source", nil, roundTripperFunc(func(req *http.Request) (*http.Response, error) { body := &blockingPacketBody{ ctx: req.Context(), startedRead: startedRead, @@ -609,7 +609,7 @@ func TestFetchPackV1ClosesBodyOnDecodeError(t *testing.T) { t.Fatalf("parse endpoint: %v", err) } body := &trackingReadCloser{ReadCloser: io.NopCloser(bytes.NewBufferString("not-a-valid-server-response"))} - conn := NewConn(ep, "source", nil, roundTripperFunc(func(req *http.Request) (*http.Response, error) { + conn := NewHTTPConn(ep, "source", nil, roundTripperFunc(func(req *http.Request) (*http.Response, error) { return &http.Response{ StatusCode: http.StatusOK, Request: req, @@ -644,7 +644,7 @@ func TestFetchPackV1ReturnedReaderClosesBodyOnInterruption(t *testing.T) { data: []byte("0008NAK\nPACK"), err: io.ErrUnexpectedEOF, } - conn := NewConn(ep, "source", nil, roundTripperFunc(func(req *http.Request) (*http.Response, error) { + conn := NewHTTPConn(ep, "source", nil, roundTripperFunc(func(req *http.Request) (*http.Response, error) { return &http.Response{ StatusCode: http.StatusOK, Request: req, @@ -686,7 +686,7 @@ func TestFetchPackV1ReturnedReaderErrorsOnMalformedMidStreamPacket(t *testing.T) t.Fatalf("parse endpoint: %v", err) } body := &trackingReadCloser{ReadCloser: io.NopCloser(bytes.NewBufferString("0008NAK\nzzzz"))} - conn := NewConn(ep, "source", nil, roundTripperFunc(func(req *http.Request) (*http.Response, error) { + conn := NewHTTPConn(ep, "source", nil, roundTripperFunc(func(req *http.Request) (*http.Response, error) { return &http.Response{ StatusCode: http.StatusOK, Request: req, @@ -733,7 +733,7 @@ func TestFetchPackV1DrainsSecondNAK(t *testing.T) { t.Fatalf("parse endpoint: %v", err) } body := &trackingReadCloser{ReadCloser: io.NopCloser(bytes.NewReader(payload))} - conn := NewConn(ep, "source", nil, roundTripperFunc(func(req *http.Request) (*http.Response, error) { + conn := NewHTTPConn(ep, "source", nil, roundTripperFunc(func(req *http.Request) (*http.Response, error) { return &http.Response{ StatusCode: http.StatusOK, Request: req, @@ -776,7 +776,7 @@ func TestFetchPackV2ClosesBodyOnDecodeError(t *testing.T) { t.Fatalf("parse endpoint: %v", err) } body := &trackingReadCloser{ReadCloser: io.NopCloser(bytes.NewBufferString("0000"))} - conn := NewConn(ep, "source", nil, roundTripperFunc(func(req *http.Request) (*http.Response, error) { + conn := NewHTTPConn(ep, "source", nil, roundTripperFunc(func(req *http.Request) (*http.Response, error) { return &http.Response{ StatusCode: http.StatusOK, Request: req, @@ -912,7 +912,7 @@ func TestFetchPackV2ManyHavesSendsDoneAndReadsPack(t *testing.T) { } seenRequest := false - conn := NewConn(ep, "source", nil, roundTripperFunc(func(req *http.Request) (*http.Response, error) { + conn := NewHTTPConn(ep, "source", nil, roundTripperFunc(func(req *http.Request) (*http.Response, error) { seenRequest = true body, err := io.ReadAll(req.Body) if err != nil { @@ -999,7 +999,7 @@ func TestFetchPackV2ReturnedReaderClosesBodyOnInterruption(t *testing.T) { data: wire.Bytes(), err: io.ErrUnexpectedEOF, } - conn := NewConn(ep, "source", nil, roundTripperFunc(func(req *http.Request) (*http.Response, error) { + conn := NewHTTPConn(ep, "source", nil, roundTripperFunc(func(req *http.Request) (*http.Response, error) { return &http.Response{ StatusCode: http.StatusOK, Request: req, @@ -1045,7 +1045,7 @@ func TestFetchPackV2ReturnedReaderErrorsOnMalformedMidStreamPacket(t *testing.T) t.Fatalf("parse endpoint: %v", err) } body := &trackingReadCloser{ReadCloser: io.NopCloser(bytes.NewBufferString(FormatPktLine("packfile\n") + "zzzz"))} - conn := NewConn(ep, "source", nil, roundTripperFunc(func(req *http.Request) (*http.Response, error) { + conn := NewHTTPConn(ep, "source", nil, roundTripperFunc(func(req *http.Request) (*http.Response, error) { return &http.Response{ StatusCode: http.StatusOK, Request: req, @@ -1088,7 +1088,7 @@ func TestFetchToStoreV2ClosesBodyOnMalformedMidStreamPacket(t *testing.T) { t.Fatalf("parse endpoint: %v", err) } body := &trackingReadCloser{ReadCloser: io.NopCloser(bytes.NewBufferString(FormatPktLine("packfile\n") + "zzzz"))} - conn := NewConn(ep, "source", nil, roundTripperFunc(func(req *http.Request) (*http.Response, error) { + conn := NewHTTPConn(ep, "source", nil, roundTripperFunc(func(req *http.Request) (*http.Response, error) { return &http.Response{ StatusCode: http.StatusOK, Request: req, diff --git a/internal/gitproto/push.go b/internal/gitproto/push.go index f114ac3f..704601da 100644 --- a/internal/gitproto/push.go +++ b/internal/gitproto/push.go @@ -34,14 +34,14 @@ type PushCommand struct { // construction without worrying about whether downstream strategies have // already captured a value copy. type Pusher struct { - Conn *Conn + Conn Conn Adv *packp.AdvRefs Verbose bool OnRejection func(refName plumbing.ReferenceName, status string) } // NewPusher builds a target-side push executor. -func NewPusher(conn *Conn, adv *packp.AdvRefs, verbose bool) *Pusher { +func NewPusher(conn Conn, adv *packp.AdvRefs, verbose bool) *Pusher { return &Pusher{Conn: conn, Adv: adv, Verbose: verbose} } @@ -140,7 +140,7 @@ func annotateLeaseFailure(err error) error { // sendReceivePack encodes and POSTs a receive-pack request, then decodes the report. func sendReceivePack( ctx context.Context, - conn *Conn, + conn Conn, req *packp.UpdateRequests, packData io.Reader, verbose bool, @@ -166,11 +166,11 @@ func sendReceivePack( switch { case req.Capabilities.Supports(capability.Sideband64k): dem := sideband.NewDemuxer(sideband.Sideband64k, reader) - dem.Progress = progressSink(verbose, "target: ", conn.ProgressOut) + dem.Progress = progressSink(verbose, "target: ", conn.ProgressWriter()) respReader = dem case req.Capabilities.Supports(capability.Sideband): dem := sideband.NewDemuxer(sideband.Sideband, reader) - dem.Progress = progressSink(verbose, "target: ", conn.ProgressOut) + dem.Progress = progressSink(verbose, "target: ", conn.ProgressWriter()) respReader = dem } @@ -201,7 +201,7 @@ func sendReceivePack( // PushObjects pushes locally-materialized objects to the target. func PushObjects( ctx context.Context, - conn *Conn, + conn Conn, adv *packp.AdvRefs, commands []PushCommand, store storer.Storer, @@ -242,7 +242,7 @@ func PushObjects( // PushPack pushes a pack stream (relay) to the target. func PushPack( ctx context.Context, - conn *Conn, + conn Conn, adv *packp.AdvRefs, commands []PushCommand, pack io.ReadCloser, @@ -276,7 +276,7 @@ func PushPack( // PushCommands sends ref update commands without a pack (for ref-only changes). func PushCommands( ctx context.Context, - conn *Conn, + conn Conn, adv *packp.AdvRefs, commands []PushCommand, verbose bool, diff --git a/internal/gitproto/push_test.go b/internal/gitproto/push_test.go index f3cd36de..75265af2 100644 --- a/internal/gitproto/push_test.go +++ b/internal/gitproto/push_test.go @@ -127,13 +127,13 @@ func fakeReceivePackServer(t *testing.T, reportErr string) *httptest.Server { })) } -func connForServer(t *testing.T, srv *httptest.Server) *Conn { +func connForServer(t *testing.T, srv *httptest.Server) *HTTPConn { t.Helper() ep, err := transport.ParseURL(srv.URL + "/repo.git") if err != nil { t.Fatalf("parse endpoint: %v", err) } - return NewConn(ep, "test", nil, srv.Client().Transport) + return NewHTTPConn(ep, "test", nil, srv.Client().Transport) } func TestPushPackClosesPackOnSuccess(t *testing.T) { @@ -189,7 +189,7 @@ func TestPushPackClosesPackOnContextCanceled(t *testing.T) { if err != nil { t.Fatalf("parse endpoint: %v", err) } - conn := NewConn(ep, "target", nil, roundTripperFunc(func(req *http.Request) (*http.Response, error) { + conn := NewHTTPConn(ep, "target", nil, roundTripperFunc(func(req *http.Request) (*http.Response, error) { started <- struct{}{} <-req.Context().Done() return nil, req.Context().Err() @@ -322,7 +322,7 @@ func TestPushPackRejectsDeletes(t *testing.T) { // Use a nil-transport conn -- we should never reach the network. ep, err := transport.ParseURL("https://example.com/repo.git") require.NoError(t, err) - conn := &Conn{Endpoint: ep, HTTP: &http.Client{}} + conn := &HTTPConn{EndpointURL: ep, HTTP: &http.Client{}} err = PushPack(context.Background(), conn, adv, []PushCommand{ {Name: "refs/heads/old", Delete: true}, diff --git a/internal/gitproto/refs.go b/internal/gitproto/refs.go index 338a30a0..c2e106d2 100644 --- a/internal/gitproto/refs.go +++ b/internal/gitproto/refs.go @@ -16,6 +16,10 @@ import ( "github.com/go-git/go-git/v6/plumbing/transport" ) +// GitProtocolV2 is the value passed via the Git-Protocol header / GIT_PROTOCOL +// env var to negotiate protocol v2 with the remote. +const GitProtocolV2 = "version=2" + // RefService encapsulates the result of source ref discovery and the negotiated // protocol, providing methods for subsequent fetch and pack operations. type RefService struct { @@ -34,7 +38,7 @@ type RefService struct { // ListSourceRefs discovers refs from the source using the configured protocol mode. // Returns the list of refs and a RefService for subsequent operations. -func ListSourceRefs(ctx context.Context, conn *Conn, protocolMode string, refPrefixes []string) ([]*plumbing.Reference, *RefService, error) { +func ListSourceRefs(ctx context.Context, conn Conn, protocolMode string, refPrefixes []string) ([]*plumbing.Reference, *RefService, error) { switch protocolMode { case "v1": adv, refs, err := listSourceRefsV1(ctx, conn) @@ -44,8 +48,15 @@ func ListSourceRefs(ctx context.Context, conn *Conn, protocolMode string, refPre return refs, &RefService{Protocol: "v1", V1Adv: adv, HeadTarget: headTargetFromAdv(adv)}, nil case "auto", "v2": - data, err := RequestInfoRefs(ctx, conn, transport.UploadPackService, "version=2") + data, err := RequestInfoRefs(ctx, conn, transport.UploadPackService, GitProtocolV2) if err != nil { + if protocolMode == "auto" && isSSHScheme(conn) { + refs, svc, v1Err := listSourceRefsAutoV1(ctx, conn) + if v1Err != nil { + return nil, nil, errors.Join(err, v1Err) + } + return refs, svc, nil + } return nil, nil, err } if caps, err := DecodeV2Capabilities(bytes.NewReader(data)); err == nil { @@ -77,8 +88,27 @@ func ListSourceRefs(ctx context.Context, conn *Conn, protocolMode string, refPre } } +func listSourceRefsAutoV1(ctx context.Context, conn Conn) ([]*plumbing.Reference, *RefService, error) { + adv, refs, err := listSourceRefsV1(ctx, conn) + if err != nil { + return nil, nil, err + } + return refs, &RefService{Protocol: "v1", V1Adv: adv, HeadTarget: headTargetFromAdv(adv)}, nil +} + +func isSSHScheme(conn Conn) bool { + if conn == nil || conn.Endpoint() == nil { + return false + } + switch conn.Endpoint().Scheme { + case "ssh", "git+ssh": + return true + } + return false +} + // AdvertisedRefsV1 fetches and decodes v1 advertised refs for the given service. -func AdvertisedRefsV1(ctx context.Context, conn *Conn, service string) (*packp.AdvRefs, error) { +func AdvertisedRefsV1(ctx context.Context, conn Conn, service string) (*packp.AdvRefs, error) { data, err := RequestInfoRefs(ctx, conn, service, "") if err != nil { return nil, err @@ -128,7 +158,7 @@ func AdvRefsCaps(adv *packp.AdvRefs) []string { return items } -func listSourceRefsV1(ctx context.Context, conn *Conn) (*packp.AdvRefs, []*plumbing.Reference, error) { +func listSourceRefsV1(ctx context.Context, conn Conn) (*packp.AdvRefs, []*plumbing.Reference, error) { adv, err := AdvertisedRefsV1(ctx, conn, transport.UploadPackService) if err != nil { return nil, nil, err @@ -140,7 +170,7 @@ func listSourceRefsV1(ctx context.Context, conn *Conn) (*packp.AdvRefs, []*plumb return adv, refs, nil } -func listSourceRefsV2(ctx context.Context, conn *Conn, caps *V2Capabilities, prefixes []string) ([]*plumbing.Reference, plumbing.ReferenceName, error) { +func listSourceRefsV2(ctx context.Context, conn Conn, caps *V2Capabilities, prefixes []string) ([]*plumbing.Reference, plumbing.ReferenceName, error) { // Always include "HEAD" so the server returns the symref-target attribute // for HEAD. Without this, callers that pass only "refs/heads/" or // "refs/tags/" prefixes filter HEAD out of the response and lose the diff --git a/internal/gitproto/refs_test.go b/internal/gitproto/refs_test.go index 0002f132..88434114 100644 --- a/internal/gitproto/refs_test.go +++ b/internal/gitproto/refs_test.go @@ -3,6 +3,8 @@ package gitproto import ( "context" "errors" + "io" + "net/url" "strings" "testing" @@ -18,9 +20,9 @@ func TestRefHashMap(t *testing.T) { hashB := plumbing.NewHash("bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb") refs := []*plumbing.Reference{ - plumbing.NewHashReference("refs/heads/main", hashA), + plumbing.NewHashReference(refsHeadsMain, hashA), plumbing.NewHashReference("refs/heads/dev", hashB), - plumbing.NewSymbolicReference("HEAD", "refs/heads/main"), // symbolic, should be skipped + plumbing.NewSymbolicReference("HEAD", refsHeadsMain), // symbolic, should be skipped } m := RefHashMap(refs) @@ -28,7 +30,7 @@ func TestRefHashMap(t *testing.T) { if len(m) != 2 { t.Fatalf("expected 2 entries, got %d", len(m)) } - if got := m["refs/heads/main"]; got != hashA { + if got := m[refsHeadsMain]; got != hashA { t.Errorf("refs/heads/main = %s, want %s", got, hashA) } if got := m["refs/heads/dev"]; got != hashB { @@ -50,7 +52,7 @@ func TestHeadTargetFromAdv(t *testing.T) { adv := &packp.AdvRefs{} adv.Capabilities.Add(capability.SymRef, "HEAD:refs/heads/main") - if got := headTargetFromAdv(adv); got.String() != "refs/heads/main" { + if got := headTargetFromAdv(adv); got.String() != refsHeadsMain { t.Errorf("headTargetFromAdv = %q, want refs/heads/main", got) } @@ -106,7 +108,7 @@ func TestAdvRefsToSlice(t *testing.T) { hashA := plumbing.NewHash("aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa") hashB := plumbing.NewHash("bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb") adv.References = []*plumbing.Reference{ - plumbing.NewHashReference("refs/heads/main", hashA), + plumbing.NewHashReference(refsHeadsMain, hashA), plumbing.NewHashReference("refs/heads/dev", hashB), } @@ -122,8 +124,8 @@ func TestAdvRefsToSlice(t *testing.T) { for _, ref := range refs { found[ref.Name()] = ref.Hash() } - if found["refs/heads/main"] != hashA { - t.Errorf("refs/heads/main = %s, want %s", found["refs/heads/main"], hashA) + if found[refsHeadsMain] != hashA { + t.Errorf("refs/heads/main = %s, want %s", found[refsHeadsMain], hashA) } if found["refs/heads/dev"] != hashB { t.Errorf("refs/heads/dev = %s, want %s", found["refs/heads/dev"], hashB) @@ -200,3 +202,134 @@ func TestListSourceRefsUnsupportedProtocol(t *testing.T) { t.Fatal("expected error for unsupported protocol mode") } } + +func TestListSourceRefsAutoFallsBackToV1AfterSSHV2ProbeError(t *testing.T) { + t.Parallel() + + conn := &stubConn{ + reqInfoRefs: func(_ context.Context, _ string, gitProtocol string) ([]byte, error) { + if gitProtocol == GitProtocolV2 { + return nil, errors.New("ssh server rejected v2 probe") + } + + var body strings.Builder + if _, err := pktline.Writef(&body, "# service=%s\n", transport.UploadPackService); err != nil { + t.Fatalf("write smart service line: %v", err) + } + if err := pktline.WriteFlush(&body); err != nil { + t.Fatalf("write smart flush: %v", err) + } + if _, err := pktline.Writef(&body, "%s HEAD\x00%s\n", strings.Repeat("a", 40), capability.SymRef+"=HEAD:refs/heads/main"); err != nil { + t.Fatalf("write advertised head: %v", err) + } + if _, err := pktline.Writef(&body, "%s refs/heads/main\n", strings.Repeat("a", 40)); err != nil { + t.Fatalf("write advertised ref: %v", err) + } + if err := pktline.WriteFlush(&body); err != nil { + t.Fatalf("write trailing flush: %v", err) + } + return []byte(body.String()), nil + }, + } + + refs, svc, err := ListSourceRefs(t.Context(), conn, "auto", nil) + if err != nil { + t.Fatalf("ListSourceRefs(auto) error = %v", err) + } + if svc.Protocol != "v1" { + t.Fatalf("protocol = %q, want v1", svc.Protocol) + } + if got := svc.HeadTarget.String(); got != refsHeadsMain { + t.Fatalf("head target = %q, want refs/heads/main", got) + } + foundMain := false + for _, ref := range refs { + if ref.Name().String() == refsHeadsMain { + foundMain = true + break + } + } + if !foundMain { + t.Fatalf("refs = %#v, want refs/heads/main to be advertised", refs) + } +} + +func TestListSourceRefsAutoJoinsErrorsWhenV1FallbackAlsoFails(t *testing.T) { + t.Parallel() + + v2Err := errors.New("ssh v2 probe failed") + v1Err := errors.New("ssh v1 fallback failed") + conn := &stubConn{ + reqInfoRefs: func(_ context.Context, _ string, gitProtocol string) ([]byte, error) { + if gitProtocol == GitProtocolV2 { + return nil, v2Err + } + return nil, v1Err + }, + } + + _, _, err := ListSourceRefs(t.Context(), conn, "auto", nil) + if err == nil { + t.Fatal("expected error when both v2 probe and v1 fallback fail") + } + if !errors.Is(err, v2Err) { + t.Errorf("err does not wrap v2 error: %v", err) + } + if !errors.Is(err, v1Err) { + t.Errorf("err does not wrap v1 error: %v", err) + } +} + +func TestListSourceRefsAutoDoesNotFallBackForNonSSH(t *testing.T) { + t.Parallel() + + v2Err := errors.New("https v2 probe failed") + v1Called := false + conn := &stubConn{ + reqInfoRefs: func(_ context.Context, _ string, gitProtocol string) ([]byte, error) { + if gitProtocol == GitProtocolV2 { + return nil, v2Err + } + v1Called = true + return nil, errors.New("unexpected v1 call") + }, + endpoint: &url.URL{Scheme: "https", Host: "github.com"}, + } + + _, _, err := ListSourceRefs(t.Context(), conn, "auto", nil) + if err == nil { + t.Fatal("expected error from v2 probe") + } + if !errors.Is(err, v2Err) { + t.Errorf("err does not wrap v2 error: %v", err) + } + if v1Called { + t.Error("v1 fallback should not be attempted for non-SSH schemes") + } +} + +type stubConn struct { + reqInfoRefs func(ctx context.Context, service string, gitProtocol string) ([]byte, error) + endpoint *url.URL +} + +func (s *stubConn) RequestInfoRefs(ctx context.Context, service string, gitProtocol string) ([]byte, error) { + return s.reqInfoRefs(ctx, service, gitProtocol) +} + +func (s *stubConn) PostRPCStreamBody(context.Context, string, io.Reader, bool, string) (io.ReadCloser, error) { + return nil, errors.New("unexpected PostRPCStreamBody call") +} + +func (s *stubConn) Endpoint() *url.URL { + if s.endpoint != nil { + return s.endpoint + } + return &url.URL{Scheme: "ssh", Host: "github.com"} +} + +func (s *stubConn) ProgressWriter() io.Writer { return nil } + +func (s *stubConn) SetProgressWriter(io.Writer) {} + +func (s *stubConn) Close() error { return nil } diff --git a/internal/gitproto/smarthttp.go b/internal/gitproto/smarthttp.go index 81d97dc7..90c5fe45 100644 --- a/internal/gitproto/smarthttp.go +++ b/internal/gitproto/smarthttp.go @@ -70,12 +70,12 @@ type AuthMethod interface { Authorizer(req *http.Request) error } -// Conn represents a connection to a remote Git HTTP endpoint. -type Conn struct { - Label string - Endpoint *url.URL - HTTP *http.Client - Auth AuthMethod +// HTTPConn represents a connection to a remote Git HTTP endpoint. +type HTTPConn struct { + Label string + EndpointURL *url.URL + HTTP *http.Client + Auth AuthMethod // FollowInfoRefsRedirect, when true, rewrites Endpoint.Scheme and // Endpoint.Host to the final URL returned by RequestInfoRefs after @@ -96,28 +96,36 @@ type Conn struct { ProgressOut io.Writer } -// NewConn creates a new connection to the given endpoint. -func NewConn(ep *url.URL, label string, auth AuthMethod, rt http.RoundTripper) *Conn { +// NewHTTPConn creates a new connection to the given endpoint. +func NewHTTPConn(ep *url.URL, label string, auth AuthMethod, rt http.RoundTripper) *HTTPConn { httpClient := &http.Client{Transport: rt} - return NewConnWithHTTPClient(ep, label, auth, httpClient) + return NewHTTPConnWithClient(ep, label, auth, httpClient) } -// NewConnWithHTTPClient creates a new connection using the provided HTTP client. +// NewHTTPConnWithClient creates a new connection using the provided HTTP client. // Passing nil falls back to a default client and is intended only for direct // callers outside git-sync's normal instrumented session setup. -func NewConnWithHTTPClient(ep *url.URL, label string, auth AuthMethod, httpClient *http.Client) *Conn { +func NewHTTPConnWithClient(ep *url.URL, label string, auth AuthMethod, httpClient *http.Client) *HTTPConn { if httpClient == nil { httpClient = &http.Client{Transport: http.DefaultTransport} } normalizeEndpointPath(ep) - return &Conn{ - Label: label, - Endpoint: ep, - HTTP: httpClient, - Auth: auth, + return &HTTPConn{ + Label: label, + EndpointURL: ep, + HTTP: httpClient, + Auth: auth, } } +func (c *HTTPConn) Endpoint() *url.URL { return c.EndpointURL } + +func (c *HTTPConn) ProgressWriter() io.Writer { return c.ProgressOut } + +func (c *HTTPConn) SetProgressWriter(w io.Writer) { c.ProgressOut = w } + +func (c *HTTPConn) Close() error { return nil } + func normalizeEndpointPath(ep *url.URL) { if ep == nil { return @@ -143,8 +151,17 @@ func NewHTTPTransport(skipTLS bool) http.RoundTripper { } // RequestInfoRefs fetches /info/refs for the given service. -func RequestInfoRefs(ctx context.Context, conn *Conn, service string, gitProtocol string) ([]byte, error) { - reqURL := fmt.Sprintf("%s/info/refs?service=%s", conn.Endpoint.String(), service) +func RequestInfoRefs(ctx context.Context, conn Conn, service string, gitProtocol string) ([]byte, error) { + data, err := conn.RequestInfoRefs(ctx, service, gitProtocol) + if err != nil { + return nil, fmt.Errorf("request info refs: %w", err) + } + return data, nil +} + +// RequestInfoRefs fetches /info/refs for the given service. +func (c *HTTPConn) RequestInfoRefs(ctx context.Context, service string, gitProtocol string) ([]byte, error) { + reqURL := fmt.Sprintf("%s/info/refs?service=%s", c.EndpointURL.String(), service) req, err := http.NewRequestWithContext(ctx, http.MethodGet, reqURL, nil) if err != nil { return nil, fmt.Errorf("create info-refs request: %w", err) @@ -155,9 +172,9 @@ func RequestInfoRefs(ctx context.Context, conn *Conn, service string, gitProtoco if gitProtocol != "" { req.Header.Set("Git-Protocol", gitProtocol) } - ApplyAuth(req, conn.Auth) + ApplyAuth(req, c.Auth) - res, err := conn.HTTP.Do(req) + res, err := c.HTTP.Do(req) if err != nil { return nil, fmt.Errorf("request info-refs: %w", err) } @@ -176,11 +193,11 @@ func RequestInfoRefs(ctx context.Context, conn *Conn, service string, gitProtoco if gotMediaType != wantContentType { return nil, fmt.Errorf("unexpected info/refs content-type %q, want %q", gotContentType, wantContentType) } - if conn.FollowInfoRefsRedirect && res.Request != nil && res.Request.URL != nil { + if c.FollowInfoRefsRedirect && res.Request != nil && res.Request.URL != nil { final := res.Request.URL - if final.Host != conn.Endpoint.Host || final.Scheme != conn.Endpoint.Scheme { - conn.Endpoint.Scheme = final.Scheme - conn.Endpoint.Host = final.Host + if final.Host != c.EndpointURL.Host || final.Scheme != c.EndpointURL.Scheme { + c.EndpointURL.Scheme = final.Scheme + c.EndpointURL.Host = final.Host } } // Bound the read to prevent unbounded memory allocation (issue #9). @@ -198,7 +215,7 @@ func RequestInfoRefs(ctx context.Context, conn *Conn, service string, gitProtoco // PostRPC sends a buffered POST to the given service and returns the full response body. // Responses are bounded to prevent unbounded memory allocation (issue #9). -func PostRPC(ctx context.Context, conn *Conn, service string, body []byte, v2 bool, phase string) ([]byte, error) { +func PostRPC(ctx context.Context, conn Conn, service string, body []byte, v2 bool, phase string) ([]byte, error) { reader, err := PostRPCStream(ctx, conn, service, body, v2, phase) if err != nil { return nil, err @@ -218,14 +235,24 @@ func PostRPC(ctx context.Context, conn *Conn, service string, body []byte, v2 bo // PostRPCStream sends a POST to the given service and returns the response body // as a streaming reader. Caller must close the returned ReadCloser. -func PostRPCStream(ctx context.Context, conn *Conn, service string, body []byte, v2 bool, phase string) (io.ReadCloser, error) { +func PostRPCStream(ctx context.Context, conn Conn, service string, body []byte, v2 bool, phase string) (io.ReadCloser, error) { return PostRPCStreamBody(ctx, conn, service, bytes.NewReader(body), v2, phase) } // PostRPCStreamBody sends a POST to the given service using a streaming request body. // Caller must close the returned ReadCloser. -func PostRPCStreamBody(ctx context.Context, conn *Conn, service string, body io.Reader, v2 bool, phase string) (io.ReadCloser, error) { - reqURL := fmt.Sprintf("%s/%s", conn.Endpoint.String(), service) +func PostRPCStreamBody(ctx context.Context, conn Conn, service string, body io.Reader, v2 bool, phase string) (io.ReadCloser, error) { + reader, err := conn.PostRPCStreamBody(ctx, service, body, v2, phase) + if err != nil { + return nil, fmt.Errorf("post RPC stream body: %w", err) + } + return reader, nil +} + +// PostRPCStreamBody sends a POST to the given service using a streaming request body. +// Caller must close the returned ReadCloser. +func (c *HTTPConn) PostRPCStreamBody(ctx context.Context, service string, body io.Reader, v2 bool, phase string) (io.ReadCloser, error) { + reqURL := fmt.Sprintf("%s/%s", c.EndpointURL.String(), service) req, err := http.NewRequestWithContext(ctx, http.MethodPost, reqURL, body) if err != nil { return nil, fmt.Errorf("create RPC request: %w", err) @@ -235,11 +262,11 @@ func PostRPCStreamBody(ctx context.Context, conn *Conn, service string, body io. req.Header.Set("User-Agent", capability.DefaultAgent()) req.Header.Set(StatsPhaseHeader, phase) if v2 { - req.Header.Set("Git-Protocol", "version=2") + req.Header.Set("Git-Protocol", GitProtocolV2) } - ApplyAuth(req, conn.Auth) + ApplyAuth(req, c.Auth) - res, err := conn.HTTP.Do(req) + res, err := c.HTTP.Do(req) if err != nil { return nil, fmt.Errorf("post RPC: %w", err) } diff --git a/internal/gitproto/smarthttp_test.go b/internal/gitproto/smarthttp_test.go index d77a7c5a..4ba264eb 100644 --- a/internal/gitproto/smarthttp_test.go +++ b/internal/gitproto/smarthttp_test.go @@ -15,19 +15,19 @@ import ( transporthttp "github.com/go-git/go-git/v6/plumbing/transport/http" ) -func TestNewConn(t *testing.T) { +func TestNewHTTPConn(t *testing.T) { ep, err := transport.ParseURL("https://github.com/user/repo.git") if err != nil { t.Fatalf("parse endpoint: %v", err) } auth := &transporthttp.BasicAuth{Username: "user", Password: "pass"} - conn := NewConn(ep, "test-label", auth, http.DefaultTransport) + conn := NewHTTPConn(ep, "test-label", auth, http.DefaultTransport) if conn.Label != "test-label" { t.Errorf("Label = %q, want %q", conn.Label, "test-label") } - if conn.Endpoint != ep { - t.Error("Endpoint mismatch") + if conn.EndpointURL != ep { + t.Error("EndpointURL mismatch") } if conn.Auth != auth { t.Error("Auth mismatch") @@ -37,13 +37,13 @@ func TestNewConn(t *testing.T) { } } -func TestNewConnStripsTrailingEndpointSlash(t *testing.T) { +func TestNewHTTPConnStripsTrailingEndpointSlash(t *testing.T) { ep, err := url.Parse("https://example.com/repo.git///") if err != nil { t.Fatalf("parse endpoint: %v", err) } var gotURLs []string - conn := NewConn(ep, "source", nil, roundTripperFunc(func(req *http.Request) (*http.Response, error) { + conn := NewHTTPConn(ep, "source", nil, roundTripperFunc(func(req *http.Request) (*http.Response, error) { gotURLs = append(gotURLs, req.URL.String()) res := &http.Response{ StatusCode: http.StatusOK, @@ -58,8 +58,8 @@ func TestNewConnStripsTrailingEndpointSlash(t *testing.T) { return res, nil })) - if got, want := conn.Endpoint.Path, "/repo.git"; got != want { - t.Fatalf("Endpoint.Path = %q, want %q", got, want) + if got, want := conn.EndpointURL.Path, "/repo.git"; got != want { + t.Fatalf("EndpointURL.Path = %q, want %q", got, want) } if _, err := RequestInfoRefs(t.Context(), conn, transport.UploadPackService, ""); err != nil { t.Fatalf("RequestInfoRefs: %v", err) @@ -141,7 +141,7 @@ func TestRequestInfoRefsContextCanceled(t *testing.T) { if err != nil { t.Fatalf("parse endpoint: %v", err) } - conn := NewConn(ep, "source", nil, roundTripperFunc(func(req *http.Request) (*http.Response, error) { + conn := NewHTTPConn(ep, "source", nil, roundTripperFunc(func(req *http.Request) (*http.Response, error) { started <- struct{}{} <-req.Context().Done() return nil, req.Context().Err() @@ -151,7 +151,7 @@ func TestRequestInfoRefsContextCanceled(t *testing.T) { done := make(chan error, 1) go func() { - _, err := RequestInfoRefs(ctx, conn, "git-upload-pack", "version=2") + _, err := RequestInfoRefs(ctx, conn, "git-upload-pack", GitProtocolV2) done <- err }() @@ -211,7 +211,7 @@ func TestRequestInfoRefsRequiresAdvertisementContentType(t *testing.T) { if err != nil { t.Fatalf("parse endpoint: %v", err) } - conn := NewConn(ep, "source", nil, roundTripperFunc(func(req *http.Request) (*http.Response, error) { + conn := NewHTTPConn(ep, "source", nil, roundTripperFunc(func(req *http.Request) (*http.Response, error) { res := &http.Response{ StatusCode: http.StatusOK, Request: req, @@ -250,7 +250,7 @@ func TestPostRPCStreamContextCanceled(t *testing.T) { if err != nil { t.Fatalf("parse endpoint: %v", err) } - conn := NewConn(ep, "source", nil, roundTripperFunc(func(req *http.Request) (*http.Response, error) { + conn := NewHTTPConn(ep, "source", nil, roundTripperFunc(func(req *http.Request) (*http.Response, error) { started <- struct{}{} <-req.Context().Done() return nil, req.Context().Err() @@ -285,7 +285,7 @@ func TestPostRPCStreamContextCanceled(t *testing.T) { } // TestRequestInfoRefs_FollowInfoRefsRedirect verifies that when the flag is -// set, a 307 on /info/refs rewrites Conn.Endpoint.Host so subsequent PostRPC +// set, a 307 on /info/refs rewrites HTTPConn.EndpointURL.Host so subsequent PostRPC // calls target the redirected node. Matches vanilla git's smart-HTTP // behaviour and lets clients use a cluster entry domain for info/refs while // packs land on the hosting replica. @@ -307,7 +307,7 @@ func TestRequestInfoRefs_FollowInfoRefsRedirect(t *testing.T) { if err != nil { t.Fatalf("parse endpoint: %v", err) } - conn := NewConn(ep, "test", nil, http.DefaultTransport) + conn := NewHTTPConn(ep, "test", nil, http.DefaultTransport) conn.FollowInfoRefsRedirect = true if _, err := RequestInfoRefs(t.Context(), conn, transport.UploadPackService, ""); err != nil { @@ -315,8 +315,8 @@ func TestRequestInfoRefs_FollowInfoRefsRedirect(t *testing.T) { } nodeURL := strings.TrimPrefix(node.URL, "http://") - if conn.Endpoint.Host != nodeURL { - t.Errorf("Endpoint.Host = %q, want %q (endpoint should follow the 307)", conn.Endpoint.Host, nodeURL) + if conn.EndpointURL.Host != nodeURL { + t.Errorf("EndpointURL.Host = %q, want %q (endpoint should follow the 307)", conn.EndpointURL.Host, nodeURL) } } @@ -360,7 +360,7 @@ func TestRequestInfoRefs_FollowInfoRefsRedirect_SubsequentPOSTHitsRedirectedHost if err != nil { t.Fatalf("parse endpoint: %v", err) } - conn := NewConn(ep, "test", nil, http.DefaultTransport) + conn := NewHTTPConn(ep, "test", nil, http.DefaultTransport) conn.FollowInfoRefsRedirect = true if _, err := RequestInfoRefs(t.Context(), conn, transport.UploadPackService, ""); err != nil { @@ -387,7 +387,7 @@ func TestRequestInfoRefs_FollowInfoRefsRedirect_SubsequentPOSTHitsRedirectedHost } // TestRequestInfoRefs_DoesNotFollowByDefault confirms the default behaviour -// is unchanged: Endpoint is stable even if the server 307s. +// is unchanged: EndpointURL is stable even if the server 307s. func TestRequestInfoRefs_DoesNotFollowByDefault(t *testing.T) { node := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { w.Header().Set("Content-Type", "application/x-git-upload-pack-advertisement") @@ -407,15 +407,15 @@ func TestRequestInfoRefs_DoesNotFollowByDefault(t *testing.T) { t.Fatalf("parse endpoint: %v", err) } entryHost := ep.Host - conn := NewConn(ep, "test", nil, http.DefaultTransport) + conn := NewHTTPConn(ep, "test", nil, http.DefaultTransport) // FollowInfoRefsRedirect intentionally not set. if _, err := RequestInfoRefs(t.Context(), conn, transport.UploadPackService, ""); err != nil { t.Fatalf("RequestInfoRefs: %v", err) } - if conn.Endpoint.Host != entryHost { - t.Errorf("Endpoint.Host = %q, want %q (endpoint should be unchanged by default)", conn.Endpoint.Host, entryHost) + if conn.EndpointURL.Host != entryHost { + t.Errorf("EndpointURL.Host = %q, want %q (endpoint should be unchanged by default)", conn.EndpointURL.Host, entryHost) } } diff --git a/internal/gitproto/ssh.go b/internal/gitproto/ssh.go new file mode 100644 index 00000000..f8a36d93 --- /dev/null +++ b/internal/gitproto/ssh.go @@ -0,0 +1,334 @@ +package gitproto + +import ( + "bufio" + "bytes" + "context" + "errors" + "fmt" + "io" + "net/url" + "os/exec" + "strings" + "sync" +) + +// SSHLookPath is replaceable in tests. +var SSHLookPath = exec.LookPath + +// SSHConn represents a Git transport over the local ssh binary. +type SSHConn struct { + Label string + EndpointURL *url.URL + sshPath string + progressOut io.Writer +} + +// NewSSHConn creates a new SSH transport connection backed by the local ssh +// binary. +func NewSSHConn(ep *url.URL, label string) (*SSHConn, error) { + sshPath, err := SSHLookPath("ssh") + if err != nil { + return nil, fmt.Errorf("locate ssh binary: %w", err) + } + normalizeEndpointPath(ep) + return &SSHConn{ + Label: label, + EndpointURL: ep, + sshPath: sshPath, + }, nil +} + +func (c *SSHConn) Endpoint() *url.URL { return c.EndpointURL } + +func (c *SSHConn) ProgressWriter() io.Writer { return c.progressOut } + +func (c *SSHConn) SetProgressWriter(w io.Writer) { c.progressOut = w } + +func (c *SSHConn) Close() error { return nil } + +func (c *SSHConn) RequestInfoRefs(ctx context.Context, service string, gitProtocol string) ([]byte, error) { + cmd, stderr, err := c.startRPC(ctx, service, gitProtocol) + if err != nil { + return nil, err + } + return requestInfoRefsWithCommand(ctx, service, cmd, stderr) +} + +func requestInfoRefsWithCommand(ctx context.Context, service string, cmd *sshCommand, stderr *sshCommandError) ([]byte, error) { + if err := cmd.Stdin.Close(); err != nil { + return nil, fmt.Errorf("close ssh stdin for %s: %w", service, errors.Join(err, cleanupSSHCommand(cmd))) + } + data, readErr := io.ReadAll(cmd.Stdout) + waitErr := cmd.wait() + if ctx.Err() != nil { + return nil, errors.Join(ctx.Err(), readErr, stderr.wrap(waitErr)) + } + if readErr != nil { + return nil, fmt.Errorf("%s info-refs: %w", service, readErr) + } + if len(data) > 0 { + return data, nil + } + if waitErr != nil { + return nil, fmt.Errorf("%s info-refs: %w", service, stderr.wrap(waitErr)) + } + return data, nil +} + +func (c *SSHConn) PostRPCStreamBody(ctx context.Context, service string, body io.Reader, v2 bool, phase string) (io.ReadCloser, error) { + _ = phase // phase labels are HTTP-only today; SSH transport has no per-RPC stats tagging + gitProtocol := "" + if v2 { + gitProtocol = GitProtocolV2 + } + cmd, stderr, err := c.startRPC(ctx, service, gitProtocol) + if err != nil { + return nil, err + } + copyErr := make(chan error, 1) + go func() { + _, err := io.Copy(cmd.Stdin, body) + closeErr := cmd.Stdin.Close() + if err != nil { + copyErr <- err + return + } + copyErr <- closeErr + }() + stdout, err := discardSSHAdvertisement(cmd.Stdout) + if err != nil { + _ = cmd.Stdout.Close() + waitErr := cmd.wait() + return nil, fmt.Errorf("%s advertisement: %w", service, errors.Join(stderr.wrap(err), waitErr)) + } + return &sshRPCStream{ + ctx: ctx, + stdout: stdout, + wait: cmd.wait, + copyErr: copyErr, + stderr: stderr, + }, nil +} + +func (c *SSHConn) startRPC(ctx context.Context, service string, gitProtocol string) (*sshCommand, *sshCommandError, error) { + args, err := sshInvocationArgs(c.EndpointURL, service, gitProtocol) + if err != nil { + return nil, nil, err + } + cmd := exec.CommandContext(ctx, c.sshPath, args...) + stderr := &sshCommandError{} + cmd.Stderr = stderr + stdout, err := cmd.StdoutPipe() + if err != nil { + return nil, nil, fmt.Errorf("open ssh stdout for %s: %w", service, err) + } + stdin, err := cmd.StdinPipe() + if err != nil { + return nil, nil, fmt.Errorf("open ssh stdin for %s: %w", service, err) + } + if err := cmd.Start(); err != nil { + return nil, nil, fmt.Errorf("start ssh for %s: %w", service, stderr.wrap(err)) + } + return &sshCommand{Cmd: cmd, Stdin: stdin, Stdout: stdout}, stderr, nil +} + +func sshInvocationArgs(ep *url.URL, service string, gitProtocol string) ([]string, error) { + destination, err := sshDestination(ep) + if err != nil { + return nil, err + } + remoteCommand, err := sshRemoteCommand(ep, service, gitProtocol) + if err != nil { + return nil, err + } + args := []string{"-o", "BatchMode=yes"} + if port := ep.Port(); port != "" { + args = append(args, "-p", port) + } + args = append(args, destination, remoteCommand) + return args, nil +} + +func sshDestination(ep *url.URL) (string, error) { + if ep == nil || ep.Hostname() == "" { + return "", errors.New("missing SSH host") + } + host := ep.Hostname() + if ep.User != nil && ep.User.Username() != "" { + return ep.User.Username() + "@" + host, nil + } + return host, nil +} + +func sshRemoteCommand(ep *url.URL, service string, gitProtocol string) (string, error) { + if ep == nil || ep.Path == "" { + return "", errors.New("missing SSH repository path") + } + path := shellQuotePath(ep.Path) + if gitProtocol != "" { + return gitProtocolEnv(gitProtocol) + " " + service + " " + path, nil + } + return service + " " + path, nil +} + +func gitProtocolEnv(gitProtocol string) string { + return "GIT_PROTOCOL=" + shellQuote(gitProtocol) +} + +func shellQuote(s string) string { + return "'" + strings.ReplaceAll(s, "'", `'"'"'`) + "'" +} + +func shellQuotePath(path string) string { + if !strings.HasPrefix(path, "~") { + return shellQuote(path) + } + slash := strings.IndexByte(path, '/') + if slash < 0 { + return path + } + return path[:slash+1] + shellQuote(path[slash+1:]) +} + +type sshCommand struct { + Cmd *exec.Cmd + Stdin io.WriteCloser + Stdout io.ReadCloser + waitFn func() error +} + +func (c *sshCommand) wait() error { + if c.waitFn != nil { + return c.waitFn() + } + if err := c.Cmd.Wait(); err != nil { + return fmt.Errorf("wait for ssh command: %w", err) + } + return nil +} + +func cleanupSSHCommand(cmd *sshCommand) error { + if cmd == nil { + return nil + } + var errs []error + if cmd.Stdout != nil { + if err := cmd.Stdout.Close(); err != nil { + errs = append(errs, fmt.Errorf("close ssh stdout: %w", err)) + } + } + if err := cmd.wait(); err != nil { + errs = append(errs, err) + } + return errors.Join(errs...) +} + +func discardSSHAdvertisement(stdout io.ReadCloser) (io.ReadCloser, error) { + buffered := bufio.NewReader(stdout) + header, err := buffered.Peek(4) + if err != nil { + if errors.Is(err, io.EOF) { + return &bufferedReadCloser{Reader: buffered, Closer: stdout}, nil + } + return nil, fmt.Errorf("peek SSH advertisement: %w", err) + } + if !looksLikePktlineHeader(header) { + return &bufferedReadCloser{Reader: buffered, Closer: stdout}, nil + } + reader := NewPacketReader(buffered) + for { + kind, _, err := reader.ReadPacket() + if err != nil { + return nil, err + } + if kind == PacketFlush { + break + } + } + return &bufferedReadCloser{Reader: reader.BufReader(), Closer: stdout}, nil +} + +func looksLikePktlineHeader(header []byte) bool { + if len(header) != 4 { + return false + } + var fixed [4]byte + copy(fixed[:], header) + switch string(fixed[:]) { + case "0000", "0001", "0002": + return true + } + _, err := parseHexLength(fixed) + return err == nil +} + +type sshRPCStream struct { + ctx context.Context + stdout io.ReadCloser + wait func() error + copyErr <-chan error + stderr *sshCommandError + + waitOnce sync.Once + waitErr error +} + +type bufferedReadCloser struct { + *bufio.Reader + io.Closer +} + +func (s *sshRPCStream) Read(p []byte) (int, error) { + n, err := s.stdout.Read(p) + return n, err //nolint:wrapcheck // io.Reader contract requires forwarding EOF and stream errors as-is +} + +func (s *sshRPCStream) Close() error { + closeErr := s.stdout.Close() + s.waitOnce.Do(func() { + copyErr := <-s.copyErr + waitErr := s.wait() + if copyErr != nil { + s.waitErr = copyErr + if waitErr != nil { + s.waitErr = errors.Join(copyErr, s.stderr.wrap(waitErr)) + } + return + } + if waitErr != nil { + s.waitErr = s.stderr.wrap(waitErr) + } + if s.ctx != nil && s.ctx.Err() != nil { + s.waitErr = errors.Join(s.ctx.Err(), s.waitErr) + } + }) + return errors.Join(closeErr, s.waitErr) +} + +type sshCommandError struct { + mu sync.Mutex + buf bytes.Buffer +} + +func (e *sshCommandError) Write(p []byte) (int, error) { + e.mu.Lock() + defer e.mu.Unlock() + return e.buf.Write(p) //nolint:wrapcheck // io.Writer implementation forwards buffer write errors verbatim +} + +func (e *sshCommandError) String() string { + e.mu.Lock() + defer e.mu.Unlock() + return strings.TrimSpace(e.buf.String()) +} + +func (e *sshCommandError) wrap(err error) error { + if err == nil { + return nil + } + if msg := e.String(); msg != "" { + return fmt.Errorf("%w: %s", err, msg) + } + return err +} diff --git a/internal/gitproto/ssh_test.go b/internal/gitproto/ssh_test.go new file mode 100644 index 00000000..64d7094c --- /dev/null +++ b/internal/gitproto/ssh_test.go @@ -0,0 +1,286 @@ +package gitproto + +import ( + "context" + "errors" + "io" + "os" + "path/filepath" + "strconv" + "strings" + "testing" + "time" + + "github.com/go-git/go-git/v6/plumbing/transport" +) + +func TestNewSSHConnRequiresBinary(t *testing.T) { + orig := SSHLookPath + t.Cleanup(func() { SSHLookPath = orig }) + SSHLookPath = func(string) (string, error) { + return "", errors.New("not found") + } + + ep, err := transport.ParseURL("ssh://example.com/repo.git") + if err != nil { + t.Fatalf("parse url: %v", err) + } + _, err = NewSSHConn(ep, "source") + if err == nil || !strings.Contains(err.Error(), "locate ssh binary") { + t.Fatalf("NewSSHConn error = %v, want locate ssh binary failure", err) + } +} + +func TestSSHConnRequestInfoRefsHonorsUserConfigAndProtocolV2(t *testing.T) { + env := newSSHShimEnv(t) + conn := newSSHTestConn(t, "ssh://example.com/repo.git", env.script) + + body, err := conn.RequestInfoRefs(t.Context(), "git-upload-pack", GitProtocolV2) + if err != nil { + t.Fatalf("RequestInfoRefs: %v", err) + } + if string(body) != "response-1" { + t.Fatalf("RequestInfoRefs body = %q, want %q", body, "response-1") + } + + logLine := env.logLines(t)[0] + if got, want := logLine, "example.com\tGIT_PROTOCOL='version=2' git-upload-pack '/repo.git'"; got != want { + t.Fatalf("ssh invocation = %q, want %q", got, want) + } +} + +func TestSSHConnRequestInfoRefsSupportsSCPStyleAndPort(t *testing.T) { + env := newSSHShimEnv(t) + + scpConn := newSSHTestConn(t, "git@example.com:repo.git", env.script) + if _, err := scpConn.RequestInfoRefs(t.Context(), "git-upload-pack", ""); err != nil { + t.Fatalf("scp RequestInfoRefs: %v", err) + } + + portConn := newSSHTestConn(t, "ssh://alice@example.com:2222/repo.git", env.script) + if _, err := portConn.RequestInfoRefs(t.Context(), "git-upload-pack", ""); err != nil { + t.Fatalf("port RequestInfoRefs: %v", err) + } + + lines := env.logLines(t) + if got, want := lines[0], "git@example.com\tgit-upload-pack 'repo.git'"; got != want { + t.Fatalf("scp invocation = %q, want %q", got, want) + } + if got, want := lines[1], "-p 2222 alice@example.com\tgit-upload-pack '/repo.git'"; got != want { + t.Fatalf("port invocation = %q, want %q", got, want) + } +} + +func TestSSHConnRequestInfoRefsPreservesTildePaths(t *testing.T) { + env := newSSHShimEnv(t) + + conn := newSSHTestConn(t, "git@example.com:~/repo with spaces.git", env.script) + if _, err := conn.RequestInfoRefs(t.Context(), "git-upload-pack", ""); err != nil { + t.Fatalf("RequestInfoRefs: %v", err) + } + if got, want := env.logLines(t)[0], "git@example.com\tgit-upload-pack ~/'repo with spaces.git'"; got != want { + t.Fatalf("ssh invocation = %q, want %q", got, want) + } +} + +func TestSSHConnPostRPCStreamBodyCanBeCalledRepeatedly(t *testing.T) { + env := newSSHShimEnv(t) + conn := newSSHTestConn(t, "ssh://example.com/repo.git", env.script) + + reader1, err := conn.PostRPCStreamBody(t.Context(), "git-upload-pack", strings.NewReader("first-body"), false, "fetch one") + if err != nil { + t.Fatalf("PostRPCStreamBody first: %v", err) + } + data1, err := io.ReadAll(reader1) + if err != nil { + t.Fatalf("read first response: %v", err) + } + if err := reader1.Close(); err != nil { + t.Fatalf("close first response: %v", err) + } + + reader2, err := conn.PostRPCStreamBody(t.Context(), "git-upload-pack", strings.NewReader("second-body"), false, "fetch two") + if err != nil { + t.Fatalf("PostRPCStreamBody second: %v", err) + } + data2, err := io.ReadAll(reader2) + if err != nil { + t.Fatalf("read second response: %v", err) + } + if err := reader2.Close(); err != nil { + t.Fatalf("close second response: %v", err) + } + + if string(data1) != "response-1" || string(data2) != "response-2" { + t.Fatalf("responses = %q / %q, want response-1 / response-2", data1, data2) + } + if got, want := env.body(t, 1), "first-body"; got != want { + t.Fatalf("first request body = %q, want %q", got, want) + } + if got, want := env.body(t, 2), "second-body"; got != want { + t.Fatalf("second request body = %q, want %q", got, want) + } + if got := len(env.logLines(t)); got != 2 { + t.Fatalf("ssh invocation count = %d, want 2", got) + } +} + +func TestSSHConnRequestInfoRefsHonorsContext(t *testing.T) { + dir := t.TempDir() + script := filepath.Join(dir, "ssh-sleep.sh") + if err := os.WriteFile(script, []byte("#!/bin/sh\nsleep 5\n"), 0o755); err != nil { + t.Fatalf("write script: %v", err) + } + + conn := newSSHTestConn(t, "ssh://example.com/repo.git", script) + ctx, cancel := context.WithTimeout(t.Context(), 50*time.Millisecond) + defer cancel() + + _, err := conn.RequestInfoRefs(ctx, "git-upload-pack", "") + if err == nil { + t.Fatal("RequestInfoRefs returned nil error on canceled context") + } + if !strings.Contains(err.Error(), "context deadline exceeded") { + t.Fatalf("RequestInfoRefs error = %v, want context deadline exceeded", err) + } +} + +func TestSSHConnPostRPCStreamBodyHonorsContext(t *testing.T) { + dir := t.TempDir() + script := filepath.Join(dir, "ssh-read-sleep.sh") + if err := os.WriteFile(script, []byte("#!/bin/sh\ncat >/dev/null\nsleep 5\n"), 0o755); err != nil { + t.Fatalf("write script: %v", err) + } + + conn := newSSHTestConn(t, "ssh://example.com/repo.git", script) + ctx, cancel := context.WithTimeout(t.Context(), 50*time.Millisecond) + defer cancel() + + reader, err := conn.PostRPCStreamBody(ctx, "git-upload-pack", strings.NewReader("body"), false, "fetch") + if err != nil { + t.Fatalf("PostRPCStreamBody: %v", err) + } + if _, err := io.ReadAll(reader); err != nil && !strings.Contains(err.Error(), "context deadline exceeded") { + t.Fatalf("ReadAll error = %v", err) + } + if err := reader.Close(); err == nil || !strings.Contains(err.Error(), "context deadline exceeded") { + t.Fatalf("reader.Close error = %v, want context deadline exceeded", err) + } +} + +func TestRequestInfoRefsCleansUpWhenStdinCloseFails(t *testing.T) { + t.Parallel() + + stdoutClosed := false + waitCalled := false + _, err := requestInfoRefsWithCommand( + t.Context(), + "git-upload-pack", + &sshCommand{ + Stdin: closeWriterFunc(func() error { return errors.New("close failed") }), + Stdout: closeReaderFunc(func() error { stdoutClosed = true; return nil }), + waitFn: func() error { waitCalled = true; return nil }, + }, + &sshCommandError{}, + ) + if err == nil || !strings.Contains(err.Error(), "close ssh stdin for git-upload-pack") { + t.Fatalf("requestInfoRefsWithCommand error = %v", err) + } + if !stdoutClosed { + t.Fatal("stdout was not closed on stdin-close failure") + } + if !waitCalled { + t.Fatal("wait was not called on stdin-close failure") + } +} + +type sshShimEnv struct { + script string + logFile string + bodyPrefix string +} + +type closeWriterFunc func() error + +func (f closeWriterFunc) Write(p []byte) (int, error) { return len(p), nil } +func (f closeWriterFunc) Close() error { return f() } + +type closeReaderFunc func() error + +func (f closeReaderFunc) Read([]byte) (int, error) { return 0, io.EOF } +func (f closeReaderFunc) Close() error { return f() } + +func newSSHShimEnv(t *testing.T) sshShimEnv { + t.Helper() + dir := t.TempDir() + logFile := filepath.Join(dir, "ssh.log") + countFile := filepath.Join(dir, "count") + bodyPrefix := filepath.Join(dir, "body-") + script := filepath.Join(dir, "ssh-shim.sh") + content := strings.Join([]string{ + "#!/bin/sh", + "count=0", + "if [ -f " + shellQuote(countFile) + " ]; then count=$(cat " + shellQuote(countFile) + "); fi", + "count=$((count+1))", + "printf '%s' \"$count\" >" + shellQuote(countFile), + "dest=\"\"", + "remote=\"\"", + "if [ \"$1\" = \"-o\" ]; then", + " shift 2", + "fi", + "if [ \"$1\" = \"-p\" ]; then", + " port=\"$2\"", + " shift 2", + " dest=\"-p $port $1\"", + "else", + " dest=\"$1\"", + "fi", + "remote=\"$2\"", + "printf '%s\\t%s\\n' \"$dest\" \"$remote\" >>" + shellQuote(logFile), + "cat >" + shellQuote(bodyPrefix) + "\"$count\"", + "printf 'response-%s' \"$count\"", + }, "\n") + if err := os.WriteFile(script, []byte(content), 0o755); err != nil { + t.Fatalf("write ssh shim: %v", err) + } + return sshShimEnv{script: script, logFile: logFile, bodyPrefix: bodyPrefix} +} + +func newSSHTestConn(t *testing.T, rawURL, script string) *SSHConn { + t.Helper() + orig := SSHLookPath + t.Cleanup(func() { SSHLookPath = orig }) + SSHLookPath = func(string) (string, error) { return script, nil } + + ep, err := transport.ParseURL(rawURL) + if err != nil { + t.Fatalf("parse url %q: %v", rawURL, err) + } + conn, err := NewSSHConn(ep, "source") + if err != nil { + t.Fatalf("NewSSHConn(%q): %v", rawURL, err) + } + return conn +} + +func (e sshShimEnv) logLines(t *testing.T) []string { + t.Helper() + data, err := os.ReadFile(e.logFile) + if err != nil { + t.Fatalf("read log: %v", err) + } + trimmed := strings.TrimSpace(string(data)) + if trimmed == "" { + return nil + } + return strings.Split(trimmed, "\n") +} + +func (e sshShimEnv) body(t *testing.T, count int) string { + t.Helper() + data, err := os.ReadFile(e.bodyPrefix + strconv.Itoa(count)) + if err != nil { + t.Fatalf("read body %d: %v", count, err) + } + return string(data) +} diff --git a/internal/strategy/bootstrap/bootstrap.go b/internal/strategy/bootstrap/bootstrap.go index f1b9db2c..2fb47cb8 100644 --- a/internal/strategy/bootstrap/bootstrap.go +++ b/internal/strategy/bootstrap/bootstrap.go @@ -40,10 +40,10 @@ var GitHubRepoAPIBaseURL = "https://api.github.com" // Params holds the inputs for a bootstrap execution. type Params struct { - SourceConn *gitproto.Conn + SourceConn gitproto.Conn SourceService interface { - FetchPack(ctx context.Context, conn *gitproto.Conn, desired map[plumbing.ReferenceName]gitproto.DesiredRef, haves map[plumbing.ReferenceName]plumbing.Hash) (io.ReadCloser, error) - FetchCommitGraph(ctx context.Context, store storer.Storer, conn *gitproto.Conn, ref gitproto.DesiredRef, haves []plumbing.Hash) error + FetchPack(ctx context.Context, conn gitproto.Conn, desired map[plumbing.ReferenceName]gitproto.DesiredRef, haves map[plumbing.ReferenceName]plumbing.Hash) (io.ReadCloser, error) + FetchCommitGraph(ctx context.Context, store storer.Storer, conn gitproto.Conn, ref gitproto.DesiredRef, haves []plumbing.Hash) error SupportsBootstrapBatch() bool } TargetPusher interface { @@ -1326,7 +1326,7 @@ func packReaderForCheckpoint( // --- GitHub preflight --- func githubBatchLimit(ctx context.Context, p Params) (int64, bool) { - if p.TargetMaxPack > 0 || p.SourceConn == nil || p.SourceConn.Endpoint == nil { + if p.TargetMaxPack > 0 || p.SourceConn == nil || p.SourceConn.Endpoint() == nil { return 0, false } if p.SourceService == nil || !p.SourceService.SupportsBootstrapBatch() { @@ -1346,7 +1346,11 @@ func githubBatchLimit(ctx context.Context, p Params) (int64, bool) { return limit, true } -func lookupGitHubRepoSizeKB(ctx context.Context, conn *gitproto.Conn) (int64, bool) { +func lookupGitHubRepoSizeKB(ctx context.Context, conn gitproto.Conn) (int64, bool) { + httpConn, ok := conn.(*gitproto.HTTPConn) + if !ok { + return 0, false + } owner, repo, ok := GitHubOwnerRepo(conn) if !ok { return 0, false @@ -1360,7 +1364,7 @@ func lookupGitHubRepoSizeKB(ctx context.Context, conn *gitproto.Conn) (int64, bo req.Header.Set("X-Github-Api-Version", "2022-11-28") req.Header.Set("User-Agent", capability.DefaultAgent()) req.Header.Set(gitproto.StatsPhaseHeader, "github repo metadata") - resp, err := conn.HTTP.Do(req) + resp, err := httpConn.HTTP.Do(req) if err != nil { return 0, false } @@ -1378,11 +1382,11 @@ func lookupGitHubRepoSizeKB(ctx context.Context, conn *gitproto.Conn) (int64, bo } // GitHubOwnerRepo extracts the owner/repo from a GitHub endpoint. -func GitHubOwnerRepo(conn *gitproto.Conn) (string, string, bool) { - if conn == nil || conn.Endpoint == nil { +func GitHubOwnerRepo(conn gitproto.Conn) (string, string, bool) { + if conn == nil || conn.Endpoint() == nil { return "", "", false } - ep := conn.Endpoint + ep := conn.Endpoint() if ep.Scheme != "http" && ep.Scheme != "https" { return "", "", false } diff --git a/internal/strategy/bootstrap/bootstrap_test.go b/internal/strategy/bootstrap/bootstrap_test.go index 339e6047..11452921 100644 --- a/internal/strategy/bootstrap/bootstrap_test.go +++ b/internal/strategy/bootstrap/bootstrap_test.go @@ -1149,7 +1149,7 @@ func TestExecuteBatchedSubsumedBranchSkipsPack(t *testing.T) { _, err := Execute(context.Background(), Params{ SourceService: fakeBootstrapSource{ - fetchCommitGraph: func(_ context.Context, store storer.Storer, _ *gitproto.Conn, ref gitproto.DesiredRef, _ []plumbing.Hash) error { + fetchCommitGraph: func(_ context.Context, store storer.Storer, _ gitproto.Conn, ref gitproto.DesiredRef, _ []plumbing.Hash) error { graphFetches++ if ref.SourceRef != mainRef { t.Errorf("unexpected commit-graph fetch for %s; subsumed branch should have been skipped", ref.SourceRef) @@ -1157,7 +1157,7 @@ func TestExecuteBatchedSubsumedBranchSkipsPack(t *testing.T) { writeLinearCommitChain(t, store, 3) return nil }, - fetchPack: func(_ context.Context, _ *gitproto.Conn, desired map[plumbing.ReferenceName]gitproto.DesiredRef, _ map[plumbing.ReferenceName]plumbing.Hash) (io.ReadCloser, error) { + fetchPack: func(_ context.Context, _ gitproto.Conn, desired map[plumbing.ReferenceName]gitproto.DesiredRef, _ map[plumbing.ReferenceName]plumbing.Hash) (io.ReadCloser, error) { packFetches++ if _, ok := desired[featureRef]; ok { t.Errorf("unexpected pack fetch including feature ref: %+v", desired) @@ -1212,13 +1212,13 @@ func TestExecuteBatchedSubsumedBranchSkipsPack(t *testing.T) { } type fakeBootstrapSource struct { - fetchPack func(context.Context, *gitproto.Conn, map[plumbing.ReferenceName]gitproto.DesiredRef, map[plumbing.ReferenceName]plumbing.Hash) (io.ReadCloser, error) - fetchCommitGraph func(context.Context, storer.Storer, *gitproto.Conn, gitproto.DesiredRef, []plumbing.Hash) error + fetchPack func(context.Context, gitproto.Conn, map[plumbing.ReferenceName]gitproto.DesiredRef, map[plumbing.ReferenceName]plumbing.Hash) (io.ReadCloser, error) + fetchCommitGraph func(context.Context, storer.Storer, gitproto.Conn, gitproto.DesiredRef, []plumbing.Hash) error } func (f fakeBootstrapSource) FetchPack( ctx context.Context, - conn *gitproto.Conn, + conn gitproto.Conn, desired map[plumbing.ReferenceName]gitproto.DesiredRef, targetRefs map[plumbing.ReferenceName]plumbing.Hash, ) (io.ReadCloser, error) { @@ -1228,7 +1228,7 @@ func (f fakeBootstrapSource) FetchPack( func (f fakeBootstrapSource) FetchCommitGraph( ctx context.Context, store storer.Storer, - conn *gitproto.Conn, + conn gitproto.Conn, ref gitproto.DesiredRef, haves []plumbing.Hash, ) error { @@ -1298,7 +1298,7 @@ func TestExecuteOneShotUsesTargetPusher(t *testing.T) { result, err := Execute(context.Background(), Params{ SourceService: fakeBootstrapSource{ - fetchPack: func(_ context.Context, _ *gitproto.Conn, desired map[plumbing.ReferenceName]gitproto.DesiredRef, targetRefs map[plumbing.ReferenceName]plumbing.Hash) (io.ReadCloser, error) { + fetchPack: func(_ context.Context, _ gitproto.Conn, desired map[plumbing.ReferenceName]gitproto.DesiredRef, targetRefs map[plumbing.ReferenceName]plumbing.Hash) (io.ReadCloser, error) { gotDesired = desired if targetRefs != nil { t.Fatalf("expected nil target refs during one-shot bootstrap fetch, got %v", targetRefs) @@ -1344,7 +1344,7 @@ func TestExecuteOneShotClosesPackOnPushError(t *testing.T) { _, err := Execute(context.Background(), Params{ SourceService: fakeBootstrapSource{ - fetchPack: func(_ context.Context, _ *gitproto.Conn, _ map[plumbing.ReferenceName]gitproto.DesiredRef, _ map[plumbing.ReferenceName]plumbing.Hash) (io.ReadCloser, error) { + fetchPack: func(_ context.Context, _ gitproto.Conn, _ map[plumbing.ReferenceName]gitproto.DesiredRef, _ map[plumbing.ReferenceName]plumbing.Hash) (io.ReadCloser, error) { return pack, nil }, }, @@ -1378,7 +1378,7 @@ func TestExecuteOneShotClosesPackWhenPusherDoesNot(t *testing.T) { _, err := Execute(context.Background(), Params{ SourceService: fakeBootstrapSource{ - fetchPack: func(_ context.Context, _ *gitproto.Conn, _ map[plumbing.ReferenceName]gitproto.DesiredRef, _ map[plumbing.ReferenceName]plumbing.Hash) (io.ReadCloser, error) { + fetchPack: func(_ context.Context, _ gitproto.Conn, _ map[plumbing.ReferenceName]gitproto.DesiredRef, _ map[plumbing.ReferenceName]plumbing.Hash) (io.ReadCloser, error) { return pack, nil }, }, @@ -1411,11 +1411,11 @@ func TestExecuteBatchedClosesCheckpointPackOnPushError(t *testing.T) { _, err := Execute(context.Background(), Params{ SourceService: fakeBootstrapSource{ - fetchCommitGraph: func(_ context.Context, store storer.Storer, _ *gitproto.Conn, _ gitproto.DesiredRef, _ []plumbing.Hash) error { + fetchCommitGraph: func(_ context.Context, store storer.Storer, _ gitproto.Conn, _ gitproto.DesiredRef, _ []plumbing.Hash) error { writeLinearCommitChain(t, store, 1) return nil }, - fetchPack: func(_ context.Context, _ *gitproto.Conn, _ map[plumbing.ReferenceName]gitproto.DesiredRef, _ map[plumbing.ReferenceName]plumbing.Hash) (io.ReadCloser, error) { + fetchPack: func(_ context.Context, _ gitproto.Conn, _ map[plumbing.ReferenceName]gitproto.DesiredRef, _ map[plumbing.ReferenceName]plumbing.Hash) (io.ReadCloser, error) { return pack, nil }, }, @@ -1451,11 +1451,11 @@ func TestExecuteBatchedClosesCheckpointPackOnReadInterruption(t *testing.T) { _, err := Execute(context.Background(), Params{ SourceService: fakeBootstrapSource{ - fetchCommitGraph: func(_ context.Context, store storer.Storer, _ *gitproto.Conn, _ gitproto.DesiredRef, _ []plumbing.Hash) error { + fetchCommitGraph: func(_ context.Context, store storer.Storer, _ gitproto.Conn, _ gitproto.DesiredRef, _ []plumbing.Hash) error { writeLinearCommitChain(t, store, 1) return nil }, - fetchPack: func(_ context.Context, _ *gitproto.Conn, _ map[plumbing.ReferenceName]gitproto.DesiredRef, _ map[plumbing.ReferenceName]plumbing.Hash) (io.ReadCloser, error) { + fetchPack: func(_ context.Context, _ gitproto.Conn, _ map[plumbing.ReferenceName]gitproto.DesiredRef, _ map[plumbing.ReferenceName]plumbing.Hash) (io.ReadCloser, error) { return pack, nil }, }, @@ -1502,7 +1502,7 @@ func TestExecuteRequiresTargetPusherBeforeFetch(t *testing.T) { calledFetch := false _, err := Execute(context.Background(), Params{ SourceService: fakeBootstrapSource{ - fetchPack: func(context.Context, *gitproto.Conn, map[plumbing.ReferenceName]gitproto.DesiredRef, map[plumbing.ReferenceName]plumbing.Hash) (io.ReadCloser, error) { + fetchPack: func(context.Context, gitproto.Conn, map[plumbing.ReferenceName]gitproto.DesiredRef, map[plumbing.ReferenceName]plumbing.Hash) (io.ReadCloser, error) { calledFetch = true return io.NopCloser(bytes.NewReader(nil)), nil }, @@ -1546,12 +1546,12 @@ func TestExecuteRequiresTargetPusherBeforeGitHubPreflight(t *testing.T) { } _, err = Execute(context.Background(), Params{ - SourceConn: &gitproto.Conn{ - Endpoint: ep, - HTTP: server.Client(), + SourceConn: &gitproto.HTTPConn{ + EndpointURL: ep, + HTTP: server.Client(), }, SourceService: fakeBootstrapSource{ - fetchPack: func(context.Context, *gitproto.Conn, map[plumbing.ReferenceName]gitproto.DesiredRef, map[plumbing.ReferenceName]plumbing.Hash) (io.ReadCloser, error) { + fetchPack: func(context.Context, gitproto.Conn, map[plumbing.ReferenceName]gitproto.DesiredRef, map[plumbing.ReferenceName]plumbing.Hash) (io.ReadCloser, error) { t.Fatal("unexpected fetch") return nil, nil //nolint:nilnil // test fake returns nil to signal no data }, diff --git a/internal/strategy/incremental/incremental.go b/internal/strategy/incremental/incremental.go index 6862ea26..09ab0326 100644 --- a/internal/strategy/incremental/incremental.go +++ b/internal/strategy/incremental/incremental.go @@ -19,9 +19,9 @@ import ( // Params holds the inputs for an incremental relay execution. type Params struct { - SourceConn *gitproto.Conn + SourceConn gitproto.Conn SourceService interface { - FetchPack(ctx context.Context, conn *gitproto.Conn, desired map[plumbing.ReferenceName]gitproto.DesiredRef, haves map[plumbing.ReferenceName]plumbing.Hash) (io.ReadCloser, error) + FetchPack(ctx context.Context, conn gitproto.Conn, desired map[plumbing.ReferenceName]gitproto.DesiredRef, haves map[plumbing.ReferenceName]plumbing.Hash) (io.ReadCloser, error) } TargetPusher interface { PushPack(ctx context.Context, cmds []gitproto.PushCommand, pack io.ReadCloser) error diff --git a/internal/strategy/incremental/incremental_test.go b/internal/strategy/incremental/incremental_test.go index 317cf0ee..bc3f3900 100644 --- a/internal/strategy/incremental/incremental_test.go +++ b/internal/strategy/incremental/incremental_test.go @@ -113,12 +113,12 @@ func TestPlansToPushPlans(t *testing.T) { } type fakeSourceService struct { - fetchPack func(context.Context, *gitproto.Conn, map[plumbing.ReferenceName]gitproto.DesiredRef, map[plumbing.ReferenceName]plumbing.Hash) (io.ReadCloser, error) + fetchPack func(context.Context, gitproto.Conn, map[plumbing.ReferenceName]gitproto.DesiredRef, map[plumbing.ReferenceName]plumbing.Hash) (io.ReadCloser, error) } func (f fakeSourceService) FetchPack( ctx context.Context, - conn *gitproto.Conn, + conn gitproto.Conn, desired map[plumbing.ReferenceName]gitproto.DesiredRef, targetRefs map[plumbing.ReferenceName]plumbing.Hash, ) (io.ReadCloser, error) { @@ -179,7 +179,7 @@ func TestExecuteIncrementalRelayUsesTargetRefsAsHaves(t *testing.T) { params := Params{ SourceService: fakeSourceService{ - fetchPack: func(_ context.Context, _ *gitproto.Conn, desired map[plumbing.ReferenceName]gitproto.DesiredRef, targetRefs map[plumbing.ReferenceName]plumbing.Hash) (io.ReadCloser, error) { + fetchPack: func(_ context.Context, _ gitproto.Conn, desired map[plumbing.ReferenceName]gitproto.DesiredRef, targetRefs map[plumbing.ReferenceName]plumbing.Hash) (io.ReadCloser, error) { gotDesired = desired gotHaves = targetRefs return io.NopCloser(bytes.NewReader([]byte("PACK"))), nil @@ -242,7 +242,7 @@ func TestExecuteFullTagCreateRelayOmitsHaves(t *testing.T) { params := Params{ SourceService: fakeSourceService{ - fetchPack: func(_ context.Context, _ *gitproto.Conn, _ map[plumbing.ReferenceName]gitproto.DesiredRef, targetRefs map[plumbing.ReferenceName]plumbing.Hash) (io.ReadCloser, error) { + fetchPack: func(_ context.Context, _ gitproto.Conn, _ map[plumbing.ReferenceName]gitproto.DesiredRef, targetRefs map[plumbing.ReferenceName]plumbing.Hash) (io.ReadCloser, error) { gotHaves = targetRefs return io.NopCloser(bytes.NewReader([]byte("PACK"))), nil }, @@ -297,7 +297,7 @@ func TestExecuteIncrementalRelayClosesPackOnPushError(t *testing.T) { _, err := Execute(context.Background(), Params{ SourceService: fakeSourceService{ - fetchPack: func(_ context.Context, _ *gitproto.Conn, _ map[plumbing.ReferenceName]gitproto.DesiredRef, _ map[plumbing.ReferenceName]plumbing.Hash) (io.ReadCloser, error) { + fetchPack: func(_ context.Context, _ gitproto.Conn, _ map[plumbing.ReferenceName]gitproto.DesiredRef, _ map[plumbing.ReferenceName]plumbing.Hash) (io.ReadCloser, error) { return pack, nil }, }, @@ -344,7 +344,7 @@ func TestExecuteIncrementalRelayClosesPackOnReadInterruption(t *testing.T) { _, err := Execute(context.Background(), Params{ SourceService: fakeSourceService{ - fetchPack: func(_ context.Context, _ *gitproto.Conn, _ map[plumbing.ReferenceName]gitproto.DesiredRef, _ map[plumbing.ReferenceName]plumbing.Hash) (io.ReadCloser, error) { + fetchPack: func(_ context.Context, _ gitproto.Conn, _ map[plumbing.ReferenceName]gitproto.DesiredRef, _ map[plumbing.ReferenceName]plumbing.Hash) (io.ReadCloser, error) { return pack, nil }, }, diff --git a/internal/strategy/materialized/materialized.go b/internal/strategy/materialized/materialized.go index 6e7bb1dd..a15c124a 100644 --- a/internal/strategy/materialized/materialized.go +++ b/internal/strategy/materialized/materialized.go @@ -20,9 +20,9 @@ import ( // Params holds the inputs for a materialized push. type Params struct { Store storer.Storer - SourceConn *gitproto.Conn + SourceConn gitproto.Conn SourceService interface { - FetchToStore(ctx context.Context, store storer.Storer, conn *gitproto.Conn, desired map[plumbing.ReferenceName]gitproto.DesiredRef, haves map[plumbing.ReferenceName]plumbing.Hash) error + FetchToStore(ctx context.Context, store storer.Storer, conn gitproto.Conn, desired map[plumbing.ReferenceName]gitproto.DesiredRef, haves map[plumbing.ReferenceName]plumbing.Hash) error } TargetPusher interface { PushObjects(ctx context.Context, cmds []gitproto.PushCommand, store storer.Storer, hashes []plumbing.Hash) error diff --git a/internal/strategy/replicate/replicate.go b/internal/strategy/replicate/replicate.go index 8886ef16..70f1b600 100644 --- a/internal/strategy/replicate/replicate.go +++ b/internal/strategy/replicate/replicate.go @@ -17,9 +17,9 @@ import ( // Params holds the inputs for a replication relay execution. type Params struct { - SourceConn *gitproto.Conn + SourceConn gitproto.Conn SourceService interface { - FetchPack(ctx context.Context, conn *gitproto.Conn, desired map[plumbing.ReferenceName]gitproto.DesiredRef, haves map[plumbing.ReferenceName]plumbing.Hash) (io.ReadCloser, error) + FetchPack(ctx context.Context, conn gitproto.Conn, desired map[plumbing.ReferenceName]gitproto.DesiredRef, haves map[plumbing.ReferenceName]plumbing.Hash) (io.ReadCloser, error) } TargetPusher interface { PushPack(ctx context.Context, cmds []gitproto.PushCommand, pack io.ReadCloser) error diff --git a/internal/strategy/replicate/replicate_test.go b/internal/strategy/replicate/replicate_test.go index 5751aa7f..5d4892c3 100644 --- a/internal/strategy/replicate/replicate_test.go +++ b/internal/strategy/replicate/replicate_test.go @@ -13,12 +13,12 @@ import ( ) type fakeSourceService struct { - fetchPack func(context.Context, *gitproto.Conn, map[plumbing.ReferenceName]gitproto.DesiredRef, map[plumbing.ReferenceName]plumbing.Hash) (io.ReadCloser, error) + fetchPack func(context.Context, gitproto.Conn, map[plumbing.ReferenceName]gitproto.DesiredRef, map[plumbing.ReferenceName]plumbing.Hash) (io.ReadCloser, error) } func (f fakeSourceService) FetchPack( ctx context.Context, - conn *gitproto.Conn, + conn gitproto.Conn, desired map[plumbing.ReferenceName]gitproto.DesiredRef, targetRefs map[plumbing.ReferenceName]plumbing.Hash, ) (io.ReadCloser, error) { @@ -51,7 +51,7 @@ func TestExecuteReplicateRelaysUpdatesAndDeletesSeparately(t *testing.T) { result, err := Execute(context.Background(), Params{ SourceService: fakeSourceService{ - fetchPack: func(_ context.Context, _ *gitproto.Conn, desired map[plumbing.ReferenceName]gitproto.DesiredRef, targetRefs map[plumbing.ReferenceName]plumbing.Hash) (io.ReadCloser, error) { + fetchPack: func(_ context.Context, _ gitproto.Conn, desired map[plumbing.ReferenceName]gitproto.DesiredRef, targetRefs map[plumbing.ReferenceName]plumbing.Hash) (io.ReadCloser, error) { gotDesired = desired gotHaves = targetRefs return io.NopCloser(bytes.NewReader([]byte("PACK"))), nil diff --git a/internal/syncer/auth_test.go b/internal/syncer/auth_test.go index 0c5be333..7e2019a6 100644 --- a/internal/syncer/auth_test.go +++ b/internal/syncer/auth_test.go @@ -15,6 +15,7 @@ import ( transporthttp "github.com/go-git/go-git/v6/plumbing/transport/http" "entire.io/entire/git-sync/internal/auth" + "entire.io/entire/git-sync/internal/gitproto" ) func TestResolveAuthMethodPrefersExplicitToken(t *testing.T) { @@ -47,7 +48,7 @@ func TestResolveAuthMethodPrefersExplicitToken(t *testing.T) { } } -func TestNewConnSkipTLSVerify(t *testing.T) { +func TestNewHTTPConnSkipTLSVerify(t *testing.T) { stats := newStats(false) conn, err := newConn(Endpoint{ URL: "https://example.com/repo.git", @@ -56,9 +57,13 @@ func TestNewConnSkipTLSVerify(t *testing.T) { if err != nil { t.Fatalf("new conn: %v", err) } - rt, ok := conn.HTTP.Transport.(*countingRoundTripper) + httpConn, ok := conn.(*gitproto.HTTPConn) if !ok { - t.Fatalf("expected countingRoundTripper, got %T", conn.HTTP.Transport) + t.Fatalf("expected *gitproto.HTTPConn, got %T", conn) + } + rt, ok := httpConn.HTTP.Transport.(*countingRoundTripper) + if !ok { + t.Fatalf("expected countingRoundTripper, got %T", httpConn.HTTP.Transport) } base, ok := rt.base.(*http.Transport) if !ok { @@ -69,7 +74,7 @@ func TestNewConnSkipTLSVerify(t *testing.T) { } } -func TestNewConnUsesProvidedHTTPClient(t *testing.T) { +func TestNewHTTPConnUsesProvidedHTTPClient(t *testing.T) { stats := newStats(false) baseTransport := http.DefaultTransport baseClient := &http.Client{Transport: baseTransport} @@ -78,12 +83,16 @@ func TestNewConnUsesProvidedHTTPClient(t *testing.T) { if err != nil { t.Fatalf("new conn: %v", err) } - if conn.HTTP == baseClient { + httpConn, ok := conn.(*gitproto.HTTPConn) + if !ok { + t.Fatalf("expected *gitproto.HTTPConn, got %T", conn) + } + if httpConn.HTTP == baseClient { t.Fatalf("expected cloned HTTP client, got original pointer") } - rt, ok := conn.HTTP.Transport.(*countingRoundTripper) + rt, ok := httpConn.HTTP.Transport.(*countingRoundTripper) if !ok { - t.Fatalf("expected countingRoundTripper, got %T", conn.HTTP.Transport) + t.Fatalf("expected countingRoundTripper, got %T", httpConn.HTTP.Transport) } if rt.base != baseTransport { t.Fatalf("wrapped base transport = %T, want %T", rt.base, baseTransport) diff --git a/internal/syncer/ssh_integration_test.go b/internal/syncer/ssh_integration_test.go new file mode 100644 index 00000000..3b3162d6 --- /dev/null +++ b/internal/syncer/ssh_integration_test.go @@ -0,0 +1,106 @@ +package syncer + +import ( + "context" + "os" + "path/filepath" + "strings" + "testing" + + "entire.io/entire/git-sync/internal/gitproto" + "github.com/go-git/go-git/v6/plumbing" +) + +func TestRun_IntegrationSyncOverSSHShimV2(t *testing.T) { + root := t.TempDir() + sourceBare := filepath.Join(root, "source.git") + targetBare := filepath.Join(root, "target.git") + worktree := filepath.Join(root, "worktree") + logFile := filepath.Join(root, "ssh.log") + shim := filepath.Join(root, "ssh-shim.sh") + + runGit(t, root, "init", "--bare", sourceBare) + runGit(t, root, "init", "--bare", targetBare) + runGit(t, root, "init", worktree) + runGit(t, worktree, "config", "user.name", "test") + runGit(t, worktree, "config", "user.email", "test@example.com") + writeFile(t, filepath.Join(worktree, "tracked.txt"), "hello over ssh\n") + runGit(t, worktree, "add", "tracked.txt") + runGit(t, worktree, "commit", "-m", "initial") + runGit(t, worktree, "remote", "add", "origin", sourceBare) + runGit(t, worktree, "push", "origin", "HEAD:refs/heads/"+testBranch) + + shimBody := strings.Join([]string{ + "#!/bin/sh", + "if [ \"$1\" = \"-o\" ]; then", + " shift 2", + "fi", + "if [ \"$1\" = \"-p\" ]; then", + " shift 2", + "fi", + "dest=\"$1\"", + "remote=\"$2\"", + "printf '%s\\t%s\\n' \"$dest\" \"$remote\" >>" + shSingleQuote(logFile), + "exec sh -c \"$remote\"", + }, "\n") + if err := os.WriteFile(shim, []byte(shimBody), 0o755); err != nil { + t.Fatalf("write ssh shim: %v", err) + } + + orig := gitproto.SSHLookPath + t.Cleanup(func() { gitproto.SSHLookPath = orig }) + gitproto.SSHLookPath = func(string) (string, error) { + return shim, nil + } + + result, err := Run(context.Background(), Config{ + Source: Endpoint{URL: "ssh://example.com" + sourceBare}, + Target: Endpoint{URL: "ssh://example.com" + targetBare}, + ProtocolMode: protocolModeAuto, + }) + if err != nil { + t.Fatalf("Run over SSH shim failed: %v", err) + } + if result.Protocol != protocolModeV2 { + t.Fatalf("result.Protocol = %q, want %q", result.Protocol, protocolModeV2) + } + + assertGitRefEqual(t, sourceBare, targetBare, plumbing.NewBranchReferenceName(testBranch)) + + data, err := os.ReadFile(logFile) + if err != nil { + t.Fatalf("read ssh log: %v", err) + } + lines := strings.Split(strings.TrimSpace(string(data)), "\n") + if len(lines) < 5 { + t.Fatalf("expected at least 5 ssh invocations, got %d\n%s", len(lines), string(data)) + } + + uploadPackCalls := 0 + receivePackCalls := 0 + v2Calls := 0 + for _, line := range lines { + switch { + case strings.Contains(line, "git-upload-pack"): + uploadPackCalls++ + case strings.Contains(line, "git-receive-pack"): + receivePackCalls++ + } + if strings.Contains(line, "GIT_PROTOCOL='version=2' git-upload-pack") { + v2Calls++ + } + } + if uploadPackCalls < 3 { + t.Fatalf("expected repeated source upload-pack RPCs, got %d\n%s", uploadPackCalls, string(data)) + } + if receivePackCalls < 2 { + t.Fatalf("expected target receive-pack discovery + push, got %d\n%s", receivePackCalls, string(data)) + } + if v2Calls == 0 { + t.Fatalf("expected at least one protocol v2 upload-pack call\n%s", string(data)) + } +} + +func shSingleQuote(s string) string { + return "'" + strings.ReplaceAll(s, "'", `'"'"'`) + "'" +} diff --git a/internal/syncer/syncer.go b/internal/syncer/syncer.go index 1f1d155e..862cd3b7 100644 --- a/internal/syncer/syncer.go +++ b/internal/syncer/syncer.go @@ -349,11 +349,20 @@ func measurementLine(m Measurement) []string { // --- Session setup --- -func newConn(raw Endpoint, label string, stats *statsCollector, httpClient *http.Client) (*gitproto.Conn, error) { +func newConn(raw Endpoint, label string, stats *statsCollector, httpClient *http.Client) (gitproto.Conn, error) { //nolint:ireturn // transport selection intentionally returns the shared connection interface ep, err := transport.ParseURL(raw.URL) if err != nil { return nil, fmt.Errorf("parse endpoint: %w", err) } + switch ep.Scheme { + case "ssh", "git+ssh": + stats.setSideDisplay(label, hostnameFromURL(raw.URL)) + conn, err := gitproto.NewSSHConn(ep, label) + if err != nil { + return nil, fmt.Errorf("new SSH connection: %w", err) + } + return conn, nil + } authEp := auth.Endpoint{ Username: raw.Username, Token: raw.Token, @@ -366,7 +375,7 @@ func newConn(raw Endpoint, label string, stats *statsCollector, httpClient *http } stats.setSideDisplay(label, hostnameFromURL(raw.URL)) client := instrumentHTTPClient(httpClient, raw.SkipTLSVerify, label, stats) - conn := gitproto.NewConnWithHTTPClient(ep, label, authMethod, client) + conn := gitproto.NewHTTPConnWithClient(ep, label, authMethod, client) conn.FollowInfoRefsRedirect = raw.FollowInfoRefsRedirect return conn, nil } @@ -531,6 +540,25 @@ func needsLocalSourceClosure( return false } +func sshStatsWarning(cfg Config, sourceConn, targetConn gitproto.Conn) string { + if !cfg.Progress && !cfg.ShowStats { + return "" + } + hasSSH := false + if _, ok := sourceConn.(*gitproto.SSHConn); ok { + hasSSH = true + } + if !hasSSH && targetConn != nil { + if _, ok := targetConn.(*gitproto.SSHConn); ok { + hasSSH = true + } + } + if !hasSSH { + return "" + } + return "warning: SSH transport does not yet expose byte-counted throughput; --progress and --stats output will omit SSH transfer bytes" +} + // --- Session setup (issue #12) --- // syncSession holds the shared state for a sync operation, reducing @@ -539,7 +567,7 @@ type syncSession struct { cfg Config stats *statsCollector logger *slog.Logger - sourceConn *gitproto.Conn + sourceConn gitproto.Conn sourceService *gitproto.RefService sourceRefMap map[plumbing.ReferenceName]plumbing.Hash target *targetSession @@ -556,6 +584,12 @@ func (s *syncSession) finish() { if s.progress != nil { s.progress.terminate() } + if s.sourceConn != nil { + _ = s.sourceConn.Close() + } + if s.target != nil && s.target.conn != nil { + _ = s.target.conn.Close() + } } // notice surfaces a one-line human-readable event during a sync. When @@ -571,7 +605,7 @@ func (s *syncSession) notice(msg string) { } type targetSession struct { - conn *gitproto.Conn + conn gitproto.Conn adv *packp.AdvRefs refMap map[plumbing.ReferenceName]plumbing.Hash features gitproto.TargetFeatures @@ -614,6 +648,22 @@ func newSession(ctx context.Context, cfg Config, needTarget bool) (*syncSession, stats: newStats(cfg.ShowStats), measurementDone: startMeasurement(cfg.MeasureMemory), } + var warnedSSHStats bool + warnSSHStats := func(sourceConn, targetConn gitproto.Conn) { + if warnedSSHStats { + return + } + warning := sshStatsWarning(cfg, sourceConn, targetConn) + if warning == "" { + return + } + warnedSSHStats = true + out := cfg.progressOut + if out == nil { + out = os.Stderr + } + fmt.Fprintln(out, warning) + } if cfg.Verbose { s.logger = slog.New(slog.NewTextHandler(&sessionStderr{s: s}, &slog.HandlerOptions{ Level: slog.LevelInfo, @@ -624,7 +674,8 @@ func newSession(ctx context.Context, cfg Config, needTarget bool) (*syncSession, if err != nil { return nil, fmt.Errorf("create source transport: %w", err) } - s.sourceConn.ProgressOut = &sessionStderr{s: s} + s.sourceConn.SetProgressWriter(&sessionStderr{s: s}) + warnSSHStats(s.sourceConn, nil) refPrefixes := planner.RefPrefixes(planConfig(cfg)) sourceRefs, sourceService, err := gitproto.ListSourceRefs(ctx, s.sourceConn, cfg.ProtocolMode, refPrefixes) @@ -640,7 +691,8 @@ func newSession(ctx context.Context, cfg Config, needTarget bool) (*syncSession, if err != nil { return nil, fmt.Errorf("create target transport: %w", err) } - targetConn.ProgressOut = &sessionStderr{s: s} + targetConn.SetProgressWriter(&sessionStderr{s: s}) + warnSSHStats(s.sourceConn, targetConn) targetAdv, err := gitproto.AdvertisedRefsV1(ctx, targetConn, transport.ReceivePackService) if err != nil { return nil, fmt.Errorf("list target refs: %w", err) @@ -669,7 +721,6 @@ func newSession(ctx context.Context, cfg Config, needTarget bool) (*syncSession, } } } - // Start the live progress ticker only after auth resolution and the // initial ref-listing round trips have completed. The auth path may // shell out to `git credential fill`, which inherits our stderr and diff --git a/internal/syncer/syncer_test.go b/internal/syncer/syncer_test.go index 7a63c37d..2306f7c6 100644 --- a/internal/syncer/syncer_test.go +++ b/internal/syncer/syncer_test.go @@ -2,11 +2,14 @@ package syncer import ( "context" + "os" + "path/filepath" "strings" "testing" "github.com/go-git/go-git/v6/plumbing" + "entire.io/entire/git-sync/internal/gitproto" bstrap "entire.io/entire/git-sync/internal/strategy/bootstrap" ) @@ -212,14 +215,18 @@ func TestProbeWithoutTargetIgnoresEndpointEqualityCheck(t *testing.T) { // TestNewConn_PropagatesFollowInfoRefsRedirect proves the plumbing from // Endpoint → gitproto.Conn is in place. Without this the flag on // Endpoint is dead config. -func TestNewConn_PropagatesFollowInfoRefsRedirect(t *testing.T) { +func TestNewHTTPConn_PropagatesFollowInfoRefsRedirect(t *testing.T) { stats := newStats(false) off, err := newConn(Endpoint{URL: "https://node.example/repo.git"}, "target", stats, nil) if err != nil { t.Fatalf("new conn (off): %v", err) } - if off.FollowInfoRefsRedirect { + offHTTP, ok := off.(*gitproto.HTTPConn) + if !ok { + t.Fatalf("expected *gitproto.HTTPConn, got %T", off) + } + if offHTTP.FollowInfoRefsRedirect { t.Error("FollowInfoRefsRedirect should default to false") } @@ -227,7 +234,84 @@ func TestNewConn_PropagatesFollowInfoRefsRedirect(t *testing.T) { if err != nil { t.Fatalf("new conn (on): %v", err) } - if !on.FollowInfoRefsRedirect { + onHTTP, ok := on.(*gitproto.HTTPConn) + if !ok { + t.Fatalf("expected *gitproto.HTTPConn, got %T", on) + } + if !onHTTP.FollowInfoRefsRedirect { t.Error("FollowInfoRefsRedirect was not propagated from Endpoint to Conn") } } + +func TestNewConnBuildsSSHTransport(t *testing.T) { + orig := gitproto.SSHLookPath + t.Cleanup(func() { gitproto.SSHLookPath = orig }) + script := filepath.Join(t.TempDir(), "ssh-stub.sh") + if err := os.WriteFile(script, []byte("#!/bin/sh\nexit 1\n"), 0o755); err != nil { + t.Fatalf("write ssh stub: %v", err) + } + gitproto.SSHLookPath = func(string) (string, error) { + return script, nil + } + + stats := newStats(false) + tests := []string{ + "ssh://example.com/repo.git", + "git+ssh://example.com/repo.git", + "git@example.com:repo.git", + } + for _, raw := range tests { + t.Run(raw, func(t *testing.T) { + conn, err := newConn(Endpoint{URL: raw}, "source", stats, nil) + if err != nil { + t.Fatalf("new conn: %v", err) + } + if _, ok := conn.(*gitproto.SSHConn); !ok { + t.Fatalf("expected *gitproto.SSHConn, got %T", conn) + } + }) + } +} + +func TestSSHStatsWarning(t *testing.T) { + tests := []struct { + name string + cfg Config + source gitproto.Conn + target gitproto.Conn + want bool + }{ + { + name: "no flags", + cfg: Config{}, + source: &gitproto.SSHConn{}, + }, + { + name: "http only", + cfg: Config{Progress: true}, + source: &gitproto.HTTPConn{}, + target: &gitproto.HTTPConn{}, + }, + { + name: "progress with ssh source", + cfg: Config{Progress: true}, + source: &gitproto.SSHConn{}, + want: true, + }, + { + name: "show stats with ssh target", + cfg: Config{ShowStats: true}, + source: &gitproto.HTTPConn{}, + target: &gitproto.SSHConn{}, + want: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := sshStatsWarning(tt.cfg, tt.source, tt.target) + if (got != "") != tt.want { + t.Fatalf("sshStatsWarning() = %q, want warning=%t", got, tt.want) + } + }) + } +}