Update padding implementation to reduce memory footprint
modeling_qwen3_vl_nemotron_embed.py (CHANGED)
@@ -94,6 +94,42 @@ def _create_bidirectional_mask(
     return None


+def _pad_and_stack_embeddings(tensors: List[torch.Tensor]) -> torch.Tensor:
+    """Pad embedding tensors to uniform sequence length and concatenate.
+
+    Args:
+        tensors: List of tensors with shape (batch, seq_len, hidden_dim).
+            Each tensor may have a different seq_len.
+
+    Returns:
+        Concatenated tensor with shape (total_batch, max_seq_len, hidden_dim),
+        where sequences shorter than max_seq_len are zero-padded.
+    """
+    if not tensors:
+        raise ValueError("Cannot pad empty tensor list")
+
+    max_seq_len = max(t.shape[1] for t in tensors)
+    total_docs = sum(t.shape[0] for t in tensors)
+    hidden_dim = tensors[0].shape[2]
+    dtype = tensors[0].dtype
+
+    # Pre-allocate result tensor
+    result = torch.zeros(total_docs, max_seq_len, hidden_dim, dtype=dtype)
+
+    # Copy in-place and release references to free memory
+    offset = 0
+    for i in range(len(tensors)):
+        t = tensors[i]
+        tensors[i] = None  # Release reference immediately
+        batch_size = t.shape[0]
+        seq_len = t.shape[1]
+        result[offset : offset + batch_size, :seq_len, :] = t
+        offset += batch_size
+        del t
+
+    return result
+
+
 class Qwen3VLNemotronEmbedTextModel(Qwen3VLTextModel):
     """Bidirectional text model for Qwen3VLNemotronEmbed."""

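For reference, a minimal usage sketch of the new helper (not part of the diff; it assumes the file above is importable as modeling_qwen3_vl_nemotron_embed):

import torch
from modeling_qwen3_vl_nemotron_embed import _pad_and_stack_embeddings

# Two batches of per-token embeddings with different sequence lengths.
batches = [
    torch.randn(2, 5, 8),   # (batch=2, seq_len=5, hidden_dim=8)
    torch.randn(3, 9, 8),   # (batch=3, seq_len=9, hidden_dim=8)
]

stacked = _pad_and_stack_embeddings(batches)
print(stacked.shape)               # torch.Size([5, 9, 8])
print(stacked[0, 5:].abs().sum())  # tensor(0.) -- positions past each sequence's original length are zero-padded

Note that the helper consumes the list it is given (each entry is set to None once copied), so callers that still need the per-batch tensors should pass a shallow copy.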
@@ -247,7 +283,7 @@ class EmbeddingMixin:
             Tensor of embeddings with shape (num_samples, max_seq_len, hidden_dim).
         """
         device = next(self.parameters()).device
-
+        embedding_batches = []
         message = "query" if is_query else "document"

         for batch in tqdm(dataloader, desc=f"Extracting {message} embeddings..."):

@@ -269,10 +305,9 @@ class EmbeddingMixin:
             if not torch.isfinite(embeddings).all():
                 raise ValueError("Embeddings contain NaN or Inf values")

-
+            embedding_batches.append(embeddings.detach().cpu())

-
-        return all_embeddings_tensor
+        return _pad_and_stack_embeddings(embedding_batches)

     def forward_queries(self, queries: List[str], batch_size: int = 8) -> torch.Tensor:
         """Forward text queries and extract embeddings.
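The EmbeddingMixin change follows the same idea: each batch is detached and moved to CPU as soon as it is produced, and the padded result is built once at the end. A rough sizing sketch of why the single pre-allocation helps (the numbers below are illustrative and not taken from the diff):

# Hypothetical workload: 10_000 documents, max_seq_len 512, hidden_dim 128, fp32 embeddings.
num_docs, max_seq_len, hidden_dim, bytes_per_el = 10_000, 512, 128, 4

final_bytes = num_docs * max_seq_len * hidden_dim * bytes_per_el  # ~2.6 GB result tensor
old_extra = 2 * final_bytes  # padded copies of every batch plus the torch.cat output, on top of the originals
new_extra = final_bytes      # only the pre-allocated result, on top of the originals (freed as they are copied)
print(f"result: {final_bytes / 1e9:.1f} GB, old overhead ~{old_extra / 1e9:.1f} GB, new overhead ~{new_extra / 1e9:.1f} GB")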
@@ -357,22 +392,6 @@ class ColBERTScoringMixin:
 class ColBERTScoringMixin:
     """Mixin providing ColBERT MaxSim scoring methods."""

-    def padding_various_shape_tensor(self, tensors: List[torch.Tensor]) -> torch.Tensor:
-        """Pad tensors of various shapes for ColBERT-like scoring.
-
-        Args:
-            tensors: List of tensors with shape (batch, seq_len, hidden_dim)
-
-        Returns:
-            Concatenated tensor with all sequences padded to max length.
-        """
-        max_seq_len = max(t.shape[1] for t in tensors)
-        padded_tensors = [
-            F.pad(t, (0, 0, 0, max_seq_len - t.shape[1]), mode="constant", value=0)
-            for t in tensors
-        ]
-        return torch.cat(padded_tensors, dim=0)
-
     def colbert_score(
         self,
         qs: Union[torch.Tensor, List[torch.Tensor]],
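The removed helper above padded every batch with F.pad and then concatenated the copies, so the padded copies and the torch.cat output briefly coexist in memory; the new module-level helper writes each batch directly into a single pre-allocated zero tensor. A small sanity check that the two produce the same result (illustrative only; the old logic is re-created inline and the module import is assumed):

import torch
import torch.nn.functional as F
from modeling_qwen3_vl_nemotron_embed import _pad_and_stack_embeddings

def pad_with_cat(tensors):
    # Re-creation of the removed padding_various_shape_tensor logic, for comparison only.
    max_seq_len = max(t.shape[1] for t in tensors)
    padded = [
        F.pad(t, (0, 0, 0, max_seq_len - t.shape[1]), mode="constant", value=0)
        for t in tensors
    ]
    return torch.cat(padded, dim=0)

batches = [torch.randn(2, 5, 8), torch.randn(3, 9, 8)]
old_result = pad_with_cat(batches)
new_result = _pad_and_stack_embeddings(list(batches))  # shallow copy: the new helper consumes its input list
assert torch.equal(old_result, new_result)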
@@ -448,11 +467,11 @@ class ColBERTScoringMixin:
         if isinstance(query_embeddings, list):
             if len(query_embeddings[0].shape) == 2:
                 query_embeddings = [q.unsqueeze(0) for q in query_embeddings]
-            query_embeddings = self.padding_various_shape_tensor(query_embeddings)
+            query_embeddings = _pad_and_stack_embeddings(query_embeddings)
         if isinstance(passage_embeddings, list):
             if len(passage_embeddings[0].shape) == 2:
                 passage_embeddings = [p.unsqueeze(0) for p in passage_embeddings]
-            passage_embeddings = self.padding_various_shape_tensor(passage_embeddings)
+            passage_embeddings = _pad_and_stack_embeddings(passage_embeddings)

         return self.colbert_score(
             query_embeddings, passage_embeddings, batch_size or 128
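At the scoring call site above, per-item embeddings may arrive as 2D (seq_len, hidden_dim) tensors of varying length; they are unsqueezed to a batch dimension and stacked with the new helper before MaxSim scoring. A minimal sketch of that path (module import assumed; shapes are made up):

import torch
from modeling_qwen3_vl_nemotron_embed import _pad_and_stack_embeddings

passage_embeddings = [torch.randn(7, 8), torch.randn(12, 8), torch.randn(3, 8)]
if len(passage_embeddings[0].shape) == 2:
    passage_embeddings = [p.unsqueeze(0) for p in passage_embeddings]
passage_embeddings = _pad_and_stack_embeddings(passage_embeddings)
print(passage_embeddings.shape)  # torch.Size([3, 12, 8])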