Added StyleTransfer settings in project settings

imported network with actual weights
This commit is contained in:
Manuel Wagner 2022-09-05 18:12:20 +02:00
parent b580ba6316
commit 0c2a994eac
9 changed files with 58 additions and 19 deletions

View File

@ -218,3 +218,8 @@ VoiceChatVolumeControlBus=/Game/Audio/Modulation/ControlBuses/CB_VoiceChat.CB_Vo
LoadingScreenControlBusMix=/Game/Audio/Modulation/ControlBusMixes/CBM_LoadingScreenMix.CBM_LoadingScreenMix LoadingScreenControlBusMix=/Game/Audio/Modulation/ControlBusMixes/CBM_LoadingScreenMix.CBM_LoadingScreenMix
[/Script/StyleTransfer.StyleTransferSettings]
StyleTransferNetwork=/StyleTransfer/NN_TransferWithWeights.NN_TransferWithWeights
StylePredictionNetwork=/StyleTransfer/NN_StylePredictorWithWeights.NN_StylePredictorWithWeights
StyleTexture=/StyleTransfer/T_StyleImage.T_StyleImage

Binary file not shown.

Binary file not shown.

Binary file not shown.

BIN
Plugins/StyleTransfer/Content/NN_TransferWithWeights.uasset (Stored with Git LFS) Normal file

Binary file not shown.

View File

@ -0,0 +1,9 @@
// Fill out your copyright notice in the Description page of Project Settings.
#include "StyleTransferSettings.h"
UStyleTransferSettings::UStyleTransferSettings()
{
this->CategoryName = NAME_Game;
}

View File

@ -0,0 +1,28 @@
// Fill out your copyright notice in the Description page of Project Settings.
#pragma once
#include "CoreMinimal.h"
#include "NeuralNetwork.h"
#include "UObject/Object.h"
#include "StyleTransferSettings.generated.h"
/**
*
*/
UCLASS(Config=Game, defaultconfig, meta=(DisplayName="Style Transfer"))
class STYLETRANSFER_API UStyleTransferSettings : public UDeveloperSettings
{
GENERATED_BODY()
public:
UStyleTransferSettings();
UPROPERTY(EditAnywhere, Config)
TSoftObjectPtr<UNeuralNetwork> StyleTransferNetwork = nullptr;
UPROPERTY(EditAnywhere, Config)
TSoftObjectPtr<UNeuralNetwork> StylePredictionNetwork = nullptr;
UPROPERTY(EditAnywhere, Config)
TSoftObjectPtr<UTexture2D> StyleTexture = nullptr;
};

View File

@ -9,6 +9,7 @@
#include "ScreenPass.h" #include "ScreenPass.h"
#include "StyleTransferModule.h" #include "StyleTransferModule.h"
#include "StyleTransferSceneViewExtension.h" #include "StyleTransferSceneViewExtension.h"
#include "StyleTransferSettings.h"
#include "TextureCompiler.h" #include "TextureCompiler.h"
#include "Rendering/Texture2DResource.h" #include "Rendering/Texture2DResource.h"
@ -43,20 +44,21 @@ void UStyleTransferSubsystem::StartStylizingViewport(FViewportClient* ViewportCl
if (!StyleTransferSceneViewExtension) if (!StyleTransferSceneViewExtension)
{ {
const UStyleTransferSettings* StyleTransferSettings = GetDefault<UStyleTransferSettings>();
StylePredictionInferenceContext = StylePredictionNetwork->CreateInferenceContext(); StylePredictionInferenceContext = StylePredictionNetwork->CreateInferenceContext();
checkf(StylePredictionInferenceContext != INDEX_NONE, TEXT("Could not create inference context for StylePredictionNetwork")); checkf(StylePredictionInferenceContext != INDEX_NONE, TEXT("Could not create inference context for StylePredictionNetwork"));
StyleTransferInferenceContext = MakeShared<int32>(StyleTransferNetwork->CreateInferenceContext()); StyleTransferInferenceContext = MakeShared<int32>(StyleTransferNetwork->CreateInferenceContext());
checkf(*StyleTransferInferenceContext != INDEX_NONE, TEXT("Could not create inference context for StyleTransferNetwork")); checkf(*StyleTransferInferenceContext != INDEX_NONE, TEXT("Could not create inference context for StyleTransferNetwork"));
UTexture2D* StyleTexture = LoadObject<UTexture2D>(this, TEXT("/Script/Engine.Texture2D'/StyleTransfer/T_StyleImage.T_StyleImage'")); UTexture2D* StyleTexture = StyleTransferSettings->StyleTexture.LoadSynchronous();
//UTexture2D* StyleTexture = LoadObject<UTexture2D>(this, TEXT("/Script/Engine.Texture2D'/StyleTransfer/T_StyleImage.T_StyleImage'"));
FTextureCompilingManager::Get().FinishCompilation({StyleTexture}); FTextureCompilingManager::Get().FinishCompilation({StyleTexture});
//UpdateStyle(StyleTexture); UpdateStyle(StyleTexture);
UpdateStyle(FPaths::GetPath("C:\\projects\\realtime-style-transfer\\temp\\style_params_tensor.bin")); //UpdateStyle(FPaths::GetPath("C:\\projects\\realtime-style-transfer\\temp\\style_params_tensor.bin"));
StyleTransferSceneViewExtension = FSceneViewExtensions::NewExtension<FStyleTransferSceneViewExtension>(ViewportClient, StyleTransferNetwork, StyleTransferInferenceContext.ToSharedRef()); StyleTransferSceneViewExtension = FSceneViewExtensions::NewExtension<FStyleTransferSceneViewExtension>(ViewportClient, StyleTransferNetwork, StyleTransferInferenceContext.ToSharedRef());
} }
StyleTransferSceneViewExtension->SetEnabled(true); StyleTransferSceneViewExtension->SetEnabled(true);
ViewportClient->GetWorld()->GetWorldSettings()->SetPauserPlayerState(ViewportClient->GetWorld()->GetFirstPlayerController()->PlayerState);
} }
void UStyleTransferSubsystem::StopStylizingViewport() void UStyleTransferSubsystem::StopStylizingViewport()
@ -166,8 +168,9 @@ void UStyleTransferSubsystem::HandleConsoleVariableChanged(IConsoleVariable* Con
void UStyleTransferSubsystem::LoadNetworks() void UStyleTransferSubsystem::LoadNetworks()
{ {
StyleTransferNetwork = LoadObject<UNeuralNetwork>(this, TEXT("/StyleTransfer/NN_StyleTransfer.NN_StyleTransfer")); const UStyleTransferSettings* StyleTransferSettings = GetDefault<UStyleTransferSettings>();
StylePredictionNetwork = LoadObject<UNeuralNetwork>(this, TEXT("/StyleTransfer/NN_StylePredictor.NN_StylePredictor")); StyleTransferNetwork = StyleTransferSettings->StyleTransferNetwork.LoadSynchronous();
StylePredictionNetwork = StyleTransferSettings->StylePredictionNetwork.LoadSynchronous();
if (StyleTransferNetwork->IsLoaded()) if (StyleTransferNetwork->IsLoaded())
{ {

View File

@ -45,13 +45,7 @@ public class StyleTransfer : ModuleRules
"StyleTransferShaders", "StyleTransferShaders",
"PixWinPlugin", "PixWinPlugin",
"InputDevice", "InputDevice",
} "DeveloperSettings",
);
DynamicallyLoadedModuleNames.AddRange(
new string[]
{
} }
); );
} }