Shortcuts

Source code for mmdeploy.apis.pytorch2torchscript

# Copyright (c) OpenMMLab. All rights reserved.
import os.path as osp
from typing import Any, Optional, Union

import mmengine

from mmdeploy.apis.core.pipeline_manager import PIPELINE_MANAGER, no_mp


def _move_to_device(inputs: Any, device: str) -> Any:
    """Move a tensor, or every tensor inside a list/tuple, to ``device``.

    Non-tensor values are returned unchanged. ``Tensor.to`` returns the
    tensor itself when it already has the requested device, so no explicit
    device comparison is needed before calling it.

    Args:
        inputs (Any): A ``torch.Tensor``, a (possibly nested) list/tuple
            containing tensors, or any other object.
        device (str): Target device string, e.g. ``'cuda:0'``.

    Returns:
        Any: ``inputs`` with every tensor moved to ``device``; lists stay
        lists and tuples become plain tuples.
    """
    import torch

    if isinstance(inputs, torch.Tensor):
        return inputs.to(device)
    if isinstance(inputs, list):
        return [_move_to_device(item, device) for item in inputs]
    if isinstance(inputs, tuple):
        return tuple(_move_to_device(item, device) for item in inputs)
    return inputs


@PIPELINE_MANAGER.register_pipeline()
def torch2torchscript(img: Any,
                      work_dir: str,
                      save_file: str,
                      deploy_cfg: Union[str, mmengine.Config],
                      model_cfg: Union[str, mmengine.Config],
                      model_checkpoint: Optional[str] = None,
                      device: str = 'cuda:0') -> None:
    """Convert PyTorch model to torchscript model.

    The model is built through the task processor for ``model_cfg``. A
    sample input is created from ``img``, and the model is traced with
    ``trace(..., check_trace=False)``. The output path prefix passed to
    ``trace`` is ``work_dir`` joined with ``save_file`` minus its extension.

    Args:
        img (str | np.ndarray | torch.Tensor): Input image used to assist
            converting model.
        work_dir (str): A working directory to save files. It is created
            if it does not exist.
        save_file (str): Filename to save torchscript model. Only the part
            before the extension is used as the output path prefix.
        deploy_cfg (str | mmengine.Config): Deployment config file or
            Config object.
        model_cfg (str | mmengine.Config): Model config file or Config
            object.
        model_checkpoint (str, optional): A checkpoint path of PyTorch
            model, defaults to `None`.
        device (str): A string specifying device type, defaults to
            'cuda:0'.

    Returns:
        None: The traced model is written by ``trace`` as a side effect.
    """
    import torch

    from mmdeploy.apis import build_task_processor
    from mmdeploy.utils import get_backend, get_input_shape, load_config
    from .torch_jit import trace

    # load deploy_cfg / model_cfg if they were given as file paths
    deploy_cfg, model_cfg = load_config(deploy_cfg, model_cfg)
    mmengine.mkdir_or_exist(osp.abspath(work_dir))

    input_shape = get_input_shape(deploy_cfg)

    task_processor = build_task_processor(model_cfg, deploy_cfg, device)
    torch_model = task_processor.build_pytorch_model(model_checkpoint)

    # Hand the model's own data_preprocessor (if it has one) to
    # create_input so the sample input is prepared for this model.
    data, model_inputs = task_processor.create_input(
        img,
        input_shape,
        data_preprocessor=getattr(torch_model, 'data_preprocessor', None))

    data_samples = data['data_samples']
    input_metas = {'data_samples': data_samples, 'mode': 'predict'}

    # A single-element sequence is unwrapped so trace receives a bare
    # tensor; longer sequences are passed through as-is.
    if not isinstance(model_inputs, torch.Tensor) and len(model_inputs) == 1:
        model_inputs = model_inputs[0]

    context_info = dict(deploy_cfg=deploy_cfg)
    backend = get_backend(deploy_cfg).value
    output_prefix = osp.join(work_dir, osp.splitext(save_file)[0])

    # Previously this read ``model_inputs.device`` directly, which raised
    # AttributeError when model_inputs was a multi-element list/tuple.
    model_inputs = _move_to_device(model_inputs, device)

    # NOTE(review): no_mp() presumably keeps tracing out of the pipeline
    # manager's multiprocessing mode; confirm in pipeline_manager.
    with no_mp():
        trace(
            torch_model,
            model_inputs,
            output_path_prefix=output_prefix,
            backend=backend,
            input_metas=input_metas,
            context_info=context_info,
            check_trace=False)
Read the Docs v: latest
Versions
latest
stable
1.x
v1.3.0
v1.2.0
v1.1.0
v1.0.0
0.x
v0.14.0
Downloads
pdf
html
epub
On Read the Docs
Project Home
Builds

Free document hosting provided by Read the Docs.