refactored the NNI system so RDG can keep track of the style transfer operations

This commit is contained in:
Manuel Wagner 2022-11-21 15:06:06 +01:00
parent f16119ca84
commit 9a83d413fc
21 changed files with 94 additions and 56 deletions

View File

@ -222,9 +222,8 @@ VoiceChatVolumeControlBus=/Game/Audio/Modulation/ControlBuses/CB_VoiceChat.CB_Vo
LoadingScreenControlBusMix=/Game/Audio/Modulation/ControlBusMixes/CBM_LoadingScreenMix.CBM_LoadingScreenMix
[/Script/StyleTransfer.StyleTransferSettings]
StyleTransferNetwork=/StyleTransfer/NN_TransferVGG.NN_TransferVGG
StylePredictionNetwork=/StyleTransfer/NN_PredictorVGG.NN_PredictorVGG
+StyleTextures=/StyleTransfer/T_StyleImage.T_StyleImage
+StyleTextures=/StyleTransfer/T_StyleImage3.T_StyleImage3
StyleTransferNetwork=/StyleTransfer/rstnet-960-32-3_transfer.rstnet-960-32-3_transfer
StylePredictionNetwork=/StyleTransfer/rstnet-960-32-3_predictor.rstnet-960-32-3_predictor
+StyleTextures=/StyleTransfer/Styles/T_StyleImage1.T_StyleImage1
InterpolationCurve=(EditorCurveData=(Keys=((InterpMode=RCIM_Cubic,TangentMode=RCTM_User),(InterpMode=RCIM_Cubic,TangentMode=RCTM_User,Time=2.241476,Value=1.000000),(InterpMode=RCIM_Cubic,TangentMode=RCTM_User,Time=3.908512,Value=1.000000),(InterpMode=RCIM_Cubic,TangentMode=RCTM_User,Time=5.831762),(Time=7.971708)),DefaultValue=340282346638528859811704183484516925440.000000,PreInfinityExtrap=RCCE_Constant,PostInfinityExtrap=RCCE_Constant),ExternalCurve=None)

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

BIN
Plugins/StyleTransfer/Content/Styles/T_StyleImage.uasset (Stored with Git LFS) Normal file

Binary file not shown.

BIN
Plugins/StyleTransfer/Content/Styles/T_StyleImage1.uasset (Stored with Git LFS) Normal file

Binary file not shown.

BIN
Plugins/StyleTransfer/Content/Styles/T_StyleImage2.uasset (Stored with Git LFS) Normal file

Binary file not shown.

BIN
Plugins/StyleTransfer/Content/Styles/T_StyleImage3.uasset (Stored with Git LFS) Normal file

Binary file not shown.

BIN
Plugins/StyleTransfer/Content/Styles/T_StyleImage4.uasset (Stored with Git LFS) Normal file

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

View File

