[PATCH] chacha20-avx512: add handling for any input block count and tweak 16 block code a bit

Jussi Kivilinna jussi.kivilinna at iki.fi
Mon Dec 5 21:17:55 CET 2022


* cipher/chacha20-amd64-avx512.S: Add tail handling for 8/4/2/1
blocks; Rename `_gcry_chacha20_amd64_avx512_blocks16` to
`_gcry_chacha20_amd64_avx512_blocks`; Tweak 16-block parallel processing
for a small speed improvement.
* cipher/chacha20.c (_gcry_chacha20_amd64_avx512_blocks16): Rename to ...
(_gcry_chacha20_amd64_avx512_blocks): ... this.
(chacha20_blocks) [USE_AVX512]: Add AVX512 code-path.
(do_chacha20_encrypt_stream_tail) [USE_AVX512]: Change to handle any
number of full input blocks instead of multiples of 16.
--

This patch improves the performance of the ChaCha20-AVX512 implementation for
small input buffers (shorter than 16 * 64 B = 1024 B).
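
As an illustration of the new tail handling, below is a small, self-contained
C sketch of the dispatch order the rewritten entry point follows (the function
name and its printf labels are hypothetical; the real logic is the assembly in
cipher/chacha20-amd64-avx512.S): a 16-way ZMM loop consumes as many full
16-block chunks as possible, then at most one 8-way YMM pass and one 4-way XMM
pass run, and a 2-block/1-block XMM tail finishes the rest.

/* Self-contained sketch: given a block count, print which passes the new
 * _gcry_chacha20_amd64_avx512_blocks entry point would run.  The pass widths
 * mirror the assembly; the function itself is illustrative only. */
#include <stdio.h>
#include <stddef.h>

static void
avx512_dispatch_sketch (size_t nblks)
{
  const size_t blk = 64;  /* CHACHA20_BLOCK_SIZE */
  size_t off = 0;         /* byte offset into src/dst */

  while (nblks >= 16)     /* main 16-way ZMM loop */
    { printf ("16-way pass at byte %zu\n", off); off += 16 * blk; nblks -= 16; }
  if (nblks >= 8)         /* at most one 8-way YMM pass */
    { printf (" 8-way pass at byte %zu\n", off); off += 8 * blk; nblks -= 8; }
  if (nblks >= 4)         /* at most one 4-way XMM pass */
    { printf (" 4-way pass at byte %zu\n", off); off += 4 * blk; nblks -= 4; }
  if (nblks >= 2)         /* 2-block horizontal XMM pass */
    { printf (" 2-block pass at byte %zu\n", off); off += 2 * blk; nblks -= 2; }
  if (nblks == 1)         /* final single block */
    printf (" 1-block pass at byte %zu\n", off);
}

int
main (void)
{
  avx512_dispatch_sketch (23);  /* e.g. 23 blocks -> 16 + 4 + 2 + 1 */
  return 0;
}

Because every full block is now consumed by the AVX512 entry point,
do_chacha20_encrypt_stream_tail() only keeps the sub-block remainder
(length %= CHACHA20_BLOCK_SIZE), and buffers of 1 to 15 blocks no longer fall
back to the narrower AVX2/SSSE3 code paths.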

===

Benchmark on AMD Ryzen 9 7900X:

Before:
 CHACHA20       |  nanosecs/byte   mebibytes/sec   cycles/byte  auto Mhz
     STREAM enc |     0.130 ns/B      7330 MiB/s     0.716 c/B      5500
     STREAM dec |     0.128 ns/B      7426 MiB/s     0.713 c/B      5555
   POLY1305 enc |     0.175 ns/B      5444 MiB/s     0.964 c/B      5500
   POLY1305 dec |     0.175 ns/B      5455 MiB/s     0.962 c/B      5500

After:
 CHACHA20       |  nanosecs/byte   mebibytes/sec   cycles/byte  auto Mhz
     STREAM enc |     0.124 ns/B      7675 MiB/s     0.699 c/B      5625
     STREAM dec |     0.126 ns/B      7544 MiB/s     0.695 c/B      5500
   POLY1305 enc |     0.170 ns/B      5626 MiB/s     0.954 c/B      5625
   POLY1305 dec |     0.169 ns/B      5639 MiB/s     0.945 c/B      5587

===

Benchmark on Intel Core i3-1115G4:

Before:
 CHACHA20       |  nanosecs/byte   mebibytes/sec   cycles/byte  auto Mhz
     STREAM enc |     0.161 ns/B      5934 MiB/s     0.658 c/B      4097±3
     STREAM dec |     0.160 ns/B      5951 MiB/s     0.656 c/B      4097±4
   POLY1305 enc |     0.220 ns/B      4333 MiB/s     0.902 c/B      4096±3
   POLY1305 dec |     0.220 ns/B      4325 MiB/s     0.903 c/B      4096±3

After:
 CHACHA20       |  nanosecs/byte   mebibytes/sec   cycles/byte  auto Mhz
     STREAM enc |     0.154 ns/B      6186 MiB/s     0.631 c/B      4096±3
     STREAM dec |     0.153 ns/B      6215 MiB/s     0.629 c/B      4096±3
   POLY1305 enc |     0.216 ns/B      4407 MiB/s     0.886 c/B      4096±3
   POLY1305 dec |     0.216 ns/B      4419 MiB/s     0.884 c/B      4096±3

Signed-off-by: Jussi Kivilinna <jussi.kivilinna at iki.fi>
---
 cipher/chacha20-amd64-avx512.S | 463 ++++++++++++++++++++++++++++++---
 cipher/chacha20.c              |  24 +-
 2 files changed, 447 insertions(+), 40 deletions(-)

diff --git a/cipher/chacha20-amd64-avx512.S b/cipher/chacha20-amd64-avx512.S
index 8b4d7499..b48b1bf7 100644
--- a/cipher/chacha20-amd64-avx512.S
+++ b/cipher/chacha20-amd64-avx512.S
@@ -61,14 +61,56 @@
 #define X13 %zmm13
 #define X14 %zmm14
 #define X15 %zmm15
+#define X0y %ymm0
+#define X1y %ymm1
+#define X2y %ymm2
+#define X3y %ymm3
+#define X4y %ymm4
+#define X5y %ymm5
+#define X6y %ymm6
+#define X7y %ymm7
+#define X8y %ymm8
+#define X9y %ymm9
+#define X10y %ymm10
+#define X11y %ymm11
+#define X12y %ymm12
+#define X13y %ymm13
+#define X14y %ymm14
+#define X15y %ymm15
+#define X0x %xmm0
+#define X1x %xmm1
+#define X2x %xmm2
+#define X3x %xmm3
+#define X4x %xmm4
+#define X5x %xmm5
+#define X6x %xmm6
+#define X7x %xmm7
+#define X8x %xmm8
+#define X9x %xmm9
+#define X10x %xmm10
+#define X11x %xmm11
+#define X12x %xmm12
+#define X13x %xmm13
+#define X14x %xmm14
+#define X15x %xmm15
 
 #define TMP0 %zmm16
 #define TMP1 %zmm17
+#define TMP0y %ymm16
+#define TMP1y %ymm17
+#define TMP0x %xmm16
+#define TMP1x %xmm17
 
 #define COUNTER_ADD %zmm18
+#define COUNTER_ADDy %ymm18
+#define COUNTER_ADDx %xmm18
 
 #define X12_SAVE %zmm19
+#define X12_SAVEy %ymm19
+#define X12_SAVEx %xmm19
 #define X13_SAVE %zmm20
+#define X13_SAVEy %ymm20
+#define X13_SAVEx %xmm20
 
 #define S0 %zmm21
 #define S1 %zmm22
@@ -81,6 +123,28 @@
 #define S8 %zmm29
 #define S14 %zmm30
 #define S15 %zmm31
+#define S0y %ymm21
+#define S1y %ymm22
+#define S2y %ymm23
+#define S3y %ymm24
+#define S4y %ymm25
+#define S5y %ymm26
+#define S6y %ymm27
+#define S7y %ymm28
+#define S8y %ymm29
+#define S14y %ymm30
+#define S15y %ymm31
+#define S0x %xmm21
+#define S1x %xmm22
+#define S2x %xmm23
+#define S3x %xmm24
+#define S4x %xmm25
+#define S5x %xmm26
+#define S6x %xmm27
+#define S7x %xmm28
+#define S8x %xmm29
+#define S14x %xmm30
+#define S15x %xmm31
 
 /**********************************************************************
   helper macros
@@ -114,6 +178,12 @@
 	vshufi32x4 $0xdd, x2, t2, x3; \
 	vshufi32x4 $0x88, x2, t2, x2;
 
+/* 2x2 128-bit matrix transpose */
+#define transpose_16byte_2x2(x0,x1,t1) \
+	vmovdqa32  x0, t1; \
+	vshufi32x4 $0x0, x1, x0, x0; \
+	vshufi32x4 $0x3, x1, t1, x1;
+
 #define xor_src_dst_4x4(dst, src, offset, add, x0, x4, x8, x12) \
 	vpxord (offset + 0 * (add))(src), x0, x0; \
 	vpxord (offset + 1 * (add))(src), x4, x4; \
@@ -141,7 +211,7 @@
 	clear_vec4(%xmm19, %xmm23, %xmm27, %xmm31);
 
 /**********************************************************************
-  16-way chacha20
+  16-way (zmm), 8-way (ymm), 4-way (xmm) chacha20
  **********************************************************************/
 
 #define ROTATE2(v1,v2,c)	\
@@ -154,7 +224,7 @@
 #define PLUS(ds,s) \
 	vpaddd s, ds, ds;
 
-#define QUARTERROUND2(a1,b1,c1,d1,a2,b2,c2,d2)			\
+#define QUARTERROUND2V(a1,b1,c1,d1,a2,b2,c2,d2)			\
 	PLUS(a1,b1); PLUS(a2,b2); XOR(d1,a1); XOR(d2,a2);	\
 	    ROTATE2(d1, d2, 16);				\
 	PLUS(c1,d1); PLUS(c2,d2); XOR(b1,c1); XOR(b2,c2);	\
@@ -164,33 +234,99 @@
 	PLUS(c1,d1); PLUS(c2,d2); XOR(b1,c1); XOR(b2,c2);	\
 	    ROTATE2(b1, b2, 7);
 
+/**********************************************************************
+  1-way/2-way (xmm) chacha20
+ **********************************************************************/
+
+#define ROTATE(v1,c)			\
+	vprold $(c), v1, v1;		\
+
+#define WORD_SHUF(v1,shuf)		\
+	vpshufd $shuf, v1, v1;
+
+#define QUARTERROUND1H(x0,x1,x2,x3,shuf_x1,shuf_x2,shuf_x3) \
+	PLUS(x0, x1); XOR(x3, x0); ROTATE(x3, 16); \
+	PLUS(x2, x3); XOR(x1, x2); ROTATE(x1, 12); \
+	PLUS(x0, x1); XOR(x3, x0); ROTATE(x3, 8); \
+	PLUS(x2, x3); \
+	  WORD_SHUF(x3, shuf_x3); \
+		      XOR(x1, x2); \
+	  WORD_SHUF(x2, shuf_x2); \
+				   ROTATE(x1, 7); \
+	  WORD_SHUF(x1, shuf_x1);
+
+#define QUARTERROUND2H(x0,x1,x2,x3,y0,y1,y2,y3,shuf_x1,shuf_x2,shuf_x3) \
+	PLUS(x0, x1); PLUS(y0, y1); XOR(x3, x0); XOR(y3, y0); \
+	  ROTATE(x3, 16); ROTATE(y3, 16); \
+	PLUS(x2, x3); PLUS(y2, y3); XOR(x1, x2); XOR(y1, y2); \
+	  ROTATE(x1, 12); ROTATE(y1, 12); \
+	PLUS(x0, x1); PLUS(y0, y1); XOR(x3, x0); XOR(y3, y0); \
+	  ROTATE(x3, 8); ROTATE(y3, 8); \
+	PLUS(x2, x3); PLUS(y2, y3); \
+	  WORD_SHUF(x3, shuf_x3); WORD_SHUF(y3, shuf_x3); \
+		      XOR(x1, x2); XOR(y1, y2); \
+	  WORD_SHUF(x2, shuf_x2); WORD_SHUF(y2, shuf_x2); \
+				   ROTATE(x1, 7); ROTATE(y1, 7); \
+	  WORD_SHUF(x1, shuf_x1); WORD_SHUF(y1, shuf_x1);
+
 .align 64
 ELF(.type _gcry_chacha20_amd64_avx512_data,@object;)
 _gcry_chacha20_amd64_avx512_data:
-.Linc_counter:
-	.byte 0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15
+.Lcounter_0_1_2_3:
+.Lcounter_0_1:
+	.long 0,0,0,0
 .Lone:
 	.long 1,0,0,0
+.Lcounter_2_3:
+.Ltwo:
+	.long 2,0,0,0
+.Lthree:
+	.long 3,0,0,0
+.Linc_counter:
+	.byte 0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15
 ELF(.size _gcry_chacha20_amd64_avx512_data,.-_gcry_chacha20_amd64_avx512_data)
 
 .align 16
-.globl _gcry_chacha20_amd64_avx512_blocks16
-ELF(.type _gcry_chacha20_amd64_avx512_blocks16,@function;)
-_gcry_chacha20_amd64_avx512_blocks16:
+.globl _gcry_chacha20_amd64_avx512_blocks
+ELF(.type _gcry_chacha20_amd64_avx512_blocks,@function;)
+_gcry_chacha20_amd64_avx512_blocks:
 	/* input:
 	 *	%rdi: input
 	 *	%rsi: dst
 	 *	%rdx: src
-	 *	%rcx: nblks (multiple of 16)
+	 *	%rcx: nblks
 	 */
 	CFI_STARTPROC();
 
 	vpxord %xmm16, %xmm16, %xmm16;
-	vpopcntb %zmm16, %zmm16; /* spec stop for old AVX512 CPUs */
+	vpopcntb %ymm16, %ymm16; /* spec stop for old AVX512 CPUs */
+
+	cmpq $4, NBLKS;
+	jb .Lskip_vertical_handling;
 
+	/* Load constants */
 	vpmovzxbd .Linc_counter rRIP, COUNTER_ADD;
 
-	/* Preload state */
+	cmpq $16, NBLKS;
+	jae .Lload_zmm_state;
+
+	/* Preload state to YMM registers */
+	vpbroadcastd (0 * 4)(INPUT), S0y;
+	vpbroadcastd (1 * 4)(INPUT), S1y;
+	vpbroadcastd (2 * 4)(INPUT), S2y;
+	vpbroadcastd (3 * 4)(INPUT), S3y;
+	vpbroadcastd (4 * 4)(INPUT), S4y;
+	vpbroadcastd (5 * 4)(INPUT), S5y;
+	vpbroadcastd (6 * 4)(INPUT), S6y;
+	vpbroadcastd (7 * 4)(INPUT), S7y;
+	vpbroadcastd (8 * 4)(INPUT), S8y;
+	vpbroadcastd (14 * 4)(INPUT), S14y;
+	vpbroadcastd (15 * 4)(INPUT), S15y;
+	jmp .Lskip16v;
+
+.align 16
+.Lload_zmm_state:
+	/* Preload state to ZMM registers */
 	vpbroadcastd (0 * 4)(INPUT), S0;
 	vpbroadcastd (1 * 4)(INPUT), S1;
 	vpbroadcastd (2 * 4)(INPUT), S2;
@@ -204,13 +340,14 @@ _gcry_chacha20_amd64_avx512_blocks16:
 	vpbroadcastd (15 * 4)(INPUT), S15;
 
 .align 16
-.Loop16:
+.Loop16v:
+	/* Process 16 ChaCha20 blocks */
 	movl $20, ROUND;
+	subq $16, NBLKS;
 
 	/* Construct counter vectors X12 and X13 */
-	vpbroadcastd (12 * 4)(INPUT), X12;
+	vpaddd (12 * 4)(INPUT){1to16}, COUNTER_ADD, X12;
 	vpbroadcastd (13 * 4)(INPUT), X13;
-	vpaddd COUNTER_ADD, X12, X12;
 	vpcmpud $6, X12, COUNTER_ADD, %k2;
 	vpaddd .Lone rRIP {1to16}, X13, X13{%k2};
 	vmovdqa32 X12, X12_SAVE;
@@ -223,7 +360,7 @@ _gcry_chacha20_amd64_avx512_blocks16:
 	vmovdqa32 S1, X1;
 	vmovdqa32 S5, X5;
 	vpbroadcastd (9 * 4)(INPUT), X9;
-	QUARTERROUND2(X0, X4,  X8, X12,   X1, X5,  X9, X13)
+	QUARTERROUND2V(X0, X4,  X8, X12,   X1, X5,  X9, X13)
 	vmovdqa32 S2, X2;
 	vmovdqa32 S6, X6;
 	vpbroadcastd (10 * 4)(INPUT), X10;
@@ -235,19 +372,18 @@ _gcry_chacha20_amd64_avx512_blocks16:
 
 	/* Update counter */
 	addq $16, (12 * 4)(INPUT);
-	jmp .Lround2_entry;
+	jmp .Lround2_entry_16v;
 
 .align 16
-.Lround2:
-	QUARTERROUND2(X2, X7,  X8, X13,   X3, X4,  X9, X14)
-	QUARTERROUND2(X0, X4,  X8, X12,   X1, X5,  X9, X13)
-.Lround2_entry:
+.Lround2_16v:
+	QUARTERROUND2V(X2, X7,  X8, X13,   X3, X4,  X9, X14)
+	QUARTERROUND2V(X0, X4,  X8, X12,   X1, X5,  X9, X13)
+.Lround2_entry_16v:
+	QUARTERROUND2V(X2, X6, X10, X14,   X3, X7, X11, X15)
+	QUARTERROUND2V(X0, X5, X10, X15,   X1, X6, X11, X12)
 	subl $2, ROUND;
-	QUARTERROUND2(X2, X6, X10, X14,   X3, X7, X11, X15)
-	QUARTERROUND2(X0, X5, X10, X15,   X1, X6, X11, X12)
-	jnz .Lround2;
+	jnz .Lround2_16v;
 
-.Lround2_end:
 	PLUS(X0, S0);
 	PLUS(X1, S1);
 	PLUS(X5, S5);
@@ -256,7 +392,7 @@ _gcry_chacha20_amd64_avx512_blocks16:
 	PLUS(X11, (11 * 4)(INPUT){1to16});
 	PLUS(X15, S15);
 	PLUS(X12, X12_SAVE);
-	QUARTERROUND2(X2, X7,  X8, X13,   X3, X4,  X9, X14)
+	QUARTERROUND2V(X2, X7,  X8, X13,   X3, X4,  X9, X14)
 
 	PLUS(X2, S2);
 	PLUS(X3, S3);
@@ -280,21 +416,286 @@ _gcry_chacha20_amd64_avx512_blocks16:
 	transpose_16byte_4x4(X3, X7, X11, X15, TMP0, TMP1);
 	xor_src_dst_4x4(DST, SRC, (64 * 3), (64 * 4), X3, X7, X11, X15);
 
-	subq $16, NBLKS;
 	leaq (16 * 64)(SRC), SRC;
 	leaq (16 * 64)(DST), DST;
-	jnz .Loop16;
+	cmpq $16, NBLKS;
+	jae .Loop16v;
+
+.align 8
+.Lskip16v:
+	cmpq $8, NBLKS;
+	jb .Lskip8v;
+
+	/* Process 8 ChaCha20 blocks */
+	movl $20, ROUND;
+	subq $8, NBLKS;
+
+	/* Construct counter vectors X12 and X13 */
+	vpaddd (12 * 4)(INPUT){1to8}, COUNTER_ADDy, X12y;
+	vpbroadcastd (13 * 4)(INPUT), X13y;
+	vpcmpud $6, X12y, COUNTER_ADDy, %k2;
+	vpaddd .Lone rRIP {1to8}, X13y, X13y{%k2};
+	vmovdqa32 X12y, X12_SAVEy;
+	vmovdqa32 X13y, X13_SAVEy;
+
+	/* Load vectors */
+	vmovdqa32 S0y, X0y;
+	vmovdqa32 S4y, X4y;
+	vmovdqa32 S8y, X8y;
+	vmovdqa32 S1y, X1y;
+	vmovdqa32 S5y, X5y;
+	vpbroadcastd (9 * 4)(INPUT), X9y;
+	QUARTERROUND2V(X0y, X4y,  X8y, X12y,   X1y, X5y,  X9y, X13y)
+	vmovdqa32 S2y, X2y;
+	vmovdqa32 S6y, X6y;
+	vpbroadcastd (10 * 4)(INPUT), X10y;
+	vmovdqa32 S14y, X14y;
+	vmovdqa32 S3y, X3y;
+	vmovdqa32 S7y, X7y;
+	vpbroadcastd (11 * 4)(INPUT), X11y;
+	vmovdqa32 S15y, X15y;
+
+	/* Update counter */
+	addq $8, (12 * 4)(INPUT);
+	jmp .Lround2_entry_8v;
+
+.align 16
+.Lround2_8v:
+	QUARTERROUND2V(X2y, X7y,  X8y, X13y,   X3y, X4y,  X9y, X14y)
+	QUARTERROUND2V(X0y, X4y,  X8y, X12y,   X1y, X5y,  X9y, X13y)
+.Lround2_entry_8v:
+	QUARTERROUND2V(X2y, X6y, X10y, X14y,   X3y, X7y, X11y, X15y)
+	QUARTERROUND2V(X0y, X5y, X10y, X15y,   X1y, X6y, X11y, X12y)
+	subl $2, ROUND;
+	jnz .Lround2_8v;
+
+	PLUS(X0y, S0y);
+	PLUS(X1y, S1y);
+	PLUS(X5y, S5y);
+	PLUS(X6y, S6y);
+	PLUS(X10y, (10 * 4)(INPUT){1to8});
+	PLUS(X11y, (11 * 4)(INPUT){1to8});
+	PLUS(X15y, S15y);
+	PLUS(X12y, X12_SAVEy);
+	QUARTERROUND2V(X2y, X7y,  X8y, X13y,   X3y, X4y,  X9y, X14y)
+
+	PLUS(X2y, S2y);
+	PLUS(X3y, S3y);
+	PLUS(X4y, S4y);
+	PLUS(X7y, S7y);
+	transpose_4x4(X0y, X1y, X2y, X3y, TMP0y, TMP1y);
+	transpose_4x4(X4y, X5y, X6y, X7y, TMP0y, TMP1y);
+	PLUS(X8y, S8y);
+	PLUS(X9y, (9 * 4)(INPUT){1to8});
+	PLUS(X13y, X13_SAVEy);
+	PLUS(X14y, S14y);
+	transpose_16byte_2x2(X0y, X4y, TMP0y);
+	transpose_16byte_2x2(X1y, X5y, TMP0y);
+	transpose_16byte_2x2(X2y, X6y, TMP0y);
+	transpose_16byte_2x2(X3y, X7y, TMP0y);
+	transpose_4x4(X8y, X9y, X10y, X11y, TMP0y, TMP1y);
+	transpose_4x4(X12y, X13y, X14y, X15y, TMP0y, TMP1y);
+	xor_src_dst_4x4(DST, SRC, (16 * 0),  (64 * 1), X0y, X1y, X2y, X3y);
+	xor_src_dst_4x4(DST, SRC, (16 * 16), (64 * 1), X4y, X5y, X6y, X7y);
+	transpose_16byte_2x2(X8y, X12y, TMP0y);
+	transpose_16byte_2x2(X9y, X13y, TMP0y);
+	transpose_16byte_2x2(X10y, X14y, TMP0y);
+	transpose_16byte_2x2(X11y, X15y, TMP0y);
+	xor_src_dst_4x4(DST, SRC, (16 * 2),  (64 * 1), X8y, X9y, X10y, X11y);
+	xor_src_dst_4x4(DST, SRC, (16 * 18), (64 * 1), X12y, X13y, X14y, X15y);
+
+	leaq (8 * 64)(SRC), SRC;
+	leaq (8 * 64)(DST), DST;
+
+.align 8
+.Lskip8v:
+	cmpq $4, NBLKS;
+	jb .Lskip4v;
+
+	/* Process 4 ChaCha20 blocks */
+	movl $20, ROUND;
+	subq $4, NBLKS;
+
+	/* Construct counter vectors X12 and X13 */
+	vpaddd (12 * 4)(INPUT){1to4}, COUNTER_ADDx, X12x;
+	vpbroadcastd (13 * 4)(INPUT), X13x;
+	vpcmpud $6, X12x, COUNTER_ADDx, %k2;
+	vpaddd .Lone rRIP {1to4}, X13x, X13x{%k2};
+	vmovdqa32 X12x, X12_SAVEx;
+	vmovdqa32 X13x, X13_SAVEx;
+
+	/* Load vectors */
+	vmovdqa32 S0x, X0x;
+	vmovdqa32 S4x, X4x;
+	vmovdqa32 S8x, X8x;
+	vmovdqa32 S1x, X1x;
+	vmovdqa32 S5x, X5x;
+	vpbroadcastd (9 * 4)(INPUT), X9x;
+	QUARTERROUND2V(X0x, X4x,  X8x, X12x,   X1x, X5x,  X9x, X13x)
+	vmovdqa32 S2x, X2x;
+	vmovdqa32 S6x, X6x;
+	vpbroadcastd (10 * 4)(INPUT), X10x;
+	vmovdqa32 S14x, X14x;
+	vmovdqa32 S3x, X3x;
+	vmovdqa32 S7x, X7x;
+	vpbroadcastd (11 * 4)(INPUT), X11x;
+	vmovdqa32 S15x, X15x;
 
-	/* clear the used vector registers */
+	/* Update counter */
+	addq $4, (12 * 4)(INPUT);
+	jmp .Lround2_entry_4v;
+
+.align 16
+.Lround2_4v:
+	QUARTERROUND2V(X2x, X7x,  X8x, X13x,   X3x, X4x,  X9x, X14x)
+	QUARTERROUND2V(X0x, X4x,  X8x, X12x,   X1x, X5x,  X9x, X13x)
+.Lround2_entry_4v:
+	QUARTERROUND2V(X2x, X6x, X10x, X14x,   X3x, X7x, X11x, X15x)
+	QUARTERROUND2V(X0x, X5x, X10x, X15x,   X1x, X6x, X11x, X12x)
+	subl $2, ROUND;
+	jnz .Lround2_4v;
+
+	PLUS(X0x, S0x);
+	PLUS(X1x, S1x);
+	PLUS(X5x, S5x);
+	PLUS(X6x, S6x);
+	PLUS(X10x, (10 * 4)(INPUT){1to4});
+	PLUS(X11x, (11 * 4)(INPUT){1to4});
+	PLUS(X15x, S15x);
+	PLUS(X12x, X12_SAVEx);
+	QUARTERROUND2V(X2x, X7x,  X8x, X13x,   X3x, X4x,  X9x, X14x)
+
+	PLUS(X2x, S2x);
+	PLUS(X3x, S3x);
+	PLUS(X4x, S4x);
+	PLUS(X7x, S7x);
+	transpose_4x4(X0x, X1x, X2x, X3x, TMP0x, TMP1x);
+	transpose_4x4(X4x, X5x, X6x, X7x, TMP0x, TMP1x);
+	xor_src_dst_4x4(DST, SRC, (16 * 0), (64 * 1), X0x, X1x, X2x, X3x);
+	PLUS(X8x, S8x);
+	PLUS(X9x, (9 * 4)(INPUT){1to4});
+	xor_src_dst_4x4(DST, SRC, (16 * 1), (64 * 1), X4x, X5x, X6x, X7x);
+	PLUS(X13x, X13_SAVEx);
+	PLUS(X14x, S14x);
+	transpose_4x4(X8x, X9x, X10x, X11x, TMP0x, TMP1x);
+	transpose_4x4(X12x, X13x, X14x, X15x, TMP0x, TMP1x);
+	xor_src_dst_4x4(DST, SRC, (16 * 2), (64 * 1), X8x, X9x, X10x, X11x);
+	xor_src_dst_4x4(DST, SRC, (16 * 3), (64 * 1), X12x, X13x, X14x, X15x);
+
+	leaq (4 * 64)(SRC), SRC;
+	leaq (4 * 64)(DST), DST;
+
+.align 8
+.Lskip4v:
+	/* clear AVX512 registers */
+	kxorq %k2, %k2, %k2;
+	vzeroupper;
 	clear_zmm16_zmm31();
-	kxord %k2, %k2, %k2;
+
+.align 8
+.Lskip_vertical_handling:
+	cmpq $0, NBLKS;
+	je .Ldone;
+
+	/* Load state */
+	vmovdqu (0 * 4)(INPUT), X10x;
+	vmovdqu (4 * 4)(INPUT), X11x;
+	vmovdqu (8 * 4)(INPUT), X12x;
+	vmovdqu (12 * 4)(INPUT), X13x;
+
+	/* Load constant */
+	vmovdqa .Lone rRIP, X4x;
+
+	cmpq $1, NBLKS;
+	je .Lhandle1;
+
+	/* Process two ChaCha20 blocks (XMM) */
+	movl $20, ROUND;
+	subq $2, NBLKS;
+
+	vmovdqa X10x, X0x;
+	vmovdqa X11x, X1x;
+	vmovdqa X12x, X2x;
+	vmovdqa X13x, X3x;
+
+	vmovdqa X10x, X8x;
+	vmovdqa X11x, X9x;
+	vmovdqa X12x, X14x;
+	vpaddq X4x, X13x, X15x;
+	vmovdqa X15x, X7x;
+
+.align 16
+.Lround2_2:
+	QUARTERROUND2H(X0x, X1x, X2x,  X3x,  X8x, X9x, X14x, X15x,
+		       0x39, 0x4e, 0x93);
+	QUARTERROUND2H(X0x, X1x, X2x,  X3x,  X8x, X9x, X14x, X15x,
+		       0x93, 0x4e, 0x39);
+	subl $2, ROUND;
+	jnz .Lround2_2;
+
+	PLUS(X0x, X10x);
+	PLUS(X1x, X11x);
+	PLUS(X2x, X12x);
+	PLUS(X3x, X13x);
+
+	vpaddq .Ltwo rRIP, X13x, X13x; /* Update counter */
+
+	PLUS(X8x, X10x);
+	PLUS(X9x, X11x);
+	PLUS(X14x, X12x);
+	PLUS(X15x, X7x);
+
+	xor_src_dst_4x4(DST, SRC, 0 * 4, 4 * 4, X0x, X1x, X2x, X3x);
+	xor_src_dst_4x4(DST, SRC, 16 * 4, 4 * 4, X8x, X9x, X14x, X15x);
+	lea (2 * 64)(DST), DST;
+	lea (2 * 64)(SRC), SRC;
+
+	cmpq $0, NBLKS;
+	je .Lskip1;
+
+.align 8
+.Lhandle1:
+	/* Process one ChaCha20 block (XMM) */
+	movl $20, ROUND;
+	subq $1, NBLKS;
+
+	vmovdqa X10x, X0x;
+	vmovdqa X11x, X1x;
+	vmovdqa X12x, X2x;
+	vmovdqa X13x, X3x;
+
+.align 16
+.Lround2_1:
+	QUARTERROUND1H(X0x, X1x, X2x, X3x, 0x39, 0x4e, 0x93);
+	QUARTERROUND1H(X0x, X1x, X2x, X3x, 0x93, 0x4e, 0x39);
+	subl $2, ROUND;
+	jnz .Lround2_1;
+
+	PLUS(X0x, X10x);
+	PLUS(X1x, X11x);
+	PLUS(X2x, X12x);
+	PLUS(X3x, X13x);
+
+	vpaddq X4x, X13x, X13x; /* Update counter */
+
+	xor_src_dst_4x4(DST, SRC, 0 * 4, 4 * 4, X0x, X1x, X2x, X3x);
+	/*lea (1 * 64)(DST), DST;*/
+	/*lea (1 * 64)(SRC), SRC;*/
+
+.align 8
+.Lskip1:
+	/* Store counter */
+	vmovdqu X13x, (12 * 4)(INPUT);
+
+.align 8
+.Ldone:
 	vzeroall; /* clears ZMM0-ZMM15 */
 
-	/* eax zeroed by round loop. */
+	xorl %eax, %eax;
 	ret_spec_stop;
 	CFI_ENDPROC();
-ELF(.size _gcry_chacha20_amd64_avx512_blocks16,
-	  .-_gcry_chacha20_amd64_avx512_blocks16;)
+ELF(.size _gcry_chacha20_amd64_avx512_blocks,
+	  .-_gcry_chacha20_amd64_avx512_blocks;)
 
 #endif /*defined(HAVE_COMPATIBLE_GCC_AMD64_PLATFORM_AS)*/
 #endif /*__x86_64*/
diff --git a/cipher/chacha20.c b/cipher/chacha20.c
index f0cb8721..a7e0dd63 100644
--- a/cipher/chacha20.c
+++ b/cipher/chacha20.c
@@ -173,9 +173,9 @@ unsigned int _gcry_chacha20_poly1305_amd64_avx2_blocks8(
 
 #ifdef USE_AVX512
 
-unsigned int _gcry_chacha20_amd64_avx512_blocks16(u32 *state, byte *dst,
-						  const byte *src,
-						  size_t nblks) ASM_FUNC_ABI;
+unsigned int _gcry_chacha20_amd64_avx512_blocks(u32 *state, byte *dst,
+                                                const byte *src,
+                                                size_t nblks) ASM_FUNC_ABI;
 
 #endif /* USE_AVX2 */
 
@@ -352,6 +352,13 @@ static unsigned int
 chacha20_blocks (CHACHA20_context_t *ctx, byte *dst, const byte *src,
 		 size_t nblks)
 {
+#ifdef USE_AVX512
+  if (ctx->use_avx512)
+    {
+      return _gcry_chacha20_amd64_avx512_blocks(ctx->input, dst, src, nblks);
+    }
+#endif
+
 #ifdef USE_SSSE3
   if (ctx->use_ssse3)
     {
@@ -546,14 +553,13 @@ do_chacha20_encrypt_stream_tail (CHACHA20_context_t *ctx, byte *outbuf,
   unsigned int nburn, burn = 0;
 
 #ifdef USE_AVX512
-  if (ctx->use_avx512 && length >= CHACHA20_BLOCK_SIZE * 16)
+  if (ctx->use_avx512 && length >= CHACHA20_BLOCK_SIZE)
     {
       size_t nblocks = length / CHACHA20_BLOCK_SIZE;
-      nblocks -= nblocks % 16;
-      nburn = _gcry_chacha20_amd64_avx512_blocks16(ctx->input, outbuf, inbuf,
-						   nblocks);
+      nburn = _gcry_chacha20_amd64_avx512_blocks(ctx->input, outbuf, inbuf,
+                                                 nblocks);
       burn = nburn > burn ? nburn : burn;
-      length -= nblocks * CHACHA20_BLOCK_SIZE;
+      length %= CHACHA20_BLOCK_SIZE;
       outbuf += nblocks * CHACHA20_BLOCK_SIZE;
       inbuf  += nblocks * CHACHA20_BLOCK_SIZE;
     }
@@ -662,7 +668,7 @@ do_chacha20_encrypt_stream_tail (CHACHA20_context_t *ctx, byte *outbuf,
       size_t nblocks = length / CHACHA20_BLOCK_SIZE;
       nburn = chacha20_blocks(ctx, outbuf, inbuf, nblocks);
       burn = nburn > burn ? nburn : burn;
-      length -= nblocks * CHACHA20_BLOCK_SIZE;
+      length %= CHACHA20_BLOCK_SIZE;
       outbuf += nblocks * CHACHA20_BLOCK_SIZE;
       inbuf  += nblocks * CHACHA20_BLOCK_SIZE;
     }
-- 
2.37.2