Skip to main content

Overwriting shaders

note

The code described in this section can be run from the overwriting_shaders.rs example with cargo run --features derive --example overwriting_shaders. Complete source code is also provided at the end of this page.

It is possible to replace the WGSL sources of any shader with another source file (assuming it exports the same symbols and/or compute kernels). This can be useful in particular if you need to inject some special behavior into a particular WGSL function, fix a bug, or improve performance in a shader belonging to a dependency you don’t have control over. It can also be used to overwrite any embedded shader with a source file available separately from your executable (for example if some of your game shaders are bundled in a separate assets directory).

Given a shader deriving the Shader trait with derive(Shader), overwriting its sources only requires a single line of code:

MyShader::set_wgsl_path("path/to/your/shader.wgsl");

Since this is a static method, this will set the WGSL path globally. In order to take this change into account, any compute pipeline that was constructed from this shader (or that depends directly or indirectly on this shader) has to be re-created; subsequent calls to Shader::from_device, as well as any other function provided by the Shader trait, will automatically take the new path into account. Typically, you would set this path before constructing any compute pipeline.

info

The shader at the path set with set_wgsl_path will not be embedded in the built executable or library. So be sure that the path is always valid relative to wherever your executable will be run. Such shaders will typically be deployed alongside the executable in a known path.

All the overwritten shader paths are registered globally in wgcore::shader::ShaderRegistry. This is essentially a hash-map between a shader type and its overwritten path. Another way to obtain the known path for a shader is by calling Shader::wgsl_path.

Complete example

// Build-time guard: when the `derive` feature is not enabled, abort compilation with a
// readable banner instead of letting the `#[derive(Shader)]` usages below fail with less
// helpful errors.
#[cfg(not(feature = "derive"))]
std::compile_error!(
r#"
###############################################################
## The `derive` feature must be enabled to run this example. ##
###############################################################
"#
);

use nalgebra::DVector;
use std::fmt::Debug;
use wgcore::gpu::GpuInstance;
use wgcore::kernel::{KernelInvocationBuilder, KernelInvocationQueue};
use wgcore::tensor::GpuVector;
use wgcore::Shader;
use wgpu::{BufferUsages, ComputePipeline};

// Declare our shader module that contains our composable functions.
// Note that we don’t build any compute pipeline from this wgsl file.
// This is the shader whose sources get replaced later in `main` through
// `Composable::set_wgsl_path`.
#[derive(Shader)]
#[shader(
src = "compose_dependency.wgsl" // Shader source code, will be embedded in the exe with `include_str!`
)]
struct Composable;

// Compute shader of this example. It pulls in the functions exported by `Composable` and
// exposes a single compute pipeline, `main`, which `run_kernel` dispatches.
#[derive(Shader)]
#[shader(
derive(Composable), // This shader depends on the `Composable` shader.
src = "compose_kernel.wgsl", // Shader source code, will be embedded in the exe with `include_str!`.
composable = false // This shader doesn’t export any symbols reusable from other wgsl shaders.
)]
struct WgKernel {
// This ComputePipeline field indicates that the Shader macro needs to generate the boilerplate
// for loading the compute pipeline in `WgKernel::from_device`.
main: ComputePipeline,
}

// Element type stored in the GPU vectors of this example: a single `f32` wrapped in a struct.
// NOTE(review): `#[repr(C)]` plus the `bytemuck::Pod`/`bytemuck::Zeroable` derives are
// presumably what allows `GpuVector` to upload and read back this type as raw bytes —
// confirm against the wgcore `GpuVector` bounds.
#[derive(Copy, Clone, PartialEq, Default, bytemuck::Pod, bytemuck::Zeroable)]
#[repr(C)]
pub struct MyStruct {
// The wrapped scalar; printed as-is by the `Debug` impl.
value: f32,
}

