
feat: Add multiprocessing #141

Merged
merged 9 commits on Dec 27, 2024

Changes from 1 commit

Added multiprocessing threshold
Pringled committed Dec 27, 2024
commit 345f590a4c739f11e0e6720688ad5d8b8d71fe71
17 changes: 7 additions & 10 deletions model2vec/model.py
@@ -7,21 +7,17 @@
 from typing import Any, Iterator, Union
 
 import numpy as np
+from joblib import delayed
 from tokenizers import Encoding, Tokenizer
 from tqdm import tqdm
 
-from model2vec.utils import load_local_model
+from model2vec.utils import ProgressParallel, load_local_model
 
 PathLike = Union[Path, str]
 
 
 logger = getLogger(__name__)
 
-
-from joblib import delayed
-from tqdm.auto import tqdm
-
-from model2vec.utils import ProgressParallel
-
+MULTIPROCESSING_THRESHOLD = 6000
 
 class StaticModel:
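
For context, ProgressParallel is imported from model2vec.utils and its definition is not part of this diff. A common way to implement such a helper, and a plausible sketch of what it does here, is a joblib.Parallel subclass that mirrors joblib's task counters into a tqdm bar; the use_tqdm and total keyword arguments match the call sites in the hunks below, everything else is an assumption:

from __future__ import annotations

from typing import Any

import joblib
from tqdm.auto import tqdm


# Sketch only: the real model2vec.utils.ProgressParallel is not shown in this diff.
class ProgressParallel(joblib.Parallel):
    """A joblib.Parallel that reports progress through a tqdm bar."""

    def __init__(self, use_tqdm: bool = True, total: int | None = None, *args: Any, **kwargs: Any) -> None:
        self._use_tqdm = use_tqdm
        self._total = total
        super().__init__(*args, **kwargs)  # n_jobs etc. are forwarded to joblib.Parallel

    def __call__(self, *args: Any, **kwargs: Any) -> Any:
        # Keep one bar open for the lifetime of the parallel run.
        with tqdm(disable=not self._use_tqdm, total=self._total) as self._pbar:
            return super().__call__(*args, **kwargs)

    def print_progress(self) -> None:
        # joblib invokes this hook as tasks finish; sync its counters to the bar.
        if self._total is None:
            self._pbar.total = self.n_dispatched_tasks
        self._pbar.n = self.n_completed_tasks
        self._pbar.refresh()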
@@ -205,13 +201,14 @@
         sentence_batches = list(self._batch(sentences, batch_size))
         total_batches = math.ceil(len(sentences) / batch_size)
 
-        if use_multiprocessing:
+        if use_multiprocessing and len(sentences) > MULTIPROCESSING_THRESHOLD:
+            # Use joblib for multiprocessing if requested, and if we have enough sentences
             results = ProgressParallel(n_jobs=-1, use_tqdm=show_progress_bar, total=total_batches)(
                 delayed(self._encode_batch_as_sequence)(batch, max_length) for batch in sentence_batches
             )
             out_array: list[np.ndarray] = []
             for r in results:
                 out_array.extend(r)
         else:
             out_array = []
             for batch in tqdm(
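
Note the reduction in this branch: _encode_batch_as_sequence (not shown in this diff) returns one result per batch that is itself a list, presumably one array of per-token embeddings per sentence, so the worker outputs are flattened with extend rather than stacked. A minimal illustration of that step, with made-up shapes:

import numpy as np

# Made-up worker outputs: each batch yields a list of (num_tokens, dim)
# arrays, one per sentence, and sentences have different token counts.
results = [
    [np.zeros((3, 8)), np.zeros((5, 8))],  # batch 1: two sentences
    [np.zeros((2, 8))],                    # batch 2: one sentence
]

# Flatten the batches into one list, keeping a separate (num_tokens, dim)
# array per sentence instead of merging them into a single matrix.
out_array: list[np.ndarray] = []
for r in results:
    out_array.extend(r)
assert len(out_array) == 3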
Expand Down Expand Up @@ -270,12 +267,12 @@
         sentence_batches = list(self._batch(sentences, batch_size))
         total_batches = math.ceil(len(sentences) / batch_size)
 
-        if use_multiprocessing:
-            # Use joblib for multiprocessing if requested
+        if use_multiprocessing and len(sentences) > MULTIPROCESSING_THRESHOLD:
+            # Use joblib for multiprocessing if requested, and if we have enough sentences
             results = ProgressParallel(n_jobs=-1, use_tqdm=show_progress_bar, total=total_batches)(
                 delayed(self._encode_batch)(batch, max_length) for batch in sentence_batches
             )
             out_array = np.concatenate(results, axis=0)
         else:
             # Don't use multiprocessing
             out_arrays: list[np.ndarray] = []

Codecov / codecov/patch: added lines #L206, #L209-L211, #L272, and #L275 in model2vec/model.py were not covered by tests.
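
Taken together, the commit makes use_multiprocessing take effect only when the input is large enough to amortize the cost of spawning workers and serializing batches. A hypothetical usage sketch (the model name is only an example):

from model2vec import StaticModel

model = StaticModel.from_pretrained("minishlab/potion-base-8M")  # example model

sentences = ["it is dangerous to go alone"] * 10_000

# 100 sentences: below MULTIPROCESSING_THRESHOLD (6000), so this encodes in a
# single process even though multiprocessing was requested.
small_out = model.encode(sentences[:100], use_multiprocessing=True)

# 10,000 sentences: above the threshold, so batches are dispatched through
# ProgressParallel(n_jobs=-1), i.e. one worker per available CPU core.
large_out = model.encode(sentences, use_multiprocessing=True)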