From d60609cc9f7bf71fdece096f1db2a36d30cba732 Mon Sep 17 00:00:00 2001 From: Chris Duncan Date: Tue, 14 Jan 2025 08:24:37 -0800 Subject: [PATCH] Replace more scalar addition with vector addition. --- src/shaders/compute.wgsl | 336 +++++++++++++++++++++++++-------------- 1 file changed, 216 insertions(+), 120 deletions(-) diff --git a/src/shaders/compute.wgsl b/src/shaders/compute.wgsl index 53610b6..3162f0a 100644 --- a/src/shaders/compute.wgsl +++ b/src/shaders/compute.wgsl @@ -496,11 +496,15 @@ fn main(@builtin(global_invocation_id) id: vec3) { */ // a = a + b - o0 = v0 + v10; - o1 = v1 + v11; - o1 = o1 + select(0u, 1u, o0 < v0); - v0 = o0; - v1 = o1; + v_01.x = v0; + v_01.y = v1; + v_1011.x = v10; + v_1011.y = v11; + v_01 = v_01 + v_1011 + select(vec2(0u), vec2(0u, 1u), v_01.x + v_1011.x < v_01.x); + v0 = v_01.x; + v1 = v_01.y; + v10 = v_1011.x; + v11 = v_1011.y; // // a = a + m[sigma[r][2*i+0]] // // skip since adding 0u does nothing @@ -532,11 +536,15 @@ fn main(@builtin(global_invocation_id) id: vec3) { v11 = (xor1 >> 24u) ^ (xor0 << 8u); // a = a + b - o0 = v0 + v10; - o1 = v1 + v11; - o1 = o1 + select(0u, 1u, o0 < v0); - v0 = o0; - v1 = o1; + v_01.x = v0; + v_01.y = v1; + v_1011.x = v10; + v_1011.y = v11; + v_01 = v_01 + v_1011 + select(vec2(0u), vec2(0u, 1u), v_01.x + v_1011.x < v_01.x); + v0 = v_01.x; + v1 = v_01.y; + v10 = v_1011.x; + v11 = v_1011.y; // // a = a + m[sigma[r][2*i+1]] // // skip since adding 0u does nothing @@ -1201,11 +1209,15 @@ fn main(@builtin(global_invocation_id) id: vec3) { */ // a = a + b - o0 = v0 + v10; - o1 = v1 + v11; - o1 = o1 + select(0u, 1u, o0 < v0); - v0 = o0; - v1 = o1; + v_01.x = v0; + v_01.y = v1; + v_1011.x = v10; + v_1011.y = v11; + v_01 = v_01 + v_1011 + select(vec2(0u), vec2(0u, 1u), v_01.x + v_1011.x < v_01.x); + v0 = v_01.x; + v1 = v_01.y; + v10 = v_1011.x; + v11 = v_1011.y; // a = a + m[sigma[r][2*i+0]] o0 = v0 + m2; @@ -1234,11 +1246,15 @@ fn main(@builtin(global_invocation_id) id: vec3) { v11 = (xor1 >> 24u) ^ (xor0 << 8u); // a = a + b - o0 = v0 + v10; - o1 = v1 + v11; - o1 = o1 + select(0u, 1u, o0 < v0); - v0 = o0; - v1 = o1; + v_01.x = v0; + v_01.y = v1; + v_1011.x = v10; + v_1011.y = v11; + v_01 = v_01 + v_1011 + select(vec2(0u), vec2(0u, 1u), v_01.x + v_1011.x < v_01.x); + v0 = v_01.x; + v1 = v_01.y; + v10 = v_1011.x; + v11 = v_1011.y; // // a = a + m[sigma[r][2*i+1]] // // skip since adding 0u does nothing @@ -1891,11 +1907,15 @@ fn main(@builtin(global_invocation_id) id: vec3) { */ // a = a + b - o0 = v0 + v10; - o1 = v1 + v11; - o1 = o1 + select(0u, 1u, o0 < v0); - v0 = o0; - v1 = o1; + v_01.x = v0; + v_01.y = v1; + v_1011.x = v10; + v_1011.y = v11; + v_01 = v_01 + v_1011 + select(vec2(0u), vec2(0u, 1u), v_01.x + v_1011.x < v_01.x); + v0 = v_01.x; + v1 = v_01.y; + v10 = v_1011.x; + v11 = v_1011.y; // // a = a + m[sigma[r][2*i+0]] // // skip since adding 0u does nothing @@ -1927,11 +1947,15 @@ fn main(@builtin(global_invocation_id) id: vec3) { v11 = (xor1 >> 24u) ^ (xor0 << 8u); // a = a + b - o0 = v0 + v10; - o1 = v1 + v11; - o1 = o1 + select(0u, 1u, o0 < v0); - v0 = o0; - v1 = o1; + v_01.x = v0; + v_01.y = v1; + v_1011.x = v10; + v_1011.y = v11; + v_01 = v_01 + v_1011 + select(vec2(0u), vec2(0u, 1u), v_01.x + v_1011.x < v_01.x); + v0 = v_01.x; + v1 = v_01.y; + v10 = v_1011.x; + v11 = v_1011.y; // // a = a + m[sigma[r][2*i+1]] // // skip since adding 0u does nothing @@ -2584,11 +2608,15 @@ fn main(@builtin(global_invocation_id) id: vec3) { */ // a = a + b - o0 = v0 + v10; - o1 = v1 + v11; - o1 = o1 + select(0u, 1u, o0 < v0); - v0 = o0; - v1 = o1; + v_01.x = v0; + v_01.y = v1; + v_1011.x = v10; + v_1011.y = v11; + v_01 = v_01 + v_1011 + select(vec2(0u), vec2(0u, 1u), v_01.x + v_1011.x < v_01.x); + v0 = v_01.x; + v1 = v_01.y; + v10 = v_1011.x; + v11 = v_1011.y; // a = a + m[sigma[r][2*i+0]] o0 = v0 + m4; @@ -2617,11 +2645,15 @@ fn main(@builtin(global_invocation_id) id: vec3) { v11 = (xor1 >> 24u) ^ (xor0 << 8u); // a = a + b - o0 = v0 + v10; - o1 = v1 + v11; - o1 = o1 + select(0u, 1u, o0 < v0); - v0 = o0; - v1 = o1; + v_01.x = v0; + v_01.y = v1; + v_1011.x = v10; + v_1011.y = v11; + v_01 = v_01 + v_1011 + select(vec2(0u), vec2(0u, 1u), v_01.x + v_1011.x < v_01.x); + v0 = v_01.x; + v1 = v_01.y; + v10 = v_1011.x; + v11 = v_1011.y; // // a = a + m[sigma[r][2*i+1]] // // skip since adding 0u does nothing @@ -3274,11 +3306,15 @@ fn main(@builtin(global_invocation_id) id: vec3) { */ // a = a + b - o0 = v0 + v10; - o1 = v1 + v11; - o1 = o1 + select(0u, 1u, o0 < v0); - v0 = o0; - v1 = o1; + v_01.x = v0; + v_01.y = v1; + v_1011.x = v10; + v_1011.y = v11; + v_01 = v_01 + v_1011 + select(vec2(0u), vec2(0u, 1u), v_01.x + v_1011.x < v_01.x); + v0 = v_01.x; + v1 = v_01.y; + v10 = v_1011.x; + v11 = v_1011.y; // // a = a + m[sigma[r][2*i+0]] // // skip since adding 0u does nothing @@ -3310,11 +3346,15 @@ fn main(@builtin(global_invocation_id) id: vec3) { v11 = (xor1 >> 24u) ^ (xor0 << 8u); // a = a + b - o0 = v0 + v10; - o1 = v1 + v11; - o1 = o1 + select(0u, 1u, o0 < v0); - v0 = o0; - v1 = o1; + v_01.x = v0; + v_01.y = v1; + v_1011.x = v10; + v_1011.y = v11; + v_01 = v_01 + v_1011 + select(vec2(0u), vec2(0u, 1u), v_01.x + v_1011.x < v_01.x); + v0 = v_01.x; + v1 = v_01.y; + v10 = v_1011.x; + v11 = v_1011.y; // a = a + m[sigma[r][2*i+1]] o0 = v0 + m2; @@ -3967,11 +4007,15 @@ fn main(@builtin(global_invocation_id) id: vec3) { */ // a = a + b - o0 = v0 + v10; - o1 = v1 + v11; - o1 = o1 + select(0u, 1u, o0 < v0); - v0 = o0; - v1 = o1; + v_01.x = v0; + v_01.y = v1; + v_1011.x = v10; + v_1011.y = v11; + v_01 = v_01 + v_1011 + select(vec2(0u), vec2(0u, 1u), v_01.x + v_1011.x < v_01.x); + v0 = v_01.x; + v1 = v_01.y; + v10 = v_1011.x; + v11 = v_1011.y; // a = a + m[sigma[r][2*i+0]] o0 = v0 + m8; @@ -4000,11 +4044,15 @@ fn main(@builtin(global_invocation_id) id: vec3) { v11 = (xor1 >> 24u) ^ (xor0 << 8u); // a = a + b - o0 = v0 + v10; - o1 = v1 + v11; - o1 = o1 + select(0u, 1u, o0 < v0); - v0 = o0; - v1 = o1; + v_01.x = v0; + v_01.y = v1; + v_1011.x = v10; + v_1011.y = v11; + v_01 = v_01 + v_1011 + select(vec2(0u), vec2(0u, 1u), v_01.x + v_1011.x < v_01.x); + v0 = v_01.x; + v1 = v_01.y; + v10 = v_1011.x; + v11 = v_1011.y; // // a = a + m[sigma[r][2*i+1]] // // skip since adding 0u does nothing @@ -4663,11 +4711,15 @@ fn main(@builtin(global_invocation_id) id: vec3) { */ // a = a + b - o0 = v0 + v10; - o1 = v1 + v11; - o1 = o1 + select(0u, 1u, o0 < v0); - v0 = o0; - v1 = o1; + v_01.x = v0; + v_01.y = v1; + v_1011.x = v10; + v_1011.y = v11; + v_01 = v_01 + v_1011 + select(vec2(0u), vec2(0u, 1u), v_01.x + v_1011.x < v_01.x); + v0 = v_01.x; + v1 = v_01.y; + v10 = v_1011.x; + v11 = v_1011.y; // a = a + m[sigma[r][2*i+0]] o0 = v0 + m0; @@ -4696,11 +4748,15 @@ fn main(@builtin(global_invocation_id) id: vec3) { v11 = (xor1 >> 24u) ^ (xor0 << 8u); // a = a + b - o0 = v0 + v10; - o1 = v1 + v11; - o1 = o1 + select(0u, 1u, o0 < v0); - v0 = o0; - v1 = o1; + v_01.x = v0; + v_01.y = v1; + v_1011.x = v10; + v_1011.y = v11; + v_01 = v_01 + v_1011 + select(vec2(0u), vec2(0u, 1u), v_01.x + v_1011.x < v_01.x); + v0 = v_01.x; + v1 = v_01.y; + v10 = v_1011.x; + v11 = v_1011.y; // // a = a + m[sigma[r][2*i+1]] // // skip since adding 0u does nothing @@ -5356,11 +5412,15 @@ fn main(@builtin(global_invocation_id) id: vec3) { */ // a = a + b - o0 = v0 + v10; - o1 = v1 + v11; - o1 = o1 + select(0u, 1u, o0 < v0); - v0 = o0; - v1 = o1; + v_01.x = v0; + v_01.y = v1; + v_1011.x = v10; + v_1011.y = v11; + v_01 = v_01 + v_1011 + select(vec2(0u), vec2(0u, 1u), v_01.x + v_1011.x < v_01.x); + v0 = v_01.x; + v1 = v_01.y; + v10 = v_1011.x; + v11 = v_1011.y; // // a = a + m[sigma[r][2*i+0]] // // skip since adding 0u does nothing @@ -5392,11 +5452,15 @@ fn main(@builtin(global_invocation_id) id: vec3) { v11 = (xor1 >> 24u) ^ (xor0 << 8u); // a = a + b - o0 = v0 + v10; - o1 = v1 + v11; - o1 = o1 + select(0u, 1u, o0 < v0); - v0 = o0; - v1 = o1; + v_01.x = v0; + v_01.y = v1; + v_1011.x = v10; + v_1011.y = v11; + v_01 = v_01 + v_1011 + select(vec2(0u), vec2(0u, 1u), v_01.x + v_1011.x < v_01.x); + v0 = v_01.x; + v1 = v_01.y; + v10 = v_1011.x; + v11 = v_1011.y; // a = a + m[sigma[r][2*i+1]] o0 = v0 + m0; @@ -6049,11 +6113,15 @@ fn main(@builtin(global_invocation_id) id: vec3) { */ // a = a + b - o0 = v0 + v10; - o1 = v1 + v11; - o1 = o1 + select(0u, 1u, o0 < v0); - v0 = o0; - v1 = o1; + v_01.x = v0; + v_01.y = v1; + v_1011.x = v10; + v_1011.y = v11; + v_01 = v_01 + v_1011 + select(vec2(0u), vec2(0u, 1u), v_01.x + v_1011.x < v_01.x); + v0 = v_01.x; + v1 = v_01.y; + v10 = v_1011.x; + v11 = v_1011.y; // // a = a + m[sigma[r][2*i+0]] // // skip since adding 0u does nothing @@ -6085,11 +6153,15 @@ fn main(@builtin(global_invocation_id) id: vec3) { v11 = (xor1 >> 24u) ^ (xor0 << 8u); // a = a + b - o0 = v0 + v10; - o1 = v1 + v11; - o1 = o1 + select(0u, 1u, o0 < v0); - v0 = o0; - v1 = o1; + v_01.x = v0; + v_01.y = v1; + v_1011.x = v10; + v_1011.y = v11; + v_01 = v_01 + v_1011 + select(vec2(0u), vec2(0u, 1u), v_01.x + v_1011.x < v_01.x); + v0 = v_01.x; + v1 = v_01.y; + v10 = v_1011.x; + v11 = v_1011.y; // a = a + m[sigma[r][2*i+1]] o0 = v0 + m4; @@ -6739,11 +6811,15 @@ fn main(@builtin(global_invocation_id) id: vec3) { */ // a = a + b - o0 = v0 + v10; - o1 = v1 + v11; - o1 = o1 + select(0u, 1u, o0 < v0); - v0 = o0; - v1 = o1; + v_01.x = v0; + v_01.y = v1; + v_1011.x = v10; + v_1011.y = v11; + v_01 = v_01 + v_1011 + select(vec2(0u), vec2(0u, 1u), v_01.x + v_1011.x < v_01.x); + v0 = v_01.x; + v1 = v_01.y; + v10 = v_1011.x; + v11 = v_1011.y; // // a = a + m[sigma[r][2*i+0]] // // skip since adding 0u does nothing @@ -6775,11 +6851,15 @@ fn main(@builtin(global_invocation_id) id: vec3) { v11 = (xor1 >> 24u) ^ (xor0 << 8u); // a = a + b - o0 = v0 + v10; - o1 = v1 + v11; - o1 = o1 + select(0u, 1u, o0 < v0); - v0 = o0; - v1 = o1; + v_01.x = v0; + v_01.y = v1; + v_1011.x = v10; + v_1011.y = v11; + v_01 = v_01 + v_1011 + select(vec2(0u), vec2(0u, 1u), v_01.x + v_1011.x < v_01.x); + v0 = v_01.x; + v1 = v_01.y; + v10 = v_1011.x; + v11 = v_1011.y; // // a = a + m[sigma[r][2*i+1]] // // skip since adding 0u does nothing @@ -7426,11 +7506,15 @@ fn main(@builtin(global_invocation_id) id: vec3) { */ // a = a + b - o0 = v0 + v10; - o1 = v1 + v11; - o1 = o1 + select(0u, 1u, o0 < v0); - v0 = o0; - v1 = o1; + v_01.x = v0; + v_01.y = v1; + v_1011.x = v10; + v_1011.y = v11; + v_01 = v_01 + v_1011 + select(vec2(0u), vec2(0u, 1u), v_01.x + v_1011.x < v_01.x); + v0 = v_01.x; + v1 = v_01.y; + v10 = v_1011.x; + v11 = v_1011.y; // // a = a + m[sigma[r][2*i+0]] // // skip since adding 0u does nothing @@ -7462,11 +7546,15 @@ fn main(@builtin(global_invocation_id) id: vec3) { v11 = (xor1 >> 24u) ^ (xor0 << 8u); // a = a + b - o0 = v0 + v10; - o1 = v1 + v11; - o1 = o1 + select(0u, 1u, o0 < v0); - v0 = o0; - v1 = o1; + v_01.x = v0; + v_01.y = v1; + v_1011.x = v10; + v_1011.y = v11; + v_01 = v_01 + v_1011 + select(vec2(0u), vec2(0u, 1u), v_01.x + v_1011.x < v_01.x); + v0 = v_01.x; + v1 = v_01.y; + v10 = v_1011.x; + v11 = v_1011.y; // // a = a + m[sigma[r][2*i+1]] // // skip since adding 0u does nothing @@ -8131,11 +8219,15 @@ fn main(@builtin(global_invocation_id) id: vec3) { */ // a = a + b - o0 = v0 + v10; - o1 = v1 + v11; - o1 = o1 + select(0u, 1u, o0 < v0); - v0 = o0; - v1 = o1; + v_01.x = v0; + v_01.y = v1; + v_1011.x = v10; + v_1011.y = v11; + v_01 = v_01 + v_1011 + select(vec2(0u), vec2(0u, 1u), v_01.x + v_1011.x < v_01.x); + v0 = v_01.x; + v1 = v_01.y; + v10 = v_1011.x; + v11 = v_1011.y; // a = a + m[sigma[r][2*i+0]] o0 = v0 + m2; @@ -8164,11 +8256,15 @@ fn main(@builtin(global_invocation_id) id: vec3) { v11 = (xor1 >> 24u) ^ (xor0 << 8u); // a = a + b - o0 = v0 + v10; - o1 = v1 + v11; - o1 = o1 + select(0u, 1u, o0 < v0); - v0 = o0; - v1 = o1; + v_01.x = v0; + v_01.y = v1; + v_1011.x = v10; + v_1011.y = v11; + v_01 = v_01 + v_1011 + select(vec2(0u), vec2(0u, 1u), v_01.x + v_1011.x < v_01.x); + v0 = v_01.x; + v1 = v_01.y; + v10 = v_1011.x; + v11 = v_1011.y; // // a = a + m[sigma[r][2*i+1]] // // skip since adding 0u does nothing -- 2.34.1