// Optional: print only the wrapped number so that a whole result vector stays on a short
// line (e.g. `[0, 11, 22]` rather than `[MyStruct { value: 0.0 }, ...]`).
impl Debug for MyStruct {
    /// Writes the inner `value` using its `Display` representation.
    ///
    /// The value is written through fresh formatting arguments, so any width, precision or
    /// flags requested by the caller's format string are not applied.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self { value } = self;
        f.write_fmt(format_args!("{}", value))
    }
}

#[async_std::main]
async fn main() -> anyhow::Result<()> {
    // Set up the gpu device and its queue.
    //
    // `GpuInstance` is only a convenience wrapper around that initialization. If you need
    // finer control, or if the device/queue are already owned by something else (a game
    // engine, for instance), you can create or reuse them yourself instead.
    let gpu = GpuInstance::new().await?;

    // First pass: compile the kernel from its embedded sources and run it.
    // `from_device` is generated by the `Shader` derive, which also takes care of resolving
    // the dependency on `Composable`.
    let original_kernel = WgKernel::from_device(gpu.device())?;
    let original_output = run_kernel(&gpu, &original_kernel).await;

    // Point the dependency module at a replacement source file.
    // Because the example is launched with `cargo run --example`, this path is relative to
    // the `target/debug` folder.
    Composable::set_wgsl_path("../../crates/wgcore/examples/overwritten_dependency.wgsl");

    // Second pass: the pipeline has to be rebuilt for the new sources to be picked up, so
    // compile the kernel again and run the modified version.
    let patched_kernel = WgKernel::from_device(gpu.device())?;
    let patched_output = run_kernel(&gpu, &patched_kernel).await;

    println!("Result before overwrite: {:?}", original_output);
    println!("Result after overwrite: {:?}", patched_output);

    Ok(())
}

/// Runs `kernel.main` once on two freshly created input vectors and returns the content of
/// the first vector as read back from the gpu after the dispatch.
///
/// # Panics
///
/// Panics if the staging buffer cannot be read back.
async fn run_kernel(gpu: &GpuInstance, kernel: &WgKernel) -> Vec<MyStruct> {
    // Number of elements in each input vector.
    const LEN: u32 = 10;

    // Host-side inputs: element `i` holds `i` in the first vector and `i * 10` in the second.
    let lhs_values = DVector::from_fn(LEN as usize, |row, _| MyStruct { value: row as f32 });
    let rhs_values = DVector::from_fn(LEN as usize, |row, _| MyStruct {
        value: row as f32 * 10.0,
    });

    // Device-side buffers. The first input is also the copy source for the read-back, and the
    // staging buffer is the only one that gets mapped for reading.
    let lhs_usage = BufferUsages::STORAGE | BufferUsages::COPY_SRC;
    let lhs = GpuVector::init(gpu.device(), &lhs_values, lhs_usage);
    let rhs = GpuVector::init(gpu.device(), &rhs_values, BufferUsages::STORAGE);
    let readback_usage = BufferUsages::COPY_DST | BufferUsages::MAP_READ;
    let readback = GpuVector::uninit(gpu.device(), LEN, readback_usage);

    // Queue a single invocation of the kernel with both vectors bound.
    // NOTE(review): the divisor 64 presumably mirrors the workgroup size declared in
    // `compose_kernel.wgsl` — confirm against the shader.
    let workgroups = LEN.div_ceil(64);
    let mut invocations = KernelInvocationQueue::new(gpu.device());
    KernelInvocationBuilder::new(&mut invocations, &kernel.main)
        .bind0([lhs.buffer(), rhs.buffer()])
        .queue(workgroups);

    // Record the queued work, then the copy of the first vector into the staging buffer, and
    // submit everything to the gpu in one command buffer.
    let mut commands = gpu.device().create_command_encoder(&Default::default());
    invocations.encode(&mut commands, None);
    readback.copy_from(&mut commands, &lhs);
    gpu.queue().submit(Some(commands.finish()));

    // Fetch the staged data back to the host.
    let fetched = readback.read(gpu.device()).await;
    fetched.expect("Failed to read result from the GPU.")
}