From 10a09953bf3c57d87075f7d02421da4009178c0c Mon Sep 17 00:00:00 2001 From: Chris Duncan Date: Tue, 14 Jan 2025 11:52:50 -0800 Subject: [PATCH] Replace some 63-bit scalar rotations with vector rotations. --- src/shaders/compute.wgsl | 144 ++++++++++++++++++++++++++------------- 1 file changed, 96 insertions(+), 48 deletions(-) diff --git a/src/shaders/compute.wgsl b/src/shaders/compute.wgsl index 5cff052..90d4062 100644 --- a/src/shaders/compute.wgsl +++ b/src/shaders/compute.wgsl @@ -213,10 +213,14 @@ fn main(@builtin(global_invocation_id) id: vec3) { v25 = v_2425.y; // b = rotr64(b ^ c, 63) - xor0 = v8 ^ v16; - xor1 = v9 ^ v17; - v8 = (xor1 >> 31u) ^ (xor0 << 1u); - v9 = (xor0 >> 31u) ^ (xor1 << 1u); + v_1617.x = v16; + v_1617.y = v17; + v_89.x = v8; + v_89.y = v9; + xor = v_89 ^ v_1617; + v_89 = vec2((xor.x << 1u) | (xor.y >> 31u), (xor.y << 1u) | (xor.x >> 31u)); + v8 = v_89.x; + v9 = v_89.y; @@ -1026,10 +1030,14 @@ fn main(@builtin(global_invocation_id) id: vec3) { v25 = v_2425.y; // b = rotr64(b ^ c, 63) - xor0 = v8 ^ v16; - xor1 = v9 ^ v17; - v8 = (xor1 >> 31u) ^ (xor0 << 1u); - v9 = (xor0 >> 31u) ^ (xor1 << 1u); + v_1617.x = v16; + v_1617.y = v17; + v_89.x = v8; + v_89.y = v9; + xor = v_89 ^ v_1617; + v_89 = vec2((xor.x << 1u) | (xor.y >> 31u), (xor.y << 1u) | (xor.x >> 31u)); + v8 = v_89.x; + v9 = v_89.y; @@ -1842,10 +1850,14 @@ fn main(@builtin(global_invocation_id) id: vec3) { v25 = v_2425.y; // b = rotr64(b ^ c, 63) - xor0 = v8 ^ v16; - xor1 = v9 ^ v17; - v8 = (xor1 >> 31u) ^ (xor0 << 1u); - v9 = (xor0 >> 31u) ^ (xor1 << 1u); + v_1617.x = v16; + v_1617.y = v17; + v_89.x = v8; + v_89.y = v9; + xor = v_89 ^ v_1617; + v_89 = vec2((xor.x << 1u) | (xor.y >> 31u), (xor.y << 1u) | (xor.x >> 31u)); + v8 = v_89.x; + v9 = v_89.y; @@ -2649,10 +2661,14 @@ fn main(@builtin(global_invocation_id) id: vec3) { v25 = v_2425.y; // b = rotr64(b ^ c, 63) - xor0 = v8 ^ v16; - xor1 = v9 ^ v17; - v8 = (xor1 >> 31u) ^ (xor0 << 1u); - v9 = (xor0 >> 31u) ^ (xor1 << 1u); + v_1617.x = v16; + v_1617.y = v17; + v_89.x = v8; + v_89.y = v9; + xor = v_89 ^ v_1617; + v_89 = vec2((xor.x << 1u) | (xor.y >> 31u), (xor.y << 1u) | (xor.x >> 31u)); + v8 = v_89.x; + v9 = v_89.y; @@ -3471,10 +3487,14 @@ fn main(@builtin(global_invocation_id) id: vec3) { v25 = v_2425.y; // b = rotr64(b ^ c, 63) - xor0 = v8 ^ v16; - xor1 = v9 ^ v17; - v8 = (xor1 >> 31u) ^ (xor0 << 1u); - v9 = (xor0 >> 31u) ^ (xor1 << 1u); + v_1617.x = v16; + v_1617.y = v17; + v_89.x = v8; + v_89.y = v9; + xor = v_89 ^ v_1617; + v_89 = vec2((xor.x << 1u) | (xor.y >> 31u), (xor.y << 1u) | (xor.x >> 31u)); + v8 = v_89.x; + v9 = v_89.y; @@ -4296,10 +4316,14 @@ fn main(@builtin(global_invocation_id) id: vec3) { v25 = v_2425.y; // b = rotr64(b ^ c, 63) - xor0 = v8 ^ v16; - xor1 = v9 ^ v17; - v8 = (xor1 >> 31u) ^ (xor0 << 1u); - v9 = (xor0 >> 31u) ^ (xor1 << 1u); + v_1617.x = v16; + v_1617.y = v17; + v_89.x = v8; + v_89.y = v9; + xor = v_89 ^ v_1617; + v_89 = vec2((xor.x << 1u) | (xor.y >> 31u), (xor.y << 1u) | (xor.x >> 31u)); + v8 = v_89.x; + v9 = v_89.y; @@ -5115,10 +5139,14 @@ fn main(@builtin(global_invocation_id) id: vec3) { v25 = v_2425.y; // b = rotr64(b ^ c, 63) - xor0 = v8 ^ v16; - xor1 = v9 ^ v17; - v8 = (xor1 >> 31u) ^ (xor0 << 1u); - v9 = (xor0 >> 31u) ^ (xor1 << 1u); + v_1617.x = v16; + v_1617.y = v17; + v_89.x = v8; + v_89.y = v9; + xor = v_89 ^ v_1617; + v_89 = vec2((xor.x << 1u) | (xor.y >> 31u), (xor.y << 1u) | (xor.x >> 31u)); + v8 = v_89.x; + v9 = v_89.y; @@ -5931,10 +5959,14 @@ fn main(@builtin(global_invocation_id) id: vec3) { v25 = v_2425.y; // b = rotr64(b ^ c, 63) - xor0 = v8 ^ v16; - xor1 = v9 ^ v17; - v8 = (xor1 >> 31u) ^ (xor0 << 1u); - v9 = (xor0 >> 31u) ^ (xor1 << 1u); + v_1617.x = v16; + v_1617.y = v17; + v_89.x = v8; + v_89.y = v9; + xor = v_89 ^ v_1617; + v_89 = vec2((xor.x << 1u) | (xor.y >> 31u), (xor.y << 1u) | (xor.x >> 31u)); + v8 = v_89.x; + v9 = v_89.y; @@ -6747,10 +6779,14 @@ fn main(@builtin(global_invocation_id) id: vec3) { v25 = v_2425.y; // b = rotr64(b ^ c, 63) - xor0 = v8 ^ v16; - xor1 = v9 ^ v17; - v8 = (xor1 >> 31u) ^ (xor0 << 1u); - v9 = (xor0 >> 31u) ^ (xor1 << 1u); + v_1617.x = v16; + v_1617.y = v17; + v_89.x = v8; + v_89.y = v9; + xor = v_89 ^ v_1617; + v_89 = vec2((xor.x << 1u) | (xor.y >> 31u), (xor.y << 1u) | (xor.x >> 31u)); + v8 = v_89.x; + v9 = v_89.y; @@ -7569,10 +7605,14 @@ fn main(@builtin(global_invocation_id) id: vec3) { v25 = v_2425.y; // b = rotr64(b ^ c, 63) - xor0 = v8 ^ v16; - xor1 = v9 ^ v17; - v8 = (xor1 >> 31u) ^ (xor0 << 1u); - v9 = (xor0 >> 31u) ^ (xor1 << 1u); + v_1617.x = v16; + v_1617.y = v17; + v_89.x = v8; + v_89.y = v9; + xor = v_89 ^ v_1617; + v_89 = vec2((xor.x << 1u) | (xor.y >> 31u), (xor.y << 1u) | (xor.x >> 31u)); + v8 = v_89.x; + v9 = v_89.y; @@ -8391,10 +8431,14 @@ fn main(@builtin(global_invocation_id) id: vec3) { v25 = v_2425.y; // b = rotr64(b ^ c, 63) - xor0 = v8 ^ v16; - xor1 = v9 ^ v17; - v8 = (xor1 >> 31u) ^ (xor0 << 1u); - v9 = (xor0 >> 31u) ^ (xor1 << 1u); + v_1617.x = v16; + v_1617.y = v17; + v_89.x = v8; + v_89.y = v9; + xor = v_89 ^ v_1617; + v_89 = vec2((xor.x << 1u) | (xor.y >> 31u), (xor.y << 1u) | (xor.x >> 31u)); + v8 = v_89.x; + v9 = v_89.y; @@ -9204,10 +9248,14 @@ fn main(@builtin(global_invocation_id) id: vec3) { v25 = v_2425.y; // b = rotr64(b ^ c, 63) - xor0 = v8 ^ v16; - xor1 = v9 ^ v17; - v8 = (xor1 >> 31u) ^ (xor0 << 1u); - v9 = (xor0 >> 31u) ^ (xor1 << 1u); + v_1617.x = v16; + v_1617.y = v17; + v_89.x = v8; + v_89.y = v9; + xor = v_89 ^ v_1617; + v_89 = vec2((xor.x << 1u) | (xor.y >> 31u), (xor.y << 1u) | (xor.x >> 31u)); + v8 = v_89.x; + v9 = v_89.y; -- 2.34.1