@ -173,10 +173,8 @@ FRDGTexture* FStyleTransferSceneViewExtension::TensorToTexture(FRDGBuilder& Grap
FRDGTexture* OutputTexture = GraphBuilder.CreateTexture(
DestinationDesc, TEXT("OutputTexture"));
FRDGBufferRef SourceTensorBuffer = GraphBuilder.RegisterExternalBuffer(SourceTensor.GetPooledBuffer());
auto OutputTensorToSceneColorParameters = GraphBuilder.AllocParameters<FOutputTensorToSceneColorCS::FParameters>();
OutputTensorToSceneColorParameters->InputTensor = GraphBuilder.CreateSRV(SourceTensorBuffer, EPixelFormat::PF_R32_FLOAT);
OutputTensorToSceneColorParameters->InputTensor = SourceTensor.GetBufferSRVRef();
OutputTensorToSceneColorParameters->OutputTexture = GraphBuilder.CreateUAV(OutputTexture);
OutputTensorToSceneColorParameters->TensorVolume = SourceTensor.Num();
OutputTensorToSceneColorParameters->TextureSize = DestinationDesc.Extent;
@ -219,10 +217,9 @@ FRDGTexture* TensorToTexture(FRDGBuilder& GraphBuilder, const FRDGTextureDesc& B
FRDGTexture* OutputTexture = GraphBuilder.CreateTexture(
DestinationDesc, TEXT("OutputTexture"));
FRDGBufferRef SourceTensorBuffer = GraphBuilder.RegisterExternalBuffer(SourceTensor.GetPooledBuffer());
auto OutputTensorToSceneColorParameters = GraphBuilder.AllocParameters<FOutputTensorToSceneColorCS::FParameters>();
OutputTensorToSceneColorParameters->InputTensor = GraphBuilder.CreateSRV(SourceTensorBuffer, EPixelFormat::PF_R32_FLOAT);
OutputTensorToSceneColorParameters->InputTensor = SourceTensor.GetBufferSRVRef();
OutputTensorToSceneColorParameters->OutputTexture = GraphBuilder.CreateUAV(OutputTexture);
OutputTensorToSceneColorParameters->TensorVolume = SourceTensor.Num();
OutputTensorToSceneColorParameters->TextureSize = DestinationDesc.Extent;
@ -246,7 +243,7 @@ FRDGTexture* TensorToTexture(FRDGBuilder& GraphBuilder, const FRDGTextureDesc& B
return OutputTexture;
}
void TextureToTensorRGB(FRDGBuilder& GraphBuilder, FRDGTextureRef SourceTexture, const FNeuralTensor& DestinationTensor)
FRDGPassRef TextureToTensorRGB(FRDGBuilder& GraphBuilder, FRDGTextureRef SourceTexture, FNeuralTensor& DestinationTensor)
{
const FIntVector InputTensorDimensions = {
CastNarrowingSafe<int32>(DestinationTensor.GetSize(1)),
@ -255,12 +252,11 @@ void TextureToTensorRGB(FRDGBuilder& GraphBuilder, FRDGTextureRef SourceTexture,
};
const FIntPoint RgbRenderTargetDimensions = SourceTexture->Desc.Extent;
const FRDGBufferRef DestinationTensorBuffer = GraphBuilder.RegisterExternalBuffer(DestinationTensor.GetPooledBuffer());
FSceneColorToInputTensorCS::FParameters* RgbToInputTensorParameters = GraphBuilder.AllocParameters<FSceneColorToInputTensorCS::FParameters>();
RgbToInputTensorParameters->TensorVolume = CastNarrowingSafe<uint32>(DestinationTensor.Num());
RgbToInputTensorParameters->InputTexture = SourceTexture;
RgbToInputTensorParameters->InputTextureSampler = TStaticSamplerState<SF_Bilinear>::GetRHI();
RgbToInputTensorParameters->OutputUAV = GraphBuilder.CreateUAV(DestinationTensorBuffer);
RgbToInputTensorParameters->OutputUAV = DestinationTensor.GetBufferUAVRef();
RgbToInputTensorParameters->OutputDimensions = {InputTensorDimensions.X, InputTensorDimensions.Y};
RgbToInputTensorParameters->HalfPixelUV = FVector2f(0.5f / RgbRenderTargetDimensions.X, 0.5 / RgbRenderTargetDimensions.Y);
FIntVector ComputeGroupCount = FComputeShaderUtils::GetGroupCount(
@ -269,7 +265,7 @@ void TextureToTensorRGB(FRDGBuilder& GraphBuilder, FRDGTextureRef SourceTexture,
);
TShaderMapRef<FSceneColorToInputTensorCS> RgbToInputTensorCS(GetGlobalShaderMap(GMaxRHIFeatureLevel));
GraphBuilder.AddPass(
return GraphBuilder.AddPass(
RDG_EVENT_NAME("TextureToTensorRGB(%s)", FSceneColorToInputTensorCS::StaticType.GetName()),
RgbToInputTensorParameters,
ERDGPassFlags::Compute,
@ -281,7 +277,7 @@ void TextureToTensorRGB(FRDGBuilder& GraphBuilder, FRDGTextureRef SourceTexture,
);
}
void TextureToTensorGrayscale(FRDGBuilder& GraphBuilder, FRDGTextureRef SourceTexture, const FNeuralTensor& DestinationTensor)
FRDGPassRef TextureToTensorGrayscale(FRDGBuilder& GraphBuilder, FRDGTextureRef SourceTexture, FNeuralTensor& DestinationTensor)
{
const FIntVector InputTensorDimensions = {
CastNarrowingSafe<int32>(DestinationTensor.GetSize(1)),
@ -290,12 +286,11 @@ void TextureToTensorGrayscale(FRDGBuilder& GraphBuilder, FRDGTextureRef SourceTe
};
const FIntPoint GrayscaleRenderTargetDimensions = SourceTexture->Desc.Extent;
const FRDGBufferRef DestinationTensorBuffer = GraphBuilder.RegisterExternalBuffer(DestinationTensor.GetPooledBuffer());
FShadowMaskToInputTensorCS::FParameters* GrayscaleToInputTensorParameters = GraphBuilder.AllocParameters<FShadowMaskToInputTensorCS::FParameters>();
GrayscaleToInputTensorParameters->TensorVolume = CastNarrowingSafe<uint32>(DestinationTensor.Num());
GrayscaleToInputTensorParameters->InputTexture = SourceTexture;
GrayscaleToInputTensorParameters->InputTextureSampler = TStaticSamplerState<SF_Bilinear>::GetRHI();
GrayscaleToInputTensorParameters->OutputUAV = GraphBuilder.CreateUAV(DestinationTensorBuffer);
GrayscaleToInputTensorParameters->OutputUAV = DestinationTensor.GetBufferUAVRef();
GrayscaleToInputTensorParameters->OutputDimensions = {InputTensorDimensions.X, InputTensorDimensions.Y};
GrayscaleToInputTensorParameters->HalfPixelUV = FVector2f(0.5f / GrayscaleRenderTargetDimensions.X, 0.5 / GrayscaleRenderTargetDimensions.Y);
FIntVector ComputeGroupCount = FComputeShaderUtils::GetGroupCount(
@ -304,7 +299,7 @@ void TextureToTensorGrayscale(FRDGBuilder& GraphBuilder, FRDGTextureRef SourceTe
);
TShaderMapRef<FShadowMaskToInputTensorCS> GrayscaleToInputTensorCS(GetGlobalShaderMap(GMaxRHIFeatureLevel));
GraphBuilder.AddPass(
return GraphBuilder.AddPass(
RDG_EVENT_NAME("TextureToTensorGrayscale(%s)", FShadowMaskToInputTensorCS::StaticType.GetName()),
GrayscaleToInputTensorParameters,
ERDGPassFlags::Compute,
@ -316,24 +311,19 @@ void TextureToTensorGrayscale(FRDGBuilder& GraphBuilder, FRDGTextureRef SourceTe
);
}
void FStyleTransferSceneViewExtension::TextureToTensorRGB(FRDGBuilder& GraphBuilder, FRDGTextureRef SourceTexture, const FNeuralTensor& DestinationTensor)
void FStyleTransferSceneViewExtension::TextureToTensorRGB(FRDGBuilder& GraphBuilder, FRDGTextureRef SourceTexture, FNeuralTensor& DestinationTensor)
{
::TextureToTensorRGB(GraphBuilder, SourceTexture, DestinationTensor);
}
void FStyleTransferSceneViewExtension::InterpolateTensors(FRDGBuilder& GraphBuilder, const FNeuralTensor& DestinationTensor, const FNeuralTensor& InputTensorA, const FNeuralTensor& InputTensorB, float Alpha)
void FStyleTransferSceneViewExtension::InterpolateTensors(FRDGBuilder& GraphBuilder, FNeuralTensor& DestinationTensor, const FNeuralTensor& InputTensorA, const FNeuralTensor& InputTensorB, float Alpha)
{
RDG_EVENT_SCOPE(GraphBuilder, "InterpolateTensors");
const FRDGBufferRef DestinationBuffer = GraphBuilder.RegisterExternalBuffer(DestinationTensor.GetPooledBuffer());
const FRDGBufferRef InputBufferA = GraphBuilder.RegisterExternalBuffer(InputTensorA.GetPooledBuffer());
const FRDGBufferRef InputBufferB = GraphBuilder.RegisterExternalBuffer(InputTensorB.GetPooledBuffer());
auto InterpolateTensorsParameters = GraphBuilder.AllocParameters<FInterpolateTensorsCS::FParameters>();
InterpolateTensorsParameters->InputSrvA = GraphBuilder.CreateSRV(InputBufferA, EPixelFormat::PF_R32_FLOAT);
InterpolateTensorsParameters->InputSrvB = GraphBuilder.CreateSRV(InputBufferB, EPixelFormat::PF_R32_FLOAT);
InterpolateTensorsParameters->OutputUAV = GraphBuilder.CreateUAV(DestinationBuffer);
InterpolateTensorsParameters->InputSrvA = InputTensorA.GetBufferSRVRef();
InterpolateTensorsParameters->InputSrvB = InputTensorB.GetBufferSRVRef();
InterpolateTensorsParameters->OutputUAV = DestinationTensor.GetBufferUAVRef();
InterpolateTensorsParameters->Alpha = Alpha;
InterpolateTensorsParameters->TensorVolume = DestinationTensor.Num();
FIntVector InterpolateTensorsThreadGroupCount = FComputeShaderUtils::GetGroupCount(
@ -379,11 +369,16 @@ FScreenPassTexture FStyleTransferSceneViewExtension::PostProcessPassAfterTonemap
checkSlow(View.bIsViewInfo);
const FViewInfo& ViewInfo = static_cast<const FViewInfo&>(View);
const FNeuralTensor& StyleTransferContentInputTensor = StyleTransferNetwork->GetInputTensorForContext(*InferenceContext, ContentInputTensorIndex);
FNeuralTensor& StyleTransferContentInputTensor = StyleTransferNetwork->GetInputTensorForContextMutable(*InferenceContext, ContentInputTensorIndex);
FNeuralTensor& StyleTransferStyleParamsInputTensor = StyleTransferNetwork->GetInputTensorForContextMutable(*InferenceContext, StyleParamsInputTensorIndex);
StyleTransferContentInputTensor.GPUToRDGBuilder_RenderThread(&GraphBuilder);
StyleTransferStyleParamsInputTensor.GPUToRDGBuilder_RenderThread(&GraphBuilder);
if (StyleWeightsInputTensorIndex != INDEX_NONE)
{
const FNeuralTensor& StyleTransferStyleWeightsInputTensor = StyleTransferNetwork->GetInputTensorForContext(*InferenceContext, StyleWeightsInputTensorIndex);
FNeuralTensor& StyleTransferStyleWeightsInputTensor = StyleTransferNetwork->GetInputTensorForContextMutable(*InferenceContext, StyleWeightsInputTensorIndex);
StyleTransferStyleWeightsInputTensor.GPUToRDGBuilder_RenderThread(&GraphBuilder);
check(GScreenShadowMaskTexture);
::TextureToTensorGrayscale(GraphBuilder, GScreenShadowMaskTexture, StyleTransferStyleWeightsInputTensor);
@ -393,7 +388,8 @@ FScreenPassTexture FStyleTransferSceneViewExtension::PostProcessPassAfterTonemap
StyleTransferNetwork->Run(GraphBuilder, *InferenceContext);
const FNeuralTensor& StyleTransferContentOutputTensor = StyleTransferNetwork->GetOutputTensorForContext(*InferenceContext, 0);
FNeuralTensor& StyleTransferContentOutputTensor = StyleTransferNetwork->GetOutputTensorForContextMutable(*InferenceContext, 0);
FRDGTexture* StyleTransferRenderTargetTexture = TensorToTexture(GraphBuilder, SceneColor.Texture->Desc, StyleTransferContentOutputTensor);
TSharedPtr<FScreenPassRenderTarget> StyleTransferOutputTarget = MakeShared<FScreenPassRenderTarget>(StyleTransferRenderTargetTexture, SceneColor.ViewRect,

View File

@ -92,6 +92,7 @@ void UStyleTransferSubsystem::StartStylizingViewport(FViewportClient* ViewportCl
FTextureCompilingManager::Get().FinishCompilation({StyleTexture});
#endif
UpdateStyle(StyleTexture, i, StylePredictionInferenceContext);
//UpdateStyle(StyleTexture, i, StylePredictionInferenceContext);
}
//UpdateStyle(FPaths::GetPath("C:\\projects\\realtime-style-transfer\\temp\\style_params_tensor.bin"));
UE_LOG(LogStyleTransfer, Log, TEXT("Creating FStyleTransferSceneViewExtension"));
@ -136,23 +137,26 @@ void UStyleTransferSubsystem::UpdateStyle(UTexture2D* StyleTexture, uint32 Style
ENQUEUE_RENDER_COMMAND(StylePrediction)([this, StyleTexture, StylePredictionInferenceContext, StyleIndex](FRHICommandListImmediate& RHICommandList)
{
IRenderCaptureProvider* RenderCaptureProvider = ConditionalBeginRenderCapture(RHICommandList);
FRDGBuilder GraphBuilder(RHICommandList);
{
RDG_EVENT_SCOPE(GraphBuilder, "StylePrediction");
const FNeuralTensor& InputStyleImageTensor = StylePredictionNetwork->GetInputTensorForContext(StylePredictionInferenceContext, 0);
FNeuralTensor& InputStyleImageTensor = StylePredictionNetwork->GetInputTensorForContextMutable(StylePredictionInferenceContext, 0);
FNeuralTensor& InputStyleParams = StyleTransferNetwork->GetInputTensorForContextMutable(*StyleTransferInferenceContext, StyleTransferStyleParamsInputIndex);
InputStyleImageTensor.GPUToRDGBuilder_RenderThread(&GraphBuilder);
InputStyleParams.GPUToRDGBuilder_RenderThread(&GraphBuilder);
FTextureResource* StyleTextureResource = StyleTexture->GetResource();
FRDGTextureRef RDGStyleTexture = GraphBuilder.RegisterExternalTexture(CreateRenderTarget(StyleTextureResource->TextureRHI, TEXT("StyleInputTexture")));
FStyleTransferSceneViewExtension::TextureToTensorRGB(GraphBuilder, RDGStyleTexture, InputStyleImageTensor);
StylePredictionNetwork->Run(GraphBuilder, StylePredictionInferenceContext);
const FNeuralTensor& OutputStyleParams = StylePredictionNetwork->GetOutputTensorForContext(StylePredictionInferenceContext, 0);
const FNeuralTensor& InputStyleParams = StyleTransferNetwork->GetInputTensorForContext(*StyleTransferInferenceContext, StyleTransferStyleParamsInputIndex);
FRDGBufferRef OutputStyleParamsBuffer = GraphBuilder.RegisterExternalBuffer(OutputStyleParams.GetPooledBuffer());
FRDGBufferRef InputStyleParamsBuffer = GraphBuilder.RegisterExternalBuffer(InputStyleParams.GetPooledBuffer());
FNeuralTensor& OutputStyleParams = StylePredictionNetwork->GetOutputTensorForContextMutable(StylePredictionInferenceContext, 0);
FRDGBufferRef OutputStyleParamsBuffer = OutputStyleParams.GetBufferSRVRef()->GetParent();
FRDGBufferRef InputStyleParamsBuffer = InputStyleParams.GetBufferUAVRef()->GetParent();
const uint64 NumBytes = OutputStyleParams.NumInBytes();
const uint64 DstOffset = StyleIndex * NumBytes;
@ -162,12 +166,12 @@ void UStyleTransferSubsystem::UpdateStyle(UTexture2D* StyleTexture, uint32 Style
Parameters->DstBuffer = InputStyleParamsBuffer;
GraphBuilder.AddPass(
RDG_EVENT_NAME("CopyBuffer(%s Size=%ubytes)", Parameters->SrcBuffer, Parameters->SrcBuffer->Desc.GetSize()),
RDG_EVENT_NAME("CopyBuffer(%s Size=%ubytes)", Parameters->SrcBuffer->Name, Parameters->SrcBuffer->Desc.GetSize()),
Parameters,
ERDGPassFlags::Copy,
[Parameters, NumBytes, DstOffset](FRHICommandList& RHICmdList)
{
RHICmdList.CopyBufferRegion(Parameters->DstBuffer->GetRHI(), DstOffset, Parameters->SrcBuffer->GetRHI(), 0, NumBytes);
RHICmdList.CopyBufferRegion(Parameters->DstBuffer->GetRHI(), /*DstOffset*/ 0, Parameters->SrcBuffer->GetRHI(), 0, NumBytes);
});
}
GraphBuilder.Execute();
@ -185,7 +189,7 @@ void UStyleTransferSubsystem::UpdateStyle(FString StyleTensorDataPath)
{
FArchive& FileReader = *IFileManager::Get().CreateFileReader(*StyleTensorDataPath);
TArray<float> StyleParams;
StyleParams.SetNumUninitialized(2758);
StyleParams.SetNumUninitialized(2758); // hardcoded for brevity reasons
FileReader << StyleParams;
@ -268,9 +272,12 @@ void UStyleTransferSubsystem::InterpolateStyles(int32 StylePredictionInferenceCo
{
RDG_EVENT_SCOPE(GraphBuilder, "StylePrediction");
const FNeuralTensor& InputStyleImageTensorA = StylePredictionNetwork->GetOutputTensorForContext(StylePredictionInferenceContextA, 0);
const FNeuralTensor& InputStyleImageTensorB = StylePredictionNetwork->GetOutputTensorForContext(StylePredictionInferenceContextB, 0);
const FNeuralTensor& OutputStyleParamsTensor = StyleTransferNetwork->GetInputTensorForContext(*StyleTransferInferenceContext, StyleTransferStyleParamsInputIndex);
FNeuralTensor& InputStyleImageTensorA = StylePredictionNetwork->GetOutputTensorForContextMutable(StylePredictionInferenceContextA, 0);
FNeuralTensor& InputStyleImageTensorB = StylePredictionNetwork->GetOutputTensorForContextMutable(StylePredictionInferenceContextB, 0);
FNeuralTensor& OutputStyleParamsTensor = StyleTransferNetwork->GetInputTensorForContextMutable(*StyleTransferInferenceContext, StyleTransferStyleParamsInputIndex);
InputStyleImageTensorA.GPUToRDGBuilder_RenderThread(&GraphBuilder);
InputStyleImageTensorB.GPUToRDGBuilder_RenderThread(&GraphBuilder);
OutputStyleParamsTensor.GPUToRDGBuilder_RenderThread(&GraphBuilder);
FStyleTransferSceneViewExtension::InterpolateTensors(GraphBuilder, OutputStyleParamsTensor, InputStyleImageTensorA, InputStyleImageTensorB, Alpha);
}
GraphBuilder.Execute();

View File

@ -42,9 +42,9 @@ public:
static void AddRescalingTextureCopy(FRDGBuilder& GraphBuilder, FRDGTexture& RDGSourceTexture, FScreenPassRenderTarget& DestinationRenderTarget);
static FRDGTexture* TensorToTexture(FRDGBuilder& GraphBuilder, const FRDGTextureDesc& BaseDestinationDesc, const FNeuralTensor& SourceTensor);
static void TextureToTensorRGB(FRDGBuilder& GraphBuilder, FRDGTextureRef SourceTexture, const FNeuralTensor& DestinationTensor);
static void TextureToTensorGrayscale(FRDGBuilder& GraphBuilder, FRDGTextureRef SourceTexture, const FNeuralTensor& DestinationTensor);
static void InterpolateTensors(FRDGBuilder& GraphBuilder, const FNeuralTensor& DestinationTensor, const FNeuralTensor& InputTensorA, const FNeuralTensor& InputTensorB, float Alpha);
static void TextureToTensorRGB(FRDGBuilder& GraphBuilder, FRDGTextureRef SourceTexture, FNeuralTensor& DestinationTensor);
static void TextureToTensorGrayscale(FRDGBuilder& GraphBuilder, FRDGTextureRef SourceTexture, FNeuralTensor& DestinationTensor);
static void InterpolateTensors(FRDGBuilder& GraphBuilder, FNeuralTensor& DestinationTensor, const FNeuralTensor& InputTensorA, const FNeuralTensor& InputTensorB, float Alpha);
private:
/** The actual Network pointer is not tracked so we need a WeakPtr too so we can check its validity on the game thread. */