chacha/asm/chacha-x86_64.pl: add AVX512 path optimized for shorter inputs.
authorAndy Polyakov <appro@openssl.org>
Mon, 19 Dec 2016 15:26:35 +0000 (16:26 +0100)
committerAndy Polyakov <appro@openssl.org>
Sun, 25 Dec 2016 15:31:40 +0000 (16:31 +0100)
Reviewed-by: Richard Levitte <levitte@openssl.org>
crypto/chacha/asm/chacha-x86_64.pl

index fd3fdeb10c7572bd0dc6e1b99fad24488719cab0..ac169ee33cc943287a473585432abffce2c093a2 100755 (executable)
@@ -112,6 +112,10 @@ $code.=<<___;
 .Lsigma:
 .asciz "expand 32-byte k"
 .align 64
+.Lzeroz:
+.long  0,0,0,0, 1,0,0,0, 2,0,0,0, 3,0,0,0
+.Lfourz:
+.long  4,0,0,0, 4,0,0,0, 4,0,0,0, 4,0,0,0
 .Lincz:
 .long  0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15
 .Lsixteen:
@@ -241,6 +245,12 @@ ChaCha20_ctr32:
        cmp     \$0,$len
        je      .Lno_data
        mov     OPENSSL_ia32cap_P+4(%rip),%r10
+___
+$code.=<<___   if ($avx>2);
+       bt      \$48,%r10               # check for AVX512F
+       jc      .LChaCha20_avx512
+___
+$code.=<<___;
        test    \$`1<<(41-32)`,%r10d
        jnz     .LChaCha20_ssse3
 
@@ -447,7 +457,7 @@ $code.=<<___;
        ja      .LChaCha20_4x           # but overall it won't be slower
 
 .Ldo_sse3_after_all:
-       push    %rbx
+       push    %rbx                    # just to share SEH handler, no pops
        push    %rbp
        push    %r12
        push    %r13
@@ -472,7 +482,7 @@ $code.=<<___;
        movdqa  $b,0x10(%rsp)
        movdqa  $c,0x20(%rsp)
        movdqa  $d,0x30(%rsp)
-       mov     \$10,%ebp
+       mov     \$10,$counter           # reuse $counter
        jmp     .Loop_ssse3
 
 .align 32
@@ -482,7 +492,7 @@ $code.=<<___;
        movdqa  0x10(%rsp),$b
        movdqa  0x20(%rsp),$c
        paddd   0x30(%rsp),$d
-       mov     \$10,%ebp
+       mov     \$10,$counter
        movdqa  $d,0x30(%rsp)
        jmp     .Loop_ssse3
 
@@ -500,7 +510,7 @@ ___
        &pshufd ($b,$b,0b10010011);
        &pshufd ($d,$d,0b00111001);
 
-       &dec    ("%ebp");
+       &dec    ($counter);
        &jnz    (".Loop_ssse3");
 
 $code.=<<___;
@@ -539,14 +549,14 @@ $code.=<<___;
        movdqa  $b,0x10(%rsp)
        movdqa  $c,0x20(%rsp)
        movdqa  $d,0x30(%rsp)
-       xor     %rbx,%rbx
+       xor     $counter,$counter
 
 .Loop_tail_ssse3:
-       movzb   ($inp,%rbx),%eax
-       movzb   (%rsp,%rbx),%ecx
-       lea     1(%rbx),%rbx
+       movzb   ($inp,$counter),%eax
+       movzb   (%rsp,$counter),%ecx
+       lea     1($counter),$counter
        xor     %ecx,%eax
-       mov     %al,-1($out,%rbx)
+       mov     %al,-1($out,$counter)
        dec     $len
        jnz     .Loop_tail_ssse3
 
@@ -557,13 +567,7 @@ $code.=<<___       if ($win64);
        movaps  64+48(%rsp),%xmm7
 ___
 $code.=<<___;
-       add     \$64+$xframe,%rsp
-       pop     %r15
-       pop     %r14
-       pop     %r13
-       pop     %r12
-       pop     %rbp
-       pop     %rbx
+       add     \$64+$xframe+48,%rsp
        ret
 .size  ChaCha20_ssse3,.-ChaCha20_ssse3
 ___
@@ -1732,12 +1736,6 @@ $code.=<<___;
 .align 32
 ChaCha20_8x:
 .LChaCha20_8x:
-___
-$code.=<<___           if ($avx>2);
-       test            \$`1<<16`,%r10d                 # check for AVX512F
-       jnz             .LChaCha20_16x
-___
-$code.=<<___;
        mov             %rsp,%r10
        sub             \$0x280+$xframe,%rsp
        and             \$-32,%rsp
@@ -2229,7 +2227,7 @@ $code.=<<___;
        jnz             .Loop_tail8x
 
 .Ldone8x:
