-
Notifications
You must be signed in to change notification settings - Fork 12
Expand file tree
/
Copy pathControlNetRepository.cpp
More file actions
208 lines (174 loc) · 6.32 KB
/
ControlNetRepository.cpp
File metadata and controls
208 lines (174 loc) · 6.32 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
#include "pch.h"
#include "ControlNetRepository.h"
using namespace Axodox::MachineLearning::Web;
using namespace std;
using namespace winrt::Windows::Storage;
namespace winrt::Unpaint
{
//Hugging Face repository the ControlNet and annotator ONNX files are downloaded from (see TryEnsureModes)
const char* const ControlNetRepository::_controlnetRepository{ "axodoxian/controlnet_onnx" };

//Annotator files needed by each ControlNet mode. File names are relative to the annotator folder
//locally and to "annotators/" in the repository. Modes missing from this table have no annotator.
const std::unordered_map<std::string, std::vector<std::string>> ControlNetRepository::_annotators{
  { "canny",    { "canny.onnx" } },
  { "depth",    { "depth.onnx" } },
  { "hed",      { "hed.onnx" } },
  { "openpose", { "openpose.onnx" } }
};
std::vector<ControlNetModeViewModel> ControlNetRepository::Modes()
{
vector<ControlNetModeViewModel> results;
results.push_back(ControlNetModeViewModel{
.Id = L"canny",
.Name = L"Canny edges",
.ShortName = L"Canny",
.Description = L"Generates images based on a monochrome image with white edges on a black background."
});
results.push_back(ControlNetModeViewModel{
.Id = L"depth",
.Name = L"Depth image",
.ShortName = L"Depth",
.Description = L"Generates images based on a grayscale image with black representing deep areas and white representing shallow areas."
});
results.push_back(ControlNetModeViewModel{
.Id = L"hed",
.Name = L"HED edges",
.ShortName = L"HED",
.Description = L"Generates images based on a monochrome image with white soft edges on a black background (Holistically-Nested Edge Detection)."
});
results.push_back(ControlNetModeViewModel{
.Id = L"mlsd",
.Name = L"MLSD edges",
.ShortName = L"MLSD",
.Description = L"Generates images based on a monochrome image composed only of white straight lines on a black background (Mobile Line Segment Detection)."
});
results.push_back(ControlNetModeViewModel{
.Id = L"normal",
.Name = L"Normal map",
.ShortName = L"Normal",
.Description = L"Generates images based on a normal map."
});
results.push_back(ControlNetModeViewModel{
.Id = L"openpose",
.Name = L"OpenPose",
.ShortName = L"Pose",
.Description = L"Generates images based on an OpenPose bone image."
});
results.push_back(ControlNetModeViewModel{
.Id = L"scribble",
.Name = L"Scribble",
.ShortName = L"Scribble",
.Description = L"Generates images based on a hand-drawn monochrome image with white outlines on a black background."
});
results.push_back(ControlNetModeViewModel{
.Id = L"seg",
.Name = L"Segmentation",
.ShortName = L"Seg",
.Description = L"Generates images based on an ADE20K segmentation protocol image."
});
results.push_back(ControlNetModeViewModel{
.Id = L"inpaint",
.Name = L"Inpainting",
.ShortName = L"Inpaint",
.Description = L"Generates images based on an existing image."
});
for (auto& result : results)
{
result.ExampleInput = format(L"ms-appx:///Assets/controlnet/{}_input.png", result.Id);
result.ExampleOutput = format(L"ms-appx:///Assets/controlnet/{}_output.png", result.Id);
}
return results;
}
//Places the ControlNet models in <LocalCacheFolder>\controlnet and the annotators in
//<LocalCacheFolder>\annotators, creates both folders, then scans them via Refresh()
ControlNetRepository::ControlNetRepository() :
  _controlnetRoot((ApplicationData::Current().LocalCacheFolder().Path() + L"\\controlnet").c_str()),
  _annotatorRoot((ApplicationData::Current().LocalCacheFolder().Path() + L"\\annotators").c_str())
{
  //Folder creation is best-effort: any error is deliberately ignored here
  auto ensureDirectory = [](const filesystem::path& directory) {
    std::error_code ignored;
    filesystem::create_directories(directory, ignored);
  };

  ensureDirectory(_controlnetRoot);
  ensureDirectory(_annotatorRoot);

  Refresh();
}
const std::filesystem::path& ControlNetRepository::Root() const
{
return _controlnetRoot;
}
const std::vector<std::string>& ControlNetRepository::InstalledModes() const
{
return _installedModes;
}
const std::vector<std::string>& ControlNetRepository::InstalledAnnotators() const
{
return _installedAnnotators;
}
bool ControlNetRepository::TryEnsureModes(const std::vector<std::string>& modes, Axodox::Threading::async_operation& operation)
{
//Define install & remove tasks
set<string> filesToInstall{};
set<string> modesToRemove{ _installedModes.begin(), _installedModes.end() };
for (auto& mode : modes)
{
if (!modesToRemove.erase(mode))
{
filesToInstall.emplace(format("controlnet/{}.onnx", mode));
auto annotatorIt = _annotators.find(mode);
if (annotatorIt != _annotators.end())
{
for (auto& annotator : annotatorIt->second)
{
filesToInstall.emplace("annotators/" + annotator);
}
}
}
}
//Install new modes
HuggingFaceClient huggingFaceClient{};
auto result = huggingFaceClient.TryDownloadModel(_controlnetRepository, filesToInstall, {}, _controlnetRoot.parent_path(), operation);
//Remove old modes
for (auto& mode : modesToRemove)
{
error_code ec;
filesystem::remove(_controlnetRoot / format("{}.onnx", mode), ec);
auto annotatorIt = _annotators.find(mode);
if (annotatorIt != _annotators.end())
{
for (auto& annotator : annotatorIt->second)
{
filesystem::remove(_annotatorRoot / annotator, ec);
}
}
}
Refresh();
return result;
}
void ControlNetRepository::Refresh()
{
//Detect controlnet modes
vector<string> installedModes;
for (auto& file : filesystem::directory_iterator{ _controlnetRoot })
{
if (file.path().extension() != ".onnx") continue;
installedModes.push_back(file.path().stem().string());
}
//Detect annotators
error_code ec;
vector<string> installedAnnotators;
for (auto& mode : installedModes)
{
bool isInstalled = false;
auto annotatorIt = _annotators.find(mode);
if (annotatorIt != _annotators.end())
{
isInstalled = true;
for (auto& annotator : annotatorIt->second)
{
if (!filesystem::exists(_annotatorRoot / annotator, ec))
{
isInstalled = false;
break;
}
}
}
if (isInstalled) installedAnnotators.push_back(mode);
}
//Update state
_installedModes = installedModes;
_installedAnnotators = installedAnnotators;
}
}