# YOLOv5 🚀 by Ultralytics, AGPL-3.0 license
"""Utils to interact with the Triton Inference Server."""

import typing
from urllib.parse import urlparse

import torch


class TritonRemoteModel:
    """A wrapper over a model served by the Triton Inference Server. It can
    be configured to communicate over GRPC or HTTP. It accepts Torch Tensors
    as input and returns them as outputs.
    """

    def __init__(self, url: str):
        """
        Arguments:
        url: Fully qualified address of the Triton server, e.g. grpc://localhost:8000
        """
        parsed_url = urlparse(url)
        if parsed_url.scheme == 'grpc':
            from tritonclient.grpc import InferenceServerClient, InferInput

            self.client = InferenceServerClient(parsed_url.netloc)  # Triton GRPC client
            model_repository = self.client.get_model_repository_index()
            self.model_name = model_repository.models[0].name  # GRPC index is a protobuf message
            self.metadata = self.client.get_model_metadata(self.model_name, as_json=True)

            def create_input_placeholders() -> typing.List[InferInput]:
                return [
                    InferInput(i['name'], [int(s) for s in i['shape']], i['datatype']) for i in self.metadata['inputs']]

        else:
            from tritonclient.http import InferenceServerClient, InferInput

            self.client = InferenceServerClient(parsed_url.netloc)  # Triton HTTP client
            model_repository = self.client.get_model_repository_index()
            self.model_name = model_repository[0]['name']  # HTTP index is a list of dicts
            self.metadata = self.client.get_model_metadata(self.model_name)

            def create_input_placeholders() -> typing.List[InferInput]:
                return [
                    InferInput(i['name'], [int(s) for s in i['shape']], i['datatype']) for i in self.metadata['inputs']]

        self._create_input_placeholders_fn = create_input_placeholders

    @property
    def runtime(self):
        """Returns the model runtime (the 'backend' or 'platform' field from the model metadata)."""
        return self.metadata.get('backend', self.metadata.get('platform'))

    def __call__(self, *args, **kwargs) -> typing.Union[torch.Tensor, typing.List[torch.Tensor]]:
        """Invokes the model. Parameters can be provided via args or kwargs.
        args, if provided, are assumed to match the order of inputs of the model.
        kwargs are matched with the model input names.
        """
        inputs = self._create_inputs(*args, **kwargs)
        response = self.client.infer(model_name=self.model_name, inputs=inputs)
        result = []
        for output in self.metadata['outputs']:
            tensor = torch.as_tensor(response.as_numpy(output['name']))
            result.append(tensor)
        return result[0] if len(result) == 1 else result

    def _create_inputs(self, *args, **kwargs):
        """Maps args or kwargs onto the model's input placeholders and fills them with tensor data."""
        args_len, kwargs_len = len(args), len(kwargs)
        if not args_len and not kwargs_len:
            raise RuntimeError('No inputs provided.')
        if args_len and kwargs_len:
            raise RuntimeError('Cannot specify args and kwargs at the same time')
        placeholders = self._create_input_placeholders_fn()
        if args_len:
            if args_len != len(placeholders):
                raise RuntimeError(f'Expected {len(placeholders)} inputs, got {args_len}.')
            for placeholder, value in zip(placeholders, args):
                placeholder.set_data_from_numpy(value.cpu().numpy())
        else:
            for placeholder in placeholders:
                value = kwargs[placeholder.name()]  # InferInput.name() returns the input's name string
                placeholder.set_data_from_numpy(value.cpu().numpy())
        return placeholders
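

if __name__ == '__main__':
    # Minimal usage sketch. Assumes a Triton server is reachable at the URL below
    # with at least one model loaded; the URL, tensor shape, and the input name
    # 'images' are illustrative assumptions, not part of this module.
    model = TritonRemoteModel('http://localhost:8000')
    print(f'Model: {model.model_name}, runtime: {model.runtime}')

    # Positional call: tensors are matched to the model's inputs by position.
    dummy = torch.zeros(1, 3, 640, 640)  # dtype must match the model's declared input datatype
    outputs = model(dummy)

    # Keyword call: tensors are matched to the model's inputs by name, e.g. for a
    # model whose single input is named 'images':
    # outputs = model(images=dummy)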