Added runtime style prediction

This commit is contained in:
Manuel Wagner 2022-09-05 13:38:05 +02:00
parent 20144b8b59
commit b580ba6316
5 changed files with 68 additions and 29 deletions

Binary file not shown.

View File

@ -113,6 +113,7 @@ void FStyleTransferSceneViewExtension::SubscribeToPostProcessingPass(EPostProces
void FStyleTransferSceneViewExtension::PreRenderViewFamily_RenderThread(FRDGBuilder& GraphBuilder, FSceneViewFamily& InViewFamily)
{
return;
const FName RenderCaptureProviderType = IRenderCaptureProvider::GetModularFeatureName();
if(!IModularFeatures::Get().IsModularFeatureAvailable(RenderCaptureProviderType))
return;
@ -181,19 +182,19 @@ FRDGTexture* FStyleTransferSceneViewExtension::TensorToTexture(FRDGBuilder& Grap
return OutputTexture;
}
void FStyleTransferSceneViewExtension::TextureToTensor(FRDGBuilder& GraphBuilder, const FScreenPassTexture& SourceTexture, const FNeuralTensor& DestinationTensor)
void FStyleTransferSceneViewExtension::TextureToTensor(FRDGBuilder& GraphBuilder, FRDGTextureRef SourceTexture, const FNeuralTensor& DestinationTensor)
{
const FIntVector InputTensorDimensions = {
CastNarrowingSafe<int32>(DestinationTensor.GetSize(1)),
CastNarrowingSafe<int32>(DestinationTensor.GetSize(2)),
CastNarrowingSafe<int32>(DestinationTensor.GetSize(3)),
};
const FIntPoint SceneColorRenderTargetDimensions = SourceTexture.Texture->Desc.Extent;
const FIntPoint SceneColorRenderTargetDimensions = SourceTexture->Desc.Extent;
FRDGBufferRef StyleTransferContentInputBuffer = GraphBuilder.RegisterExternalBuffer(DestinationTensor.GetPooledBuffer());
auto SceneColorToInputTensorParameters = GraphBuilder.AllocParameters<FSceneColorToInputTensorCS::FParameters>();
SceneColorToInputTensorParameters->TensorVolume = CastNarrowingSafe<uint32>(DestinationTensor.Num());
SceneColorToInputTensorParameters->InputTexture = SourceTexture.Texture;
SceneColorToInputTensorParameters->InputTexture = SourceTexture;
SceneColorToInputTensorParameters->InputTextureSampler = TStaticSamplerState<SF_Bilinear>::GetRHI();
SceneColorToInputTensorParameters->OutputUAV = GraphBuilder.CreateUAV(StyleTransferContentInputBuffer);
SceneColorToInputTensorParameters->OutputDimensions = {InputTensorDimensions.X, InputTensorDimensions.Y};
@ -239,7 +240,7 @@ FScreenPassTexture FStyleTransferSceneViewExtension::PostProcessPassAfterTonemap
const FNeuralTensor& StyleTransferContentInputTensor = StyleTransferNetwork->GetInputTensorForContext(*InferenceContext, 0);
TextureToTensor(GraphBuilder, SceneColor, StyleTransferContentInputTensor);
TextureToTensor(GraphBuilder, SceneColor.Texture, StyleTransferContentInputTensor);
StyleTransferNetwork->Run(GraphBuilder, *InferenceContext);

View File

@ -3,11 +3,13 @@
#include "StyleTransferSubsystem.h"
#include "IRenderCaptureProvider.h"
#include "NeuralNetwork.h"
#include "RenderGraphUtils.h"
#include "ScreenPass.h"
#include "StyleTransferModule.h"
#include "StyleTransferSceneViewExtension.h"
#include "TextureCompiler.h"
#include "Rendering/Texture2DResource.h"
TAutoConsoleVariable<bool> CVarStyleTransferEnabled(
@ -47,11 +49,14 @@ void UStyleTransferSubsystem::StartStylizingViewport(FViewportClient* ViewportCl
checkf(*StyleTransferInferenceContext != INDEX_NONE, TEXT("Could not create inference context for StyleTransferNetwork"));
UTexture2D* StyleTexture = LoadObject<UTexture2D>(this, TEXT("/Script/Engine.Texture2D'/StyleTransfer/T_StyleImage.T_StyleImage'"));
UpdateStyle(StyleTexture);
FTextureCompilingManager::Get().FinishCompilation({StyleTexture});
//UpdateStyle(StyleTexture);
UpdateStyle(FPaths::GetPath("C:\\projects\\realtime-style-transfer\\temp\\style_params_tensor.bin"));
StyleTransferSceneViewExtension = FSceneViewExtensions::NewExtension<FStyleTransferSceneViewExtension>(ViewportClient, StyleTransferNetwork, StyleTransferInferenceContext.ToSharedRef());
}
StyleTransferSceneViewExtension->SetEnabled(true);
ViewportClient->GetWorld()->GetWorldSettings()->SetPauserPlayerState(ViewportClient->GetWorld()->GetFirstPlayerController()->PlayerState);
}
void UStyleTransferSubsystem::StopStylizingViewport()
@ -74,22 +79,25 @@ void UStyleTransferSubsystem::UpdateStyle(UTexture2D* StyleTexture)
{
checkf(StyleTransferInferenceContext.IsValid() && (*StyleTransferInferenceContext) != INDEX_NONE, TEXT("Can not infer style without inference context"));
checkf(StylePredictionInferenceContext != INDEX_NONE, TEXT("Can not update style without inference context"));
FlushRenderingCommands();
ENQUEUE_RENDER_COMMAND(StylePrediction)([this, StyleTexture](FRHICommandListImmediate& RHICommandList)
{
FRDGBuilder GraphBuilder(RHICommandList);
const FNeuralTensor& InputStyleImageTensor = StylePredictionNetwork->GetInputTensorForContext(StylePredictionInferenceContext, 0);
FTextureResource* StyleTextureResource = StyleTexture->GetResource();
if(!StyleTextureResource->IsInitialized())
IRenderCaptureProvider* RenderCaptureProvider = nullptr;
const FName RenderCaptureProviderType = IRenderCaptureProvider::GetModularFeatureName();
if(IModularFeatures::Get().IsModularFeatureAvailable(RenderCaptureProviderType))
{
StyleTextureResource->UpdateRHI();
RenderCaptureProvider = &IModularFeatures::Get().GetModularFeature<IRenderCaptureProvider>(RenderCaptureProviderType);
RenderCaptureProvider->BeginCapture(&RHICommandList);
}
FRDGBuilder GraphBuilder(RHICommandList);
{
RDG_EVENT_SCOPE(GraphBuilder, "StylePrediction");
const FNeuralTensor& InputStyleImageTensor = StylePredictionNetwork->GetInputTensorForContext(StylePredictionInferenceContext, 0);
FTextureResource* StyleTextureResource = StyleTexture->GetResource();
FRDGTextureRef RDGStyleTexture = GraphBuilder.RegisterExternalTexture(CreateRenderTarget(StyleTextureResource->TextureRHI, TEXT("StyleInputTexture")));
FStyleTransferSceneViewExtension::TextureToTensor(GraphBuilder, FScreenPassTexture(RDGStyleTexture), InputStyleImageTensor);
FStyleTransferSceneViewExtension::TextureToTensor(GraphBuilder, RDGStyleTexture, InputStyleImageTensor);
StylePredictionNetwork->Run(GraphBuilder, StylePredictionInferenceContext);
@ -104,13 +112,42 @@ void UStyleTransferSubsystem::UpdateStyle(UTexture2D* StyleTexture)
check(InputStyleParamsBuffer->GetSize() == InputStyleParams.NumInBytes());
AddCopyBufferPass(GraphBuilder, InputStyleParamsBuffer, OutputStyleParamsBuffer);
}
GraphBuilder.Execute();
if(RenderCaptureProvider)
{
RenderCaptureProvider->EndCapture(&RHICommandList);
}
});
FlushRenderingCommands();
}
void UStyleTransferSubsystem::UpdateStyle(FString StyleTensorDataPath)
{
FArchive& FileReader = *IFileManager::Get().CreateFileReader(*StyleTensorDataPath);
TArray<float> StyleParams;
StyleParams.SetNumUninitialized(192);
FileReader << StyleParams;
ENQUEUE_RENDER_COMMAND(StyleParamsLoad)([this, StyleParams = MoveTemp(StyleParams)](FRHICommandListImmediate& RHICommandList)
{
const FNeuralTensor& InputStyleParams = StyleTransferNetwork->GetInputTensorForContext(*StyleTransferInferenceContext, StyleTransferStyleParamsInputIndex);
FRDGBuilder GraphBuilder(RHICommandList);
{
RDG_EVENT_SCOPE(GraphBuilder, "StyleParamsLoad");
FRDGBufferRef InputStyleParamsBuffer = GraphBuilder.RegisterExternalBuffer(InputStyleParams.GetPooledBuffer());
GraphBuilder.QueueBufferUpload(InputStyleParamsBuffer, StyleParams.GetData(), StyleParams.Num() * StyleParams.GetTypeSize(), ERDGInitialDataFlags::NoCopy);
}
GraphBuilder.Execute();
});
FlushRenderingCommands();
}
void UStyleTransferSubsystem::HandleConsoleVariableChanged(IConsoleVariable* ConsoleVariable)
{
check(ConsoleVariable == CVarStyleTransferEnabled.AsVariable());

View File

@ -38,7 +38,7 @@ public:
static void AddRescalingTextureCopy(FRDGBuilder& GraphBuilder, FRDGTexture& RDGSourceTexture, FScreenPassRenderTarget& DestinationRenderTarget);
static FRDGTexture* TensorToTexture(FRDGBuilder& GraphBuilder, const FRDGTextureDesc& BaseDestinationDesc, const FNeuralTensor& SourceTensor);
static void TextureToTensor(FRDGBuilder& GraphBuilder, const FScreenPassTexture& SourceTexture, const FNeuralTensor& DestinationTensor);
static void TextureToTensor(FRDGBuilder& GraphBuilder, FRDGTextureRef SourceTexture, const FNeuralTensor& DestinationTensor);
private:
/** The actual Network pointer is not tracked so we need a WeakPtr too so we can check its validity on the game thread. */

View File

@ -26,6 +26,7 @@ public:
void StopStylizingViewport();
void UpdateStyle(UTexture2D* StyleTexture);
void UpdateStyle(FString StyleTensorDataPath);
private:
FStyleTransferSceneViewExtension::Ptr StyleTransferSceneViewExtension;