Skip to content

Commit

Permalink
Lazy Library Loading support in GPUs with different architectures (#2…
Browse files Browse the repository at this point in the history
…040)

* Fix to use correct device property for solution selection

* Support Lazy Loading for heterogeneous architectures
  • Loading branch information
rkamd authored Sep 27, 2023
1 parent 212cbf0 commit bc4d8f5
Showing 1 changed file with 108 additions and 14 deletions.
122 changes: 108 additions & 14 deletions library/src/tensile_host.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -201,6 +201,77 @@ namespace
}
}

/*************************************************************************
 * Map a HIP device to the Tensile lazy-loading architecture enumerator. *
 *                                                                       *
 * deviceID : HIP device ordinal whose properties are queried.           *
 * Returns  : the Tensile::LazyLoadingInit value whose gfx name occurs   *
 *            in the device's gcnArchName (feature suffix after ':'      *
 *            removed), or Tensile::LazyLoadingInit::None when no known  *
 *            name matches.                                              *
 *************************************************************************/
Tensile::LazyLoadingInit getLazyLoadingArch(int deviceID)
{
    // Value-initialize the properties so gcnArchName starts out as an empty
    // string. If the query fails without filling the struct in, the name stays
    // empty and the lookup below falls through to None, instead of reading an
    // uninitialized character buffer.
    hipDeviceProp_t deviceProperties{};
    static_cast<void>(hipGetDeviceProperties(&deviceProperties, deviceID));

    // strip out xnack/ecc from name
    const std::string deviceFullString(deviceProperties.gcnArchName);
    const std::string deviceString = deviceFullString.substr(0, deviceFullString.find(':'));

    struct ArchEntry
    {
        const char*              name;
        Tensile::LazyLoadingInit arch;
    };

    // Checked in order; the first name found inside deviceString wins.
    static const ArchEntry knownArchs[] = {
        {"gfx803", Tensile::LazyLoadingInit::gfx803},
        {"gfx900", Tensile::LazyLoadingInit::gfx900},
        {"gfx906", Tensile::LazyLoadingInit::gfx906},
        {"gfx908", Tensile::LazyLoadingInit::gfx908},
        {"gfx90a", Tensile::LazyLoadingInit::gfx90a},
        {"gfx940", Tensile::LazyLoadingInit::gfx940},
        {"gfx941", Tensile::LazyLoadingInit::gfx941},
        {"gfx942", Tensile::LazyLoadingInit::gfx942},
        {"gfx1010", Tensile::LazyLoadingInit::gfx1010},
        {"gfx1011", Tensile::LazyLoadingInit::gfx1011},
        {"gfx1012", Tensile::LazyLoadingInit::gfx1012},
        {"gfx1030", Tensile::LazyLoadingInit::gfx1030},
        {"gfx1100", Tensile::LazyLoadingInit::gfx1100},
        {"gfx1101", Tensile::LazyLoadingInit::gfx1101},
        {"gfx1102", Tensile::LazyLoadingInit::gfx1102},
    };

    for(const ArchEntry& entry : knownArchs)
    {
        if(deviceString.find(entry.name) != std::string::npos)
        {
            return entry.arch;
        }
    }

    // Unknown or unavailable architecture name.
    return Tensile::LazyLoadingInit::None;
}

/*************************************************************************
* Class for converting alpha and beta between rocBLAS and Tensile types *
* By default, alpha and beta are the same type as Tc compute_type *
Expand Down Expand Up @@ -502,7 +573,7 @@ namespace
{
// The library object
std::shared_ptr<Tensile::MasterSolutionLibrary<Tensile::ContractionProblem>> m_library;
std::shared_ptr<hipDeviceProp_t> m_deviceProp;
std::unordered_map<std::string, std::shared_ptr<hipDeviceProp_t>> m_devicePropMap;

// The adapter object. mutable is used to allow adapters to be modified
// even when they are stored in a const vector which is immutable in size
Expand Down Expand Up @@ -554,9 +625,9 @@ namespace
return m_library;
}

auto& get_device_property() const
auto& get_device_property(const std::string& deviceName) const
{
return m_deviceProp;
return m_devicePropMap.at(deviceName);
}

auto& get_adapters() const
Expand Down Expand Up @@ -585,9 +656,12 @@ namespace
std::string path;
std::string tensileLibraryPath;
bool tensile_lazy_load_enabled = false;
//Function-local static variables are used to guarantee thread-safe initialization
//and to avoid the static initialization order fiasco
static std::future<
std::shared_ptr<Tensile::SolutionLibrary<Tensile::ContractionProblem>>>
ftr_lib;
ftr_lib;
static std::unordered_set<Tensile::LazyLoadingInit> tensileDeviceSet;

#ifndef WIN32
path.reserve(PATH_MAX);
Expand Down Expand Up @@ -705,6 +779,31 @@ namespace
else
tensile_lazy_load_enabled = true;

//Supports multi architecture configuration in lazy library loading mode
static int initialize_once = [&] {
hipDeviceProp_t prop;
int count;
HIP_CHECK_EXC(hipGetDeviceCount(&count));

for(int devId = 0; devId < count; devId++)
{
auto deviceArch = getLazyLoadingArch(devId);
if(tensileDeviceSet.find(deviceArch) == tensileDeviceSet.end())
{
//populate the arch list for lazy loading
tensileDeviceSet.insert(deviceArch);
//populate device property map, used in finding solutions based on arch
HIP_CHECK_EXC(hipGetDeviceProperties(&prop, devId));
// strip out xnack/ecc from name
std::string deviceFullString(prop.gcnArchName);
std::string deviceString
= deviceFullString.substr(0, deviceFullString.find(":"));
m_devicePropMap[deviceString] = std::make_shared<hipDeviceProp_t>(prop);
}
}
return 0;
}();

if(!tensile_lazy_load_enabled || rocblas_initialize_called())
{

Expand Down Expand Up @@ -792,8 +891,8 @@ namespace
= std::async(std::launch::async,
Tensile::LoadLibraryFilePreload<Tensile::ContractionProblem>,
tensileLibraryPath,
std::vector<Tensile::LazyLoadingInit>{});

std::vector<Tensile::LazyLoadingInit>{tensileDeviceSet.begin(),
tensileDeviceSet.end()});
return 0;
}();

Expand Down Expand Up @@ -821,18 +920,13 @@ namespace
rocblas_abort();
}

hipDeviceProp_t prop;
HIP_CHECK_EXC(hipGetDeviceProperties(&prop, deviceId));

m_deviceProp = std::make_shared<hipDeviceProp_t>(prop);

// Preload problem/solution mappings
const char* overrideEnv = getenv("ROCBLAS_TENSILE_GEMM_OVERRIDE_PATH");
if(overrideEnv)
{
std::string overridePath = overrideEnv;
std::shared_ptr<Tensile::Hardware> hardware
= Tensile::hip::GetDevice(*m_deviceProp);
std::shared_ptr<Tensile::Hardware> hardware = Tensile::hip::GetDevice(
*(get_device_property(rocblas_internal_get_arch_name())));
bool success = m_library->setOverridesFromFile(*hardware, overridePath);

if(!success)
Expand Down Expand Up @@ -887,7 +981,7 @@ namespace
if(library)
*library = host.get_library();
if(deviceProp)
*deviceProp = host.get_device_property();
*deviceProp = host.get_device_property(rocblas_internal_get_arch_name());

return *adapter;
}
Expand Down

0 comments on commit bc4d8f5

Please sign in to comment.