Set minimum KV chunk size automatically
Fixes broken setting of minimum KV chunk size. Also removes --sub-quad-min-chunk-size and --sub-quad-min-chunk-vram-percent in favor of automatically adjusting minimum KV chunk size and chunk threshold to use all available memory.
brkirch committed Dec 31, 2022
1 parent 1824569 commit dabcda4
Showing 3 changed files with 14 additions and 26 deletions.
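The "use all available memory" behavior described in the commit message amounts to replacing the two removed flags with a runtime measurement of free memory. A minimal sketch of that idea in isolation (hypothetical helper name; the 0.95 ratio and memory queries mirror get_sub_quad_chunk_threshold() in the diff below, but this is not the repository's API):

import psutil
import torch

def estimate_chunk_threshold_bytes(device: torch.device) -> int:
    # Hypothetical stand-in for get_sub_quad_chunk_threshold() below.
    # On CUDA, budget ~95% of currently free memory: unreserved VRAM plus
    # bytes the caching allocator has reserved but is not actively using.
    if device.type == 'cuda':
        stats = torch.cuda.memory_stats(device)
        mem_active = stats['active_bytes.all.current']
        mem_reserved = stats['reserved_bytes.all.current']
        mem_free_cuda, _ = torch.cuda.mem_get_info(device)
        mem_free_torch = mem_reserved - mem_active
        return int(0.95 * (mem_free_cuda + mem_free_torch))
    # On CPU/MPS, fall back to the available system RAM reported by psutil.
    return psutil.virtual_memory().available

Because the threshold is recomputed inside each attention forward pass rather than set once from a command-line flag, it tracks however much memory is actually free at that moment.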
5 changes: 0 additions & 5 deletions modules/processing.py
@@ -15,7 +15,6 @@
import modules.sd_hijack
from modules import devices, prompt_parser, masking, sd_samplers, lowvram, generation_parameters_copypaste, script_callbacks
from modules.sd_hijack import model_hijack
-from modules.sd_hijack_optimizations import set_sub_quad_chunk_threshold
from modules.shared import opts, cmd_opts, state
import modules.shared as shared
import modules.face_restoration
@@ -699,8 +698,6 @@ def init(self, all_prompts, all_seeds, all_subseeds):

self.truncate_x = int(self.firstphase_width - firstphase_width_truncated) // opt_f
self.truncate_y = int(self.firstphase_height - firstphase_height_truncated) // opt_f
-if shared.cmd_opts.sub_quad_min_chunk_vram_percent > 0 or shared.cmd_opts.sub_quad_min_chunk_size > 0:
-set_sub_quad_chunk_threshold()

def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength, prompts):
self.sampler = sd_samplers.create_sampler(self.sampler_name, self.sd_model)
@@ -909,8 +906,6 @@ def init(self, all_prompts, all_seeds, all_subseeds):
self.init_latent = self.init_latent * self.mask

self.image_conditioning = self.img2img_image_conditioning(image, self.init_latent, image_mask)
-if shared.cmd_opts.sub_quad_min_chunk_vram_percent > 0 or shared.cmd_opts.sub_quad_min_chunk_size > 0:
-set_sub_quad_chunk_threshold()

def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength, prompts):
x = create_random_tensors([opt_C, self.height // opt_f, self.width // opt_f], seeds=seeds, subseeds=subseeds, subseed_strength=self.subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self)
33 changes: 14 additions & 19 deletions modules/sd_hijack_optimizations.py
@@ -207,29 +207,17 @@ def split_cross_attention_forward_invokeAI(self, x, context=None, mask=None):
# -- End of code from /~https://github.com/invoke-ai/InvokeAI --


-chunk_threshold_bytes = None
-def set_sub_quad_chunk_threshold():
-global chunk_threshold_bytes
-if shared.cmd_opts.sub_quad_min_chunk_size > 0:
-chunk_threshold_bytes = shared.cmd_opts.sub_quad_min_chunk_size
-if shared.device.type == 'mps' and chunk_threshold_bytes >= 2**34:
-print("Warning: Minimum KV chunk size is set to " + str(chunk_threshold_bytes) + " but on MPS KV chunks can't exceed a size of 17179869183.", file=sys.stderr)
-print("Setting minimum size of KV chunks to 17179869183 to avoid crashing.", file=sys.stderr)
-chunk_threshold_bytes = 2**34-1
-return
+def get_sub_quad_chunk_threshold():
if shared.device.type == 'cuda':
stats = torch.cuda.memory_stats(shared.device)
mem_active = stats['active_bytes.all.current']
mem_reserved = stats['reserved_bytes.all.current']
mem_free_cuda, _ = torch.cuda.mem_get_info(torch.cuda.current_device())
mem_free_torch = mem_reserved - mem_active
mem_free_total = mem_free_cuda + mem_free_torch
-chunk_threshold_bytes = int(shared.cmd_opts.sub_quad_min_chunk_vram_percent * 0.01 * mem_free_total)
+return int(0.95 * mem_free_total)
-else:
-chunk_threshold_bytes = int(shared.cmd_opts.sub_quad_min_chunk_vram_percent * 0.01 * psutil.virtual_memory().available)
-if shared.device.type == 'mps' and chunk_threshold_bytes >= 2**34:
-chunk_threshold_bytes = 2**34-1

+return psutil.virtual_memory().available

# Based on Birch-san's modified implementation of sub-quadratic attention from /~https://github.com/Birch-san/diffusers/pull/1
def sub_quad_attention_forward(self, x, context=None, mask=None):
@@ -249,6 +237,8 @@ def sub_quad_attention_forward(self, x, context=None, mask=None):
k = k.unflatten(-1, (h, -1)).transpose(1,2).flatten(end_dim=1)
v = v.unflatten(-1, (h, -1)).transpose(1,2).flatten(end_dim=1)

+chunk_threshold_bytes = get_sub_quad_chunk_threshold()

x = sub_quad_attention(q, k, v, kv_chunk_size_min=chunk_threshold_bytes, chunk_threshold_bytes=chunk_threshold_bytes, use_checkpoint=self.training)

x = x.unflatten(0, (-1, h)).transpose(1,2).flatten(start_dim=2)
@@ -266,20 +256,22 @@ def sub_quad_attention(q, k, v, q_chunk_size=1024, kv_chunk_size=None, kv_chunk_
qk_matmul_size_bytes = batch_x_heads * bytes_per_token * q_tokens * k_tokens

if chunk_threshold_bytes is not None and qk_matmul_size_bytes <= chunk_threshold_bytes:
-# the big matmul fits into our memory limit; do everything in 1 chunk,
-# i.e. send it down the unchunked fast-path
-query_chunk_size = q_tokens
-kv_chunk_size = k_tokens
+# the big matmul fits into our memory limit; do everything in 1 chunk,
+# i.e. send it down the unchunked fast-path
+query_chunk_size = q_tokens
+kv_chunk_size = k_tokens

return efficient_dot_product_attention(
q,
k,
v,
query_chunk_size=q_chunk_size,
kv_chunk_size=kv_chunk_size,
+kv_chunk_size_min = kv_chunk_size_min,
use_checkpoint=use_checkpoint,
)

# MemoryEfficientAttnBlock forward from /~https://github.com/Stability-AI/stablediffusion modified to use sub-quadratic attention instead of xformers
def sub_quad_attnblock_forward(self, x):
h_ = x
h_ = self.norm(h_)
@@ -299,6 +291,9 @@ def sub_quad_attnblock_forward(self, x):
.contiguous(),
(q, k, v),
)

+chunk_threshold_bytes = get_sub_quad_chunk_threshold()

out = sub_quad_attention(q, k, v, kv_chunk_size_min=chunk_threshold_bytes, chunk_threshold_bytes=chunk_threshold_bytes, use_checkpoint=self.training)

out = (
2 changes: 0 additions & 2 deletions modules/shared.py
@@ -57,8 +57,6 @@
parser.add_argument("--deepdanbooru", action='store_true', help="does not do anything")
parser.add_argument("--opt-split-attention", action='store_true', help="force-enables Doggettx's cross-attention layer optimization. By default, it's on for torch cuda.")
parser.add_argument("--opt-sub-quad-attention", action='store_true', help="enable memory efficient sub-quadratic cross-attention layer optimization")
parser.add_argument("--sub-quad-min-chunk-size", type=int, help="minimum kv chunk size for the sub-quadratic cross-attention layer optimization to use", default=0)
parser.add_argument("--sub-quad-min-chunk-vram-percent", type=int, help="minimum percentage of available VRAM for sub-quadratic cross-attention layer optimization kv chunks to use", default=0)
parser.add_argument("--opt-split-attention-invokeai", action='store_true', help="force-enables InvokeAI's cross-attention layer optimization. By default, it's on when cuda is unavailable.")
parser.add_argument("--opt-split-attention-v1", action='store_true', help="enable older version of split attention optimization that does not consume all the VRAM it can find")
parser.add_argument("--disable-opt-split-attention", action='store_true', help="force-disables cross-attention layer optimization")
