Skip to content

Commit

Permalink
bevy_pbr: Avoid copying structs and using registers in shaders (#7069)
Browse files Browse the repository at this point in the history
# Objective

- The #7064 PR had poor performance on an M1 Max in MacOS due to significant overuse of registers resulting in 'register spilling' where data that would normally be stored in registers on the GPU is instead stored in VRAM. The latency to read from/write to VRAM instead of registers incurs a significant performance penalty.
- Use of registers is a limiting factor in shader performance. Assignment of a struct from memory to a local variable can incur copies. Passing a variable that has struct type as an argument to a function can also incur copies. As such, these two cases can incur increased register usage and decreased performance.

## Solution

- Remove/avoid a number of assignments of light struct type data to local variables.
- Remove/avoid a number of passing light struct type variables/data as value arguments to shader functions.
  • Loading branch information
superdump committed Jan 2, 2023
1 parent b833bda commit b44b606
Show file tree
Hide file tree
Showing 3 changed files with 51 additions and 49 deletions.
15 changes: 6 additions & 9 deletions crates/bevy_pbr/src/render/pbr_functions.wgsl
Original file line number Diff line number Diff line change
Expand Up @@ -196,38 +196,35 @@ fn pbr(
// point lights
for (var i: u32 = offset_and_counts[0]; i < offset_and_counts[0] + offset_and_counts[1]; i = i + 1u) {
let light_id = get_light_id(i);
let light = point_lights.data[light_id];
var shadow: f32 = 1.0;
if ((mesh.flags & MESH_FLAGS_SHADOW_RECEIVER_BIT) != 0u
&& (light.flags & POINT_LIGHT_FLAGS_SHADOWS_ENABLED_BIT) != 0u) {
&& (point_lights.data[light_id].flags & POINT_LIGHT_FLAGS_SHADOWS_ENABLED_BIT) != 0u) {
shadow = fetch_point_shadow(light_id, in.world_position, in.world_normal);
}
let light_contrib = point_light(in.world_position.xyz, light, roughness, NdotV, in.N, in.V, R, F0, diffuse_color);
let light_contrib = point_light(in.world_position.xyz, light_id, roughness, NdotV, in.N, in.V, R, F0, diffuse_color);
light_accum = light_accum + light_contrib * shadow;
}

// spot lights
for (var i: u32 = offset_and_counts[0] + offset_and_counts[1]; i < offset_and_counts[0] + offset_and_counts[1] + offset_and_counts[2]; i = i + 1u) {
let light_id = get_light_id(i);
let light = point_lights.data[light_id];
var shadow: f32 = 1.0;
if ((mesh.flags & MESH_FLAGS_SHADOW_RECEIVER_BIT) != 0u
&& (light.flags & POINT_LIGHT_FLAGS_SHADOWS_ENABLED_BIT) != 0u) {
&& (point_lights.data[light_id].flags & POINT_LIGHT_FLAGS_SHADOWS_ENABLED_BIT) != 0u) {
shadow = fetch_spot_shadow(light_id, in.world_position, in.world_normal);
}
let light_contrib = spot_light(in.world_position.xyz, light, roughness, NdotV, in.N, in.V, R, F0, diffuse_color);
let light_contrib = spot_light(in.world_position.xyz, light_id, roughness, NdotV, in.N, in.V, R, F0, diffuse_color);
light_accum = light_accum + light_contrib * shadow;
}

let n_directional_lights = lights.n_directional_lights;
for (var i: u32 = 0u; i < n_directional_lights; i = i + 1u) {
let light = lights.directional_lights[i];
var shadow: f32 = 1.0;
if ((mesh.flags & MESH_FLAGS_SHADOW_RECEIVER_BIT) != 0u
&& (light.flags & DIRECTIONAL_LIGHT_FLAGS_SHADOWS_ENABLED_BIT) != 0u) {
&& (lights.directional_lights[i].flags & DIRECTIONAL_LIGHT_FLAGS_SHADOWS_ENABLED_BIT) != 0u) {
shadow = fetch_directional_shadow(i, in.world_position, in.world_normal);
}
let light_contrib = directional_light(light, roughness, NdotV, in.N, in.V, R, F0, diffuse_color);
let light_contrib = directional_light(i, roughness, NdotV, in.N, in.V, R, F0, diffuse_color);
light_accum = light_accum + light_contrib * shadow;
}

Expand Down
37 changes: 21 additions & 16 deletions crates/bevy_pbr/src/render/pbr_lighting.wgsl
Original file line number Diff line number Diff line change
Expand Up @@ -150,22 +150,23 @@ fn perceptualRoughnessToRoughness(perceptualRoughness: f32) -> f32 {
}

fn point_light(
world_position: vec3<f32>, light: PointLight, roughness: f32, NdotV: f32, N: vec3<f32>, V: vec3<f32>,
world_position: vec3<f32>, light_id: u32, roughness: f32, NdotV: f32, N: vec3<f32>, V: vec3<f32>,
R: vec3<f32>, F0: vec3<f32>, diffuseColor: vec3<f32>
) -> vec3<f32> {
let light_to_frag = light.position_radius.xyz - world_position.xyz;
let light = &point_lights.data[light_id];
let light_to_frag = (*light).position_radius.xyz - world_position.xyz;
let distance_square = dot(light_to_frag, light_to_frag);
let rangeAttenuation =
getDistanceAttenuation(distance_square, light.color_inverse_square_range.w);
getDistanceAttenuation(distance_square, (*light).color_inverse_square_range.w);

// Specular.
// Representative Point Area Lights.
// see http://blog.selfshadow.com/publications/s2013-shading-course/karis/s2013_pbs_epic_notes_v2.pdf p14-16
let a = roughness;
let centerToRay = dot(light_to_frag, R) * R - light_to_frag;
let closestPoint = light_to_frag + centerToRay * saturate(light.position_radius.w * inverseSqrt(dot(centerToRay, centerToRay)));
let closestPoint = light_to_frag + centerToRay * saturate((*light).position_radius.w * inverseSqrt(dot(centerToRay, centerToRay)));
let LspecLengthInverse = inverseSqrt(dot(closestPoint, closestPoint));
let normalizationFactor = a / saturate(a + (light.position_radius.w * 0.5 * LspecLengthInverse));
let normalizationFactor = a / saturate(a + ((*light).position_radius.w * 0.5 * LspecLengthInverse));
let specularIntensity = normalizationFactor * normalizationFactor;

var L: vec3<f32> = closestPoint * LspecLengthInverse; // normalize() equivalent?
Expand Down Expand Up @@ -197,40 +198,44 @@ fn point_light(
// I = Φ / 4 π
// The derivation of this can be seen here: https://google.github.io/filament/Filament.html#mjx-eqn-pointLightLuminousPower

// NOTE: light.color.rgb is premultiplied with light.intensity / 4 π (which would be the luminous intensity) on the CPU
// NOTE: (*light).color.rgb is premultiplied with (*light).intensity / 4 π (which would be the luminous intensity) on the CPU

// TODO compensate for energy loss https://google.github.io/filament/Filament.html#materialsystem/improvingthebrdfs/energylossinspecularreflectance

return ((diffuse + specular_light) * light.color_inverse_square_range.rgb) * (rangeAttenuation * NoL);
return ((diffuse + specular_light) * (*light).color_inverse_square_range.rgb) * (rangeAttenuation * NoL);
}

fn spot_light(
world_position: vec3<f32>, light: PointLight, roughness: f32, NdotV: f32, N: vec3<f32>, V: vec3<f32>,
world_position: vec3<f32>, light_id: u32, roughness: f32, NdotV: f32, N: vec3<f32>, V: vec3<f32>,
R: vec3<f32>, F0: vec3<f32>, diffuseColor: vec3<f32>
) -> vec3<f32> {
// reuse the point light calculations
let point_light = point_light(world_position, light, roughness, NdotV, N, V, R, F0, diffuseColor);
let point_light = point_light(world_position, light_id, roughness, NdotV, N, V, R, F0, diffuseColor);

let light = &point_lights.data[light_id];

// reconstruct spot dir from x/z and y-direction flag
var spot_dir = vec3<f32>(light.light_custom_data.x, 0.0, light.light_custom_data.y);
var spot_dir = vec3<f32>((*light).light_custom_data.x, 0.0, (*light).light_custom_data.y);
spot_dir.y = sqrt(max(0.0, 1.0 - spot_dir.x * spot_dir.x - spot_dir.z * spot_dir.z));
if ((light.flags & POINT_LIGHT_FLAGS_SPOT_LIGHT_Y_NEGATIVE) != 0u) {
if (((*light).flags & POINT_LIGHT_FLAGS_SPOT_LIGHT_Y_NEGATIVE) != 0u) {
spot_dir.y = -spot_dir.y;
}
let light_to_frag = light.position_radius.xyz - world_position.xyz;
let light_to_frag = (*light).position_radius.xyz - world_position.xyz;

// calculate attenuation based on filament formula https://google.github.io/filament/Filament.html#listing_glslpunctuallight
// spot_scale and spot_offset have been precomputed
// note we normalize here to get "l" from the filament listing. spot_dir is already normalized
let cd = dot(-spot_dir, normalize(light_to_frag));
let attenuation = saturate(cd * light.light_custom_data.z + light.light_custom_data.w);
let attenuation = saturate(cd * (*light).light_custom_data.z + (*light).light_custom_data.w);
let spot_attenuation = attenuation * attenuation;

return point_light * spot_attenuation;
}

fn directional_light(light: DirectionalLight, roughness: f32, NdotV: f32, normal: vec3<f32>, view: vec3<f32>, R: vec3<f32>, F0: vec3<f32>, diffuseColor: vec3<f32>) -> vec3<f32> {
let incident_light = light.direction_to_light.xyz;
fn directional_light(light_id: u32, roughness: f32, NdotV: f32, normal: vec3<f32>, view: vec3<f32>, R: vec3<f32>, F0: vec3<f32>, diffuseColor: vec3<f32>) -> vec3<f32> {
let light = &lights.directional_lights[light_id];

let incident_light = (*light).direction_to_light.xyz;

let half_vector = normalize(incident_light + view);
let NoL = saturate(dot(normal, incident_light));
Expand All @@ -241,5 +246,5 @@ fn directional_light(light: DirectionalLight, roughness: f32, NdotV: f32, normal
let specularIntensity = 1.0;
let specular_light = specular(F0, roughness, half_vector, NdotV, NoL, NoH, LoH, specularIntensity);

return (specular_light + diffuse) * light.color.rgb * NoL;
return (specular_light + diffuse) * (*light).color.rgb * NoL;
}
48 changes: 24 additions & 24 deletions crates/bevy_pbr/src/render/shadows.wgsl
Original file line number Diff line number Diff line change
@@ -1,31 +1,31 @@
#define_import_path bevy_pbr::shadows

fn fetch_point_shadow(light_id: u32, frag_position: vec4<f32>, surface_normal: vec3<f32>) -> f32 {
let light = point_lights.data[light_id];
let light = &point_lights.data[light_id];

// because the shadow maps align with the axes and the frustum planes are at 45 degrees
// we can get the worldspace depth by taking the largest absolute axis
let surface_to_light = light.position_radius.xyz - frag_position.xyz;
let surface_to_light = (*light).position_radius.xyz - frag_position.xyz;
let surface_to_light_abs = abs(surface_to_light);
let distance_to_light = max(surface_to_light_abs.x, max(surface_to_light_abs.y, surface_to_light_abs.z));

// The normal bias here is already scaled by the texel size at 1 world unit from the light.
// The texel size increases proportionally with distance from the light so multiplying by
// distance to light scales the normal bias to the texel size at the fragment distance.
let normal_offset = light.shadow_normal_bias * distance_to_light * surface_normal.xyz;
let depth_offset = light.shadow_depth_bias * normalize(surface_to_light.xyz);
let normal_offset = (*light).shadow_normal_bias * distance_to_light * surface_normal.xyz;
let depth_offset = (*light).shadow_depth_bias * normalize(surface_to_light.xyz);
let offset_position = frag_position.xyz + normal_offset + depth_offset;

// similar largest-absolute-axis trick as above, but now with the offset fragment position
let frag_ls = light.position_radius.xyz - offset_position.xyz;
let frag_ls = (*light).position_radius.xyz - offset_position.xyz;
let abs_position_ls = abs(frag_ls);
let major_axis_magnitude = max(abs_position_ls.x, max(abs_position_ls.y, abs_position_ls.z));

// NOTE: These simplifications come from multiplying:
// projection * vec4(0, 0, -major_axis_magnitude, 1.0)
// and keeping only the terms that have any impact on the depth.
// Projection-agnostic approach:
let zw = -major_axis_magnitude * light.light_custom_data.xy + light.light_custom_data.zw;
let zw = -major_axis_magnitude * (*light).light_custom_data.xy + (*light).light_custom_data.zw;
let depth = zw.x / zw.y;

// do the lookup, using HW PCF and comparison
Expand All @@ -42,27 +42,27 @@ fn fetch_point_shadow(light_id: u32, frag_position: vec4<f32>, surface_normal: v
}

fn fetch_spot_shadow(light_id: u32, frag_position: vec4<f32>, surface_normal: vec3<f32>) -> f32 {
let light = point_lights.data[light_id];
let light = &point_lights.data[light_id];

let surface_to_light = light.position_radius.xyz - frag_position.xyz;
let surface_to_light = (*light).position_radius.xyz - frag_position.xyz;

// construct the light view matrix
var spot_dir = vec3<f32>(light.light_custom_data.x, 0.0, light.light_custom_data.y);
var spot_dir = vec3<f32>((*light).light_custom_data.x, 0.0, (*light).light_custom_data.y);
// reconstruct spot dir from x/z and y-direction flag
spot_dir.y = sqrt(1.0 - spot_dir.x * spot_dir.x - spot_dir.z * spot_dir.z);
if ((light.flags & POINT_LIGHT_FLAGS_SPOT_LIGHT_Y_NEGATIVE) != 0u) {
if (((*light).flags & POINT_LIGHT_FLAGS_SPOT_LIGHT_Y_NEGATIVE) != 0u) {
spot_dir.y = -spot_dir.y;
}

// view matrix z_axis is the reverse of transform.forward()
let fwd = -spot_dir;
let distance_to_light = dot(fwd, surface_to_light);
let offset_position =
-surface_to_light
+ (light.shadow_depth_bias * normalize(surface_to_light))
+ (surface_normal.xyz * light.shadow_normal_bias) * distance_to_light;
let offset_position =
-surface_to_light
+ ((*light).shadow_depth_bias * normalize(surface_to_light))
+ (surface_normal.xyz * (*light).shadow_normal_bias) * distance_to_light;

// the construction of the up and right vectors needs to precisely mirror the code
// the construction of the up and right vectors needs to precisely mirror the code
// in render/light.rs:spot_light_view_matrix
var sign = -1.0;
if (fwd.z >= 0.0) {
Expand All @@ -74,14 +74,14 @@ fn fetch_spot_shadow(light_id: u32, frag_position: vec4<f32>, surface_normal: ve
let right_dir = vec3<f32>(-b, -sign - fwd.y * fwd.y * a, fwd.y);
let light_inv_rot = mat3x3<f32>(right_dir, up_dir, fwd);

// because the matrix is a pure rotation matrix, the inverse is just the transpose, and to calculate
// the product of the transpose with a vector we can just post-multiply instead of pre-multplying.
// because the matrix is a pure rotation matrix, the inverse is just the transpose, and to calculate
// the product of the transpose with a vector we can just post-multiply instead of pre-multplying.
// this allows us to keep the matrix construction code identical between CPU and GPU.
let projected_position = offset_position * light_inv_rot;

// divide xy by perspective matrix "f" and by -projected.z (projected.z is -projection matrix's w)
// to get ndc coordinates
let f_div_minus_z = 1.0 / (light.spot_light_tan_angle * -projected_position.z);
let f_div_minus_z = 1.0 / ((*light).spot_light_tan_angle * -projected_position.z);
let shadow_xy_ndc = projected_position.xy * f_div_minus_z;
// convert to uv coordinates
let shadow_uv = shadow_xy_ndc * vec2<f32>(0.5, -0.5) + vec2<f32>(0.5, 0.5);
Expand All @@ -90,23 +90,23 @@ fn fetch_spot_shadow(light_id: u32, frag_position: vec4<f32>, surface_normal: ve
let depth = 0.1 / -projected_position.z;

#ifdef NO_ARRAY_TEXTURES_SUPPORT
return textureSampleCompare(directional_shadow_textures, directional_shadow_textures_sampler,
return textureSampleCompare(directional_shadow_textures, directional_shadow_textures_sampler,
shadow_uv, depth);
#else
return textureSampleCompareLevel(directional_shadow_textures, directional_shadow_textures_sampler,
return textureSampleCompareLevel(directional_shadow_textures, directional_shadow_textures_sampler,
shadow_uv, i32(light_id) + lights.spot_light_shadowmap_offset, depth);
#endif
}

fn fetch_directional_shadow(light_id: u32, frag_position: vec4<f32>, surface_normal: vec3<f32>) -> f32 {
let light = lights.directional_lights[light_id];
let light = &lights.directional_lights[light_id];

// The normal bias is scaled to the texel size.
let normal_offset = light.shadow_normal_bias * surface_normal.xyz;
let depth_offset = light.shadow_depth_bias * light.direction_to_light.xyz;
let normal_offset = (*light).shadow_normal_bias * surface_normal.xyz;
let depth_offset = (*light).shadow_depth_bias * (*light).direction_to_light.xyz;
let offset_position = vec4<f32>(frag_position.xyz + normal_offset + depth_offset, frag_position.w);

let offset_position_clip = light.view_projection * offset_position;
let offset_position_clip = (*light).view_projection * offset_position;
if (offset_position_clip.w <= 0.0) {
return 1.0;
}
Expand Down

0 comments on commit b44b606

Please sign in to comment.