-       vzeroupper
+       vzeroall
 ___
 $code.=<<___   if ($win64);
        lea             0x290+0x30(%rsp),%r11
@@ -2254,6 +2252,228 @@ ___
 ########################################################################
 # AVX512 code paths
 if ($avx>2) {
+# This one handles shorter inputs...
+
+my ($a,$b,$c,$d, $a_,$b_,$c_,$d_,$fourz) = map("%zmm$_",(0..3,16..20));
+my ($t0,$t1,$t2,$t3) = map("%xmm$_",(4..7));
+
+sub AVX512ROUND {      # critical path is 14 "SIMD ticks" per round
+       &vpaddd ($a,$a,$b);
+       &vpxord ($d,$d,$a);
+       &vprold ($d,$d,16);
+
+       &vpaddd ($c,$c,$d);
+       &vpxord ($b,$b,$c);
+       &vprold ($b,$b,12);
+
+       &vpaddd ($a,$a,$b);
+       &vpxord ($d,$d,$a);
+       &vprold ($d,$d,8);
+
+       &vpaddd ($c,$c,$d);
+       &vpxord ($b,$b,$c);
+       &vprold ($b,$b,7);
+}
+
+my $xframe = $win64 ? 32+32+8 : 24;
+
+$code.=<<___;
+.type  ChaCha20_avx512,\@function,5
+.align 32
+ChaCha20_avx512:
+.LChaCha20_avx512:
+       cmp     \$512,$len
+       ja      .LChaCha20_16x
+
+       push    %rbx                    # just to share SEH handler, no pops
+       push    %rbp
+       push    %r12
+       push    %r13
+       push    %r14
+       push    %r15
+
+       sub     \$64+$xframe,%rsp
+___
+$code.=<<___   if ($win64);
+       movaps  %xmm6,64+32(%rsp)
+       movaps  %xmm7,64+48(%rsp)
+___
+$code.=<<___;
+       vbroadcasti32x4 .Lsigma(%rip),$a
+       vbroadcasti32x4 ($key),$b
+       vbroadcasti32x4 16($key),$c
+       vbroadcasti32x4 ($counter),$d
+
+       vmovdqa32       $a,$a_
+       vmovdqa32       $b,$b_
+       vmovdqa32       $c,$c_
+       vpaddd          .Lzeroz(%rip),$d,$d
+       vmovdqa32       .Lfourz(%rip),$fourz
+       mov             \$10,$counter   # reuse $counter
+       vmovdqa32       $d,$d_
+       jmp             .Loop_avx512
+
+.align 16
+.Loop_outer_avx512:
+       vmovdqa32       $a_,$a
+       vmovdqa32       $b_,$b
+       vmovdqa32       $c_,$c
+       vpaddd          $fourz,$d_,$d
+       mov             \$10,$counter
+       vmovdqa32       $d,$d_
+       jmp             .Loop_avx512
+
+.align 32
+.Loop_avx512:
+___
+       &AVX512ROUND();
+       &vpshufd        ($c,$c,0b01001110);
+       &vpshufd        ($b,$b,0b00111001);
+       &vpshufd        ($d,$d,0b10010011);
+
+       &AVX512ROUND();
+       &vpshufd        ($c,$c,0b01001110);
+       &vpshufd        ($b,$b,0b10010011);
+       &vpshufd        ($d,$d,0b00111001);
+
+       &dec            ($counter);
+       &jnz            (".Loop_avx512");
+
+$code.=<<___;
+       vpaddd          $a_,$a,$a
+       vpaddd          $b_,$b,$b
+       vpaddd          $c_,$c,$c
+       vpaddd          $d_,$d,$d
+
+       sub             \$64,$len
+       jb              .Ltail64_avx512
+
+       vpxor           0x00($inp),%x#$a,$t0    # xor with input
+       vpxor           0x10($inp),%x#$b,$t1
+       vpxor           0x20($inp),%x#$c,$t2
+       vpxor           0x30($inp),%x#$d,$t3
+       lea             0x40($inp),$inp         # inp+=64
+
+       vmovdqu         $t0,0x00($out)          # write output
+       vmovdqu         $t1,0x10($out)
+       vmovdqu         $t2,0x20($out)
+       vmovdqu         $t3,0x30($out)
+       lea             0x40($out),$out         # out+=64
+
+       jz              .Ldone_avx512
+
+       vextracti32x4   \$1,$a,$t0
+       vextracti32x4   \$1,$b,$t1
+       vextracti32x4   \$1,$c,$t2
+       vextracti32x4   \$1,$d,$t3
+
+       sub             \$64,$len
+       jb              .Ltail_avx512
+
+       vpxor           0x00($inp),$t0,$t0      # xor with input
+       vpxor           0x10($inp),$t1,$t1
+       vpxor           0x20($inp),$t2,$t2
+       vpxor           0x30($inp),$t3,$t3
+       lea             0x40($inp),$inp         # inp+=64
+
+       vmovdqu         $t0,0x00($out)          # write output
+       vmovdqu         $t1,0x10($out)
+       vmovdqu         $t2,0x20($out)
+       vmovdqu         $t3,0x30($out)
+       lea             0x40($out),$out         # out+=64
+
+       jz              .Ldone_avx512
+
+       vextracti32x4   \$2,$a,$t0
+       vextracti32x4   \$2,$b,$t1
+       vextracti32x4   \$2,$c,$t2
+       vextracti32x4   \$2,$d,$t3
+
+       sub             \$64,$len
+       jb              .Ltail_avx512
+
+       vpxor           0x00($inp),$t0,$t0      # xor with input
+       vpxor           0x10($inp),$t1,$t1
+       vpxor           0x20($inp),$t2,$t2
+       vpxor           0x30($inp),$t3,$t3
+       lea             0x40($inp),$inp         # inp+=64
+
+       vmovdqu         $t0,0x00($out)          # write output
+       vmovdqu         $t1,0x10($out)
+       vmovdqu         $t2,0x20($out)
+       vmovdqu         $t3,0x30($out)
+       lea             0x40($out),$out         # out+=64
+
+       jz              .Ldone_avx512
+
+       vextracti32x4   \$3,$a,$t0
+       vextracti32x4   \$3,$b,$t1
+       vextracti32x4   \$3,$c,$t2
+       vextracti32x4   \$3,$d,$t3
+
+       sub             \$64,$len
+       jb              .Ltail_avx512
+
+       vpxor           0x00($inp),$t0,$t0      # xor with input
+       vpxor           0x10($inp),$t1,$t1
+       vpxor           0x20($inp),$t2,$t2
+       vpxor           0x30($inp),$t3,$t3
+       lea             0x40($inp),$inp         # inp+=64
+
+       vmovdqu         $t0,0x00($out)          # write output
+       vmovdqu         $t1,0x10($out)
+       vmovdqu         $t2,0x20($out)
+       vmovdqu         $t3,0x30($out)
+       lea             0x40($out),$out         # out+=64
+
+       jnz             .Loop_outer_avx512
+
+       jmp             .Ldone_avx512
+
+.align 16
+.Ltail64_avx512:
+       vmovdqa         %x#$a,0x00(%rsp)
+       vmovdqa         %x#$b,0x10(%rsp)
+       vmovdqa         %x#$c,0x20(%rsp)
+       vmovdqa         %x#$d,0x30(%rsp)
+       add             \$64,$len
+       jmp             .Loop_tail_avx512
+
+.align 16
+.Ltail_avx512:
+       vmovdqa         $t0,0x00(%rsp)
+       vmovdqa         $t1,0x10(%rsp)
+       vmovdqa         $t2,0x20(%rsp)
+       vmovdqa         $t3,0x30(%rsp)
+       add             \$64,$len
+
+.Loop_tail_avx512:
+       movzb           ($inp,$counter),%eax
+       movzb           (%rsp,$counter),%ecx
+       lea             1($counter),$counter
+       xor             %ecx,%eax
+       mov             %al,-1($out,$counter)
+       dec             $len
+       jnz             .Loop_tail_avx512
+
+       vmovdqa32       $a_,0x00(%rsp)
+
+.Ldone_avx512:
+       vzeroall
+___
+$code.=<<___   if ($win64);
+       movaps  64+32(%rsp),%xmm6
+       movaps  64+48(%rsp),%xmm7
+___
+$code.=<<___;
+       add     \$64+$xframe+48,%rsp
+       ret
+.size  ChaCha20_avx512,.-ChaCha20_avx512
+___
+}
+if ($avx>2) {
+# This one handles longer inputs...
+
 my ($xa0,$xa1,$xa2,$xa3, $xb0,$xb1,$xb2,$xb3,
     $xc0,$xc1,$xc2,$xc3, $xd0,$xd1,$xd2,$xd3)=map("%zmm$_",(0..15));
 my  @xx=($xa0,$xa1,$xa2,$xa3, $xb0,$xb1,$xb2,$xb3,
@@ -2728,8 +2948,11 @@ $code.=<<___;
        dec             $len
        jnz             .Loop_tail16x
 
+       vpxord          $xa0,$xa0,$xa0
+       vmovdqa32       $xa0,0(%rsp)
+
 .Ldone16x:
-       vzeroupper
+       vzeroall
 ___
 $code.=<<___   if ($win64);
        lea             0x290+0x30(%rsp),%r11
@@ -2752,9 +2975,9 @@ ___
 }
 
 foreach (split("\n",$code)) {
-       s/\`([^\`]*)\`/eval $1/geo;
+       s/\`([^\`]*)\`/eval $1/ge;
 
-       s/%x#%y/%x/go;
+       s/%x#%[yz]/%x/g;        # "down-shift"
 
        print $_,"\n";
 }