Add SIMD indexByteTwo/lastIndexByteTwo for faster prefiltering

Use SIMD to search for two bytes simultaneously, replacing the
two-pass bytes.IndexByte approach in trySkip and the scalar backward
loop in asciiFuzzyIndex. AVX2+SSE2 on amd64, NEON on arm64, with
scalar fallback for other architectures.

=== query: 'l' ===
  [all]   baseline:   100.61ms  current:    98.88ms  (+1.7%)  matches: 5069891 (94.78%)
  [1T]    baseline:   889.28ms  current:   852.71ms  (+4.1%)  matches: 5069891 (94.78%)

=== query: 'lin' ===
  [all]   baseline:   281.31ms  current:   269.35ms  (+4.3%)  matches: 3516507 (65.74%)
  [1T]    baseline:  2266.51ms  current:  2238.24ms  (+1.2%)  matches: 3516507 (65.74%)

=== query: 'linux' ===
  [all]   baseline:    69.94ms  current:    68.33ms  (+2.3%)  matches: 307229 (5.74%)
  [1T]    baseline:   642.66ms  current:   589.10ms  (+8.3%)  matches: 307229 (5.74%)

=== query: 'linuxlinux' ===
  [all]   baseline:    39.56ms  current:    35.48ms  (+10.3%)  matches: 12230 (0.23%)
  [1T]    baseline:   367.88ms  current:   333.49ms  (+9.3%)  matches: 12230 (0.23%)

=== query: 'linuxlinuxlinux' ===
  [all]   baseline:    36.22ms  current:    31.59ms  (+12.8%)  matches: 865 (0.02%)
  [1T]    baseline:   339.48ms  current:   293.02ms  (+13.7%)  matches: 865 (0.02%)
This commit is contained in:
Junegunn Choi
2026-03-02 10:37:55 +09:00
parent 4866c34361
commit 97ac7794cf
7 changed files with 975 additions and 20 deletions

View File

