diff --git a/src/algo/algo.go b/src/algo/algo.go index f9ce4df5..72d35946 100644 --- a/src/algo/algo.go +++ b/src/algo/algo.go @@ -321,22 +321,15 @@ type Algo func(caseSensitive bool, normalize bool, forward bool, input *util.Cha func trySkip(input *util.Chars, caseSensitive bool, b byte, from int) int { byteArray := input.Bytes()[from:] - idx := bytes.IndexByte(byteArray, b) - if idx == 0 { - // Can't skip any further - return from - } - // We may need to search for the uppercase letter again. We don't have to - // consider normalization as we can be sure that this is an ASCII string. + // For case-insensitive search of a letter, search for both cases in one pass if !caseSensitive && b >= 'a' && b <= 'z' { - if idx > 0 { - byteArray = byteArray[:idx] - } - uidx := bytes.IndexByte(byteArray, b-32) - if uidx >= 0 { - idx = uidx + idx := indexByteTwo(byteArray, b, b-32) + if idx < 0 { + return -1 } + return from + idx } + idx := bytes.IndexByte(byteArray, b) if idx < 0 { return -1 } @@ -380,14 +373,17 @@ func asciiFuzzyIndex(input *util.Chars, pattern []rune, caseSensitive bool) (int } // Find the last appearance of the last character of the pattern to limit the search scope - bu := b - if !caseSensitive && b >= 'a' && b <= 'z' { - bu = b - 32 - } scope := input.Bytes()[lastIdx:] - for offset := len(scope) - 1; offset > 0; offset-- { - if scope[offset] == b || scope[offset] == bu { - return firstIdx, lastIdx + offset + 1 + if len(scope) > 1 { + tail := scope[1:] + var end int + if !caseSensitive && b >= 'a' && b <= 'z' { + end = lastIndexByteTwo(tail, b, b-32) + } else { + end = bytes.LastIndexByte(tail, b) + } + if end >= 0 { + return firstIdx, lastIdx + 1 + end + 1 } } return firstIdx, lastIdx + 1 diff --git a/src/algo/indexbyte2_amd64.go b/src/algo/indexbyte2_amd64.go new file mode 100644 index 00000000..eca483c5 --- /dev/null +++ b/src/algo/indexbyte2_amd64.go @@ -0,0 +1,24 @@ +//go:build amd64 + +package algo + +var _useAVX2 bool + +func init() { + _useAVX2 = cpuHasAVX2() +} + +//go:noescape +func cpuHasAVX2() bool + +// indexByteTwo returns the index of the first occurrence of b1 or b2 in s, +// or -1 if neither is present. Uses AVX2 when available, SSE2 otherwise. +// +//go:noescape +func indexByteTwo(s []byte, b1, b2 byte) int + +// lastIndexByteTwo returns the index of the last occurrence of b1 or b2 in s, +// or -1 if neither is present. Uses AVX2 when available, SSE2 otherwise. +// +//go:noescape +func lastIndexByteTwo(s []byte, b1, b2 byte) int diff --git a/src/algo/indexbyte2_amd64.s b/src/algo/indexbyte2_amd64.s new file mode 100644 index 00000000..56f8ea61 --- /dev/null +++ b/src/algo/indexbyte2_amd64.s @@ -0,0 +1,377 @@ +#include "textflag.h" + +// func cpuHasAVX2() bool +// +// Checks CPUID and XGETBV for AVX2 + OS YMM support. +TEXT ·cpuHasAVX2(SB),NOSPLIT,$0-1 + MOVQ BX, R8 // save BX (callee-saved, clobbered by CPUID) + + // Check max CPUID leaf >= 7 + MOVL $0, AX + CPUID + CMPL AX, $7 + JL cpuid_no + + // Check OSXSAVE (CPUID.1:ECX bit 27) + MOVL $1, AX + CPUID + TESTL $(1<<27), CX + JZ cpuid_no + + // Check AVX2 (CPUID.7.0:EBX bit 5) + MOVL $7, AX + MOVL $0, CX + CPUID + TESTL $(1<<5), BX + JZ cpuid_no + + // Check OS YMM state support via XGETBV + MOVL $0, CX + BYTE $0x0F; BYTE $0x01; BYTE $0xD0 // XGETBV → EDX:EAX + ANDL $6, AX // bits 1 (XMM) and 2 (YMM) + CMPL AX, $6 + JNE cpuid_no + + MOVQ R8, BX // restore BX + MOVB $1, ret+0(FP) + RET + +cpuid_no: + MOVQ R8, BX + MOVB $0, ret+0(FP) + RET + +// func indexByteTwo(s []byte, b1, b2 byte) int +// +// Returns the index of the first occurrence of b1 or b2 in s, or -1. +// Uses AVX2 (32 bytes/iter) when available, SSE2 (16 bytes/iter) otherwise. +TEXT ·indexByteTwo(SB),NOSPLIT,$0-40 + MOVQ s_base+0(FP), SI + MOVQ s_len+8(FP), BX + MOVBLZX b1+24(FP), AX + MOVBLZX b2+25(FP), CX + LEAQ ret+32(FP), R8 + + TESTQ BX, BX + JEQ fwd_failure + + // Try AVX2 for inputs >= 32 bytes + CMPQ BX, $32 + JLT fwd_sse2 + CMPB ·_useAVX2(SB), $1 + JNE fwd_sse2 + + // ====== AVX2 forward search ====== + MOVD AX, X0 + VPBROADCASTB X0, Y0 // Y0 = splat(b1) + MOVD CX, X1 + VPBROADCASTB X1, Y1 // Y1 = splat(b2) + + MOVQ SI, DI + LEAQ -32(SI)(BX*1), AX // AX = last valid 32-byte chunk + JMP fwd_avx2_entry + +fwd_avx2_loop: + VMOVDQU (DI), Y2 + VPCMPEQB Y0, Y2, Y3 + VPCMPEQB Y1, Y2, Y4 + VPOR Y3, Y4, Y3 + VPMOVMSKB Y3, DX + BSFL DX, DX + JNZ fwd_avx2_success + ADDQ $32, DI + +fwd_avx2_entry: + CMPQ DI, AX + JB fwd_avx2_loop + + // Last 32-byte chunk (may overlap with previous) + MOVQ AX, DI + VMOVDQU (AX), Y2 + VPCMPEQB Y0, Y2, Y3 + VPCMPEQB Y1, Y2, Y4 + VPOR Y3, Y4, Y3 + VPMOVMSKB Y3, DX + BSFL DX, DX + JNZ fwd_avx2_success + + MOVQ $-1, (R8) + VZEROUPPER + RET + +fwd_avx2_success: + SUBQ SI, DI + ADDQ DX, DI + MOVQ DI, (R8) + VZEROUPPER + RET + + // ====== SSE2 forward search (< 32 bytes or no AVX2) ====== + +fwd_sse2: + // Broadcast b1 into X0 + MOVD AX, X0 + PUNPCKLBW X0, X0 + PUNPCKLBW X0, X0 + PSHUFL $0, X0, X0 + + // Broadcast b2 into X4 + MOVD CX, X4 + PUNPCKLBW X4, X4 + PUNPCKLBW X4, X4 + PSHUFL $0, X4, X4 + + CMPQ BX, $16 + JLT fwd_small + + MOVQ SI, DI + LEAQ -16(SI)(BX*1), AX + JMP fwd_sseloopentry + +fwd_sseloop: + MOVOU (DI), X1 + MOVOU X1, X2 + PCMPEQB X0, X1 + PCMPEQB X4, X2 + POR X2, X1 + PMOVMSKB X1, DX + BSFL DX, DX + JNZ fwd_ssesuccess + ADDQ $16, DI + +fwd_sseloopentry: + CMPQ DI, AX + JB fwd_sseloop + + // Search the last 16-byte chunk (may overlap) + MOVQ AX, DI + MOVOU (AX), X1 + MOVOU X1, X2 + PCMPEQB X0, X1 + PCMPEQB X4, X2 + POR X2, X1 + PMOVMSKB X1, DX + BSFL DX, DX + JNZ fwd_ssesuccess + +fwd_failure: + MOVQ $-1, (R8) + RET + +fwd_ssesuccess: + SUBQ SI, DI + ADDQ DX, DI + MOVQ DI, (R8) + RET + +fwd_small: + // Check if loading 16 bytes from SI would cross a page boundary + LEAQ 16(SI), AX + TESTW $0xff0, AX + JEQ fwd_endofpage + + MOVOU (SI), X1 + MOVOU X1, X2 + PCMPEQB X0, X1 + PCMPEQB X4, X2 + POR X2, X1 + PMOVMSKB X1, DX + BSFL DX, DX + JZ fwd_failure + CMPL DX, BX + JAE fwd_failure + MOVQ DX, (R8) + RET + +fwd_endofpage: + MOVOU -16(SI)(BX*1), X1 + MOVOU X1, X2 + PCMPEQB X0, X1 + PCMPEQB X4, X2 + POR X2, X1 + PMOVMSKB X1, DX + MOVL BX, CX + SHLL CX, DX + SHRL $16, DX + BSFL DX, DX + JZ fwd_failure + MOVQ DX, (R8) + RET + +// func lastIndexByteTwo(s []byte, b1, b2 byte) int +// +// Returns the index of the last occurrence of b1 or b2 in s, or -1. +// Uses AVX2 (32 bytes/iter) when available, SSE2 (16 bytes/iter) otherwise. +TEXT ·lastIndexByteTwo(SB),NOSPLIT,$0-40 + MOVQ s_base+0(FP), SI + MOVQ s_len+8(FP), BX + MOVBLZX b1+24(FP), AX + MOVBLZX b2+25(FP), CX + LEAQ ret+32(FP), R8 + + TESTQ BX, BX + JEQ back_failure + + // Try AVX2 for inputs >= 32 bytes + CMPQ BX, $32 + JLT back_sse2 + CMPB ·_useAVX2(SB), $1 + JNE back_sse2 + + // ====== AVX2 backward search ====== + MOVD AX, X0 + VPBROADCASTB X0, Y0 + MOVD CX, X1 + VPBROADCASTB X1, Y1 + + // DI = start of last 32-byte chunk + LEAQ -32(SI)(BX*1), DI + +back_avx2_loop: + CMPQ DI, SI + JBE back_avx2_first + + VMOVDQU (DI), Y2 + VPCMPEQB Y0, Y2, Y3 + VPCMPEQB Y1, Y2, Y4 + VPOR Y3, Y4, Y3 + VPMOVMSKB Y3, DX + BSRL DX, DX + JNZ back_avx2_success + SUBQ $32, DI + JMP back_avx2_loop + +back_avx2_first: + // First 32 bytes (DI <= SI, load from SI) + VMOVDQU (SI), Y2 + VPCMPEQB Y0, Y2, Y3 + VPCMPEQB Y1, Y2, Y4 + VPOR Y3, Y4, Y3 + VPMOVMSKB Y3, DX + BSRL DX, DX + JNZ back_avx2_firstsuccess + + MOVQ $-1, (R8) + VZEROUPPER + RET + +back_avx2_success: + SUBQ SI, DI + ADDQ DX, DI + MOVQ DI, (R8) + VZEROUPPER + RET + +back_avx2_firstsuccess: + MOVQ DX, (R8) + VZEROUPPER + RET + + // ====== SSE2 backward search (< 32 bytes or no AVX2) ====== + +back_sse2: + // Broadcast b1 into X0 + MOVD AX, X0 + PUNPCKLBW X0, X0 + PUNPCKLBW X0, X0 + PSHUFL $0, X0, X0 + + // Broadcast b2 into X4 + MOVD CX, X4 + PUNPCKLBW X4, X4 + PUNPCKLBW X4, X4 + PSHUFL $0, X4, X4 + + CMPQ BX, $16 + JLT back_small + + // DI = start of last 16-byte chunk + LEAQ -16(SI)(BX*1), DI + +back_sseloop: + CMPQ DI, SI + JBE back_ssefirst + + MOVOU (DI), X1 + MOVOU X1, X2 + PCMPEQB X0, X1 + PCMPEQB X4, X2 + POR X2, X1 + PMOVMSKB X1, DX + BSRL DX, DX + JNZ back_ssesuccess + SUBQ $16, DI + JMP back_sseloop + +back_ssefirst: + // First 16 bytes (DI <= SI, load from SI) + MOVOU (SI), X1 + MOVOU X1, X2 + PCMPEQB X0, X1 + PCMPEQB X4, X2 + POR X2, X1 + PMOVMSKB X1, DX + BSRL DX, DX + JNZ back_ssefirstsuccess + +back_failure: + MOVQ $-1, (R8) + RET + +back_ssesuccess: + SUBQ SI, DI + ADDQ DX, DI + MOVQ DI, (R8) + RET + +back_ssefirstsuccess: + // DX = byte offset from base + MOVQ DX, (R8) + RET + +back_small: + // Check page boundary + LEAQ 16(SI), AX + TESTW $0xff0, AX + JEQ back_endofpage + + MOVOU (SI), X1 + MOVOU X1, X2 + PCMPEQB X0, X1 + PCMPEQB X4, X2 + POR X2, X1 + PMOVMSKB X1, DX + // Mask to first BX bytes: keep bits 0..BX-1 + MOVL $1, AX + MOVL BX, CX + SHLL CX, AX + DECL AX + ANDL AX, DX + BSRL DX, DX + JZ back_failure + MOVQ DX, (R8) + RET + +back_endofpage: + // Load 16 bytes ending at base+n + MOVOU -16(SI)(BX*1), X1 + MOVOU X1, X2 + PCMPEQB X0, X1 + PCMPEQB X4, X2 + POR X2, X1 + PMOVMSKB X1, DX + // Bits correspond to bytes [base+n-16, base+n). + // We want original bytes [0, n), which are bits [16-n, 16). + // Mask: keep bits (16-n) through 15. + MOVL $16, CX + SUBL BX, CX + SHRL CX, DX + SHLL CX, DX + BSRL DX, DX + JZ back_failure + // DX is the bit position in the loaded chunk. + // Original byte index = DX - (16 - n) = DX + n - 16 + ADDL BX, DX + SUBL $16, DX + MOVQ DX, (R8) + RET diff --git a/src/algo/indexbyte2_arm64.go b/src/algo/indexbyte2_arm64.go new file mode 100644 index 00000000..fa028aff --- /dev/null +++ b/src/algo/indexbyte2_arm64.go @@ -0,0 +1,17 @@ +//go:build arm64 + +package algo + +// indexByteTwo returns the index of the first occurrence of b1 or b2 in s, +// or -1 if neither is present. Implemented in assembly using ARM64 NEON +// to search for both bytes in a single pass. +// +//go:noescape +func indexByteTwo(s []byte, b1, b2 byte) int + +// lastIndexByteTwo returns the index of the last occurrence of b1 or b2 in s, +// or -1 if neither is present. Implemented in assembly using ARM64 NEON, +// scanning backward. +// +//go:noescape +func lastIndexByteTwo(s []byte, b1, b2 byte) int diff --git a/src/algo/indexbyte2_arm64.s b/src/algo/indexbyte2_arm64.s new file mode 100644 index 00000000..7442c4dd --- /dev/null +++ b/src/algo/indexbyte2_arm64.s @@ -0,0 +1,249 @@ +#include "textflag.h" + +// func indexByteTwo(s []byte, b1, b2 byte) int +// +// Returns the index of the first occurrence of b1 or b2 in s, or -1. +// Uses ARM64 NEON to search for both bytes in a single pass over the data. +// Adapted from Go's internal/bytealg/indexbyte_arm64.s (single-byte version). +TEXT ·indexByteTwo(SB),NOSPLIT,$0-40 + MOVD s_base+0(FP), R0 + MOVD s_len+8(FP), R2 + MOVBU b1+24(FP), R1 + MOVBU b2+25(FP), R7 + MOVD $ret+32(FP), R8 + + // Core algorithm: + // For each 32-byte chunk we calculate a 64-bit syndrome value, + // with two bits per byte. We compare against both b1 and b2, + // OR the results, then use the same syndrome extraction as + // Go's IndexByte. + + CBZ R2, fail + MOVD R0, R11 + // Magic constant 0x40100401 allows us to identify which lane matches. + // Each byte in the group of 4 gets a distinct bit: 1, 4, 16, 64. + MOVD $0x40100401, R5 + VMOV R1, V0.B16 // V0 = splat(b1) + VMOV R7, V7.B16 // V7 = splat(b2) + // Work with aligned 32-byte chunks + BIC $0x1f, R0, R3 + VMOV R5, V5.S4 + ANDS $0x1f, R0, R9 + AND $0x1f, R2, R10 + BEQ loop + + // Input string is not 32-byte aligned. Process the first + // aligned 32-byte block and mask off bytes before our start. + VLD1.P (R3), [V1.B16, V2.B16] + SUB $0x20, R9, R4 + ADDS R4, R2, R2 + // Compare against both needles + VCMEQ V0.B16, V1.B16, V3.B16 // b1 vs first 16 bytes + VCMEQ V7.B16, V1.B16, V8.B16 // b2 vs first 16 bytes + VORR V8.B16, V3.B16, V3.B16 // combine + VCMEQ V0.B16, V2.B16, V4.B16 // b1 vs second 16 bytes + VCMEQ V7.B16, V2.B16, V9.B16 // b2 vs second 16 bytes + VORR V9.B16, V4.B16, V4.B16 // combine + // Build syndrome + VAND V5.B16, V3.B16, V3.B16 + VAND V5.B16, V4.B16, V4.B16 + VADDP V4.B16, V3.B16, V6.B16 + VADDP V6.B16, V6.B16, V6.B16 + VMOV V6.D[0], R6 + // Clear the irrelevant lower bits + LSL $1, R9, R4 + LSR R4, R6, R6 + LSL R4, R6, R6 + // The first block can also be the last + BLS masklast + // Have we found something already? + CBNZ R6, tail + +loop: + VLD1.P (R3), [V1.B16, V2.B16] + SUBS $0x20, R2, R2 + // Compare against both needles, OR results + VCMEQ V0.B16, V1.B16, V3.B16 + VCMEQ V7.B16, V1.B16, V8.B16 + VORR V8.B16, V3.B16, V3.B16 + VCMEQ V0.B16, V2.B16, V4.B16 + VCMEQ V7.B16, V2.B16, V9.B16 + VORR V9.B16, V4.B16, V4.B16 + // If we're out of data we finish regardless of the result + BLS end + // Fast check: OR both halves and check for any match + VORR V4.B16, V3.B16, V6.B16 + VADDP V6.D2, V6.D2, V6.D2 + VMOV V6.D[0], R6 + CBZ R6, loop + +end: + // Found something or out of data — build full syndrome + VAND V5.B16, V3.B16, V3.B16 + VAND V5.B16, V4.B16, V4.B16 + VADDP V4.B16, V3.B16, V6.B16 + VADDP V6.B16, V6.B16, V6.B16 + VMOV V6.D[0], R6 + // Only mask for the last block + BHS tail + +masklast: + // Clear irrelevant upper bits + ADD R9, R10, R4 + AND $0x1f, R4, R4 + SUB $0x20, R4, R4 + NEG R4<<1, R4 + LSL R4, R6, R6 + LSR R4, R6, R6 + +tail: + CBZ R6, fail + RBIT R6, R6 + SUB $0x20, R3, R3 + CLZ R6, R6 + ADD R6>>1, R3, R0 + SUB R11, R0, R0 + MOVD R0, (R8) + RET + +fail: + MOVD $-1, R0 + MOVD R0, (R8) + RET + +// func lastIndexByteTwo(s []byte, b1, b2 byte) int +// +// Returns the index of the last occurrence of b1 or b2 in s, or -1. +// Scans backward using ARM64 NEON. +TEXT ·lastIndexByteTwo(SB),NOSPLIT,$0-40 + MOVD s_base+0(FP), R0 + MOVD s_len+8(FP), R2 + MOVBU b1+24(FP), R1 + MOVBU b2+25(FP), R7 + MOVD $ret+32(FP), R8 + + CBZ R2, lfail + MOVD R0, R11 // save base + ADD R0, R2, R12 // R12 = end = base + len + MOVD $0x40100401, R5 + VMOV R1, V0.B16 // V0 = splat(b1) + VMOV R7, V7.B16 // V7 = splat(b2) + VMOV R5, V5.S4 + + // Align: find the aligned block containing the last byte + SUB $1, R12, R3 + BIC $0x1f, R3, R3 // R3 = start of aligned block containing last byte + + // --- Process tail block --- + VLD1 (R3), [V1.B16, V2.B16] + VCMEQ V0.B16, V1.B16, V3.B16 + VCMEQ V7.B16, V1.B16, V8.B16 + VORR V8.B16, V3.B16, V3.B16 + VCMEQ V0.B16, V2.B16, V4.B16 + VCMEQ V7.B16, V2.B16, V9.B16 + VORR V9.B16, V4.B16, V4.B16 + VAND V5.B16, V3.B16, V3.B16 + VAND V5.B16, V4.B16, V4.B16 + VADDP V4.B16, V3.B16, V6.B16 + VADDP V6.B16, V6.B16, V6.B16 + VMOV V6.D[0], R6 + + // Mask upper bits (bytes past end of slice) + // tail_bytes = end - R3 (1..32) + SUB R3, R12, R10 // R10 = tail_bytes + MOVD $64, R4 + SUB R10<<1, R4, R4 // R4 = 64 - 2*tail_bytes + LSL R4, R6, R6 + LSR R4, R6, R6 + + // Is this also the head block? + CMP R11, R3 // R3 - R11 + BLO lmaskfirst // R3 < base: head+tail in same block + BEQ ltailonly // R3 == base: single aligned block + + // R3 > base: more blocks before this one + CBNZ R6, llast + B lbacksetup + +ltailonly: + // Single block, already masked upper bits + CBNZ R6, llast + B lfail + +lmaskfirst: + // Mask lower bits (bytes before start of slice) + SUB R3, R11, R4 // R4 = base - R3 + LSL $1, R4, R4 + LSR R4, R6, R6 + LSL R4, R6, R6 + CBNZ R6, llast + B lfail + +lbacksetup: + SUB $0x20, R3 + +lbackloop: + VLD1 (R3), [V1.B16, V2.B16] + VCMEQ V0.B16, V1.B16, V3.B16 + VCMEQ V7.B16, V1.B16, V8.B16 + VORR V8.B16, V3.B16, V3.B16 + VCMEQ V0.B16, V2.B16, V4.B16 + VCMEQ V7.B16, V2.B16, V9.B16 + VORR V9.B16, V4.B16, V4.B16 + // Quick check: any match in this block? + VORR V4.B16, V3.B16, V6.B16 + VADDP V6.D2, V6.D2, V6.D2 + VMOV V6.D[0], R6 + + // Is this a head block? (R3 < base) + CMP R11, R3 + BLO lheadblock + + // Full block (R3 >= base) + CBNZ R6, lbackfound + // More blocks? + BEQ lfail // R3 == base, no more + SUB $0x20, R3 + B lbackloop + +lbackfound: + // Build full syndrome + VAND V5.B16, V3.B16, V3.B16 + VAND V5.B16, V4.B16, V4.B16 + VADDP V4.B16, V3.B16, V6.B16 + VADDP V6.B16, V6.B16, V6.B16 + VMOV V6.D[0], R6 + B llast + +lheadblock: + // R3 < base. Build full syndrome if quick check had a match. + CBZ R6, lfail + VAND V5.B16, V3.B16, V3.B16 + VAND V5.B16, V4.B16, V4.B16 + VADDP V4.B16, V3.B16, V6.B16 + VADDP V6.B16, V6.B16, V6.B16 + VMOV V6.D[0], R6 + // Mask lower bits + SUB R3, R11, R4 // R4 = base - R3 + LSL $1, R4, R4 + LSR R4, R6, R6 + LSL R4, R6, R6 + CBZ R6, lfail + +llast: + // Find last match: highest set bit in syndrome + // Syndrome has bit 2i set for matching byte i. + // CLZ gives leading zeros; byte_offset = (63 - CLZ) / 2. + CLZ R6, R6 + MOVD $63, R4 + SUB R6, R4, R6 // R6 = 63 - CLZ = bit position + LSR $1, R6 // R6 = byte offset within block + ADD R3, R6, R0 // R0 = absolute address + SUB R11, R0, R0 // R0 = slice index + MOVD R0, (R8) + RET + +lfail: + MOVD $-1, R0 + MOVD R0, (R8) + RET diff --git a/src/algo/indexbyte2_other.go b/src/algo/indexbyte2_other.go new file mode 100644 index 00000000..44041ff0 --- /dev/null +++ b/src/algo/indexbyte2_other.go @@ -0,0 +1,33 @@ +//go:build !arm64 && !amd64 + +package algo + +import "bytes" + +// indexByteTwo returns the index of the first occurrence of b1 or b2 in s, +// or -1 if neither is present. +func indexByteTwo(s []byte, b1, b2 byte) int { + i1 := bytes.IndexByte(s, b1) + if i1 == 0 { + return 0 + } + scope := s + if i1 > 0 { + scope = s[:i1] + } + if i2 := bytes.IndexByte(scope, b2); i2 >= 0 { + return i2 + } + return i1 +} + +// lastIndexByteTwo returns the index of the last occurrence of b1 or b2 in s, +// or -1 if neither is present. +func lastIndexByteTwo(s []byte, b1, b2 byte) int { + for i := len(s) - 1; i >= 0; i-- { + if s[i] == b1 || s[i] == b2 { + return i + } + } + return -1 +} diff --git a/src/algo/indexbyte2_test.go b/src/algo/indexbyte2_test.go new file mode 100644 index 00000000..9f4dedac --- /dev/null +++ b/src/algo/indexbyte2_test.go @@ -0,0 +1,259 @@ +package algo + +import ( + "bytes" + "testing" +) + +func TestIndexByteTwo(t *testing.T) { + tests := []struct { + name string + s string + b1 byte + b2 byte + want int + }{ + {"empty", "", 'a', 'b', -1}, + {"single_b1", "a", 'a', 'b', 0}, + {"single_b2", "b", 'a', 'b', 0}, + {"single_none", "c", 'a', 'b', -1}, + {"b1_first", "xaxb", 'a', 'b', 1}, + {"b2_first", "xbxa", 'a', 'b', 1}, + {"same_byte", "xxa", 'a', 'a', 2}, + {"at_end", "xxxxa", 'a', 'b', 4}, + {"not_found", "xxxxxxxx", 'a', 'b', -1}, + {"long_b1_at_3000", string(make([]byte, 3000)) + "a" + string(make([]byte, 1000)), 'a', 'b', 3000}, + {"long_b2_at_3000", string(make([]byte, 3000)) + "b" + string(make([]byte, 1000)), 'a', 'b', 3000}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := indexByteTwo([]byte(tt.s), tt.b1, tt.b2) + if got != tt.want { + t.Errorf("indexByteTwo(%q, %c, %c) = %d, want %d", tt.s[:min(len(tt.s), 40)], tt.b1, tt.b2, got, tt.want) + } + }) + } + + // Exhaustive test: compare against loop reference for various lengths, + // including sizes around SIMD block boundaries (16, 32, 64). + for n := 0; n <= 256; n++ { + data := make([]byte, n) + for i := range data { + data[i] = byte('c' + (i % 20)) + } + // Test with match at every position + for pos := 0; pos < n; pos++ { + for _, b := range []byte{'A', 'B'} { + data[pos] = b + got := indexByteTwo(data, 'A', 'B') + want := loopIndexByteTwo(data, 'A', 'B') + if got != want { + t.Fatalf("indexByteTwo(len=%d, match=%c@%d) = %d, want %d", n, b, pos, got, want) + } + data[pos] = byte('c' + (pos % 20)) + } + } + // Test with no match + got := indexByteTwo(data, 'A', 'B') + if got != -1 { + t.Fatalf("indexByteTwo(len=%d, no match) = %d, want -1", n, got) + } + // Test with both bytes present + if n >= 2 { + data[n/3] = 'A' + data[n*2/3] = 'B' + got := indexByteTwo(data, 'A', 'B') + want := loopIndexByteTwo(data, 'A', 'B') + if got != want { + t.Fatalf("indexByteTwo(len=%d, both@%d,%d) = %d, want %d", n, n/3, n*2/3, got, want) + } + data[n/3] = byte('c' + ((n / 3) % 20)) + data[n*2/3] = byte('c' + ((n * 2 / 3) % 20)) + } + } +} + +func TestLastIndexByteTwo(t *testing.T) { + tests := []struct { + name string + s string + b1 byte + b2 byte + want int + }{ + {"empty", "", 'a', 'b', -1}, + {"single_b1", "a", 'a', 'b', 0}, + {"single_b2", "b", 'a', 'b', 0}, + {"single_none", "c", 'a', 'b', -1}, + {"b1_last", "xbxa", 'a', 'b', 3}, + {"b2_last", "xaxb", 'a', 'b', 3}, + {"same_byte", "axx", 'a', 'a', 0}, + {"at_start", "axxxx", 'a', 'b', 0}, + {"both_present", "axbx", 'a', 'b', 2}, + {"not_found", "xxxxxxxx", 'a', 'b', -1}, + {"long_b1_at_3000", string(make([]byte, 3000)) + "a" + string(make([]byte, 1000)), 'a', 'b', 3000}, + {"long_b2_at_end", string(make([]byte, 4000)) + "b", 'a', 'b', 4000}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := lastIndexByteTwo([]byte(tt.s), tt.b1, tt.b2) + if got != tt.want { + t.Errorf("lastIndexByteTwo(%q, %c, %c) = %d, want %d", tt.s[:min(len(tt.s), 40)], tt.b1, tt.b2, got, tt.want) + } + }) + } + + // Exhaustive test against loop reference + for n := 0; n <= 256; n++ { + data := make([]byte, n) + for i := range data { + data[i] = byte('c' + (i % 20)) + } + for pos := 0; pos < n; pos++ { + for _, b := range []byte{'A', 'B'} { + data[pos] = b + got := lastIndexByteTwo(data, 'A', 'B') + want := refLastIndexByteTwo(data, 'A', 'B') + if got != want { + t.Fatalf("lastIndexByteTwo(len=%d, match=%c@%d) = %d, want %d", n, b, pos, got, want) + } + data[pos] = byte('c' + (pos % 20)) + } + } + // No match + got := lastIndexByteTwo(data, 'A', 'B') + if got != -1 { + t.Fatalf("lastIndexByteTwo(len=%d, no match) = %d, want -1", n, got) + } + // Both bytes present + if n >= 2 { + data[n/3] = 'A' + data[n*2/3] = 'B' + got := lastIndexByteTwo(data, 'A', 'B') + want := refLastIndexByteTwo(data, 'A', 'B') + if got != want { + t.Fatalf("lastIndexByteTwo(len=%d, both@%d,%d) = %d, want %d", n, n/3, n*2/3, got, want) + } + data[n/3] = byte('c' + ((n / 3) % 20)) + data[n*2/3] = byte('c' + ((n * 2 / 3) % 20)) + } + } +} + +func FuzzIndexByteTwo(f *testing.F) { + f.Add([]byte("hello world"), byte('o'), byte('l')) + f.Add([]byte(""), byte('a'), byte('b')) + f.Add([]byte("aaa"), byte('a'), byte('a')) + f.Fuzz(func(t *testing.T, data []byte, b1, b2 byte) { + got := indexByteTwo(data, b1, b2) + want := loopIndexByteTwo(data, b1, b2) + if got != want { + t.Errorf("indexByteTwo(len=%d, b1=%d, b2=%d) = %d, want %d", len(data), b1, b2, got, want) + } + }) +} + +func FuzzLastIndexByteTwo(f *testing.F) { + f.Add([]byte("hello world"), byte('o'), byte('l')) + f.Add([]byte(""), byte('a'), byte('b')) + f.Add([]byte("aaa"), byte('a'), byte('a')) + f.Fuzz(func(t *testing.T, data []byte, b1, b2 byte) { + got := lastIndexByteTwo(data, b1, b2) + want := refLastIndexByteTwo(data, b1, b2) + if got != want { + t.Errorf("lastIndexByteTwo(len=%d, b1=%d, b2=%d) = %d, want %d", len(data), b1, b2, got, want) + } + }) +} + +// Reference implementations for correctness checking +func refIndexByteTwo(s []byte, b1, b2 byte) int { + i1 := bytes.IndexByte(s, b1) + if i1 == 0 { + return 0 + } + scope := s + if i1 > 0 { + scope = s[:i1] + } + if i2 := bytes.IndexByte(scope, b2); i2 >= 0 { + return i2 + } + return i1 +} + +func loopIndexByteTwo(s []byte, b1, b2 byte) int { + for i, b := range s { + if b == b1 || b == b2 { + return i + } + } + return -1 +} + +func refLastIndexByteTwo(s []byte, b1, b2 byte) int { + for i := len(s) - 1; i >= 0; i-- { + if s[i] == b1 || s[i] == b2 { + return i + } + } + return -1 +} + +func benchIndexByteTwo(b *testing.B, size int, pos int) { + data := make([]byte, size) + for i := range data { + data[i] = byte('a' + (i % 20)) + } + data[pos] = 'Z' + + type impl struct { + name string + fn func([]byte, byte, byte) int + } + impls := []impl{ + {"asm", indexByteTwo}, + {"2xIndexByte", refIndexByteTwo}, + {"loop", loopIndexByteTwo}, + } + for _, im := range impls { + b.Run(im.name, func(b *testing.B) { + for i := 0; i < b.N; i++ { + im.fn(data, 'Z', 'z') + } + }) + } +} + +func benchLastIndexByteTwo(b *testing.B, size int, pos int) { + data := make([]byte, size) + for i := range data { + data[i] = byte('a' + (i % 20)) + } + data[pos] = 'Z' + + type impl struct { + name string + fn func([]byte, byte, byte) int + } + impls := []impl{ + {"asm", lastIndexByteTwo}, + {"loop", refLastIndexByteTwo}, + } + for _, im := range impls { + b.Run(im.name, func(b *testing.B) { + for i := 0; i < b.N; i++ { + im.fn(data, 'Z', 'z') + } + }) + } +} + +func BenchmarkIndexByteTwo_10(b *testing.B) { benchIndexByteTwo(b, 10, 8) } +func BenchmarkIndexByteTwo_100(b *testing.B) { benchIndexByteTwo(b, 100, 80) } +func BenchmarkIndexByteTwo_1000(b *testing.B) { benchIndexByteTwo(b, 1000, 800) } +func BenchmarkLastIndexByteTwo_10(b *testing.B) { benchLastIndexByteTwo(b, 10, 2) } +func BenchmarkLastIndexByteTwo_100(b *testing.B) { benchLastIndexByteTwo(b, 100, 20) } +func BenchmarkLastIndexByteTwo_1000(b *testing.B) { benchLastIndexByteTwo(b, 1000, 200) }