From: Chris Duncan Date: Tue, 14 Jan 2025 19:34:21 +0000 (-0800) Subject: Replace some 16-bit scalar rotations with vector rotations. X-Git-Url: https://zoso.dev/?a=commitdiff_plain;h=a521923b55b6a3028975d78c46fb5b934da01364;p=nano-pow.git Replace some 16-bit scalar rotations with vector rotations. --- diff --git a/src/shaders/compute.wgsl b/src/shaders/compute.wgsl index 91d41be..3a887f5 100644 --- a/src/shaders/compute.wgsl +++ b/src/shaders/compute.wgsl @@ -192,10 +192,14 @@ fn main(@builtin(global_invocation_id) id: vec3<u32>) { v1 = o1; // d = rotr64(d ^ a, 16) - xor0 = v24 ^ v0; - xor1 = v25 ^ v1; - v24 = (xor0 >> 16u) ^ (xor1 << 16u); - v25 = (xor1 >> 16u) ^ (xor0 << 16u); + v_01.x = v0; + v_01.y = v1; + v_2425.x = v24; + v_2425.y = v25; + xor = v_2425 ^ v_01; + v_2425 = vec2((xor.x >> 16u) | (xor.y << 16u), (xor.y >> 16u) | (xor.x << 16u)); + v24 = v_2425.x; + v25 = v_2425.y; // c = c + d v_1617.x = v16; @@ -973,10 +977,14 @@ fn main(@builtin(global_invocation_id) id: vec3<u32>) { // // skip since adding 0u does nothing // d = rotr64(d ^ a, 16) - xor0 = v24 ^ v0; - xor1 = v25 ^ v1; - v24 = (xor0 >> 16u) ^ (xor1 << 16u); - v25 = (xor1 >> 16u) ^ (xor0 << 16u); + v_01.x = v0; + v_01.y = v1; + v_2425.x = v24; + v_2425.y = v25; + xor = v_2425 ^ v_01; + v_2425 = vec2((xor.x >> 16u) | (xor.y << 16u), (xor.y >> 16u) | (xor.x << 16u)); + v24 = v_2425.x; + v25 = v_2425.y; // c = c + d v_1617.x = v16; @@ -1757,10 +1765,14 @@ fn main(@builtin(global_invocation_id) id: vec3<u32>) { // // skip since adding 0u does nothing // d = rotr64(d ^ a, 16) - xor0 = v24 ^ v0; - xor1 = v25 ^ v1; - v24 = (xor0 >> 16u) ^ (xor1 << 16u); - v25 = (xor1 >> 16u) ^ (xor0 << 16u); + v_01.x = v0; + v_01.y = v1; + v_2425.x = v24; + v_2425.y = v25; + xor = v_2425 ^ v_01; + v_2425 = vec2((xor.x >> 16u) | (xor.y << 16u), (xor.y >> 16u) | (xor.x << 16u)); + v24 = v_2425.x; + v25 = v_2425.y; // c = c + d v_1617.x = v16; @@ -2532,10 +2544,14 @@ fn main(@builtin(global_invocation_id) id: vec3<u32>) { // // skip since
adding 0u does nothing // d = rotr64(d ^ a, 16) - xor0 = v24 ^ v0; - xor1 = v25 ^ v1; - v24 = (xor0 >> 16u) ^ (xor1 << 16u); - v25 = (xor1 >> 16u) ^ (xor0 << 16u); + v_01.x = v0; + v_01.y = v1; + v_2425.x = v24; + v_2425.y = v25; + xor = v_2425 ^ v_01; + v_2425 = vec2((xor.x >> 16u) | (xor.y << 16u), (xor.y >> 16u) | (xor.x << 16u)); + v24 = v_2425.x; + v25 = v_2425.y; // c = c + d v_1617.x = v16; @@ -3322,10 +3338,14 @@ fn main(@builtin(global_invocation_id) id: vec3<u32>) { v1 = o1; // d = rotr64(d ^ a, 16) - xor0 = v24 ^ v0; - xor1 = v25 ^ v1; - v24 = (xor0 >> 16u) ^ (xor1 << 16u); - v25 = (xor1 >> 16u) ^ (xor0 << 16u); + v_01.x = v0; + v_01.y = v1; + v_2425.x = v24; + v_2425.y = v25; + xor = v_2425 ^ v_01; + v_2425 = vec2((xor.x >> 16u) | (xor.y << 16u), (xor.y >> 16u) | (xor.x << 16u)); + v24 = v_2425.x; + v25 = v_2425.y; // c = c + d v_1617.x = v16; @@ -4115,10 +4135,14 @@ fn main(@builtin(global_invocation_id) id: vec3<u32>) { // // skip since adding 0u does nothing // d = rotr64(d ^ a, 16) - xor0 = v24 ^ v0; - xor1 = v25 ^ v1; - v24 = (xor0 >> 16u) ^ (xor1 << 16u); - v25 = (xor1 >> 16u) ^ (xor0 << 16u); + v_01.x = v0; + v_01.y = v1; + v_2425.x = v24; + v_2425.y = v25; + xor = v_2425 ^ v_01; + v_2425 = vec2((xor.x >> 16u) | (xor.y << 16u), (xor.y >> 16u) | (xor.x << 16u)); + v24 = v_2425.x; + v25 = v_2425.y; // c = c + d v_1617.x = v16; @@ -4902,10 +4926,14 @@ fn main(@builtin(global_invocation_id) id: vec3<u32>) { // // skip since adding 0u does nothing // d = rotr64(d ^ a, 16) - xor0 = v24 ^ v0; - xor1 = v25 ^ v1; - v24 = (xor0 >> 16u) ^ (xor1 << 16u); - v25 = (xor1 >> 16u) ^ (xor0 << 16u); + v_01.x = v0; + v_01.y = v1; + v_2425.x = v24; + v_2425.y = v25; + xor = v_2425 ^ v_01; + v_2425 = vec2((xor.x >> 16u) | (xor.y << 16u), (xor.y >> 16u) | (xor.x << 16u)); + v24 = v_2425.x; + v25 = v_2425.y; // c = c + d v_1617.x = v16; @@ -5686,10 +5714,14 @@ fn main(@builtin(global_invocation_id) id: vec3<u32>) { // // skip since adding 0u does nothing // d = rotr64(d ^ a, 16) - xor0 =
v24 ^ v0; - xor1 = v25 ^ v1; - v24 = (xor0 >> 16u) ^ (xor1 << 16u); - v25 = (xor1 >> 16u) ^ (xor0 << 16u); + v_01.x = v0; + v_01.y = v1; + v_2425.x = v24; + v_2425.y = v25; + xor = v_2425 ^ v_01; + v_2425 = vec2((xor.x >> 16u) | (xor.y << 16u), (xor.y >> 16u) | (xor.x << 16u)); + v24 = v_2425.x; + v25 = v_2425.y; // c = c + d v_1617.x = v16; @@ -6470,10 +6502,14 @@ fn main(@builtin(global_invocation_id) id: vec3<u32>) { // // skip since adding 0u does nothing // d = rotr64(d ^ a, 16) - xor0 = v24 ^ v0; - xor1 = v25 ^ v1; - v24 = (xor0 >> 16u) ^ (xor1 << 16u); - v25 = (xor1 >> 16u) ^ (xor0 << 16u); + v_01.x = v0; + v_01.y = v1; + v_2425.x = v24; + v_2425.y = v25; + xor = v_2425 ^ v_01; + v_2425 = vec2((xor.x >> 16u) | (xor.y << 16u), (xor.y >> 16u) | (xor.x << 16u)); + v24 = v_2425.x; + v25 = v_2425.y; // c = c + d v_1617.x = v16; @@ -7260,10 +7296,14 @@ fn main(@builtin(global_invocation_id) id: vec3<u32>) { v1 = o1; // d = rotr64(d ^ a, 16) - xor0 = v24 ^ v0; - xor1 = v25 ^ v1; - v24 = (xor0 >> 16u) ^ (xor1 << 16u); - v25 = (xor1 >> 16u) ^ (xor0 << 16u); + v_01.x = v0; + v_01.y = v1; + v_2425.x = v24; + v_2425.y = v25; + xor = v_2425 ^ v_01; + v_2425 = vec2((xor.x >> 16u) | (xor.y << 16u), (xor.y >> 16u) | (xor.x << 16u)); + v24 = v_2425.x; + v25 = v_2425.y; // c = c + d v_1617.x = v16; @@ -8050,10 +8090,14 @@ fn main(@builtin(global_invocation_id) id: vec3<u32>) { v1 = o1; // d = rotr64(d ^ a, 16) - xor0 = v24 ^ v0; - xor1 = v25 ^ v1; - v24 = (xor0 >> 16u) ^ (xor1 << 16u); - v25 = (xor1 >> 16u) ^ (xor0 << 16u); + v_01.x = v0; + v_01.y = v1; + v_2425.x = v24; + v_2425.y = v25; + xor = v_2425 ^ v_01; + v_2425 = vec2((xor.x >> 16u) | (xor.y << 16u), (xor.y >> 16u) | (xor.x << 16u)); + v24 = v_2425.x; + v25 = v_2425.y; // c = c + d v_1617.x = v16; @@ -8831,10 +8875,14 @@ fn main(@builtin(global_invocation_id) id: vec3<u32>) { // // skip since adding 0u does nothing // d = rotr64(d ^ a, 16) - xor0 = v24 ^ v0; - xor1 = v25 ^ v1; - v24 = (xor0 >> 16u) ^ (xor1 << 16u); - v25 = (xor1 >> 16u)
^ (xor0 << 16u); + v_01.x = v0; + v_01.y = v1; + v_2425.x = v24; + v_2425.y = v25; + xor = v_2425 ^ v_01; + v_2425 = vec2((xor.x >> 16u) | (xor.y << 16u), (xor.y >> 16u) | (xor.x << 16u)); + v24 = v_2425.x; + v25 = v_2425.y; // c = c + d v_1617.x = v16;