Added runtime style prediction
This commit is contained in:
parent
20144b8b59
commit
b580ba6316
BIN
Plugins/StyleTransfer/Content/T_StyleImage.uasset (Stored with Git LFS)
BIN
Plugins/StyleTransfer/Content/T_StyleImage.uasset (Stored with Git LFS)
Binary file not shown.
|
@ -113,6 +113,7 @@ void FStyleTransferSceneViewExtension::SubscribeToPostProcessingPass(EPostProces
|
|||
|
||||
void FStyleTransferSceneViewExtension::PreRenderViewFamily_RenderThread(FRDGBuilder& GraphBuilder, FSceneViewFamily& InViewFamily)
|
||||
{
|
||||
return;
|
||||
const FName RenderCaptureProviderType = IRenderCaptureProvider::GetModularFeatureName();
|
||||
if(!IModularFeatures::Get().IsModularFeatureAvailable(RenderCaptureProviderType))
|
||||
return;
|
||||
|
@ -181,19 +182,19 @@ FRDGTexture* FStyleTransferSceneViewExtension::TensorToTexture(FRDGBuilder& Grap
|
|||
return OutputTexture;
|
||||
}
|
||||
|
||||
void FStyleTransferSceneViewExtension::TextureToTensor(FRDGBuilder& GraphBuilder, const FScreenPassTexture& SourceTexture, const FNeuralTensor& DestinationTensor)
|
||||
void FStyleTransferSceneViewExtension::TextureToTensor(FRDGBuilder& GraphBuilder, FRDGTextureRef SourceTexture, const FNeuralTensor& DestinationTensor)
|
||||
{
|
||||
const FIntVector InputTensorDimensions = {
|
||||
CastNarrowingSafe<int32>(DestinationTensor.GetSize(1)),
|
||||
CastNarrowingSafe<int32>(DestinationTensor.GetSize(2)),
|
||||
CastNarrowingSafe<int32>(DestinationTensor.GetSize(3)),
|
||||
};
|
||||
const FIntPoint SceneColorRenderTargetDimensions = SourceTexture.Texture->Desc.Extent;
|
||||
const FIntPoint SceneColorRenderTargetDimensions = SourceTexture->Desc.Extent;
|
||||
|
||||
FRDGBufferRef StyleTransferContentInputBuffer = GraphBuilder.RegisterExternalBuffer(DestinationTensor.GetPooledBuffer());
|
||||
auto SceneColorToInputTensorParameters = GraphBuilder.AllocParameters<FSceneColorToInputTensorCS::FParameters>();
|
||||
SceneColorToInputTensorParameters->TensorVolume = CastNarrowingSafe<uint32>(DestinationTensor.Num());
|
||||
SceneColorToInputTensorParameters->InputTexture = SourceTexture.Texture;
|
||||
SceneColorToInputTensorParameters->InputTexture = SourceTexture;
|
||||
SceneColorToInputTensorParameters->InputTextureSampler = TStaticSamplerState<SF_Bilinear>::GetRHI();
|
||||
SceneColorToInputTensorParameters->OutputUAV = GraphBuilder.CreateUAV(StyleTransferContentInputBuffer);
|
||||
SceneColorToInputTensorParameters->OutputDimensions = {InputTensorDimensions.X, InputTensorDimensions.Y};
|
||||
|
@ -239,7 +240,7 @@ FScreenPassTexture FStyleTransferSceneViewExtension::PostProcessPassAfterTonemap
|
|||
|
||||
const FNeuralTensor& StyleTransferContentInputTensor = StyleTransferNetwork->GetInputTensorForContext(*InferenceContext, 0);
|
||||
|
||||
TextureToTensor(GraphBuilder, SceneColor, StyleTransferContentInputTensor);
|
||||
TextureToTensor(GraphBuilder, SceneColor.Texture, StyleTransferContentInputTensor);
|
||||
|
||||
StyleTransferNetwork->Run(GraphBuilder, *InferenceContext);
|
||||
|
||||
|
|
|
@ -3,11 +3,13 @@
|
|||
|
||||
#include "StyleTransferSubsystem.h"
|
||||
|
||||
#include "IRenderCaptureProvider.h"
|
||||
#include "NeuralNetwork.h"
|
||||
#include "RenderGraphUtils.h"
|
||||
#include "ScreenPass.h"
|
||||
#include "StyleTransferModule.h"
|
||||
#include "StyleTransferSceneViewExtension.h"
|
||||
#include "TextureCompiler.h"
|
||||
#include "Rendering/Texture2DResource.h"
|
||||
|
||||
TAutoConsoleVariable<bool> CVarStyleTransferEnabled(
|
||||
|
@ -47,11 +49,14 @@ void UStyleTransferSubsystem::StartStylizingViewport(FViewportClient* ViewportCl
|
|||
checkf(*StyleTransferInferenceContext != INDEX_NONE, TEXT("Could not create inference context for StyleTransferNetwork"));
|
||||
|
||||
UTexture2D* StyleTexture = LoadObject<UTexture2D>(this, TEXT("/Script/Engine.Texture2D'/StyleTransfer/T_StyleImage.T_StyleImage'"));
|
||||
UpdateStyle(StyleTexture);
|
||||
FTextureCompilingManager::Get().FinishCompilation({StyleTexture});
|
||||
//UpdateStyle(StyleTexture);
|
||||
UpdateStyle(FPaths::GetPath("C:\\projects\\realtime-style-transfer\\temp\\style_params_tensor.bin"));
|
||||
StyleTransferSceneViewExtension = FSceneViewExtensions::NewExtension<FStyleTransferSceneViewExtension>(ViewportClient, StyleTransferNetwork, StyleTransferInferenceContext.ToSharedRef());
|
||||
|
||||
}
|
||||
StyleTransferSceneViewExtension->SetEnabled(true);
|
||||
ViewportClient->GetWorld()->GetWorldSettings()->SetPauserPlayerState(ViewportClient->GetWorld()->GetFirstPlayerController()->PlayerState);
|
||||
}
|
||||
|
||||
void UStyleTransferSubsystem::StopStylizingViewport()
|
||||
|
@ -74,22 +79,25 @@ void UStyleTransferSubsystem::UpdateStyle(UTexture2D* StyleTexture)
|
|||
{
|
||||
checkf(StyleTransferInferenceContext.IsValid() && (*StyleTransferInferenceContext) != INDEX_NONE, TEXT("Can not infer style without inference context"));
|
||||
checkf(StylePredictionInferenceContext != INDEX_NONE, TEXT("Can not update style without inference context"));
|
||||
|
||||
FlushRenderingCommands();
|
||||
ENQUEUE_RENDER_COMMAND(StylePrediction)([this, StyleTexture](FRHICommandListImmediate& RHICommandList)
|
||||
{
|
||||
FRDGBuilder GraphBuilder(RHICommandList);
|
||||
|
||||
|
||||
const FNeuralTensor& InputStyleImageTensor = StylePredictionNetwork->GetInputTensorForContext(StylePredictionInferenceContext, 0);
|
||||
|
||||
FTextureResource* StyleTextureResource = StyleTexture->GetResource();
|
||||
if(!StyleTextureResource->IsInitialized())
|
||||
IRenderCaptureProvider* RenderCaptureProvider = nullptr;
|
||||
const FName RenderCaptureProviderType = IRenderCaptureProvider::GetModularFeatureName();
|
||||
if(IModularFeatures::Get().IsModularFeatureAvailable(RenderCaptureProviderType))
|
||||
{
|
||||
StyleTextureResource->UpdateRHI();
|
||||
RenderCaptureProvider = &IModularFeatures::Get().GetModularFeature<IRenderCaptureProvider>(RenderCaptureProviderType);
|
||||
RenderCaptureProvider->BeginCapture(&RHICommandList);
|
||||
}
|
||||
|
||||
FRDGBuilder GraphBuilder(RHICommandList);
|
||||
{
|
||||
RDG_EVENT_SCOPE(GraphBuilder, "StylePrediction");
|
||||
|
||||
const FNeuralTensor& InputStyleImageTensor = StylePredictionNetwork->GetInputTensorForContext(StylePredictionInferenceContext, 0);
|
||||
FTextureResource* StyleTextureResource = StyleTexture->GetResource();
|
||||
FRDGTextureRef RDGStyleTexture = GraphBuilder.RegisterExternalTexture(CreateRenderTarget(StyleTextureResource->TextureRHI, TEXT("StyleInputTexture")));
|
||||
FStyleTransferSceneViewExtension::TextureToTensor(GraphBuilder, FScreenPassTexture(RDGStyleTexture), InputStyleImageTensor);
|
||||
FStyleTransferSceneViewExtension::TextureToTensor(GraphBuilder, RDGStyleTexture, InputStyleImageTensor);
|
||||
|
||||
StylePredictionNetwork->Run(GraphBuilder, StylePredictionInferenceContext);
|
||||
|
||||
|
@ -104,13 +112,42 @@ void UStyleTransferSubsystem::UpdateStyle(UTexture2D* StyleTexture)
|
|||
check(InputStyleParamsBuffer->GetSize() == InputStyleParams.NumInBytes());
|
||||
|
||||
AddCopyBufferPass(GraphBuilder, InputStyleParamsBuffer, OutputStyleParamsBuffer);
|
||||
|
||||
}
|
||||
GraphBuilder.Execute();
|
||||
|
||||
if(RenderCaptureProvider)
|
||||
{
|
||||
RenderCaptureProvider->EndCapture(&RHICommandList);
|
||||
}
|
||||
});
|
||||
|
||||
FlushRenderingCommands();
|
||||
}
|
||||
|
||||
void UStyleTransferSubsystem::UpdateStyle(FString StyleTensorDataPath)
|
||||
{
|
||||
FArchive& FileReader = *IFileManager::Get().CreateFileReader(*StyleTensorDataPath);
|
||||
TArray<float> StyleParams;
|
||||
StyleParams.SetNumUninitialized(192);
|
||||
|
||||
FileReader << StyleParams;
|
||||
|
||||
ENQUEUE_RENDER_COMMAND(StyleParamsLoad)([this, StyleParams = MoveTemp(StyleParams)](FRHICommandListImmediate& RHICommandList)
|
||||
{
|
||||
const FNeuralTensor& InputStyleParams = StyleTransferNetwork->GetInputTensorForContext(*StyleTransferInferenceContext, StyleTransferStyleParamsInputIndex);
|
||||
|
||||
FRDGBuilder GraphBuilder(RHICommandList);
|
||||
{
|
||||
RDG_EVENT_SCOPE(GraphBuilder, "StyleParamsLoad");
|
||||
|
||||
FRDGBufferRef InputStyleParamsBuffer = GraphBuilder.RegisterExternalBuffer(InputStyleParams.GetPooledBuffer());
|
||||
GraphBuilder.QueueBufferUpload(InputStyleParamsBuffer, StyleParams.GetData(), StyleParams.Num() * StyleParams.GetTypeSize(), ERDGInitialDataFlags::NoCopy);
|
||||
}
|
||||
GraphBuilder.Execute();
|
||||
});
|
||||
FlushRenderingCommands();
|
||||
}
|
||||
|
||||
void UStyleTransferSubsystem::HandleConsoleVariableChanged(IConsoleVariable* ConsoleVariable)
|
||||
{
|
||||
check(ConsoleVariable == CVarStyleTransferEnabled.AsVariable());
|
||||
|
|
|
@ -38,7 +38,7 @@ public:
|
|||
|
||||
static void AddRescalingTextureCopy(FRDGBuilder& GraphBuilder, FRDGTexture& RDGSourceTexture, FScreenPassRenderTarget& DestinationRenderTarget);
|
||||
static FRDGTexture* TensorToTexture(FRDGBuilder& GraphBuilder, const FRDGTextureDesc& BaseDestinationDesc, const FNeuralTensor& SourceTensor);
|
||||
static void TextureToTensor(FRDGBuilder& GraphBuilder, const FScreenPassTexture& SourceTexture, const FNeuralTensor& DestinationTensor);
|
||||
static void TextureToTensor(FRDGBuilder& GraphBuilder, FRDGTextureRef SourceTexture, const FNeuralTensor& DestinationTensor);
|
||||
|
||||
private:
|
||||
/** The actual Network pointer is not tracked so we need a WeakPtr too so we can check its validity on the game thread. */
|
||||
|
|
|
@ -26,6 +26,7 @@ public:
|
|||
void StopStylizingViewport();
|
||||
|
||||
void UpdateStyle(UTexture2D* StyleTexture);
|
||||
void UpdateStyle(FString StyleTensorDataPath);
|
||||
|
||||
private:
|
||||
FStyleTransferSceneViewExtension::Ptr StyleTransferSceneViewExtension;
|
||||
|
|
Loading…
Reference in New Issue