@@ -321,22 +321,15 @@ type Algo func(caseSensitive bool, normalize bool, forward bool, input *util.Cha
func trySkip(input *util.Chars, caseSensitive bool, b byte, from int) int {
byteArray := input.Bytes()[from:]
idx := bytes.IndexByte(byteArray, b)
if idx == 0 {
// Can't skip any further
return from
}
// We may need to search for the uppercase letter again. We don't have to
// consider normalization as we can be sure that this is an ASCII string.
// For case-insensitive search of a letter, search for both cases in one pass
if !caseSensitive && b >= 'a' && b <= 'z' {
if idx > 0 {
byteArray = byteArray[:idx]
}
uidx := bytes.IndexByte(byteArray, b-32)
if uidx >= 0 {
idx = uidx
idx := indexByteTwo(byteArray, b, b-32)
if idx < 0 {
return -1
}
return from + idx
}
idx := bytes.IndexByte(byteArray, b)
if idx < 0 {
return -1
}
@@ -380,14 +373,17 @@ func asciiFuzzyIndex(input *util.Chars, pattern []rune, caseSensitive bool) (int
}
// Find the last appearance of the last character of the pattern to limit the search scope
bu := b
if !caseSensitive && b >= 'a' && b <= 'z' {
bu = b - 32
}
scope := input.Bytes()[lastIdx:]
for offset := len(scope) - 1; offset > 0; offset-- {
if scope[offset] == b || scope[offset] == bu {
return firstIdx, lastIdx + offset + 1
if len(scope) > 1 {
tail := scope[1:]
var end int
if !caseSensitive && b >= 'a' && b <= 'z' {
end = lastIndexByteTwo(tail, b, b-32)
} else {
end = bytes.LastIndexByte(tail, b)
}
if end >= 0 {
return firstIdx, lastIdx + 1 + end + 1
}
}
return firstIdx, lastIdx + 1

View File

@@ -0,0 +1,24 @@
//go:build amd64
package algo
var _useAVX2 bool
func init() {
_useAVX2 = cpuHasAVX2()
}
//go:noescape
func cpuHasAVX2() bool
// indexByteTwo returns the index of the first occurrence of b1 or b2 in s,
// or -1 if neither is present. Uses AVX2 when available, SSE2 otherwise.
//
//go:noescape
func indexByteTwo(s []byte, b1, b2 byte) int
// lastIndexByteTwo returns the index of the last occurrence of b1 or b2 in s,
// or -1 if neither is present. Uses AVX2 when available, SSE2 otherwise.
//
//go:noescape
func lastIndexByteTwo(s []byte, b1, b2 byte) int

377
src/algo/indexbyte2_amd64.s Normal file
View File

@@ -0,0 +1,377 @@
#include "textflag.h"
// func cpuHasAVX2() bool
//
// Checks CPUID and XGETBV for AVX2 + OS YMM support.
TEXT ·cpuHasAVX2(SB),NOSPLIT,$0-1
MOVQ BX, R8 // save BX (callee-saved, clobbered by CPUID)
// Check max CPUID leaf >= 7
MOVL $0, AX
CPUID
CMPL AX, $7
JL cpuid_no
// Check OSXSAVE (CPUID.1:ECX bit 27)
MOVL $1, AX
CPUID
TESTL $(1<<27), CX
JZ cpuid_no
// Check AVX2 (CPUID.7.0:EBX bit 5)
MOVL $7, AX
MOVL $0, CX
CPUID
TESTL $(1<<5), BX
JZ cpuid_no
// Check OS YMM state support via XGETBV
MOVL $0, CX
BYTE $0x0F; BYTE $0x01; BYTE $0xD0 // XGETBV → EDX:EAX
ANDL $6, AX // bits 1 (XMM) and 2 (YMM)
CMPL AX, $6
JNE cpuid_no
MOVQ R8, BX // restore BX
MOVB $1, ret+0(FP)
RET
cpuid_no:
MOVQ R8, BX
MOVB $0, ret+0(FP)
RET
// func indexByteTwo(s []byte, b1, b2 byte) int
//
// Returns the index of the first occurrence of b1 or b2 in s, or -1.
// Uses AVX2 (32 bytes/iter) when available, SSE2 (16 bytes/iter) otherwise.
TEXT ·indexByteTwo(SB),NOSPLIT,$0-40
MOVQ s_base+0(FP), SI
MOVQ s_len+8(FP), BX
MOVBLZX b1+24(FP), AX
MOVBLZX b2+25(FP), CX
LEAQ ret+32(FP), R8
TESTQ BX, BX
JEQ fwd_failure
// Try AVX2 for inputs >= 32 bytes
CMPQ BX, $32
JLT fwd_sse2
CMPB ·_useAVX2(SB), $1
JNE fwd_sse2
// ====== AVX2 forward search ======
MOVD AX, X0
VPBROADCASTB X0, Y0 // Y0 = splat(b1)
MOVD CX, X1
VPBROADCASTB X1, Y1 // Y1 = splat(b2)
MOVQ SI, DI
LEAQ -32(SI)(BX*1), AX // AX = last valid 32-byte chunk
JMP fwd_avx2_entry
fwd_avx2_loop:
VMOVDQU (DI), Y2
VPCMPEQB Y0, Y2, Y3
VPCMPEQB Y1, Y2, Y4
VPOR Y3, Y4, Y3
VPMOVMSKB Y3, DX
BSFL DX, DX
JNZ fwd_avx2_success
ADDQ $32, DI
fwd_avx2_entry:
CMPQ DI, AX
JB fwd_avx2_loop
// Last 32-byte chunk (may overlap with previous)
MOVQ AX, DI
VMOVDQU (AX), Y2
VPCMPEQB Y0, Y2, Y3
VPCMPEQB Y1, Y2, Y4
VPOR Y3, Y4, Y3
VPMOVMSKB Y3, DX
BSFL DX, DX
JNZ fwd_avx2_success
MOVQ $-1, (R8)
VZEROUPPER
RET
fwd_avx2_success:
SUBQ SI, DI
ADDQ DX, DI
MOVQ DI, (R8)
VZEROUPPER
RET
// ====== SSE2 forward search (< 32 bytes or no AVX2) ======
fwd_sse2:
// Broadcast b1 into X0
MOVD AX, X0
PUNPCKLBW X0, X0
PUNPCKLBW X0, X0
PSHUFL $0, X0, X0
// Broadcast b2 into X4
MOVD CX, X4
PUNPCKLBW X4, X4
PUNPCKLBW X4, X4
PSHUFL $0, X4, X4
CMPQ BX, $16
JLT fwd_small
MOVQ SI, DI
LEAQ -16(SI)(BX*1), AX
JMP fwd_sseloopentry
fwd_sseloop:
MOVOU (DI), X1
MOVOU X1, X2
PCMPEQB X0, X1
PCMPEQB X4, X2
POR X2, X1
PMOVMSKB X1, DX
BSFL DX, DX
JNZ fwd_ssesuccess
ADDQ $16, DI
fwd_sseloopentry:
CMPQ DI, AX
JB fwd_sseloop
// Search the last 16-byte chunk (may overlap)
MOVQ AX, DI
MOVOU (AX), X1
MOVOU X1, X2
PCMPEQB X0, X1
PCMPEQB X4, X2
POR X2, X1
PMOVMSKB X1, DX
BSFL DX, DX
JNZ fwd_ssesuccess
fwd_failure:
MOVQ $-1, (R8)
RET
fwd_ssesuccess:
SUBQ SI, DI
ADDQ DX, DI
MOVQ DI, (R8)
RET
fwd_small:
// Check if loading 16 bytes from SI would cross a page boundary
LEAQ 16(SI), AX
TESTW $0xff0, AX
JEQ fwd_endofpage
MOVOU (SI), X1
MOVOU X1, X2
PCMPEQB X0, X1
PCMPEQB X4, X2
POR X2, X1
PMOVMSKB X1, DX
BSFL DX, DX
JZ fwd_failure
CMPL DX, BX
JAE fwd_failure
MOVQ DX, (R8)
RET
fwd_endofpage:
MOVOU -16(SI)(BX*1), X1
MOVOU X1, X2
PCMPEQB X0, X1
PCMPEQB X4, X2
POR X2, X1
PMOVMSKB X1, DX
MOVL BX, CX
SHLL CX, DX
SHRL $16, DX
BSFL DX, DX
JZ fwd_failure
MOVQ DX, (R8)
RET
// func lastIndexByteTwo(s []byte, b1, b2 byte) int
//
// Returns the index of the last occurrence of b1 or b2 in s, or -1.
// Uses AVX2 (32 bytes/iter) when available, SSE2 (16 bytes/iter) otherwise.
TEXT ·lastIndexByteTwo(SB),NOSPLIT,$0-40
MOVQ s_base+0(FP), SI
MOVQ s_len+8(FP), BX
MOVBLZX b1+24(FP), AX
MOVBLZX b2+25(FP), CX
LEAQ ret+32(FP), R8
TESTQ BX, BX
JEQ back_failure
// Try AVX2 for inputs >= 32 bytes
CMPQ BX, $32
JLT back_sse2
CMPB ·_useAVX2(SB), $1
JNE back_sse2
// ====== AVX2 backward search ======
MOVD AX, X0
VPBROADCASTB X0, Y0
MOVD CX, X1
VPBROADCASTB X1, Y1
// DI = start of last 32-byte chunk
LEAQ -32(SI)(BX*1), DI
back_avx2_loop:
CMPQ DI, SI
JBE back_avx2_first
VMOVDQU (DI), Y2
VPCMPEQB Y0, Y2, Y3
VPCMPEQB Y1, Y2, Y4
VPOR Y3, Y4, Y3
VPMOVMSKB Y3, DX
BSRL DX, DX
JNZ back_avx2_success
SUBQ $32, DI
JMP back_avx2_loop
back_avx2_first:
// First 32 bytes (DI <= SI, load from SI)
VMOVDQU (SI), Y2
VPCMPEQB Y0, Y2, Y3
VPCMPEQB Y1, Y2, Y4
VPOR Y3, Y4, Y3
VPMOVMSKB Y3, DX
BSRL DX, DX
JNZ back_avx2_firstsuccess
MOVQ $-1, (R8)
VZEROUPPER
RET
back_avx2_success:
SUBQ SI, DI
ADDQ DX, DI
MOVQ DI, (R8)
VZEROUPPER
RET
back_avx2_firstsuccess:
MOVQ DX, (R8)
VZEROUPPER
RET
// ====== SSE2 backward search (< 32 bytes or no AVX2) ======
back_sse2:
// Broadcast b1 into X0
MOVD AX, X0
PUNPCKLBW X0, X0
PUNPCKLBW X0, X0
PSHUFL $0, X0, X0
// Broadcast b2 into X4
MOVD CX, X4
PUNPCKLBW X4, X4
PUNPCKLBW X4, X4
PSHUFL $0, X4, X4
CMPQ BX, $16
JLT back_small
// DI = start of last 16-byte chunk
LEAQ -16(SI)(BX*1), DI
back_sseloop:
CMPQ DI, SI
JBE back_ssefirst
MOVOU (DI), X1
MOVOU X1, X2
PCMPEQB X0, X1
PCMPEQB X4, X2
POR X2, X1
PMOVMSKB X1, DX
BSRL DX, DX
JNZ back_ssesuccess
SUBQ $16, DI
JMP back_sseloop
back_ssefirst:
// First 16 bytes (DI <= SI, load from SI)
MOVOU (SI), X1
MOVOU X1, X2
PCMPEQB X0, X1
PCMPEQB X4, X2
POR X2, X1
PMOVMSKB X1, DX
BSRL DX, DX
JNZ back_ssefirstsuccess
back_failure:
MOVQ $-1, (R8)
RET
back_ssesuccess:
SUBQ SI, DI
ADDQ DX, DI
MOVQ DI, (R8)
RET
back_ssefirstsuccess:
// DX = byte offset from base
MOVQ DX, (R8)
RET
back_small:
// Check page boundary
LEAQ 16(SI), AX
TESTW $0xff0, AX
JEQ back_endofpage
MOVOU (SI), X1
MOVOU X1, X2
PCMPEQB X0, X1
PCMPEQB X4, X2
POR X2, X1
PMOVMSKB X1, DX
// Mask to first BX bytes: keep bits 0..BX-1
MOVL $1, AX
MOVL BX, CX
SHLL CX, AX
DECL AX
ANDL AX, DX
BSRL DX, DX
JZ back_failure
MOVQ DX, (R8)
RET
back_endofpage:
// Load 16 bytes ending at base+n
MOVOU -16(SI)(BX*1), X1
MOVOU X1, X2
PCMPEQB X0, X1
PCMPEQB X4, X2
POR X2, X1
PMOVMSKB X1, DX
// Bits correspond to bytes [base+n-16, base+n).
// We want original bytes [0, n), which are bits [16-n, 16).
// Mask: keep bits (16-n) through 15.
MOVL $16, CX
SUBL BX, CX
SHRL CX, DX
SHLL CX, DX
BSRL DX, DX
JZ back_failure
// DX is the bit position in the loaded chunk.
// Original byte index = DX - (16 - n) = DX + n - 16
ADDL BX, DX
SUBL $16, DX
MOVQ DX, (R8)
RET

View File

@@ -0,0 +1,17 @@
//go:build arm64
package algo
// indexByteTwo returns the index of the first occurrence of b1 or b2 in s,
// or -1 if neither is present. Implemented in assembly using ARM64 NEON
// to search for both bytes in a single pass.
//
//go:noescape
func indexByteTwo(s []byte, b1, b2 byte) int
// lastIndexByteTwo returns the index of the last occurrence of b1 or b2 in s,
// or -1 if neither is present. Implemented in assembly using ARM64 NEON,
// scanning backward.
//
//go:noescape
func lastIndexByteTwo(s []byte, b1, b2 byte) int

249
src/algo/indexbyte2_arm64.s Normal file
View File

@@ -0,0 +1,249 @@
#include "textflag.h"
// func indexByteTwo(s []byte, b1, b2 byte) int
//
// Returns the index of the first occurrence of b1 or b2 in s, or -1.
// Uses ARM64 NEON to search for both bytes in a single pass over the data.
// Adapted from Go's internal/bytealg/indexbyte_arm64.s (single-byte version).
TEXT ·indexByteTwo(SB),NOSPLIT,$0-40
MOVD s_base+0(FP), R0
MOVD s_len+8(FP), R2
MOVBU b1+24(FP), R1
MOVBU b2+25(FP), R7
MOVD $ret+32(FP), R8
// Core algorithm:
// For each 32-byte chunk we calculate a 64-bit syndrome value,
// with two bits per byte. We compare against both b1 and b2,
// OR the results, then use the same syndrome extraction as
// Go's IndexByte.
CBZ R2, fail
MOVD R0, R11
// Magic constant 0x40100401 allows us to identify which lane matches.
// Each byte in the group of 4 gets a distinct bit: 1, 4, 16, 64.
MOVD $0x40100401, R5
VMOV R1, V0.B16 // V0 = splat(b1)
VMOV R7, V7.B16 // V7 = splat(b2)
// Work with aligned 32-byte chunks
BIC $0x1f, R0, R3
VMOV R5, V5.S4
ANDS $0x1f, R0, R9
AND $0x1f, R2, R10
BEQ loop
// Input string is not 32-byte aligned. Process the first
// aligned 32-byte block and mask off bytes before our start.
VLD1.P (R3), [V1.B16, V2.B16]
SUB $0x20, R9, R4
ADDS R4, R2, R2
// Compare against both needles
VCMEQ V0.B16, V1.B16, V3.B16 // b1 vs first 16 bytes
VCMEQ V7.B16, V1.B16, V8.B16 // b2 vs first 16 bytes
VORR V8.B16, V3.B16, V3.B16 // combine
VCMEQ V0.B16, V2.B16, V4.B16 // b1 vs second 16 bytes
VCMEQ V7.B16, V2.B16, V9.B16 // b2 vs second 16 bytes
VORR V9.B16, V4.B16, V4.B16 // combine
// Build syndrome
VAND V5.B16, V3.B16, V3.B16
VAND V5.B16, V4.B16, V4.B16
VADDP V4.B16, V3.B16, V6.B16
VADDP V6.B16, V6.B16, V6.B16
VMOV V6.D[0], R6
// Clear the irrelevant lower bits
LSL $1, R9, R4
LSR R4, R6, R6
LSL R4, R6, R6
// The first block can also be the last
BLS masklast
// Have we found something already?
CBNZ R6, tail
loop:
VLD1.P (R3), [V1.B16, V2.B16]
SUBS $0x20, R2, R2
// Compare against both needles, OR results
VCMEQ V0.B16, V1.B16, V3.B16
VCMEQ V7.B16, V1.B16, V8.B16
VORR V8.B16, V3.B16, V3.B16
VCMEQ V0.B16, V2.B16, V4.B16
VCMEQ V7.B16, V2.B16, V9.B16
VORR V9.B16, V4.B16, V4.B16
// If we're out of data we finish regardless of the result
BLS end
// Fast check: OR both halves and check for any match
VORR V4.B16, V3.B16, V6.B16
VADDP V6.D2, V6.D2, V6.D2
VMOV V6.D[0], R6
CBZ R6, loop
end:
// Found something or out of data build full syndrome
VAND V5.B16, V3.B16, V3.B16
VAND V5.B16, V4.B16, V4.B16
VADDP V4.B16, V3.B16, V6.B16
VADDP V6.B16, V6.B16, V6.B16
VMOV V6.D[0], R6
// Only mask for the last block
BHS tail
masklast:
// Clear irrelevant upper bits
ADD R9, R10, R4
AND $0x1f, R4, R4
SUB $0x20, R4, R4
NEG R4<<1, R4
LSL R4, R6, R6
LSR R4, R6, R6
tail:
CBZ R6, fail
RBIT R6, R6
SUB $0x20, R3, R3
CLZ R6, R6
ADD R6>>1, R3, R0
SUB R11, R0, R0
MOVD R0, (R8)
RET
fail:
MOVD $-1, R0
MOVD R0, (R8)
RET
// func lastIndexByteTwo(s []byte, b1, b2 byte) int
//
// Returns the index of the last occurrence of b1 or b2 in s, or -1.
// Scans backward using ARM64 NEON.
TEXT ·lastIndexByteTwo(SB),NOSPLIT,$0-40
MOVD s_base+0(FP), R0
MOVD s_len+8(FP), R2
MOVBU b1+24(FP), R1
MOVBU b2+25(FP), R7
MOVD $ret+32(FP), R8
CBZ R2, lfail
MOVD R0, R11 // save base
ADD R0, R2, R12 // R12 = end = base + len
MOVD $0x40100401, R5
VMOV R1, V0.B16 // V0 = splat(b1)
VMOV R7, V7.B16 // V7 = splat(b2)
VMOV R5, V5.S4
// Align: find the aligned block containing the last byte
SUB $1, R12, R3
BIC $0x1f, R3, R3 // R3 = start of aligned block containing last byte
// --- Process tail block ---
VLD1 (R3), [V1.B16, V2.B16]
VCMEQ V0.B16, V1.B16, V3.B16
VCMEQ V7.B16, V1.B16, V8.B16
VORR V8.B16, V3.B16, V3.B16
VCMEQ V0.B16, V2.B16, V4.B16
VCMEQ V7.B16, V2.B16, V9.B16
VORR V9.B16, V4.B16, V4.B16
VAND V5.B16, V3.B16, V3.B16
VAND V5.B16, V4.B16, V4.B16
VADDP V4.B16, V3.B16, V6.B16
VADDP V6.B16, V6.B16, V6.B16
VMOV V6.D[0], R6
// Mask upper bits (bytes past end of slice)
// tail_bytes = end - R3 (1..32)
SUB R3, R12, R10 // R10 = tail_bytes
MOVD $64, R4
SUB R10<<1, R4, R4 // R4 = 64 - 2*tail_bytes
LSL R4, R6, R6
LSR R4, R6, R6
// Is this also the head block?
CMP R11, R3 // R3 - R11
BLO lmaskfirst // R3 < base: head+tail in same block
BEQ ltailonly // R3 == base: single aligned block
// R3 > base: more blocks before this one
CBNZ R6, llast
B lbacksetup
ltailonly:
// Single block, already masked upper bits
CBNZ R6, llast
B lfail
lmaskfirst:
// Mask lower bits (bytes before start of slice)
SUB R3, R11, R4 // R4 = base - R3
LSL $1, R4, R4
LSR R4, R6, R6
LSL R4, R6, R6
CBNZ R6, llast
B lfail
lbacksetup:
SUB $0x20, R3
lbackloop:
VLD1 (R3), [V1.B16, V2.B16]
VCMEQ V0.B16, V1.B16, V3.B16
VCMEQ V7.B16, V1.B16, V8.B16
VORR V8.B16, V3.B16, V3.B16
VCMEQ V0.B16, V2.B16, V4.B16
VCMEQ V7.B16, V2.B16, V9.B16
VORR V9.B16, V4.B16, V4.B16
// Quick check: any match in this block?
VORR V4.B16, V3.B16, V6.B16
VADDP V6.D2, V6.D2, V6.D2
VMOV V6.D[0], R6
// Is this a head block? (R3 < base)
CMP R11, R3
BLO lheadblock
// Full block (R3 >= base)
CBNZ R6, lbackfound
// More blocks?
BEQ lfail // R3 == base, no more
SUB $0x20, R3
B lbackloop
lbackfound:
// Build full syndrome
VAND V5.B16, V3.B16, V3.B16
VAND V5.B16, V4.B16, V4.B16
VADDP V4.B16, V3.B16, V6.B16
VADDP V6.B16, V6.B16, V6.B16
VMOV V6.D[0], R6
B llast
lheadblock:
// R3 < base. Build full syndrome if quick check had a match.
CBZ R6, lfail
VAND V5.B16, V3.B16, V3.B16
VAND V5.B16, V4.B16, V4.B16
VADDP V4.B16, V3.B16, V6.B16
VADDP V6.B16, V6.B16, V6.B16
VMOV V6.D[0], R6
// Mask lower bits
SUB R3, R11, R4 // R4 = base - R3
LSL $1, R4, R4
LSR R4, R6, R6
LSL R4, R6, R6
CBZ R6, lfail
llast:
// Find last match: highest set bit in syndrome
// Syndrome has bit 2i set for matching byte i.
// CLZ gives leading zeros; byte_offset = (63 - CLZ) / 2.
CLZ R6, R6
MOVD $63, R4
SUB R6, R4, R6 // R6 = 63 - CLZ = bit position
LSR $1, R6 // R6 = byte offset within block
ADD R3, R6, R0 // R0 = absolute address
SUB R11, R0, R0 // R0 = slice index
MOVD R0, (R8)
RET
lfail:
MOVD $-1, R0
MOVD R0, (R8)
RET

View File

@@ -0,0 +1,33 @@
//go:build !arm64 && !amd64
package algo
import "bytes"
// indexByteTwo returns the index of the first occurrence of b1 or b2 in s,
// or -1 if neither is present.
func indexByteTwo(s []byte, b1, b2 byte) int {
i1 := bytes.IndexByte(s, b1)
if i1 == 0 {
return 0
}
scope := s
if i1 > 0 {
scope = s[:i1]
}
if i2 := bytes.IndexByte(scope, b2); i2 >= 0 {
return i2
}
return i1
}
// lastIndexByteTwo returns the index of the last occurrence of b1 or b2 in s,
// or -1 if neither is present.
func lastIndexByteTwo(s []byte, b1, b2 byte) int {
for i := len(s) - 1; i >= 0; i-- {
if s[i] == b1 || s[i] == b2 {
return i
}
}
return -1
}

259
src/algo/indexbyte2_test.go Normal file
View File

@@ -0,0 +1,259 @@
package algo
import (
"bytes"
"testing"
)
func TestIndexByteTwo(t *testing.T) {
tests := []struct {
name string
s string
b1 byte
b2 byte
want int
}{
{"empty", "", 'a', 'b', -1},
{"single_b1", "a", 'a', 'b', 0},
{"single_b2", "b", 'a', 'b', 0},
{"single_none", "c", 'a', 'b', -1},
{"b1_first", "xaxb", 'a', 'b', 1},
{"b2_first", "xbxa", 'a', 'b', 1},
{"same_byte", "xxa", 'a', 'a', 2},
{"at_end", "xxxxa", 'a', 'b', 4},
{"not_found", "xxxxxxxx", 'a', 'b', -1},
{"long_b1_at_3000", string(make([]byte, 3000)) + "a" + string(make([]byte, 1000)), 'a', 'b', 3000},
{"long_b2_at_3000", string(make([]byte, 3000)) + "b" + string(make([]byte, 1000)), 'a', 'b', 3000},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := indexByteTwo([]byte(tt.s), tt.b1, tt.b2)
if got != tt.want {
t.Errorf("indexByteTwo(%q, %c, %c) = %d, want %d", tt.s[:min(len(tt.s), 40)], tt.b1, tt.b2, got, tt.want)
}
})
}
// Exhaustive test: compare against loop reference for various lengths,
// including sizes around SIMD block boundaries (16, 32, 64).
for n := 0; n <= 256; n++ {
data := make([]byte, n)
for i := range data {
data[i] = byte('c' + (i % 20))
}
// Test with match at every position
for pos := 0; pos < n; pos++ {
for _, b := range []byte{'A', 'B'} {
data[pos] = b
got := indexByteTwo(data, 'A', 'B')
want := loopIndexByteTwo(data, 'A', 'B')
if got != want {
t.Fatalf("indexByteTwo(len=%d, match=%c@%d) = %d, want %d", n, b, pos, got, want)
}
data[pos] = byte('c' + (pos % 20))
}
}
// Test with no match
got := indexByteTwo(data, 'A', 'B')
if got != -1 {
t.Fatalf("indexByteTwo(len=%d, no match) = %d, want -1", n, got)
}
// Test with both bytes present
if n >= 2 {
data[n/3] = 'A'
data[n*2/3] = 'B'
got := indexByteTwo(data, 'A', 'B')
want := loopIndexByteTwo(data, 'A', 'B')
if got != want {
t.Fatalf("indexByteTwo(len=%d, both@%d,%d) = %d, want %d", n, n/3, n*2/3, got, want)
}
data[n/3] = byte('c' + ((n / 3) % 20))
data[n*2/3] = byte('c' + ((n * 2 / 3) % 20))
}
}
}
func TestLastIndexByteTwo(t *testing.T) {
tests := []struct {
name string
s string
b1 byte
b2 byte
want int
}{
{"empty", "", 'a', 'b', -1},
{"single_b1", "a", 'a', 'b', 0},
{"single_b2", "b", 'a', 'b', 0},
{"single_none", "c", 'a', 'b', -1},
{"b1_last", "xbxa", 'a', 'b', 3},
{"b2_last", "xaxb", 'a', 'b', 3},
{"same_byte", "axx", 'a', 'a', 0},
{"at_start", "axxxx", 'a', 'b', 0},
{"both_present", "axbx", 'a', 'b', 2},
{"not_found", "xxxxxxxx", 'a', 'b', -1},
{"long_b1_at_3000", string(make([]byte, 3000)) + "a" + string(make([]byte, 1000)), 'a', 'b', 3000},
{"long_b2_at_end", string(make([]byte, 4000)) + "b", 'a', 'b', 4000},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := lastIndexByteTwo([]byte(tt.s), tt.b1, tt.b2)
if got != tt.want {
t.Errorf("lastIndexByteTwo(%q, %c, %c) = %d, want %d", tt.s[:min(len(tt.s), 40)], tt.b1, tt.b2, got, tt.want)
}
})
}
// Exhaustive test against loop reference
for n := 0; n <= 256; n++ {
data := make([]byte, n)
for i := range data {
data[i] = byte('c' + (i % 20))
}
for pos := 0; pos < n; pos++ {
for _, b := range []byte{'A', 'B'} {
data[pos] = b
got := lastIndexByteTwo(data, 'A', 'B')
want := refLastIndexByteTwo(data, 'A', 'B')
if got != want {
t.Fatalf("lastIndexByteTwo(len=%d, match=%c@%d) = %d, want %d", n, b, pos, got, want)
}
data[pos] = byte('c' + (pos % 20))
}
}
// No match
got := lastIndexByteTwo(data, 'A', 'B')
if got != -1 {
t.Fatalf("lastIndexByteTwo(len=%d, no match) = %d, want -1", n, got)
}
// Both bytes present
if n >= 2 {
data[n/3] = 'A'
data[n*2/3] = 'B'
got := lastIndexByteTwo(data, 'A', 'B')
want := refLastIndexByteTwo(data, 'A', 'B')
if got != want {
t.Fatalf("lastIndexByteTwo(len=%d, both@%d,%d) = %d, want %d", n, n/3, n*2/3, got, want)
}
data[n/3] = byte('c' + ((n / 3) % 20))
data[n*2/3] = byte('c' + ((n * 2 / 3) % 20))
}
}
}
func FuzzIndexByteTwo(f *testing.F) {
f.Add([]byte("hello world"), byte('o'), byte('l'))
f.Add([]byte(""), byte('a'), byte('b'))
f.Add([]byte("aaa"), byte('a'), byte('a'))
f.Fuzz(func(t *testing.T, data []byte, b1, b2 byte) {
got := indexByteTwo(data, b1, b2)
want := loopIndexByteTwo(data, b1, b2)
if got != want {
t.Errorf("indexByteTwo(len=%d, b1=%d, b2=%d) = %d, want %d", len(data), b1, b2, got, want)
}
})
}
func FuzzLastIndexByteTwo(f *testing.F) {
f.Add([]byte("hello world"), byte('o'), byte('l'))
f.Add([]byte(""), byte('a'), byte('b'))
f.Add([]byte("aaa"), byte('a'), byte('a'))
f.Fuzz(func(t *testing.T, data []byte, b1, b2 byte) {
got := lastIndexByteTwo(data, b1, b2)
want := refLastIndexByteTwo(data, b1, b2)
if got != want {
t.Errorf("lastIndexByteTwo(len=%d, b1=%d, b2=%d) = %d, want %d", len(data), b1, b2, got, want)
}
})
}
// Reference implementations for correctness checking
func refIndexByteTwo(s []byte, b1, b2 byte) int {
i1 := bytes.IndexByte(s, b1)
if i1 == 0 {
return 0
}
scope := s
if i1 > 0 {
scope = s[:i1]
}
if i2 := bytes.IndexByte(scope, b2); i2 >= 0 {
return i2
}
return i1
}
func loopIndexByteTwo(s []byte, b1, b2 byte) int {
for i, b := range s {
if b == b1 || b == b2 {
return i
}
}
return -1
}
func refLastIndexByteTwo(s []byte, b1, b2 byte) int {
for i := len(s) - 1; i >= 0; i-- {
if s[i] == b1 || s[i] == b2 {
return i
}
}
return -1
}
func benchIndexByteTwo(b *testing.B, size int, pos int) {
data := make([]byte, size)
for i := range data {
data[i] = byte('a' + (i % 20))
}
data[pos] = 'Z'
type impl struct {
name string
fn func([]byte, byte, byte) int
}
impls := []impl{
{"asm", indexByteTwo},
{"2xIndexByte", refIndexByteTwo},
{"loop", loopIndexByteTwo},
}
for _, im := range impls {
b.Run(im.name, func(b *testing.B) {
for i := 0; i < b.N; i++ {
im.fn(data, 'Z', 'z')
}
})
}
}
func benchLastIndexByteTwo(b *testing.B, size int, pos int) {
data := make([]byte, size)
for i := range data {
data[i] = byte('a' + (i % 20))
}
data[pos] = 'Z'
type impl struct {
name string
fn func([]byte, byte, byte) int
}
impls := []impl{
{"asm", lastIndexByteTwo},
{"loop", refLastIndexByteTwo},
}
for _, im := range impls {
b.Run(im.name, func(b *testing.B) {
for i := 0; i < b.N; i++ {
im.fn(data, 'Z', 'z')
}
})
}
}
func BenchmarkIndexByteTwo_10(b *testing.B) { benchIndexByteTwo(b, 10, 8) }
func BenchmarkIndexByteTwo_100(b *testing.B) { benchIndexByteTwo(b, 100, 80) }
func BenchmarkIndexByteTwo_1000(b *testing.B) { benchIndexByteTwo(b, 1000, 800) }
func BenchmarkLastIndexByteTwo_10(b *testing.B) { benchLastIndexByteTwo(b, 10, 2) }
func BenchmarkLastIndexByteTwo_100(b *testing.B) { benchLastIndexByteTwo(b, 100, 20) }
func BenchmarkLastIndexByteTwo_1000(b *testing.B) { benchLastIndexByteTwo(b, 1000, 200) }