feat: Add multiprocessing #141

Merged · 9 commits · Dec 27, 2024
Added multiprocessing threshold
Pringled committed Dec 27, 2024
commit fe3eccb3213236be305d4a32a154e996c4e06075
2 changes: 1 addition & 1 deletion model2vec/model.py
@@ -17,7 +17,7 @@

logger = getLogger(__name__)

-MULTIPROCESSING_THRESHOLD = 6000
+MULTIPROCESSING_THRESHOLD = 10_000


class StaticModel:
@@ -203,12 +203,12 @@

        if use_multiprocessing and len(sentences) > MULTIPROCESSING_THRESHOLD:
            # Use joblib for multiprocessing if requested, and if we have enough sentences
            results = ProgressParallel(n_jobs=-1, use_tqdm=show_progress_bar, total=total_batches)(
                delayed(self._encode_batch_as_sequence)(batch, max_length) for batch in sentence_batches
            )
            out_array: list[np.ndarray] = []
            for r in results:
                out_array.extend(r)
        else:
            out_array = []
            for batch in tqdm(

Codecov / codecov/patch: added lines 206 and 209–211 in model2vec/model.py were not covered by tests.
@@ -269,10 +269,10 @@

        if use_multiprocessing and len(sentences) > MULTIPROCESSING_THRESHOLD:
            # Use joblib for multiprocessing if requested, and if we have enough sentences
            results = ProgressParallel(n_jobs=-1, use_tqdm=show_progress_bar, total=total_batches)(
                delayed(self._encode_batch)(batch, max_length) for batch in sentence_batches
            )
            out_array = np.concatenate(results, axis=0)
        else:
            # Don't use multiprocessing
            out_arrays: list[np.ndarray] = []

Codecov / codecov/patch: added lines 272 and 275 in model2vec/model.py were not covered by tests.
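Both hunks follow the same shape: batch the input, then fan the batches out to worker processes only when the corpus exceeds `MULTIPROCESSING_THRESHOLD`, since process startup and argument pickling would otherwise dominate for small inputs. A minimal stdlib sketch of that threshold-gated dispatch, using `multiprocessing.Pool` in place of joblib's `ProgressParallel` (whose implementation is not part of this diff); `_encode_batch` here is a hypothetical placeholder, not model2vec code:

```python
# Sketch of the threshold-gated dispatch pattern from this PR.
# multiprocessing.Pool stands in for joblib's ProgressParallel, and
# _encode_batch is a placeholder encoder; both are assumptions.
from multiprocessing import Pool

MULTIPROCESSING_THRESHOLD = 10_000  # value set by this commit


def _encode_batch(batch: list[str]) -> list[list[float]]:
    # Placeholder "embedding": one 1-d vector per sentence.
    return [[float(len(s))] for s in batch]


def encode(
    sentences: list[str],
    batch_size: int = 1024,
    use_multiprocessing: bool = True,
) -> list[list[float]]:
    """Batch the sentences, then dispatch serially or in parallel."""
    batches = [sentences[i : i + batch_size] for i in range(0, len(sentences), batch_size)]
    if use_multiprocessing and len(sentences) > MULTIPROCESSING_THRESHOLD:
        # Fan out only for large corpora: worker startup and pickling
        # overhead are too expensive to pay for small inputs.
        with Pool() as pool:
            results = pool.map(_encode_batch, batches)
    else:
        results = [_encode_batch(b) for b in batches]
    # Flatten the per-batch results, as the first hunk does with extend().
    out: list[list[float]] = []
    for r in results:
        out.extend(r)
    return out
```

Raising the threshold from 6000 to 10_000 shifts the break-even point: more mid-sized inputs take the cheaper serial path, and only genuinely large corpora pay the worker-spawn cost.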