From 81f382913b6d5c67320a796e3aa09eb979753172 Mon Sep 17 00:00:00 2001 From: Chris Duncan Date: Tue, 14 Jan 2025 11:31:22 -0800 Subject: [PATCH] Replace more 24-bit scalar rotations with vector rotations. --- src/shaders/compute.wgsl | 132 ++++++++++++++++++++++++++------------- 1 file changed, 88 insertions(+), 44 deletions(-) diff --git a/src/shaders/compute.wgsl b/src/shaders/compute.wgsl index 2267891..91d41be 100644 --- a/src/shaders/compute.wgsl +++ b/src/shaders/compute.wgsl @@ -951,10 +951,14 @@ fn main(@builtin(global_invocation_id) id: vec3) { v25 = v_2425.y; // b = rotr64(b ^ c, 24) - xor0 = v8 ^ v16; - xor1 = v9 ^ v17; - v8 = (xor0 >> 24u) ^ (xor1 << 8u); - v9 = (xor1 >> 24u) ^ (xor0 << 8u); + v_1617.x = v16; + v_1617.y = v17; + v_89.x = v8; + v_89.y = v9; + xor = v_89 ^ v_1617; + v_89 = vec2((xor.x >> 24u) | (xor.y << 8u), (xor.y >> 24u) | (xor.x << 8u)); + v8 = v_89.x; + v9 = v_89.y; // a = a + b v_89.x = v8; @@ -1731,10 +1735,14 @@ fn main(@builtin(global_invocation_id) id: vec3) { v25 = v_2425.y; // b = rotr64(b ^ c, 24) - xor0 = v8 ^ v16; - xor1 = v9 ^ v17; - v8 = (xor0 >> 24u) ^ (xor1 << 8u); - v9 = (xor1 >> 24u) ^ (xor0 << 8u); + v_1617.x = v16; + v_1617.y = v17; + v_89.x = v8; + v_89.y = v9; + xor = v_89 ^ v_1617; + v_89 = vec2((xor.x >> 24u) | (xor.y << 8u), (xor.y >> 24u) | (xor.x << 8u)); + v8 = v_89.x; + v9 = v_89.y; // a = a + b v_89.x = v8; @@ -2502,10 +2510,14 @@ fn main(@builtin(global_invocation_id) id: vec3) { v25 = v_2425.y; // b = rotr64(b ^ c, 24) - xor0 = v8 ^ v16; - xor1 = v9 ^ v17; - v8 = (xor0 >> 24u) ^ (xor1 << 8u); - v9 = (xor1 >> 24u) ^ (xor0 << 8u); + v_1617.x = v16; + v_1617.y = v17; + v_89.x = v8; + v_89.y = v9; + xor = v_89 ^ v_1617; + v_89 = vec2((xor.x >> 24u) | (xor.y << 8u), (xor.y >> 24u) | (xor.x << 8u)); + v8 = v_89.x; + v9 = v_89.y; // a = a + b v_89.x = v8; @@ -3282,10 +3294,14 @@ fn main(@builtin(global_invocation_id) id: vec3) { v25 = v_2425.y; // b = rotr64(b ^ c, 24) - xor0 = v8 ^ v16; - xor1 = v9 ^ v17; - v8 = (xor0 >> 24u) ^ (xor1 << 8u); - v9 = (xor1 >> 24u) ^ (xor0 << 8u); + v_1617.x = v16; + v_1617.y = v17; + v_89.x = v8; + v_89.y = v9; + xor = v_89 ^ v_1617; + v_89 = vec2((xor.x >> 24u) | (xor.y << 8u), (xor.y >> 24u) | (xor.x << 8u)); + v8 = v_89.x; + v9 = v_89.y; // a = a + b v_89.x = v8; @@ -4077,10 +4093,14 @@ fn main(@builtin(global_invocation_id) id: vec3) { v25 = v_2425.y; // b = rotr64(b ^ c, 24) - xor0 = v8 ^ v16; - xor1 = v9 ^ v17; - v8 = (xor0 >> 24u) ^ (xor1 << 8u); - v9 = (xor1 >> 24u) ^ (xor0 << 8u); + v_1617.x = v16; + v_1617.y = v17; + v_89.x = v8; + v_89.y = v9; + xor = v_89 ^ v_1617; + v_89 = vec2((xor.x >> 24u) | (xor.y << 8u), (xor.y >> 24u) | (xor.x << 8u)); + v8 = v_89.x; + v9 = v_89.y; // a = a + b v_89.x = v8; @@ -4860,10 +4880,14 @@ fn main(@builtin(global_invocation_id) id: vec3) { v25 = v_2425.y; // b = rotr64(b ^ c, 24) - xor0 = v8 ^ v16; - xor1 = v9 ^ v17; - v8 = (xor0 >> 24u) ^ (xor1 << 8u); - v9 = (xor1 >> 24u) ^ (xor0 << 8u); + v_1617.x = v16; + v_1617.y = v17; + v_89.x = v8; + v_89.y = v9; + xor = v_89 ^ v_1617; + v_89 = vec2((xor.x >> 24u) | (xor.y << 8u), (xor.y >> 24u) | (xor.x << 8u)); + v8 = v_89.x; + v9 = v_89.y; // a = a + b v_89.x = v8; @@ -5640,10 +5664,14 @@ fn main(@builtin(global_invocation_id) id: vec3) { v25 = v_2425.y; // b = rotr64(b ^ c, 24) - xor0 = v8 ^ v16; - xor1 = v9 ^ v17; - v8 = (xor0 >> 24u) ^ (xor1 << 8u); - v9 = (xor1 >> 24u) ^ (xor0 << 8u); + v_1617.x = v16; + v_1617.y = v17; + v_89.x = v8; + v_89.y = v9; + xor = v_89 ^ v_1617; + v_89 = vec2((xor.x >> 24u) | (xor.y << 8u), (xor.y >> 24u) | (xor.x << 8u)); + v8 = v_89.x; + v9 = v_89.y; // a = a + b v_89.x = v8; @@ -6420,10 +6448,14 @@ fn main(@builtin(global_invocation_id) id: vec3) { v25 = v_2425.y; // b = rotr64(b ^ c, 24) - xor0 = v8 ^ v16; - xor1 = v9 ^ v17; - v8 = (xor0 >> 24u) ^ (xor1 << 8u); - v9 = (xor1 >> 24u) ^ (xor0 << 8u); + v_1617.x = v16; + v_1617.y = v17; + v_89.x = v8; + v_89.y = v9; + xor = v_89 ^ v_1617; + v_89 = vec2((xor.x >> 24u) | (xor.y << 8u), (xor.y >> 24u) | (xor.x << 8u)); + v8 = v_89.x; + v9 = v_89.y; // a = a + b v_89.x = v8; @@ -7200,10 +7232,14 @@ fn main(@builtin(global_invocation_id) id: vec3) { v25 = v_2425.y; // b = rotr64(b ^ c, 24) - xor0 = v8 ^ v16; - xor1 = v9 ^ v17; - v8 = (xor0 >> 24u) ^ (xor1 << 8u); - v9 = (xor1 >> 24u) ^ (xor0 << 8u); + v_1617.x = v16; + v_1617.y = v17; + v_89.x = v8; + v_89.y = v9; + xor = v_89 ^ v_1617; + v_89 = vec2((xor.x >> 24u) | (xor.y << 8u), (xor.y >> 24u) | (xor.x << 8u)); + v8 = v_89.x; + v9 = v_89.y; // a = a + b v_89.x = v8; @@ -7986,10 +8022,14 @@ fn main(@builtin(global_invocation_id) id: vec3) { v25 = v_2425.y; // b = rotr64(b ^ c, 24) - xor0 = v8 ^ v16; - xor1 = v9 ^ v17; - v8 = (xor0 >> 24u) ^ (xor1 << 8u); - v9 = (xor1 >> 24u) ^ (xor0 << 8u); + v_1617.x = v16; + v_1617.y = v17; + v_89.x = v8; + v_89.y = v9; + xor = v_89 ^ v_1617; + v_89 = vec2((xor.x >> 24u) | (xor.y << 8u), (xor.y >> 24u) | (xor.x << 8u)); + v8 = v_89.x; + v9 = v_89.y; // a = a + b v_89.x = v8; @@ -8769,10 +8809,14 @@ fn main(@builtin(global_invocation_id) id: vec3) { v25 = v_2425.y; // b = rotr64(b ^ c, 24) - xor0 = v8 ^ v16; - xor1 = v9 ^ v17; - v8 = (xor0 >> 24u) ^ (xor1 << 8u); - v9 = (xor1 >> 24u) ^ (xor0 << 8u); + v_1617.x = v16; + v_1617.y = v17; + v_89.x = v8; + v_89.y = v9; + xor = v_89 ^ v_1617; + v_89 = vec2((xor.x >> 24u) | (xor.y << 8u), (xor.y >> 24u) | (xor.x << 8u)); + v8 = v_89.x; + v9 = v_89.y; // a = a + b v_89.x = v8; -- 2.34.1