From 13686421f1667f00e1c55b3ee65127808d34a922 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Carlos=20Henrique=20Guard=C3=A3o=20Gandarez?= Date: Fri, 13 Mar 2026 17:31:02 -0300 Subject: [PATCH] feat: Add auto-detect ESP chip via security info and improve test coverage Refactor chip detection from magic-register-only approach to use GET_SECURITY_INFO command with ChipID fallback via readReg. Add security_info.go for parsing security info responses (12/20 byte variants) and security flag decoding. Expand protocol layer with securityInfo command support and improve test coverage across flasher, protocol, and security_info packages. Co-Authored-By: Claude Opus 4.6 --- pkg/espflasher/chip.go | 44 +----- pkg/espflasher/chip_test.go | 4 +- pkg/espflasher/flasher.go | 113 ++++++++------- pkg/espflasher/flasher_test.go | 127 +++++++++++++++++ pkg/espflasher/protocol.go | 82 +++++++---- pkg/espflasher/protocol_test.go | 167 +++++++++++++++++++++- pkg/espflasher/security_info.go | 97 +++++++++++++ pkg/espflasher/security_info_test.go | 204 +++++++++++++++++++++++++++ pkg/espflasher/slip.go | 7 +- 9 files changed, 718 insertions(+), 127 deletions(-) create mode 100644 pkg/espflasher/security_info.go create mode 100644 pkg/espflasher/security_info_test.go diff --git a/pkg/espflasher/chip.go b/pkg/espflasher/chip.go index d96b34b..4025c6a 100644 --- a/pkg/espflasher/chip.go +++ b/pkg/espflasher/chip.go @@ -73,6 +73,9 @@ type chipDef struct { SPIMISOOffs uint32 SPIW0Offs uint32 + // SecureDownloadMode is true if esptool detects the ROM is in Secure Download Mode + SecureDownloadMode bool + // SPIAddrRegMSB: if true, SPI peripheral sends from MSB of 32-bit register. SPIAddrRegMSB bool @@ -129,45 +132,8 @@ var chipDefs = map[ChipType]*chipDef{ ChipESP32H2: defESP32H2, } -// detectChip reads the chip magic register or chip ID to identify the -// connected ESP device. -func detectChip(c *conn) (*chipDef, error) { - // First try reading the magic register (works for all chips) - magic, err := c.readReg(chipDetectMagicRegAddr) - if err != nil { - return nil, fmt.Errorf("read chip detect register: %w", err) - } - - // Check magic value against known chips - for _, def := range chipDefs { - if def.UsesMagicValue && def.MagicValue == magic { - return def, nil - } - } - - // For newer chips (ESP32-S3, ESP32-C3, etc.), the magic register value - // may not match. These chips use chip ID from the security info instead. - // However, reading chip ID requires the GET_SECURITY_INFO command which - // may not work before we know the chip type. - // - // As a fallback, we can also detect by the UART date register value, - // or by trying known magic values for newer chips. - // Newer chips have different magic values at this register. - - // Try matching by non-magic-value chips - for _, def := range chipDefs { - if !def.UsesMagicValue { - // For these chips, try to read a chip-specific register to verify - // We can't easily do chip_id check without knowing the chip first, - // so try to match by reading registers at known addresses. - continue - } - } - - return nil, &ChipDetectError{MagicValue: magic} -} - -// detectChipByMagic returns the chip definition matching the given magic value. +// detectChipByMagic returns the chip definition matching the given magic value, +// or nil if no match is found. func detectChipByMagic(magic uint32) *chipDef { for _, def := range chipDefs { if def.UsesMagicValue && def.MagicValue == magic { diff --git a/pkg/espflasher/chip_test.go b/pkg/espflasher/chip_test.go index ae92ea9..0c17025 100644 --- a/pkg/espflasher/chip_test.go +++ b/pkg/espflasher/chip_test.go @@ -56,10 +56,10 @@ func TestChipDefsCompleteness(t *testing.T) { if def.Name == "" { t.Errorf("chipDefs[%s].Name is empty", ct) } - if def.FlashSizes == nil || len(def.FlashSizes) == 0 { + if len(def.FlashSizes) == 0 { t.Errorf("chipDefs[%s].FlashSizes is empty", ct) } - if def.FlashFrequency == nil || len(def.FlashFrequency) == 0 { + if len(def.FlashFrequency) == 0 { t.Errorf("chipDefs[%s].FlashFrequency is empty", ct) } } diff --git a/pkg/espflasher/flasher.go b/pkg/espflasher/flasher.go index f9107e3..72be7e2 100644 --- a/pkg/espflasher/flasher.go +++ b/pkg/espflasher/flasher.go @@ -85,14 +85,38 @@ func DefaultOptions() *FlasherOptions { } } +// connection defines the low-level protocol operations for communicating +// with an ESP bootloader over a serial connection. +type connection interface { + sync() (uint32, error) + readReg(addr uint32) (uint32, error) + writeReg(addr, value, mask, delayUS uint32) error + securityInfo() ([]byte, error) + flashBegin(size, offset uint32, encrypted bool) error + flashData(block []byte, seq uint32) error + flashEnd(reboot bool) error + flashDeflBegin(uncompSize, compSize, offset uint32, encrypted bool) error + flashDeflData(block []byte, seq uint32) error + flashDeflEnd(reboot bool) error + flashMD5(addr, size uint32) ([]byte, error) + flashWriteSize() uint32 + spiAttach(value uint32) error + spiSetParams(totalSize, blockSize, sectorSize, pageSize uint32) error + changeBaud(newBaud, oldBaud uint32) error + eraseFlash() error + eraseRegion(offset, size uint32) error + flushInput() + isStub() bool + setSupportsEncryptedFlash(v bool) +} + // Flasher manages the connection to an ESP device and provides // high-level flash operations. type Flasher struct { port serial.Port - conn *conn + conn connection chip *chipDef opts *FlasherOptions - isStub bool portStr string } @@ -175,8 +199,6 @@ func (f *Flasher) connect() error { attempts = 7 } - var lastErr error - for attempt := 0; attempt < attempts; attempt++ { // Reset the chip into bootloader mode switch f.opts.ResetMode { @@ -194,12 +216,11 @@ func (f *Flasher) connect() error { // Try to sync with the bootloader time.Sleep(100 * time.Millisecond) // Give bootloader time to start f.conn.flushInput() - for syncAttempt := 0; syncAttempt < 5; syncAttempt++ { + for range 5 { _, err := f.conn.sync() if err == nil { goto synced } - lastErr = err time.Sleep(50 * time.Millisecond) } } @@ -224,50 +245,46 @@ synced: f.chip = def } - _ = lastErr // suppress unused warning - f.logf("Detected chip: %s", f.chip.Name) // Propagate chip capabilities to the connection layer. - f.conn.supportsEncryptedFlash = f.chip.SupportsEncryptedFlash + f.conn.setSupportsEncryptedFlash(f.chip.SupportsEncryptedFlash) return nil } // detectChip identifies the connected ESP chip. func (f *Flasher) detectChip() (*chipDef, error) { - // Try reading the magic register first - magic, err := f.conn.readReg(chipDetectMagicRegAddr) + si, err := f.readSecurityInfo() if err != nil { - return nil, fmt.Errorf("detect chip: %w", err) + f.logf("unable to read security info: %s", err) } - // Check magic value for older chips (ESP8266, ESP32, ESP32-S2) - if def := detectChipByMagic(magic); def != nil { - return def, nil - } + for _, def := range chipDefs { + if def.UsesMagicValue { + continue + } - // For newer chips, try to match the magic value against known values - // These chips have different magic values at the detect register - // Source: esptool chip detection logic - type magicEntry struct { - magic uint32 - def *chipDef + if si != nil && si.ChipID != nil && *si.ChipID == uint32(def.ImageChipID) { + def.SecureDownloadMode = si.ParsedFlags.SecureDownloadEnable + return def, nil + } } - // Known magic values for newer chips (from esptool source) - newChipMagics := []magicEntry{ - {0x6F51306F, defESP32C2}, // ESP32-C2 - {0x1B31506F, defESP32C3}, // ESP32-C3 - {0x0DA1806F, defESP32C6}, // ESP32-C6 - {0xD7B73E80, defESP32H2}, // ESP32-H2 - {0x09, defESP32S3}, // ESP32-S3 returns chip_id + // Otherwise, try to read the chip magic value to verify the chip type (ESP8266, ESP32, ESP32-S2) + magic, err := f.conn.readReg(chipDetectMagicRegAddr) + if err != nil { + // Only ESP32-S2 does not support chip id detection + // and supports secure download mode + f.logf("unable to read chip magic value. Defaulting to ESP32-S2: %s", err) + + chipDefs[ChipESP32S2].SecureDownloadMode = si.ParsedFlags.SecureDownloadEnable + return chipDefs[ChipESP32S2], nil } - for _, entry := range newChipMagics { - if magic == entry.magic { - return entry.def, nil - } + // Check magic value for older chips (ESP8266, ESP32, ESP32-S2) + if def := detectChipByMagic(magic); def != nil { + return def, nil } return nil, &ChipDetectError{MagicValue: magic} @@ -332,7 +349,7 @@ func (f *Flasher) FlashImage(data []byte, offset uint32, progress ProgressFunc) } // Optionally switch to higher baud rate (not supported by ESP8266 ROM) - canChangeBaud := f.chip == nil || f.chip.ROMHasChangeBaud || f.conn.isStub + canChangeBaud := f.chip == nil || f.chip.ROMHasChangeBaud || f.conn.isStub() if canChangeBaud && f.opts.FlashBaudRate > 0 && f.opts.FlashBaudRate != f.opts.BaudRate { if err := f.changeBaud(f.opts.FlashBaudRate); err != nil { f.logf("Warning: could not change baud rate to %d: %v", f.opts.FlashBaudRate, err) @@ -341,7 +358,7 @@ func (f *Flasher) FlashImage(data []byte, offset uint32, progress ProgressFunc) } // Use compressed flash only if supported (ESP8266 ROM doesn't support it) - canCompress := f.chip == nil || f.chip.ROMHasCompressedFlash || f.conn.isStub + canCompress := f.chip == nil || f.chip.ROMHasCompressedFlash || f.conn.isStub() if f.opts.Compress && canCompress { return f.flashCompressed(data, offset, progress) } @@ -384,7 +401,7 @@ func (f *Flasher) FlashImages(images []ImagePart, progress ProgressFunc) error { } // Optionally switch to higher baud rate (not supported by ESP8266 ROM) - canChangeBaud := f.chip == nil || f.chip.ROMHasChangeBaud || f.conn.isStub + canChangeBaud := f.chip == nil || f.chip.ROMHasChangeBaud || f.conn.isStub() if canChangeBaud && f.opts.FlashBaudRate > 0 && f.opts.FlashBaudRate != f.opts.BaudRate { if err := f.changeBaud(f.opts.FlashBaudRate); err != nil { f.logf("Warning: could not change baud rate to %d: %v", f.opts.FlashBaudRate, err) @@ -392,7 +409,7 @@ func (f *Flasher) FlashImages(images []ImagePart, progress ProgressFunc) error { } // Determine if compressed flash is available - canCompress := f.chip == nil || f.chip.ROMHasCompressedFlash || f.conn.isStub + canCompress := f.chip == nil || f.chip.ROMHasCompressedFlash || f.conn.isStub() totalSize := 0 for _, img := range images { @@ -481,14 +498,11 @@ func (f *Flasher) flashCompressed(data []byte, offset uint32, progress ProgressF seq := uint32(0) for sent < compSize { - blockLen := compSize - sent - if blockLen > writeSize { - blockLen = writeSize - } + blockLen := min(compSize-sent, writeSize) block := compressed[sent : sent+blockLen] if err := f.conn.flashDeflData(block, seq); err != nil { - return fmt.Errorf("flash block %d: %w", seq, err) + return fmt.Errorf("flash block %d of %d: %w", seq, numBlocks, err) } sent += blockLen @@ -496,10 +510,7 @@ func (f *Flasher) flashCompressed(data []byte, offset uint32, progress ProgressF if progress != nil { // Map compressed bytes sent to approximate uncompressed progress - approxUncomp := int(float64(sent) / float64(compSize) * float64(uncompSize)) - if approxUncomp > int(uncompSize) { - approxUncomp = int(uncompSize) - } + approxUncomp := min(int(float64(sent)/float64(compSize)*float64(uncompSize)), int(uncompSize)) progress(approxUncomp, int(uncompSize)) } } @@ -512,7 +523,7 @@ func (f *Flasher) flashCompressed(data []byte, offset uint32, progress ProgressF f.logf("Flash complete. Verifying...") // Verify via MD5 if stub is running - if f.conn.isStub { + if f.conn.isStub() { if err := f.verifyMD5(data, offset); err != nil { return err } @@ -555,7 +566,7 @@ func (f *Flasher) flashUncompressed(data []byte, offset uint32, progress Progres } if err := f.conn.flashData(block, seq); err != nil { - return fmt.Errorf("flash block %d: %w", seq, err) + return fmt.Errorf("flash block %d of %d: %w", seq, numBlocks, err) } sent += blockLen @@ -574,7 +585,7 @@ func (f *Flasher) flashUncompressed(data []byte, offset uint32, progress Progres f.logf("Flash complete. Verifying...") // Verify via MD5 if stub is running - if f.conn.isStub { + if f.conn.isStub() { if err := f.verifyMD5(data, offset); err != nil { return err } @@ -607,7 +618,7 @@ func (f *Flasher) verifyMD5(data []byte, offset uint32) error { // This operation can take a significant amount of time (30-120 seconds). // Requires the stub loader to be running. func (f *Flasher) EraseFlash() error { - if !f.conn.isStub { + if !f.conn.isStub() { return &UnsupportedCommandError{Command: "erase flash (requires stub)"} } @@ -880,7 +891,7 @@ func flashSizeFromJEDEC(capByte uint8) string { // chip capacity. Returns a size string (e.g. "1MB", "4MB") that matches // the chip's FlashSizes map, or empty string if detection fails. func (f *Flasher) detectFlashSize() string { - if f.chip == nil || f.conn == nil || f.conn.port == nil { + if f.chip == nil || f.conn == nil { return "" } diff --git a/pkg/espflasher/flasher_test.go b/pkg/espflasher/flasher_test.go index 24ad4d7..15ab759 100644 --- a/pkg/espflasher/flasher_test.go +++ b/pkg/espflasher/flasher_test.go @@ -2,6 +2,8 @@ package espflasher import ( "bytes" + "encoding/binary" + "errors" "fmt" "testing" ) @@ -256,6 +258,131 @@ func TestDetectFlashSizeNilChip(t *testing.T) { } } +func TestDetectChipOlderChipsByMagic(t *testing.T) { + tests := []struct { + name string + magic uint32 + wantChip ChipType + }{ + {"ESP8266", 0xFFF0C101, ChipESP8266}, + {"ESP32", 0x00F01D83, ChipESP32}, + {"ESP32-S2", 0x000007C6, ChipESP32S2}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Return a 12-byte security info response (short format, no ChipID) + // so the chip ID loop doesn't match and falls through to magic detection. + secInfo := make([]byte, 12) + + mc := &mockConnection{ + securityInfoFunc: func() ([]byte, error) { + return secInfo, nil + }, + readRegFunc: func(addr uint32) (uint32, error) { + return tt.magic, nil + }, + } + f := &Flasher{ + conn: mc, + } + + def, err := f.detectChip() + if err != nil { + t.Fatalf("detectChip() error: %v", err) + } + if def.ChipType != tt.wantChip { + t.Errorf("detectChip() = %v, want %v", def.ChipType, tt.wantChip) + } + }) + } +} + +// makeSecurityInfo20 builds a 20-byte security info response (long format) +// with the given chipID encoded at byte offset 12 (little-endian). +// Layout: Flags(4) + B1-B8(8) + ChipID(4) + APIVersion(4). +func makeSecurityInfo20(chipID uint32) []byte { + buf := make([]byte, 20) + binary.LittleEndian.PutUint32(buf[12:16], chipID) + return buf +} + +func TestDetectChipNewerChips(t *testing.T) { + tests := map[string]struct { + secInfo []byte + expected ChipType + }{ + "ESP32-S3": { + secInfo: makeSecurityInfo20(9), + expected: ChipESP32S3, + }, + "ESP32-C2": { + secInfo: makeSecurityInfo20(12), + expected: ChipESP32C2, + }, + "ESP32-C3": { + secInfo: makeSecurityInfo20(5), + expected: ChipESP32C3, + }, + "ESP32-C6": { + secInfo: makeSecurityInfo20(13), + expected: ChipESP32C6, + }, + "ESP32-H2": { + secInfo: makeSecurityInfo20(16), + expected: ChipESP32H2, + }, + } + + for name, tt := range tests { + t.Run(name, func(t *testing.T) { + mc := &mockConnection{ + securityInfoFunc: func() ([]byte, error) { + return tt.secInfo, nil + }, + } + f := &Flasher{ + conn: mc, + } + + def, err := f.detectChip() + if err != nil { + t.Fatalf("detectChip() error: %v", err) + } + if def.ChipType != tt.expected { + t.Errorf("detectChip() = %v, want %v", def.ChipType, tt.expected) + } + }) + } +} + +func TestDetectChipReadRegError(t *testing.T) { + // When securityInfo returns no ChipID and readReg fails, + // detectChip falls back to ESP32-S2. + secInfo := make([]byte, 12) + + mc := &mockConnection{ + securityInfoFunc: func() ([]byte, error) { + return secInfo, nil + }, + readRegFunc: func(addr uint32) (uint32, error) { + return 0, errors.New("serial port disconnected") + }, + } + f := &Flasher{ + opts: DefaultOptions(), + conn: mc, + } + + def, err := f.detectChip() + if err != nil { + t.Fatalf("detectChip() should not error when readReg fails (falls back to ESP32-S2), got: %v", err) + } + if def.ChipType != ChipESP32S2 { + t.Errorf("detectChip() = %v, want ChipESP32S2 (fallback)", def.ChipType) + } +} + func TestFlashSizeFromJEDECMatchesChipSizes(t *testing.T) { // Verify that all JEDEC-detected sizes are present in at least one chip's // FlashSizes map. diff --git a/pkg/espflasher/protocol.go b/pkg/espflasher/protocol.go index 10f819e..db53231 100644 --- a/pkg/espflasher/protocol.go +++ b/pkg/espflasher/protocol.go @@ -11,22 +11,23 @@ import ( // ROM bootloader command opcodes. // These are shared across all ESP chip families. const ( - cmdFlashBegin byte = 0x02 // Start flash download - cmdFlashData byte = 0x03 // Flash data block - cmdFlashEnd byte = 0x04 // Finish flash download - cmdMemBegin byte = 0x05 // Start RAM download - cmdMemEnd byte = 0x06 // Finish RAM download / execute - cmdMemData byte = 0x07 // RAM data block - cmdSync byte = 0x08 // Sync with bootloader - cmdWriteReg byte = 0x09 // Write 32-bit memory-mapped register - cmdReadReg byte = 0x0A // Read 32-bit memory-mapped register - cmdSPISetParams byte = 0x0B // Configure SPI flash parameters - cmdSPIAttach byte = 0x0D // Attach SPI flash - cmdChangeBaud byte = 0x0F // Change UART baud rate - cmdFlashDeflBeg byte = 0x10 // Start compressed flash download - cmdFlashDeflData byte = 0x11 // Compressed flash data block - cmdFlashDeflEnd byte = 0x12 // Finish compressed flash download - cmdSPIFlashMD5 byte = 0x13 // Calculate MD5 of flash region + cmdFlashBegin byte = 0x02 // Start flash download + cmdFlashData byte = 0x03 // Flash data block + cmdFlashEnd byte = 0x04 // Finish flash download + cmdMemBegin byte = 0x05 // Start RAM download + cmdMemEnd byte = 0x06 // Finish RAM download / execute + cmdMemData byte = 0x07 // RAM data block + cmdSync byte = 0x08 // Sync with bootloader + cmdWriteReg byte = 0x09 // Write 32-bit memory-mapped register + cmdReadReg byte = 0x0A // Read 32-bit memory-mapped register + cmdSecurityInfoReg byte = 0x14 // Read security info (chip ID, flash encryption, etc.) + cmdSPISetParams byte = 0x0B // Configure SPI flash parameters + cmdSPIAttach byte = 0x0D // Attach SPI flash + cmdChangeBaud byte = 0x0F // Change UART baud rate + cmdFlashDeflBeg byte = 0x10 // Start compressed flash download + cmdFlashDeflData byte = 0x11 // Compressed flash data block + cmdFlashDeflEnd byte = 0x12 // Finish compressed flash download + cmdSPIFlashMD5 byte = 0x13 // Calculate MD5 of flash region // Stub-only commands (available after stub loader is running) cmdEraseFlash byte = 0xD0 // Erase entire flash @@ -76,13 +77,23 @@ const ( type conn struct { port serial.Port reader *slipReader - isStub bool + stub bool // supportsEncryptedFlash indicates the ROM supports the 5th parameter // (encrypted flag) in flash_begin/flash_defl_begin commands. // Set based on chip type after detection. supportsEncryptedFlash bool } +// isStub returns whether the stub loader is running. +func (c *conn) isStub() bool { + return c.stub +} + +// setSupportsEncryptedFlash sets whether the ROM supports encrypted flash commands. +func (c *conn) setSupportsEncryptedFlash(v bool) { + c.supportsEncryptedFlash = v +} + // newConn creates a new protocol connection over the given serial port. func newConn(port serial.Port) *conn { return &conn{ @@ -136,7 +147,7 @@ func (c *conn) command(opcode byte, data []byte, chk uint32, timeout time.Durati } // Try to get a matching response (filter for correct opcode) - for retry := 0; retry < 100; retry++ { + for range 100 { frame, err := c.reader.ReadFrame(timeout) if err != nil { return nil, err @@ -236,7 +247,7 @@ func (c *conn) sync() (uint32, error) { } // Read remaining sync responses (ROM sends up to 7 more) - for i := 0; i < 7; i++ { + for range 7 { c.command(0, nil, 0, syncTimeout, true) //nolint:errcheck } @@ -268,6 +279,27 @@ func (c *conn) writeReg(addr, value, mask, delayUS uint32) error { return err } +// securityInfo reads security-related information from the device. +// Try with 20 bytes first (most chips), fallback to 12 bytes (ESP32-S2). +func (c *conn) securityInfo() ([]byte, error) { + data := make([]byte, 20) + + result, err := c.checkCommand("get security info", cmdSecurityInfoReg, data, 0, defaultTimeout, 20) + if err == nil { + // early return if successful with 20-byte response + return result, nil + } + + // fallback to 12-byte response for older ROMs (ESP32-S2) + data = make([]byte, 12) + result, err = c.checkCommand("get security info", cmdSecurityInfoReg, data, 0, defaultTimeout, 12) + if err != nil { + return nil, err + } + + return result, nil +} + // memBegin starts a RAM download operation. func (c *conn) memBegin(size, blocks, blockSize, offset uint32) error { data := make([]byte, 16) @@ -316,7 +348,7 @@ func (c *conn) flashBegin(size, offset uint32, encrypted bool) error { numBlocks := (size + writeSize - 1) / writeSize eraseSize := size - if !c.isStub { + if !c.stub { eraseSize = (numBlocks * writeSize) } @@ -326,7 +358,7 @@ func (c *conn) flashBegin(size, offset uint32, encrypted bool) error { // flag). ESP8266 and original ESP32 ROM only accept 4 parameters (16 bytes). // Sending 20 bytes to those older ROMs causes error 0x05 (invalid message). paramLen := 16 - if c.supportsEncryptedFlash || c.isStub { + if c.supportsEncryptedFlash || c.stub { paramLen = 20 } data := make([]byte, paramLen) @@ -372,7 +404,7 @@ func (c *conn) flashDeflBegin(uncompSize, compSize, offset uint32, encrypted boo numBlocks := (compSize + writeSize - 1) / writeSize var writeArg uint32 - if c.isStub { + if c.stub { writeArg = uncompSize // stub handles erase internally } else { eraseBlocks := (uncompSize + writeSize - 1) / writeSize @@ -384,7 +416,7 @@ func (c *conn) flashDeflBegin(uncompSize, compSize, offset uint32, encrypted boo // ESP32-S2 and newer ROM bootloaders support a 5th parameter (encrypted // flag). ESP8266 and original ESP32 ROM only accept 4 parameters (16 bytes). paramLen := 16 - if c.supportsEncryptedFlash || c.isStub { + if c.supportsEncryptedFlash || c.stub { paramLen = 20 } data := make([]byte, paramLen) @@ -476,7 +508,7 @@ func (c *conn) changeBaud(newBaud, oldBaud uint32) error { binary.LittleEndian.PutUint32(data[0:4], newBaud) // ROM bootloader ignores the second parameter; stub uses it to know // the current baud rate. Send 0 for ROM to match esptool behavior. - if c.isStub { + if c.stub { binary.LittleEndian.PutUint32(data[4:8], oldBaud) } @@ -503,7 +535,7 @@ func (c *conn) eraseRegion(offset, size uint32) error { // flashWriteSize returns the appropriate block size based on loader type. func (c *conn) flashWriteSize() uint32 { - if c.isStub { + if c.stub { return flashWriteSizeStub } return flashWriteSizeROM diff --git a/pkg/espflasher/protocol_test.go b/pkg/espflasher/protocol_test.go index ef0d133..9c84a6e 100644 --- a/pkg/espflasher/protocol_test.go +++ b/pkg/espflasher/protocol_test.go @@ -52,12 +52,12 @@ func TestChecksum(t *testing.T) { } func TestFlashWriteSize(t *testing.T) { - romConn := &conn{isStub: false} + romConn := &conn{stub: false} if romConn.flashWriteSize() != flashWriteSizeROM { t.Errorf("ROM write size = %d, want %d", romConn.flashWriteSize(), flashWriteSizeROM) } - stubConn := &conn{isStub: true} + stubConn := &conn{stub: true} if stubConn.flashWriteSize() != flashWriteSizeStub { t.Errorf("Stub write size = %d, want %d", stubConn.flashWriteSize(), flashWriteSizeStub) } @@ -197,7 +197,7 @@ func TestFlashBeginParamLength(t *testing.T) { c := &conn{ port: mock, reader: newSlipReader(mock), - isStub: tt.isStub, + stub: tt.isStub, supportsEncryptedFlash: tt.supportsEncr, } @@ -249,7 +249,7 @@ func TestFlashDeflBeginParamLength(t *testing.T) { c := &conn{ port: mock, reader: newSlipReader(mock), - isStub: tt.isStub, + stub: tt.isStub, supportsEncryptedFlash: tt.supportsEncr, } @@ -301,7 +301,7 @@ func TestChangeBaudSecondParam(t *testing.T) { c := &conn{ port: mock, reader: newSlipReader(mock), - isStub: tt.isStub, + stub: tt.isStub, } _ = c.changeBaud(tt.newBaud, tt.oldBaud) @@ -421,6 +421,163 @@ func (m *mockPort) GetModemStatusBits() (*serial.ModemStatusBits, error) { retur func (m *mockPort) Break(t time.Duration) error { return nil } func (m *mockPort) Drain() error { return nil } +// mockConnection implements the connection interface for testing. +type mockConnection struct { + syncFunc func() (uint32, error) + readRegFunc func(addr uint32) (uint32, error) + writeRegFunc func(addr, value, mask, delayUS uint32) error + securityInfoFunc func() ([]byte, error) + flashBeginFunc func(size, offset uint32, encrypted bool) error + flashDataFunc func(block []byte, seq uint32) error + flashEndFunc func(reboot bool) error + flashDeflBeginFunc func(uncompSize, compSize, offset uint32, encrypted bool) error + flashDeflDataFunc func(block []byte, seq uint32) error + flashDeflEndFunc func(reboot bool) error + flashMD5Func func(addr, size uint32) ([]byte, error) + flashWriteSizeFunc func() uint32 + spiAttachFunc func(value uint32) error + spiSetParamsFunc func(totalSize, blockSize, sectorSize, pageSize uint32) error + changeBaudFunc func(newBaud, oldBaud uint32) error + eraseFlashFunc func() error + eraseRegionFunc func(offset, size uint32) error + flushInputFunc func() + stubMode bool + supportsEncryptedFlashValue bool +} + +func (m *mockConnection) sync() (uint32, error) { + if m.syncFunc != nil { + return m.syncFunc() + } + return 0, nil +} + +func (m *mockConnection) readReg(addr uint32) (uint32, error) { + if m.readRegFunc != nil { + return m.readRegFunc(addr) + } + return 0, nil +} + +func (m *mockConnection) writeReg(addr, value, mask, delayUS uint32) error { + if m.writeRegFunc != nil { + return m.writeRegFunc(addr, value, mask, delayUS) + } + return nil +} + +func (m *mockConnection) securityInfo() ([]byte, error) { + if m.securityInfoFunc != nil { + return m.securityInfoFunc() + } + return nil, nil +} + +func (m *mockConnection) flashBegin(size, offset uint32, encrypted bool) error { + if m.flashBeginFunc != nil { + return m.flashBeginFunc(size, offset, encrypted) + } + return nil +} + +func (m *mockConnection) flashData(block []byte, seq uint32) error { + if m.flashDataFunc != nil { + return m.flashDataFunc(block, seq) + } + return nil +} + +func (m *mockConnection) flashEnd(reboot bool) error { + if m.flashEndFunc != nil { + return m.flashEndFunc(reboot) + } + return nil +} + +func (m *mockConnection) flashDeflBegin(uncompSize, compSize, offset uint32, encrypted bool) error { + if m.flashDeflBeginFunc != nil { + return m.flashDeflBeginFunc(uncompSize, compSize, offset, encrypted) + } + return nil +} + +func (m *mockConnection) flashDeflData(block []byte, seq uint32) error { + if m.flashDeflDataFunc != nil { + return m.flashDeflDataFunc(block, seq) + } + return nil +} + +func (m *mockConnection) flashDeflEnd(reboot bool) error { + if m.flashDeflEndFunc != nil { + return m.flashDeflEndFunc(reboot) + } + return nil +} + +func (m *mockConnection) flashMD5(addr, size uint32) ([]byte, error) { + if m.flashMD5Func != nil { + return m.flashMD5Func(addr, size) + } + return nil, nil +} + +func (m *mockConnection) flashWriteSize() uint32 { + if m.flashWriteSizeFunc != nil { + return m.flashWriteSizeFunc() + } + return flashWriteSizeROM +} + +func (m *mockConnection) spiAttach(value uint32) error { + if m.spiAttachFunc != nil { + return m.spiAttachFunc(value) + } + return nil +} + +func (m *mockConnection) spiSetParams(totalSize, blockSize, sectorSize, pageSize uint32) error { + if m.spiSetParamsFunc != nil { + return m.spiSetParamsFunc(totalSize, blockSize, sectorSize, pageSize) + } + return nil +} + +func (m *mockConnection) changeBaud(newBaud, oldBaud uint32) error { + if m.changeBaudFunc != nil { + return m.changeBaudFunc(newBaud, oldBaud) + } + return nil +} + +func (m *mockConnection) eraseFlash() error { + if m.eraseFlashFunc != nil { + return m.eraseFlashFunc() + } + return nil +} + +func (m *mockConnection) eraseRegion(offset, size uint32) error { + if m.eraseRegionFunc != nil { + return m.eraseRegionFunc(offset, size) + } + return nil +} + +func (m *mockConnection) flushInput() { + if m.flushInputFunc != nil { + m.flushInputFunc() + } +} + +func (m *mockConnection) isStub() bool { + return m.stubMode +} + +func (m *mockConnection) setSupportsEncryptedFlash(v bool) { + m.supportsEncryptedFlashValue = v +} + // makeFlashBeginResponse creates a mock response for flash begin command. func makeFlashBeginResponse() []byte { resp := make([]byte, 10) diff --git a/pkg/espflasher/security_info.go b/pkg/espflasher/security_info.go new file mode 100644 index 0000000..a5c034b --- /dev/null +++ b/pkg/espflasher/security_info.go @@ -0,0 +1,97 @@ +package espflasher + +import ( + "bytes" + "encoding/binary" + "fmt" +) + +type SecurityInfo struct { + Flags uint32 + FlashCryptCnt uint8 + KeyPurposes [7]uint8 + ChipID *uint32 + APIVersion *uint32 + ParsedFlags ParsedFlags +} + +type ParsedFlags struct { + SecureBootEn bool + SecureBootAggressiveRevoke bool + SecureDownloadEnable bool + SecureBootKeyRevoke0 bool + SecureBootKeyRevoke1 bool + SecureBootKeyRevoke2 bool + SoftDisJtag bool + HardDisJtag bool + DisUSB bool + DisDownloadDcache bool + DisDownloadIcache bool +} + +func (f *Flasher) readSecurityInfo() (*SecurityInfo, error) { + res, err := f.conn.securityInfo() + if err != nil { + return nil, err + } + + var si SecurityInfo + r := bytes.NewReader(res) + + switch len(res) { + case 20: + var full struct { + Flags uint32 + B1, B2, B3, B4, B5, B6, B7, B8 uint8 + ChipID uint32 + APIVersion uint32 + } + if err := binary.Read(r, binary.LittleEndian, &full); err != nil { + return nil, fmt.Errorf("parsing security info (20 bytes): %w", err) + } + chipID := full.ChipID + apiVersion := full.APIVersion + si = SecurityInfo{ + Flags: full.Flags, + FlashCryptCnt: full.B1, + KeyPurposes: [7]uint8{full.B2, full.B3, full.B4, full.B5, full.B6, full.B7, full.B8}, + ChipID: &chipID, + APIVersion: &apiVersion, + ParsedFlags: parseSecurityFlags(full.Flags), + } + case 12: + var short struct { + Flags uint32 + B1, B2, B3, B4, B5, B6, B7, B8 uint8 + } + if err := binary.Read(r, binary.LittleEndian, &short); err != nil { + return nil, fmt.Errorf("parsing security info (12 bytes): %w", err) + } + si = SecurityInfo{ + Flags: short.Flags, + FlashCryptCnt: short.B1, + KeyPurposes: [7]uint8{short.B2, short.B3, short.B4, short.B5, short.B6, short.B7, short.B8}, + ParsedFlags: parseSecurityFlags(short.Flags), + } + default: + return nil, fmt.Errorf("unexpected security info length: %d", len(res)) + } + + return &si, nil +} + +func parseSecurityFlags(flags uint32) ParsedFlags { + return ParsedFlags{ + SecureBootEn: flags&(1<<0) != 0, + SecureBootAggressiveRevoke: flags&(1<<1) != 0, + SecureDownloadEnable: flags&(1<<2) != 0, + SecureBootKeyRevoke0: flags&(1<<3) != 0, + SecureBootKeyRevoke1: flags&(1<<4) != 0, + SecureBootKeyRevoke2: flags&(1<<5) != 0, + SoftDisJtag: flags&(1<<6) != 0, + HardDisJtag: flags&(1<<7) != 0, + DisUSB: flags&(1<<8) != 0, + DisDownloadDcache: flags&(1<<9) != 0, + DisDownloadIcache: flags&(1<<10) != 0, + } +} diff --git a/pkg/espflasher/security_info_test.go b/pkg/espflasher/security_info_test.go new file mode 100644 index 0000000..25f8bdc --- /dev/null +++ b/pkg/espflasher/security_info_test.go @@ -0,0 +1,204 @@ +package espflasher + +import ( + "encoding/binary" + "errors" + "testing" +) + +func TestParseSecurityFlags(t *testing.T) { + tests := map[string]struct { + flags uint32 + check func(t *testing.T, pf ParsedFlags) + }{ + "all zeros": { + flags: 0, + check: func(t *testing.T, pf ParsedFlags) { + if pf.SecureBootEn || pf.SecureBootAggressiveRevoke || pf.SecureDownloadEnable || + pf.SecureBootKeyRevoke0 || pf.SecureBootKeyRevoke1 || pf.SecureBootKeyRevoke2 || + pf.SoftDisJtag || pf.HardDisJtag || pf.DisUSB || + pf.DisDownloadDcache || pf.DisDownloadIcache { + t.Error("expected all flags false for input 0") + } + }, + }, + "all bits set": { + flags: 0x7FF, + check: func(t *testing.T, pf ParsedFlags) { + if !pf.SecureBootEn || !pf.SecureBootAggressiveRevoke || !pf.SecureDownloadEnable || + !pf.SecureBootKeyRevoke0 || !pf.SecureBootKeyRevoke1 || !pf.SecureBootKeyRevoke2 || + !pf.SoftDisJtag || !pf.HardDisJtag || !pf.DisUSB || + !pf.DisDownloadDcache || !pf.DisDownloadIcache { + t.Error("expected all flags true for input 0x7FF") + } + }, + }, + "only SecureBootEn": { + flags: 1 << 0, + check: func(t *testing.T, pf ParsedFlags) { + if !pf.SecureBootEn { + t.Error("expected SecureBootEn true") + } + if pf.HardDisJtag || pf.DisUSB { + t.Error("expected other flags false") + } + }, + }, + "only HardDisJtag": { + flags: 1 << 7, + check: func(t *testing.T, pf ParsedFlags) { + if !pf.HardDisJtag { + t.Error("expected HardDisJtag true") + } + if pf.SecureBootEn || pf.SoftDisJtag { + t.Error("expected other flags false") + } + }, + }, + "only DisDownloadIcache": { + flags: 1 << 10, + check: func(t *testing.T, pf ParsedFlags) { + if !pf.DisDownloadIcache { + t.Error("expected DisDownloadIcache true") + } + if pf.DisDownloadDcache { + t.Error("expected DisDownloadDcache false") + } + }, + }, + "high bits ignored": { + flags: 0xFFFFF800, + check: func(t *testing.T, pf ParsedFlags) { + if pf.SecureBootEn || pf.DisDownloadIcache { + t.Error("expected all flags false when only high bits set") + } + }, + }, + } + + for name, tt := range tests { + t.Run(name, func(t *testing.T) { + pf := parseSecurityFlags(tt.flags) + tt.check(t, pf) + }) + } +} + +func TestReadSecurityInfo(t *testing.T) { + buf20 := make([]byte, 20) + binary.LittleEndian.PutUint32(buf20[0:4], 0x05) // flags: SecureBootEn + SecureDownloadEnable + buf20[4] = 0x03 // FlashCryptCnt + buf20[5] = 0x10 // KeyPurposes[0] + buf20[6] = 0x20 // KeyPurposes[1] + buf20[7] = 0x30 // KeyPurposes[2] + buf20[8] = 0x40 // KeyPurposes[3] + buf20[9] = 0x50 // KeyPurposes[4] + buf20[10] = 0x60 // KeyPurposes[5] + buf20[11] = 0x70 // KeyPurposes[6] + binary.LittleEndian.PutUint32(buf20[12:16], 0x0009) // ChipID + binary.LittleEndian.PutUint32(buf20[16:20], 0x0002) // APIVersion + + buf12 := make([]byte, 12) + binary.LittleEndian.PutUint32(buf12[0:4], 0x01) // flags: SecureBootEn + buf12[4] = 0x07 // FlashCryptCnt + buf12[5] = 0xAA // KeyPurposes[0] + + tests := map[string]struct { + securityInfoFunc func() ([]byte, error) + expectErr bool + check func(t *testing.T, si *SecurityInfo) + }{ + "20 bytes with ChipID and APIVersion": { + securityInfoFunc: func() ([]byte, error) { + return buf20, nil + }, + check: func(t *testing.T, si *SecurityInfo) { + if si.Flags != 0x05 { + t.Errorf("Flags = 0x%x, want 0x05", si.Flags) + } + if si.FlashCryptCnt != 0x03 { + t.Errorf("FlashCryptCnt = %d, want 3", si.FlashCryptCnt) + } + expectedKeys := [7]uint8{0x10, 0x20, 0x30, 0x40, 0x50, 0x60, 0x70} + if si.KeyPurposes != expectedKeys { + t.Errorf("KeyPurposes = %v, want %v", si.KeyPurposes, expectedKeys) + } + if si.ChipID == nil || *si.ChipID != 0x0009 { + t.Errorf("ChipID = %v, want 9", si.ChipID) + } + if si.APIVersion == nil || *si.APIVersion != 0x0002 { + t.Errorf("APIVersion = %v, want 2", si.APIVersion) + } + if !si.ParsedFlags.SecureBootEn { + t.Error("expected SecureBootEn true") + } + if !si.ParsedFlags.SecureDownloadEnable { + t.Error("expected SecureDownloadEnable true") + } + if si.ParsedFlags.SecureBootAggressiveRevoke { + t.Error("expected SecureBootAggressiveRevoke false") + } + }, + }, + "12 bytes without ChipID": { + securityInfoFunc: func() ([]byte, error) { + return buf12, nil + }, + check: func(t *testing.T, si *SecurityInfo) { + if si.Flags != 0x01 { + t.Errorf("Flags = 0x%x, want 0x01", si.Flags) + } + if si.FlashCryptCnt != 0x07 { + t.Errorf("FlashCryptCnt = %d, want 7", si.FlashCryptCnt) + } + if si.KeyPurposes[0] != 0xAA { + t.Errorf("KeyPurposes[0] = 0x%x, want 0xAA", si.KeyPurposes[0]) + } + if si.ChipID != nil { + t.Errorf("ChipID = %v, want nil", si.ChipID) + } + if si.APIVersion != nil { + t.Errorf("APIVersion = %v, want nil", si.APIVersion) + } + if !si.ParsedFlags.SecureBootEn { + t.Error("expected SecureBootEn true") + } + }, + }, + "unexpected length": { + securityInfoFunc: func() ([]byte, error) { + return make([]byte, 5), nil + }, + expectErr: true, + }, + "connection error": { + securityInfoFunc: func() ([]byte, error) { + return nil, errors.New("connection timeout") + }, + expectErr: true, + }, + } + + for name, tt := range tests { + t.Run(name, func(t *testing.T) { + mc := &mockConnection{ + securityInfoFunc: tt.securityInfoFunc, + } + f := &Flasher{conn: mc} + + si, err := f.readSecurityInfo() + if tt.expectErr { + if err == nil { + t.Fatal("expected error, got nil") + } + return + } + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if tt.check != nil { + tt.check(t, si) + } + }) + } +} diff --git a/pkg/espflasher/slip.go b/pkg/espflasher/slip.go index f6b1339..28753ba 100644 --- a/pkg/espflasher/slip.go +++ b/pkg/espflasher/slip.go @@ -92,10 +92,7 @@ func (r *slipReader) ReadFrame(timeout time.Duration) ([]byte, error) { break } - readTimeout := remaining - if readTimeout > 100*time.Millisecond { - readTimeout = 100 * time.Millisecond - } + readTimeout := min(remaining, 100*time.Millisecond) r.port.SetReadTimeout(readTimeout) n, err := r.port.Read(buf) @@ -109,7 +106,7 @@ func (r *slipReader) ReadFrame(timeout time.Duration) ([]byte, error) { continue } - for i := 0; i < n; i++ { + for i := range n { b := buf[i] if !inFrame {