Skip to content

Commit 227bc57

Browse files
committed
Revert "[Offload] Lazily initialize platforms in the Offloading API" (#163272)
Summary: Lazily initializing platforms causes issues with CUDA's teardown order, because each platform's initialization is then separated from the overall library initialization scope.
1 parent 225ee03 commit 227bc57

File tree

1 file changed

+33
-62
lines changed

1 file changed

+33
-62
lines changed

offload/liboffload/src/OffloadImpl.cpp

Lines changed: 33 additions & 62 deletions
Original file line number | Diff line number | Diff line change
@@ -45,21 +45,6 @@ struct ol_platform_impl_t {
4545
: BackendType(BackendType), Plugin(std::move(Plugin)) {}
4646
ol_platform_backend_t BackendType;
4747

48-
/// Get the plugin, lazily initializing it if necessary.
49-
llvm::Expected<GenericPluginTy *> getPlugin() {
50-
if (llvm::Error Err = init())
51-
return Err;
52-
return Plugin.get();
53-
}
54-
55-
/// Get the device list, lazily initializing it if necessary.
56-
llvm::Expected<llvm::SmallVector<std::unique_ptr<ol_device_impl_t>> &>
57-
getDevices() {
58-
if (llvm::Error Err = init())
59-
return Err;
60-
return Devices;
61-
}
62-
6348
/// Complete all pending work for this platform and perform any needed
6449
/// cleanup.
6550
///
@@ -73,8 +58,6 @@ struct ol_platform_impl_t {
7358
/// Direct access to the plugin, may be uninitialized if accessed here.
7459
std::unique_ptr<GenericPluginTy> Plugin;
7560

76-
private:
77-
std::once_flag Initialized;
7861
llvm::SmallVector<std::unique_ptr<ol_device_impl_t>> Devices;
7962
};
8063

@@ -154,39 +137,25 @@ llvm::Error ol_platform_impl_t::destroy() {
154137
}
155138

156139
llvm::Error ol_platform_impl_t::init() {
157-
std::unique_ptr<llvm::Error> Storage;
158-
159-
// This can be called concurrently, make sure we only do the actual
160-
// initialization once.
161-
std::call_once(Initialized, [&]() {
162-
// FIXME: Need better handling for the host platform.
163-
if (!Plugin)
164-
return;
165-
166-
llvm::Error Err = Plugin->init();
167-
if (Err) {
168-
Storage = std::make_unique<llvm::Error>(std::move(Err));
169-
return;
170-
}
140+
if (!Plugin)
141+
return llvm::Error::success();
171142

172-
for (auto Id = 0, End = Plugin->getNumDevices(); Id != End; Id++) {
173-
if (llvm::Error Err = Plugin->initDevice(Id)) {
174-
Storage = std::make_unique<llvm::Error>(std::move(Err));
175-
return;
176-
}
143+
if (llvm::Error Err = Plugin->init())
144+
return Err;
177145

178-
auto Device = &Plugin->getDevice(Id);
179-
auto Info = Device->obtainInfoImpl();
180-
if (llvm::Error Err = Info.takeError()) {
181-
Storage = std::make_unique<llvm::Error>(std::move(Err));
182-
return;
183-
}
184-
Devices.emplace_back(std::make_unique<ol_device_impl_t>(
185-
Id, Device, *this, std::move(*Info)));
186-
}
187-
});
146+
for (auto Id = 0, End = Plugin->getNumDevices(); Id != End; Id++) {
147+
if (llvm::Error Err = Plugin->initDevice(Id))
148+
return Err;
188149

189-
return Storage ? std::move(*Storage) : llvm::Error::success();
150+
auto Device = &Plugin->getDevice(Id);
151+
auto Info = Device->obtainInfoImpl();
152+
if (llvm::Error Err = Info.takeError())
153+
return Err;
154+
Devices.emplace_back(std::make_unique<ol_device_impl_t>(Id, Device, *this,
155+
std::move(*Info)));
156+
}
157+
158+
return llvm::Error::success();
190159
}
191160

192161
struct ol_queue_impl_t {
@@ -266,7 +235,7 @@ struct OffloadContext {
266235
std::mutex AllocInfoMapMutex{};
267236
// Partitioned list of memory base addresses. Each element in this list is a
268237
// key in AllocInfoMap
269-
llvm::SmallVector<void *> AllocBases{};
238+
SmallVector<void *> AllocBases{};
270239
SmallVector<std::unique_ptr<ol_platform_impl_t>, 4> Platforms{};
271240
ol_device_handle_t HostDevice;
272241
size_t RefCount;
@@ -314,14 +283,19 @@ Error initPlugins(OffloadContext &Context) {
314283
} while (false);
315284
#include "Shared/Targets.def"
316285

317-
// Add the special host device
286+
// Eagerly initialize all of the plugins and devices. We need to make sure
287+
// that the platform is initialized at a consistent point to maintain the
288+
// expected teardown order in the vendor libraries.
289+
for (auto &Platform : Context.Platforms) {
290+
if (Error Err = Platform->init())
291+
return Err;
292+
}
293+
294+
// Add the special host device.
318295
auto &HostPlatform = Context.Platforms.emplace_back(
319296
std::make_unique<ol_platform_impl_t>(nullptr, OL_PLATFORM_BACKEND_HOST));
320-
auto DevicesOrErr = HostPlatform->getDevices();
321-
if (!DevicesOrErr)
322-
return DevicesOrErr.takeError();
323-
Context.HostDevice = DevicesOrErr
324-
->emplace_back(std::make_unique<ol_device_impl_t>(
297+
Context.HostDevice = HostPlatform->Devices
298+
.emplace_back(std::make_unique<ol_device_impl_t>(
325299
-1, nullptr, *HostPlatform, InfoTreeNode{}))
326300
.get();
327301

@@ -355,7 +329,7 @@ Error olShutDown_impl() {
355329
if (--OffloadContext::get().RefCount != 0)
356330
return Error::success();
357331

358-
llvm::Error Result = Error::success();
332+
Error Result = Error::success();
359333
auto *OldContext = OffloadContextVal.exchange(nullptr);
360334

361335
for (auto &Platform : OldContext->Platforms) {
@@ -364,7 +338,7 @@ Error olShutDown_impl() {
364338
continue;
365339

366340
if (auto Res = Platform->destroy())
367-
Result = llvm::joinErrors(std::move(Result), std::move(Res));
341+
Result = joinErrors(std::move(Result), std::move(Res));
368342
}
369343

370344
delete OldContext;
@@ -423,7 +397,7 @@ Error olGetDeviceInfoImplDetail(ol_device_handle_t Device,
423397

424398
auto makeError = [&](ErrorCode Code, StringRef Err) {
425399
std::string ErrBuffer;
426-
llvm::raw_string_ostream(ErrBuffer) << PropName << ": " << Err;
400+
raw_string_ostream(ErrBuffer) << PropName << ": " << Err;
427401
return Plugin::error(ErrorCode::UNIMPLEMENTED, ErrBuffer.c_str());
428402
};
429403

@@ -641,10 +615,7 @@ Error olGetDeviceInfoSize_impl(ol_device_handle_t Device,
641615

642616
Error olIterateDevices_impl(ol_device_iterate_cb_t Callback, void *UserData) {
643617
for (auto &Platform : OffloadContext::get().Platforms) {
644-
auto DevicesOrErr = Platform->getDevices();
645-
if (!DevicesOrErr)
646-
return DevicesOrErr.takeError();
647-
for (auto &Device : *DevicesOrErr) {
618+
for (auto &Device : Platform->Devices) {
648619
if (!Callback(Device.get(), UserData)) {
649620
return Error::success();
650621
}
@@ -1186,7 +1157,7 @@ Error olGetSymbolInfoImplDetail(ol_symbol_handle_t Symbol,
11861157
auto CheckKind = [&](ol_symbol_kind_t Required) {
11871158
if (Symbol->Kind != Required) {
11881159
std::string ErrBuffer;
1189-
llvm::raw_string_ostream(ErrBuffer)
1160+
raw_string_ostream(ErrBuffer)
11901161
<< PropName << ": Expected a symbol of Kind " << Required
11911162
<< " but given a symbol of Kind " << Symbol->Kind;
11921163
return Plugin::error(ErrorCode::SYMBOL_KIND, ErrBuffer.c_str());

0 commit comments

Comments
 (0)