# YOLOv5 🚀 by Ultralytics, AGPL-3.0 license
"""
General utils
"""

import base64
import contextlib
import glob
import inspect
import io
import logging
import logging.config
import math
import os
import platform
import random
import re
import signal
import subprocess
import sys
import time
import urllib
from copy import deepcopy
from datetime import datetime
from itertools import repeat
from multiprocessing.pool import ThreadPool
from pathlib import Path
from subprocess import check_output
from tarfile import is_tarfile
from typing import Optional
from zipfile import ZipFile, is_zipfile

import cv2
import numpy as np
import pandas as pd
import pkg_resources as pkg
import requests
import torch
import torchvision
import yaml
from PIL import Image
from torchvision import transforms

import transforms as trans  # local transforms module, aliased to avoid clashing with torchvision.transforms
from ultralytics.data.augment import classify_transforms
from ultralytics.utils.checks import check_requirements
from utils import TryExcept, emojis
from utils.downloads import curl_download, gsutil_getsize
from utils.metrics import box_iou, fitness

FILE = Path(__file__).resolve()
ROOT = FILE.parents[1]  # YOLOv5 root directory
RANK = int(os.getenv('RANK', -1))

# Settings
NUM_THREADS = min(8, max(1, os.cpu_count() - 1))  # number of YOLOv5 multiprocessing threads
DATASETS_DIR = Path(os.getenv('YOLOv5_DATASETS_DIR', ROOT.parent / 'datasets'))  # global datasets directory
AUTOINSTALL = str(os.getenv('YOLOv5_AUTOINSTALL', True)).lower() == 'true'  # global auto-install mode
VERBOSE = str(os.getenv('YOLOv5_VERBOSE', True)).lower() == 'true'  # global verbose mode
TQDM_BAR_FORMAT = '{l_bar}{bar:10}{r_bar}'  # tqdm bar format
FONT = 'Arial.ttf'  # https://ultralytics.com/assets/Arial.ttf

torch.set_printoptions(linewidth=320, precision=5, profile='long')
np.set_printoptions(linewidth=320, formatter={'float_kind': '{:11.5g}'.format})  # format short g, %precision=5
pd.options.display.max_columns = 10
cv2.setNumThreads(0)  # prevent OpenCV from multithreading (incompatible with PyTorch DataLoader)
os.environ['NUMEXPR_MAX_THREADS'] = str(NUM_THREADS)  # NumExpr max threads
os.environ['OMP_NUM_THREADS'] = '1' if platform.system() == 'Darwin' else str(NUM_THREADS)  # OpenMP (PyTorch and SciPy); note platform.system() returns 'Darwin' on macOS
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'  # suppress verbose TF compiler warnings in Colab

mean, std = [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]  # ImageNet normalization
test = transforms.Compose([
    transforms.Resize((224, 224)),
    # transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=mean, std=std)])  # eval transform used by apply_classifier1()

v8transforms = classify_transforms(64)

transforma = trans.Compose([
    trans.Scale((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),  # operates on a Tensor
])

def is_ascii(s=''):
    # Is string composed of all ASCII (no UTF) characters? (note str().isascii() introduced in python 3.7)
    s = str(s)  # convert list, tuple, None, etc. to str
    return len(s.encode().decode('ascii', 'ignore')) == len(s)


def is_chinese(s='人工智能'):
    # Is string composed of any Chinese characters?
    return bool(re.search('[\u4e00-\u9fff]', str(s)))


def is_colab():
    # Is environment a Google Colab instance?
    return 'google.colab' in sys.modules


def is_jupyter():
    """
    Check if the current script is running inside a Jupyter Notebook.
    Verified on Colab, Jupyterlab, Kaggle, Paperspace.

    Returns:
        bool: True if running inside a Jupyter Notebook, False otherwise.
    """
    with contextlib.suppress(Exception):
        from IPython import get_ipython
        return get_ipython() is not None
    return False


def is_kaggle():
    # Is environment a Kaggle Notebook?
    return os.environ.get('PWD') == '/kaggle/working' and os.environ.get('KAGGLE_URL_BASE') == 'https://www.kaggle.com'


def is_docker() -> bool:
    """Check if the process runs inside a docker container."""
    if Path('/.dockerenv').exists():
        return True
    try:  # check if docker is in control groups
        with open('/proc/self/cgroup') as file:
            return any('docker' in line for line in file)
    except OSError:
        return False


def is_writeable(dir, test=False):
    # Return True if directory has write permissions, test opening a file with write permissions if test=True
    if not test:
        return os.access(dir, os.W_OK)  # possible issues on Windows
    file = Path(dir) / 'tmp.txt'
    try:
        with open(file, 'w'):  # open file with write permissions
            pass
        file.unlink()  # remove file
        return True
    except OSError:
        return False

LOGGING_NAME = 'yolov5'


def set_logging(name=LOGGING_NAME, verbose=True):
    # sets up logging for the given name
    rank = int(os.getenv('RANK', -1))  # rank in world for Multi-GPU trainings
    level = logging.INFO if verbose and rank in {-1, 0} else logging.ERROR
    logging.config.dictConfig({
        'version': 1,
        'disable_existing_loggers': False,
        'formatters': {
            name: {
                'format': '%(message)s'}},
        'handlers': {
            name: {
                'class': 'logging.StreamHandler',
                'formatter': name,
                'level': level}},
        'loggers': {
            name: {
                'level': level,
                'handlers': [name],
                'propagate': False}}})


set_logging(LOGGING_NAME)  # run before defining LOGGER
LOGGER = logging.getLogger(LOGGING_NAME)  # define globally (used in train.py, val.py, detect.py, etc.)
if platform.system() == 'Windows':
    for fn in LOGGER.info, LOGGER.warning:
        setattr(LOGGER, fn.__name__, lambda x, fn=fn: fn(emojis(x)))  # emoji safe logging (fn bound per iteration to avoid late-binding closure bug)

def user_config_dir(dir='Ultralytics', env_var='YOLOV5_CONFIG_DIR'):
    # Return path of user configuration directory. Prefer environment variable if exists. Make dir if required.
    env = os.getenv(env_var)
    if env:
        path = Path(env)  # use environment variable
    else:
        cfg = {'Windows': 'AppData/Roaming', 'Linux': '.config', 'Darwin': 'Library/Application Support'}  # 3 OS dirs
        path = Path.home() / cfg.get(platform.system(), '')  # OS-specific config dir
    path = (path if is_writeable(path) else Path('/tmp')) / dir  # GCP and AWS lambda fix, only /tmp is writeable
    path.mkdir(exist_ok=True)  # make if required
    return path


CONFIG_DIR = user_config_dir()  # Ultralytics settings dir

class Profile(contextlib.ContextDecorator):
    # YOLOv5 Profile class. Usage: @Profile() decorator or 'with Profile():' context manager
    def __init__(self, t=0.0):
        self.t = t
        self.cuda = torch.cuda.is_available()

    def __enter__(self):
        self.start = self.time()
        return self

    def __exit__(self, type, value, traceback):
        self.dt = self.time() - self.start  # delta-time
        self.t += self.dt  # accumulate dt

    def time(self):
        if self.cuda:
            torch.cuda.synchronize()
        return time.time()

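
# Usage sketch (illustrative comment only, not executed on import): time a block
# and read the results; p.dt is the last delta, p.t the accumulated total.
#   p = Profile()
#   with p:
#       time.sleep(0.1)  # any workload
#   print(f'{p.dt:.3f}s last, {p.t:.3f}s total')
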
class Timeout(contextlib.ContextDecorator):
    # YOLOv5 Timeout class. Usage: @Timeout(seconds) decorator or 'with Timeout(seconds):' context manager
    def __init__(self, seconds, *, timeout_msg='', suppress_timeout_errors=True):
        self.seconds = int(seconds)
        self.timeout_message = timeout_msg
        self.suppress = bool(suppress_timeout_errors)

    def _timeout_handler(self, signum, frame):
        raise TimeoutError(self.timeout_message)

    def __enter__(self):
        if platform.system() != 'Windows':  # not supported on Windows
            signal.signal(signal.SIGALRM, self._timeout_handler)  # Set handler for SIGALRM
            signal.alarm(self.seconds)  # start countdown for SIGALRM to be raised

    def __exit__(self, exc_type, exc_val, exc_tb):
        if platform.system() != 'Windows':
            signal.alarm(0)  # Cancel SIGALRM if it's scheduled
            if self.suppress and exc_type is TimeoutError:  # Suppress TimeoutError
                return True

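
# Usage sketch (comment only; SIGALRM-based, so it has no effect on Windows):
# abort a slow call after 3 seconds, with the TimeoutError suppressed by default.
#   with Timeout(3, timeout_msg='slow call'):
#       time.sleep(10)  # interrupted after ~3s on Unix
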
class WorkingDirectory(contextlib.ContextDecorator):
    # Usage: @WorkingDirectory(dir) decorator or 'with WorkingDirectory(dir):' context manager
    def __init__(self, new_dir):
        self.dir = new_dir  # new dir
        self.cwd = Path.cwd().resolve()  # current dir

    def __enter__(self):
        os.chdir(self.dir)

    def __exit__(self, exc_type, exc_val, exc_tb):
        os.chdir(self.cwd)

def methods(instance):
    # Get class/instance methods
    return [f for f in dir(instance) if callable(getattr(instance, f)) and not f.startswith('__')]


def print_args(args: Optional[dict] = None, show_file=True, show_func=False):
    # Print function arguments (optional args dict)
    x = inspect.currentframe().f_back  # previous frame
    file, _, func, _, _ = inspect.getframeinfo(x)
    if args is None:  # get args automatically
        args, _, _, frm = inspect.getargvalues(x)
        args = {k: v for k, v in frm.items() if k in args}
    try:
        file = Path(file).resolve().relative_to(ROOT).with_suffix('')
    except ValueError:
        file = Path(file).stem
    s = (f'{file}: ' if show_file else '') + (f'{func}: ' if show_func else '')
    LOGGER.info(colorstr(s) + ', '.join(f'{k}={v}' for k, v in args.items()))

def init_seeds(seed=0, deterministic=False):
    # Initialize random number generator (RNG) seeds https://pytorch.org/docs/stable/notes/randomness.html
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)  # for Multi-GPU, exception safe
    # torch.backends.cudnn.benchmark = True  # AutoBatch problem https://github.com/ultralytics/yolov5/issues/9287
    if deterministic and check_version(torch.__version__, '1.12.0'):  # https://github.com/ultralytics/yolov5/pull/8213
        torch.use_deterministic_algorithms(True)
        torch.backends.cudnn.deterministic = True
        os.environ['CUBLAS_WORKSPACE_CONFIG'] = ':4096:8'
        os.environ['PYTHONHASHSEED'] = str(seed)

def intersect_dicts(da, db, exclude=()):
    # Dictionary intersection of matching keys and shapes, omitting 'exclude' keys, using da values
    return {k: v for k, v in da.items() if k in db and all(x not in k for x in exclude) and v.shape == db[k].shape}

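
# Usage sketch (comment only; 'ckpt' and 'model' assumed defined): transfer only
# compatible weights from a checkpoint state_dict, skipping anchor buffers.
#   csd = intersect_dicts(ckpt['model'].float().state_dict(), model.state_dict(), exclude=('anchors',))
#   model.load_state_dict(csd, strict=False)
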
def get_default_args(func):
    # Get func() default arguments
    signature = inspect.signature(func)
    return {k: v.default for k, v in signature.parameters.items() if v.default is not inspect.Parameter.empty}


def get_latest_run(search_dir='.'):
    # Return path to most recent 'last.pt' in /runs (i.e. to --resume from)
    last_list = glob.glob(f'{search_dir}/**/last*.pt', recursive=True)
    return max(last_list, key=os.path.getctime) if last_list else ''


def file_age(path=__file__):
    # Return days since last file update
    dt = datetime.now() - datetime.fromtimestamp(Path(path).stat().st_mtime)  # delta
    return dt.days  # + dt.seconds / 86400  # fractional days


def file_date(path=__file__):
    # Return human-readable file modification date, i.e. '2021-3-26'
    t = datetime.fromtimestamp(Path(path).stat().st_mtime)
    return f'{t.year}-{t.month}-{t.day}'


def file_size(path):
    # Return file/dir size (MB)
    mb = 1 << 20  # bytes to MiB (1024 ** 2)
    path = Path(path)
    if path.is_file():
        return path.stat().st_size / mb
    elif path.is_dir():
        return sum(f.stat().st_size for f in path.glob('**/*') if f.is_file()) / mb
    else:
        return 0.0

def check_online():
    # Check internet connectivity
    import socket

    def run_once():
        # Check once
        try:
            socket.create_connection(('1.1.1.1', 443), 5)  # check host accessibility
            return True
        except OSError:
            return False

    return run_once() or run_once()  # check twice to increase robustness to intermittent connectivity issues


def git_describe(path=ROOT):  # path must be a directory
    # Return human-readable git description, i.e. v5.0-5-g3e25f1e https://git-scm.com/docs/git-describe
    try:
        assert (Path(path) / '.git').is_dir()
        return check_output(f'git -C {path} describe --tags --long --always', shell=True).decode()[:-1]
    except Exception:
        return ''

@TryExcept()
@WorkingDirectory(ROOT)
def check_git_status(repo='ultralytics/yolov5', branch='master'):
    # YOLOv5 status check, recommend 'git pull' if code is out of date
    url = f'https://github.com/{repo}'
    msg = f', for updates see {url}'
    s = colorstr('github: ')  # string
    assert Path('.git').exists(), s + 'skipping check (not a git repository)' + msg
    assert check_online(), s + 'skipping check (offline)' + msg

    splits = re.split(pattern=r'\s', string=check_output('git remote -v', shell=True).decode())
    matches = [repo in s for s in splits]
    if any(matches):
        remote = splits[matches.index(True) - 1]
    else:
        remote = 'ultralytics'
        check_output(f'git remote add {remote} {url}', shell=True)
    check_output(f'git fetch {remote}', shell=True, timeout=5)  # git fetch
    local_branch = check_output('git rev-parse --abbrev-ref HEAD', shell=True).decode().strip()  # checked out
    n = int(check_output(f'git rev-list {local_branch}..{remote}/{branch} --count', shell=True))  # commits behind
    if n > 0:
        pull = 'git pull' if remote == 'origin' else f'git pull {remote} {branch}'
        s += f"⚠️ YOLOv5 is out of date by {n} commit{'s' * (n > 1)}. Use '{pull}' or 'git clone {url}' to update."
    else:
        s += f'up to date with {url} ✅'
    LOGGER.info(s)


@WorkingDirectory(ROOT)
def check_git_info(path='.'):
    # YOLOv5 git info check, return {remote, branch, commit}
    check_requirements('gitpython')
    import git
    try:
        repo = git.Repo(path)
        remote = repo.remotes.origin.url.replace('.git', '')  # i.e. 'https://github.com/ultralytics/yolov5'
        commit = repo.head.commit.hexsha  # i.e. '3134699c73af83aac2a481435550b968d5792c0d'
        try:
            branch = repo.active_branch.name  # i.e. 'main'
        except TypeError:  # not on any branch
            branch = None  # i.e. 'detached HEAD' state
        return {'remote': remote, 'branch': branch, 'commit': commit}
    except git.exc.InvalidGitRepositoryError:  # path is not a git dir
        return {'remote': None, 'branch': None, 'commit': None}

def check_python(minimum='3.7.0'):
    # Check current python version vs. required python version
    check_version(platform.python_version(), minimum, name='Python ', hard=True)


def check_version(current='0.0.0', minimum='0.0.0', name='version ', pinned=False, hard=False, verbose=False):
    # Check version vs. required version
    current, minimum = (pkg.parse_version(x) for x in (current, minimum))
    result = (current == minimum) if pinned else (current >= minimum)  # bool
    s = f'WARNING ⚠️ {name}{minimum} is required by YOLOv5, but {name}{current} is currently installed'  # string
    if hard:
        assert result, emojis(s)  # assert min requirements met
    if verbose and not result:
        LOGGER.warning(s)
    return result

def check_img_size(imgsz, s=32, floor=0):
    # Verify image size is a multiple of stride s in each dimension
    if isinstance(imgsz, int):  # integer i.e. img_size=640
        new_size = max(make_divisible(imgsz, int(s)), floor)
    else:  # list i.e. img_size=[640, 480]
        imgsz = list(imgsz)  # convert to list if tuple
        new_size = [max(make_divisible(x, int(s)), floor) for x in imgsz]
    if new_size != imgsz:
        LOGGER.warning(f'WARNING ⚠️ --img-size {imgsz} must be multiple of max stride {s}, updating to {new_size}')
    return new_size

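
# Usage sketch (comment only): 641 is not divisible by the default stride 32,
# so it is rounded up to the nearest multiple.
#   check_img_size(641)         # -> 672, with a warning
#   check_img_size([640, 480])  # -> [640, 480], unchanged
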
def check_imshow(warn=False):
    # Check if environment supports image displays
    try:
        assert not is_jupyter()
        assert not is_docker()
        cv2.imshow('test', np.zeros((1, 1, 3)))
        cv2.waitKey(1)
        cv2.destroyAllWindows()
        cv2.waitKey(1)
        return True
    except Exception as e:
        if warn:
            LOGGER.warning(f'WARNING ⚠️ Environment does not support cv2.imshow() or PIL Image.show()\n{e}')
        return False


def check_suffix(file='yolov5s.pt', suffix=('.pt',), msg=''):
    # Check file(s) for acceptable suffix
    if file and suffix:
        if isinstance(suffix, str):
            suffix = [suffix]
        for f in file if isinstance(file, (list, tuple)) else [file]:
            s = Path(f).suffix.lower()  # file suffix
            if len(s):
                assert s in suffix, f'{msg}{f} acceptable suffix is {suffix}'

def check_yaml(file, suffix=('.yaml', '.yml')):
    # Search/download YAML file (if necessary) and return path, checking suffix
    return check_file(file, suffix)


def check_file(file, suffix=''):
    # Search/download file (if necessary) and return path
    check_suffix(file, suffix)  # optional
    file = str(file)  # convert to str()
    if os.path.isfile(file) or not file:  # exists
        return file
    elif file.startswith(('http:/', 'https:/')):  # download
        url = file  # warning: Pathlib turns :// -> :/
        file = Path(urllib.parse.unquote(file).split('?')[0]).name  # '%2F' to '/', split https://url.com/file.txt?auth
        if os.path.isfile(file):
            LOGGER.info(f'Found {url} locally at {file}')  # file already exists
        else:
            LOGGER.info(f'Downloading {url} to {file}...')
            torch.hub.download_url_to_file(url, file)
            assert Path(file).exists() and Path(file).stat().st_size > 0, f'File download failed: {url}'  # check
        return file
    elif file.startswith('clearml://'):  # ClearML Dataset ID
        assert 'clearml' in sys.modules, \
            "ClearML is not installed, so cannot use ClearML dataset. Try running 'pip install clearml'."
        return file
    else:  # search
        files = []
        for d in 'data', 'models', 'utils':  # search directories
            files.extend(glob.glob(str(ROOT / d / '**' / file), recursive=True))  # find file
        assert len(files), f'File not found: {file}'  # assert file was found
        assert len(files) == 1, f"Multiple files match '{file}', specify exact path: {files}"  # assert unique
        return files[0]  # return file

def check_font(font=FONT, progress=False):
    # Download font to CONFIG_DIR if necessary
    font = Path(font)
    file = CONFIG_DIR / font.name
    if not font.exists() and not file.exists():
        url = f'https://ultralytics.com/assets/{font.name}'
        LOGGER.info(f'Downloading {url} to {file}...')
        torch.hub.download_url_to_file(url, str(file), progress=progress)

def check_dataset(data, autodownload=True):
    # Download, check and/or unzip dataset if not found locally

    # Download (optional)
    extract_dir = ''
    if isinstance(data, (str, Path)) and (is_zipfile(data) or is_tarfile(data)):
        download(data, dir=f'{DATASETS_DIR}/{Path(data).stem}', unzip=True, delete=False, curl=False, threads=1)
        data = next((DATASETS_DIR / Path(data).stem).rglob('*.yaml'))
        extract_dir, autodownload = data.parent, False

    # Read yaml (optional)
    if isinstance(data, (str, Path)):
        data = yaml_load(data)  # dictionary

    # Checks
    for k in 'train', 'val', 'names':
        assert k in data, emojis(f"data.yaml '{k}:' field missing ❌")
    if isinstance(data['names'], (list, tuple)):  # old array format
        data['names'] = dict(enumerate(data['names']))  # convert to dict
    assert all(isinstance(k, int) for k in data['names'].keys()), 'data.yaml names keys must be integers, i.e. 2: car'
    data['nc'] = len(data['names'])

    # Resolve paths
    path = Path(extract_dir or data.get('path') or '')  # optional 'path' default to '.'
    if not path.is_absolute():
        path = (ROOT / path).resolve()
        data['path'] = path  # download scripts
    for k in 'train', 'val', 'test':
        if data.get(k):  # prepend path
            if isinstance(data[k], str):
                x = (path / data[k]).resolve()
                if not x.exists() and data[k].startswith('../'):
                    x = (path / data[k][3:]).resolve()
                data[k] = str(x)
            else:
                data[k] = [str((path / x).resolve()) for x in data[k]]

    # Parse yaml
    train, val, test, s = (data.get(x) for x in ('train', 'val', 'test', 'download'))
    if val:
        val = [Path(x).resolve() for x in (val if isinstance(val, list) else [val])]  # val path
        if not all(x.exists() for x in val):
            LOGGER.info('\nDataset not found ⚠️, missing paths %s' % [str(x) for x in val if not x.exists()])
            if not s or not autodownload:
                raise Exception('Dataset not found ❌')
            t = time.time()
            if s.startswith('http') and s.endswith('.zip'):  # URL
                f = Path(s).name  # filename
                LOGGER.info(f'Downloading {s} to {f}...')
                torch.hub.download_url_to_file(s, f)
                Path(DATASETS_DIR).mkdir(parents=True, exist_ok=True)  # create root
                unzip_file(f, path=DATASETS_DIR)  # unzip
                Path(f).unlink()  # remove zip
                r = None  # success
            elif s.startswith('bash '):  # bash script
                LOGGER.info(f'Running {s} ...')
                r = subprocess.run(s, shell=True)
            else:  # python script
                r = exec(s, {'yaml': data})  # return None
            dt = f'({round(time.time() - t, 1)}s)'
            s = f"success ✅ {dt}, saved to {colorstr('bold', DATASETS_DIR)}" if r in (0, None) else f'failure {dt} ❌'
            LOGGER.info(f'Dataset download {s}')
    check_font('Arial.ttf' if is_ascii(data['names']) else 'Arial.Unicode.ttf', progress=True)  # download fonts
    return data  # dictionary

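
# Usage sketch (comment only): minimal data.yaml layout this function expects,
# modeled on the coco128 dataset config.
#   path: ../datasets/coco128  # dataset root
#   train: images/train2017
#   val: images/train2017
#   names:
#     0: person
#     1: bicycle
#   download: https://ultralytics.com/assets/coco128.zip  # optional autodownload
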
def check_amp(model):
    # Check PyTorch Automatic Mixed Precision (AMP) functionality. Return True on correct operation
    from models.common import AutoShape, DetectMultiBackend

    def amp_allclose(model, im):
        # All close FP32 vs AMP results
        m = AutoShape(model, verbose=False)  # model
        a = m(im).xywhn[0]  # FP32 inference
        m.amp = True
        b = m(im).xywhn[0]  # AMP inference
        return a.shape == b.shape and torch.allclose(a, b, atol=0.1)  # close to 10% absolute tolerance

    prefix = colorstr('AMP: ')
    device = next(model.parameters()).device  # get model device
    if device.type in ('cpu', 'mps'):
        return False  # AMP only used on CUDA devices
    f = ROOT / 'data' / 'images' / 'bus.jpg'  # image to check
    im = f if f.exists() else 'https://ultralytics.com/images/bus.jpg' if check_online() else np.ones((640, 640, 3))
    try:
        assert amp_allclose(deepcopy(model), im) or amp_allclose(DetectMultiBackend('yolov5n.pt', device), im)
        LOGGER.info(f'{prefix}checks passed ✅')
        return True
    except Exception:
        help_url = 'https://github.com/ultralytics/yolov5/issues/7908'
        LOGGER.warning(f'{prefix}checks failed ❌, disabling Automatic Mixed Precision. See {help_url}')
        return False

def yaml_load(file='data.yaml'):
    # Single-line safe yaml loading
    with open(file, errors='ignore') as f:
        return yaml.safe_load(f)


def yaml_save(file='data.yaml', data={}):
    # Single-line safe yaml saving
    with open(file, 'w') as f:
        yaml.safe_dump({k: str(v) if isinstance(v, Path) else v for k, v in data.items()}, f, sort_keys=False)


def unzip_file(file, path=None, exclude=('.DS_Store', '__MACOSX')):
    # Unzip a *.zip file to path/, excluding files containing strings in exclude list
    if path is None:
        path = Path(file).parent  # default path
    with ZipFile(file) as zipObj:
        for f in zipObj.namelist():  # list all archived filenames in the zip
            if all(x not in f for x in exclude):
                zipObj.extract(f, path=path)


def url2file(url):
    # Convert URL to filename, i.e. https://url.com/file.txt?auth -> file.txt
    url = str(Path(url)).replace(':/', '://')  # Pathlib turns :// -> :/
    return Path(urllib.parse.unquote(url)).name.split('?')[0]  # '%2F' to '/', split https://url.com/file.txt?auth

def download(url, dir='.', unzip=True, delete=True, curl=False, threads=1, retry=3):
    # Multithreaded file download and unzip function, used in data.yaml for autodownload
    def download_one(url, dir):
        # Download 1 file
        success = True
        if os.path.isfile(url):
            f = Path(url)  # filename
        else:  # does not exist
            f = dir / Path(url).name
            LOGGER.info(f'Downloading {url} to {f}...')
            for i in range(retry + 1):
                if curl:
                    success = curl_download(url, f, silent=(threads > 1))
                else:
                    torch.hub.download_url_to_file(url, f, progress=threads == 1)  # torch download
                    success = f.is_file()
                if success:
                    break
                elif i < retry:
                    LOGGER.warning(f'⚠️ Download failure, retrying {i + 1}/{retry} {url}...')
                else:
                    LOGGER.warning(f'❌ Failed to download {url}...')

        if unzip and success and (f.suffix == '.gz' or is_zipfile(f) or is_tarfile(f)):
            LOGGER.info(f'Unzipping {f}...')
            if is_zipfile(f):
                unzip_file(f, dir)  # unzip
            elif is_tarfile(f):
                subprocess.run(['tar', 'xf', f, '--directory', f.parent], check=True)  # unzip
            elif f.suffix == '.gz':
                subprocess.run(['tar', 'xfz', f, '--directory', f.parent], check=True)  # unzip
            if delete:
                f.unlink()  # remove zip

    dir = Path(dir)
    dir.mkdir(parents=True, exist_ok=True)  # make directory
    if threads > 1:
        pool = ThreadPool(threads)
        pool.imap(lambda x: download_one(*x), zip(url, repeat(dir)))  # multithreaded
        pool.close()
        pool.join()
    else:
        for u in [url] if isinstance(url, (str, Path)) else url:
            download_one(u, dir)

def make_divisible(x, divisor):
    # Returns nearest x divisible by divisor
    if isinstance(divisor, torch.Tensor):
        divisor = int(divisor.max())  # to int
    return math.ceil(x / divisor) * divisor

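
# Usage sketch (comment only): round a size or channel count up to a multiple.
#   make_divisible(97, 32)  # -> 128 (ceil(97 / 32) * 32)
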
def clean_str(s):
    # Cleans a string by replacing special characters with underscore _
    return re.sub(pattern='[|@#!¡·$€%&()=?¿^*;:,¨´><+]', repl='_', string=s)


def one_cycle(y1=0.0, y2=1.0, steps=100):
    # lambda function for sinusoidal ramp from y1 to y2 https://arxiv.org/pdf/1812.01187.pdf
    return lambda x: ((1 - math.cos(x * math.pi / steps)) / 2) * (y2 - y1) + y1

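
# Usage sketch (comment only; 'optimizer' assumed defined): cosine learning-rate
# ramp over 100 epochs, usable with torch.optim.lr_scheduler.LambdaLR.
#   lf = one_cycle(1, 0.1, 100)  # lf(0) == 1.0, lf(50) == 0.55, lf(100) == 0.1
#   scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lf)
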
def colorstr(*input):
    # Colors a string https://en.wikipedia.org/wiki/ANSI_escape_code, i.e. colorstr('blue', 'hello world')
    *args, string = input if len(input) > 1 else ('blue', 'bold', input[0])  # color arguments, string
    colors = {
        'black': '\033[30m',  # basic colors
        'red': '\033[31m',
        'green': '\033[32m',
        'yellow': '\033[33m',
        'blue': '\033[34m',
        'magenta': '\033[35m',
        'cyan': '\033[36m',
        'white': '\033[37m',
        'bright_black': '\033[90m',  # bright colors
        'bright_red': '\033[91m',
        'bright_green': '\033[92m',
        'bright_yellow': '\033[93m',
        'bright_blue': '\033[94m',
        'bright_magenta': '\033[95m',
        'bright_cyan': '\033[96m',
        'bright_white': '\033[97m',
        'end': '\033[0m',  # misc
        'bold': '\033[1m',
        'underline': '\033[4m'}
    return ''.join(colors[x] for x in args) + f'{string}' + colors['end']

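
# Usage sketch (comment only): defaults to blue bold when only a string is given.
#   colorstr('hello')                   # -> '\033[34m\033[1mhello\033[0m'
#   colorstr('red', 'underline', 'hi')  # red underlined 'hi'
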
def labels_to_class_weights(labels, nc=80):
    # Get class weights (inverse frequency) from training labels
    if labels[0] is None:  # no labels loaded
        return torch.Tensor()

    labels = np.concatenate(labels, 0)  # labels.shape = (866643, 5) for COCO
    classes = labels[:, 0].astype(int)  # labels = [class xywh]
    weights = np.bincount(classes, minlength=nc)  # occurrences per class

    # Prepend gridpoint count (for uCE training)
    # gpi = ((320 / 32 * np.array([1, 2, 4])) ** 2 * 3).sum()  # gridpoints per image
    # weights = np.hstack([gpi * len(labels) - weights.sum() * 9, weights * 9]) ** 0.5  # prepend gridpoints to start

    weights[weights == 0] = 1  # replace empty bins with 1
    weights = 1 / weights  # number of targets per class
    weights /= weights.sum()  # normalize
    return torch.from_numpy(weights).float()


def labels_to_image_weights(labels, nc=80, class_weights=np.ones(80)):
    # Produces image weights based on class_weights and image contents
    # Usage: index = random.choices(range(n), weights=image_weights, k=1)  # weighted image sample
    class_counts = np.array([np.bincount(x[:, 0].astype(int), minlength=nc) for x in labels])
    return (class_weights.reshape(1, nc) * class_counts).sum(1)

def coco80_to_coco91_class():  # converts 80-index (val2014) to 91-index (paper)
    # https://tech.amikelive.com/node-718/what-object-categories-labels-are-in-coco-dataset/
    # a = np.loadtxt('data/coco.names', dtype='str', delimiter='\n')
    # b = np.loadtxt('data/coco_paper.names', dtype='str', delimiter='\n')
    # x1 = [list(a[i] == b).index(True) + 1 for i in range(80)]  # darknet to coco
    # x2 = [list(b[i] == a).index(True) if any(b[i] == a) else None for i in range(91)]  # coco to darknet
    return [
        1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 27, 28, 31, 32, 33, 34,
        35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63,
        64, 65, 67, 70, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 84, 85, 86, 87, 88, 89, 90]

def xyxy2xywh(x):
    # Convert nx4 boxes from [x1, y1, x2, y2] to [x, y, w, h] where xy1=top-left, xy2=bottom-right
    y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
    y[..., 0] = (x[..., 0] + x[..., 2]) / 2  # x center
    y[..., 1] = (x[..., 1] + x[..., 3]) / 2  # y center
    y[..., 2] = x[..., 2] - x[..., 0]  # width
    y[..., 3] = x[..., 3] - x[..., 1]  # height
    return y


def xywh2xyxy(x):
    # Convert nx4 boxes from [x, y, w, h] to [x1, y1, x2, y2] where xy1=top-left, xy2=bottom-right
    y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
    y[..., 0] = x[..., 0] - x[..., 2] / 2  # top left x
    y[..., 1] = x[..., 1] - x[..., 3] / 2  # top left y
    y[..., 2] = x[..., 0] + x[..., 2] / 2  # bottom right x
    y[..., 3] = x[..., 1] + x[..., 3] / 2  # bottom right y
    return y

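
# Usage sketch (comment only): the two conversions are exact inverses.
#   b = np.array([[0., 0., 10., 20.]])  # xyxy
#   xyxy2xywh(b)                        # -> [[5., 10., 10., 20.]] (center x, center y, w, h)
#   xywh2xyxy(xyxy2xywh(b))             # -> [[0., 0., 10., 20.]] round trip
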
def xywhn2xyxy(x, w=640, h=640, padw=0, padh=0):
    # Convert nx4 boxes from [x, y, w, h] normalized to [x1, y1, x2, y2] where xy1=top-left, xy2=bottom-right
    y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
    y[..., 0] = w * (x[..., 0] - x[..., 2] / 2) + padw  # top left x
    y[..., 1] = h * (x[..., 1] - x[..., 3] / 2) + padh  # top left y
    y[..., 2] = w * (x[..., 0] + x[..., 2] / 2) + padw  # bottom right x
    y[..., 3] = h * (x[..., 1] + x[..., 3] / 2) + padh  # bottom right y
    return y


def xyxy2xywhn(x, w=640, h=640, clip=False, eps=0.0):
    # Convert nx4 boxes from [x1, y1, x2, y2] to [x, y, w, h] normalized where xy1=top-left, xy2=bottom-right
    if clip:
        clip_boxes(x, (h - eps, w - eps))  # warning: inplace clip
    y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
    y[..., 0] = ((x[..., 0] + x[..., 2]) / 2) / w  # x center
    y[..., 1] = ((x[..., 1] + x[..., 3]) / 2) / h  # y center
    y[..., 2] = (x[..., 2] - x[..., 0]) / w  # width
    y[..., 3] = (x[..., 3] - x[..., 1]) / h  # height
    return y


def xyn2xy(x, w=640, h=640, padw=0, padh=0):
    # Convert normalized segments into pixel segments, shape (n,2)
    y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
    y[..., 0] = w * x[..., 0] + padw  # top left x
    y[..., 1] = h * x[..., 1] + padh  # top left y
    return y


def segment2box(segment, width=640, height=640):
    # Convert 1 segment label to 1 box label, applying inside-image constraint, i.e. (xy1, xy2, ...) to (xyxy)
    x, y = segment.T  # segment xy
    inside = (x >= 0) & (y >= 0) & (x <= width) & (y <= height)
    x, y = x[inside], y[inside]
    return np.array([x.min(), y.min(), x.max(), y.max()]) if any(x) else np.zeros((1, 4))  # xyxy


def segments2boxes(segments):
    # Convert segment labels to box labels, i.e. (cls, xy1, xy2, ...) to (cls, xywh)
    boxes = []
    for s in segments:
        x, y = s.T  # segment xy
        boxes.append([x.min(), y.min(), x.max(), y.max()])  # cls, xyxy
    return xyxy2xywh(np.array(boxes))  # cls, xywh


def resample_segments(segments, n=1000):
    # Up-sample an (n,2) segment
    for i, s in enumerate(segments):
        s = np.concatenate((s, s[0:1, :]), axis=0)
        x = np.linspace(0, len(s) - 1, n)
        xp = np.arange(len(s))
        segments[i] = np.concatenate([np.interp(x, xp, s[:, i]) for i in range(2)]).reshape(2, -1).T  # segment xy
    return segments

def scale_boxes(img1_shape, boxes, img0_shape, ratio_pad=None):
    # Rescale boxes (xyxy) from img1_shape to img0_shape
    if ratio_pad is None:  # calculate from img0_shape
        gain = min(img1_shape[0] / img0_shape[0], img1_shape[1] / img0_shape[1])  # gain = old / new
        pad = (img1_shape[1] - img0_shape[1] * gain) / 2, (img1_shape[0] - img0_shape[0] * gain) / 2  # wh padding
    else:
        gain = ratio_pad[0][0]
        pad = ratio_pad[1]

    boxes[..., [0, 2]] -= pad[0]  # x padding
    boxes[..., [1, 3]] -= pad[1]  # y padding
    boxes[..., :4] /= gain
    clip_boxes(boxes, img0_shape)
    return boxes


def scale_segments(img1_shape, segments, img0_shape, ratio_pad=None, normalize=False):
    # Rescale coords (xyxy) from img1_shape to img0_shape
    if ratio_pad is None:  # calculate from img0_shape
        gain = min(img1_shape[0] / img0_shape[0], img1_shape[1] / img0_shape[1])  # gain = old / new
        pad = (img1_shape[1] - img0_shape[1] * gain) / 2, (img1_shape[0] - img0_shape[0] * gain) / 2  # wh padding
    else:
        gain = ratio_pad[0][0]
        pad = ratio_pad[1]

    segments[:, 0] -= pad[0]  # x padding
    segments[:, 1] -= pad[1]  # y padding
    segments /= gain
    clip_segments(segments, img0_shape)
    if normalize:
        segments[:, 0] /= img0_shape[1]  # width
        segments[:, 1] /= img0_shape[0]  # height
    return segments

def strtolst(strpoint):
    # Parse an electric-fence string like 'x1#y1,x2#y2:x3#y3' into a flat list of [x, y] string pairs;
    # ':' separates polygons, ',' separates points, '#' separates coordinates
    strpoint = strpoint.split(':')
    lista = []
    for liststr in strpoint:
        if len(liststr) > 0:
            liststr = liststr.split(',')
            for point in liststr:
                lista.append(point.split('#'))
    return lista


def strtolstl(strpoint):
    # Same format as strtolst(), but keeps one sub-list per ':'-separated polygon
    strpoint = strpoint.split(':')
    lista = []
    for liststr in strpoint:
        if len(liststr) > 0:
            lista.append([])
            liststr = liststr.split(',')
            for point in liststr:
                lista[-1].append(point.split('#'))  # append to the polygon just created (index-safe when empty segments occur)
    return lista

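
# Usage sketch (comment only): two polygons separated by ':'.
#   strtolstl('0.1#0.1,0.9#0.1,0.5#0.9:0.2#0.2,0.8#0.8')
#   # -> [[['0.1', '0.1'], ['0.9', '0.1'], ['0.5', '0.9']], [['0.2', '0.2'], ['0.8', '0.8']]]
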
def clip_boxes(boxes, shape):
    # Clip boxes (xyxy) to image shape (height, width)
    if isinstance(boxes, torch.Tensor):  # faster individually
        boxes[..., 0].clamp_(0, shape[1])  # x1
        boxes[..., 1].clamp_(0, shape[0])  # y1
        boxes[..., 2].clamp_(0, shape[1])  # x2
        boxes[..., 3].clamp_(0, shape[0])  # y2
    else:  # np.array (faster grouped)
        boxes[..., [0, 2]] = boxes[..., [0, 2]].clip(0, shape[1])  # x1, x2
        boxes[..., [1, 3]] = boxes[..., [1, 3]].clip(0, shape[0])  # y1, y2


def clip_segments(segments, shape):
    # Clip segments (xy1,xy2,...) to image shape (height, width)
    if isinstance(segments, torch.Tensor):  # faster individually
        segments[:, 0].clamp_(0, shape[1])  # x
        segments[:, 1].clamp_(0, shape[0])  # y
    else:  # np.array (faster grouped)
        segments[:, 0] = segments[:, 0].clip(0, shape[1])  # x
        segments[:, 1] = segments[:, 1].clip(0, shape[0])  # y

def non_max_suppression(
        prediction,
        conf_thres=0.25,
        iou_thres=0.45,
        classes=None,
        agnostic=False,
        multi_label=False,
        labels=(),
        max_det=300,
        nm=0,  # number of masks
):
    """Non-Maximum Suppression (NMS) on inference results to reject overlapping detections

    Returns:
         list of detections, on (n,6) tensor per image [xyxy, conf, cls]
    """

    # Checks
    assert 0 <= conf_thres <= 1, f'Invalid Confidence threshold {conf_thres}, valid values are between 0.0 and 1.0'
    assert 0 <= iou_thres <= 1, f'Invalid IoU {iou_thres}, valid values are between 0.0 and 1.0'
    if isinstance(prediction, (list, tuple)):  # YOLOv5 model in validation mode, output = (inference_out, loss_out)
        prediction = prediction[0]  # select only inference output

    device = prediction.device
    mps = 'mps' in device.type  # Apple MPS
    if mps:  # MPS not fully supported yet, convert tensors to CPU before NMS
        prediction = prediction.cpu()
    bs = prediction.shape[0]  # batch size
    nc = prediction.shape[2] - nm - 5  # number of classes
    xc = prediction[..., 4] > conf_thres  # candidates

    # Settings
    # min_wh = 2  # (pixels) minimum box width and height
    max_wh = 7680  # (pixels) maximum box width and height
    max_nms = 30000  # maximum number of boxes into torchvision.ops.nms()
    time_limit = 0.5 + 0.05 * bs  # seconds to quit after
    redundant = True  # require redundant detections
    multi_label &= nc > 1  # multiple labels per box (adds 0.5ms/img)
    merge = False  # use merge-NMS

    t = time.time()
    mi = 5 + nc  # mask start index
    output = [torch.zeros((0, 6 + nm), device=prediction.device)] * bs
    for xi, x in enumerate(prediction):  # image index, image inference
        # Apply constraints
        # x[((x[..., 2:4] < min_wh) | (x[..., 2:4] > max_wh)).any(1), 4] = 0  # width-height
        x = x[xc[xi]]  # confidence

        # Cat apriori labels if autolabelling
        if labels and len(labels[xi]):
            lb = labels[xi]
            v = torch.zeros((len(lb), nc + nm + 5), device=x.device)
            v[:, :4] = lb[:, 1:5]  # box
            v[:, 4] = 1.0  # conf
            v[range(len(lb)), lb[:, 0].long() + 5] = 1.0  # cls
            x = torch.cat((x, v), 0)

        # If none remain process next image
        if not x.shape[0]:
            continue

        # Compute conf
        x[:, 5:] *= x[:, 4:5]  # conf = obj_conf * cls_conf

        # Box/Mask
        box = xywh2xyxy(x[:, :4])  # (center_x, center_y, width, height) to (x1, y1, x2, y2)
        mask = x[:, mi:]  # zero columns if no masks

        # Detections matrix nx6 (xyxy, conf, cls)
        if multi_label:
            i, j = (x[:, 5:mi] > conf_thres).nonzero(as_tuple=False).T
            x = torch.cat((box[i], x[i, 5 + j, None], j[:, None].float(), mask[i]), 1)
        else:  # best class only
            conf, j = x[:, 5:mi].max(1, keepdim=True)
            x = torch.cat((box, conf, j.float(), mask), 1)[conf.view(-1) > conf_thres]

        # Filter by class
        if classes is not None:
            x = x[(x[:, 5:6] == torch.tensor(classes, device=x.device)).any(1)]

        # Apply finite constraint
        # if not torch.isfinite(x).all():
        #     x = x[torch.isfinite(x).all(1)]

        # Check shape
        n = x.shape[0]  # number of boxes
        if not n:  # no boxes
            continue
        x = x[x[:, 4].argsort(descending=True)[:max_nms]]  # sort by confidence and remove excess boxes

        # Batched NMS
        c = x[:, 5:6] * (0 if agnostic else max_wh)  # classes
        boxes, scores = x[:, :4] + c, x[:, 4]  # boxes (offset by class), scores
        i = torchvision.ops.nms(boxes, scores, iou_thres)  # NMS
        i = i[:max_det]  # limit detections
        if merge and (1 < n < 3E3):  # Merge NMS (boxes merged using weighted mean)
            # update boxes as boxes(i,4) = weights(i,n) * boxes(n,4)
            iou = box_iou(boxes[i], boxes) > iou_thres  # iou matrix
            weights = iou * scores[None]  # box weights
            x[i, :4] = torch.mm(weights, x[:, :4]).float() / weights.sum(1, keepdim=True)  # merged boxes
            if redundant:
                i = i[iou.sum(1) > 1]  # require redundancy

        output[xi] = x[i]
        if mps:
            output[xi] = output[xi].to(device)
        if (time.time() - t) > time_limit:
            LOGGER.warning(f'WARNING ⚠️ NMS time limit {time_limit:.3f}s exceeded')
            break  # time limit exceeded

    return output

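
# Usage sketch (comment only; 'model' and 'im' assumed defined): raw head output of
# shape (batch, n_boxes, 5 + nc) in, one (k, 6) tensor [xyxy, conf, cls] per image out.
#   pred = model(im)
#   det = non_max_suppression(pred, conf_thres=0.25, iou_thres=0.45)[0]  # first image
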
def strip_optimizer(f='best.pt', s=''):  # from utils.general import *; strip_optimizer()
    # Strip optimizer from 'f' to finalize training, optionally save as 's'
    x = torch.load(f, map_location=torch.device('cpu'))
    if x.get('ema'):
        x['model'] = x['ema']  # replace model with ema
    for k in 'optimizer', 'best_fitness', 'ema', 'updates':  # keys
        x[k] = None
    x['epoch'] = -1
    x['model'].half()  # to FP16
    for p in x['model'].parameters():
        p.requires_grad = False
    torch.save(x, s or f)
    mb = os.path.getsize(s or f) / 1E6  # filesize
    LOGGER.info(f"Optimizer stripped from {f},{f' saved as {s},' if s else ''} {mb:.1f}MB")

def print_mutation(keys, results, hyp, save_dir, bucket, prefix=colorstr('evolve: ')):
    evolve_csv = save_dir / 'evolve.csv'
    evolve_yaml = save_dir / 'hyp_evolve.yaml'
    keys = tuple(keys) + tuple(hyp.keys())  # [results + hyps]
    keys = tuple(x.strip() for x in keys)
    vals = results + tuple(hyp.values())
    n = len(keys)

    # Download (optional)
    if bucket:
        url = f'gs://{bucket}/evolve.csv'
        if gsutil_getsize(url) > (evolve_csv.stat().st_size if evolve_csv.exists() else 0):
            subprocess.run(['gsutil', 'cp', f'{url}', f'{save_dir}'])  # download evolve.csv if larger than local

    # Log to evolve.csv
    s = '' if evolve_csv.exists() else (('%20s,' * n % keys).rstrip(',') + '\n')  # add header
    with open(evolve_csv, 'a') as f:
        f.write(s + ('%20.5g,' * n % vals).rstrip(',') + '\n')

    # Save yaml
    with open(evolve_yaml, 'w') as f:
        data = pd.read_csv(evolve_csv, skipinitialspace=True)
        data = data.rename(columns=lambda x: x.strip())  # strip keys
        i = np.argmax(fitness(data.values[:, :4]))  # best generation index
        generations = len(data)
        f.write('# YOLOv5 Hyperparameter Evolution Results\n' + f'# Best generation: {i}\n' +
                f'# Last generation: {generations - 1}\n' + '# ' + ', '.join(f'{x.strip():>20s}' for x in keys[:7]) +
                '\n' + '# ' + ', '.join(f'{x:>20.5g}' for x in data.values[i, :7]) + '\n\n')
        yaml.safe_dump(data.loc[i][7:].to_dict(), f, sort_keys=False)

    # Print to screen
    LOGGER.info(prefix + f'{generations} generations finished, current result:\n' + prefix +
                ', '.join(f'{x.strip():>20s}' for x in keys) + '\n' + prefix +
                ', '.join(f'{x:20.5g}' for x in vals) + '\n\n')

    if bucket:
        subprocess.run(['gsutil', 'cp', f'{evolve_csv}', f'{evolve_yaml}', f'gs://{bucket}'])  # upload

def apply_classifier(x, model, img, im0):
    # Apply a second stage classifier to YOLO outputs
    # Example model = torchvision.models.__dict__['efficientnet_b0'](pretrained=True).to(device).eval()
    im0 = [im0] if isinstance(im0, np.ndarray) else im0
    for i, d in enumerate(x):  # per image
        if d is not None and len(d):
            d = d.clone()

            # Reshape and pad cutouts
            b = xyxy2xywh(d[:, :4])  # boxes
            b[:, 2:] = b[:, 2:].max(1)[0].unsqueeze(1)  # rectangle to square
            b[:, 2:] = b[:, 2:] * 1.3 + 30  # pad
            d[:, :4] = xywh2xyxy(b).long()

            # Rescale boxes from img_size to im0 size
            scale_boxes(img.shape[2:], d[:, :4], im0[i].shape)

            # Classes
            pred_cls1 = d[:, 5].long()
            ims = []
            for a in d:
                cutout = im0[i][int(a[1]):int(a[3]), int(a[0]):int(a[2])]
                im = cv2.resize(cutout, (224, 224))  # BGR
                im = im[:, :, ::-1].transpose(2, 0, 1)  # BGR to RGB, to 3x224x224
                im = np.ascontiguousarray(im, dtype=np.float32)  # uint8 to float32
                im /= 255  # 0 - 255 to 0.0 - 1.0
                ims.append(im)

            pred_cls2 = model(torch.Tensor(ims).to(d.device)).argmax(1)  # classifier prediction
            x[i] = x[i][pred_cls1 == pred_cls2]  # retain matching class detections

    return x

def apply_classifier1(x, model, img, im0, modelname):
    # Apply a second stage classifier to YOLO outputs, preprocessing crops with the 'test' transform above
    # Example model = torchvision.models.__dict__['efficientnet_b0'](pretrained=True).to(device).eval()
    im0 = [im0] if isinstance(im0, np.ndarray) and len(im0.shape) == 3 else im0
    for i, d in enumerate(x):  # per image
        if d is not None and len(d):
            d = d.clone()

            # Note: the square/pad cutout step from apply_classifier() is intentionally skipped here

            # Rescale boxes from img_size to im0 size
            scale_boxes(img.shape[2:], d[:, :4], im0[i].shape)

            # Classes
            pred_cls1 = d[:, 5].long()
            ims = []
            for a in d:
                cutout = im0[i][int(a[1]):int(a[3]), int(a[0]):int(a[2])]
                im = cutout[:, :, ::-1]  # BGR to RGB
                im = Image.fromarray(np.uint8(im))
                im = test(im)  # Resize/ToTensor/Normalize (replaces the manual float32/255 conversion)
                im = im.to(a.device)
                ims.append(im)
            ims = torch.stack(ims, dim=0)
            pred_cls2 = model(ims).argmax(1)  # classifier prediction
            print('cls')
            print(time.strftime('%Y-%m-%d %H:%M:%S', time.localtime()))
            print(pred_cls2)
            print(pred_cls1)
            if modelname == 'fall':
                x[i] = x[i][pred_cls2 == 0]  # retain detections the classifier labels as class 0 (fall)
            else:
                x[i] = x[i][pred_cls1 == pred_cls2]  # retain matching class detections
    return x

def apply_classifieruniform(x, model, img, im0, modelname):
    # Apply a second stage classifier to YOLO outputs, splitting detections into
    # person boxes (class 4) to be re-classified and head boxes (classes 1, 2) passed through
    # Example model = torchvision.models.__dict__['efficientnet_b0'](pretrained=True).to(device).eval()
    im0 = [im0] if isinstance(im0, np.ndarray) and len(im0) == 1 and len(im0.shape) != 1 else im0
    xp = []  # person detections (class 4)
    xh = []  # head detections (classes 1, 2)
    for xi, xa in enumerate(x):
        xp.append(xa[(xa[:, 5:6] == torch.tensor([4], device=xa.device)).any(1)])
        xh.append(xa[(xa[:, 5:6] == torch.tensor([1, 2], device=xa.device)).any(1)])
    output = len(xp) * [[]]
    for i, d in enumerate(xp):  # per image
        if d is not None and len(d):
            d = d.clone()

            # Rescale boxes from img_size to im0 size
            print(f'orig d = {d}')  # boxes before rescaling
            scale_boxes(img.shape[2:], d[:, :4], im0[i].shape)

            # Classes
            pred_cls1 = d[:, 5].long()
            ims = []
            for a in d:
                cutout = im0[i][int(a[1]):int(a[3]), int(a[0]):int(a[2])]
                cv2.imwrite(f'{time.time()}.jpg', cutout)  # debug: save crop to disk
                ims.append(cutout)
            print(f'ims= {len(ims)}')
            ims = torch.stack([v8transforms(Image.fromarray(cv2.cvtColor(im, cv2.COLOR_BGR2RGB))) for im in ims]).to(model.device)
            # print(model(ims))  # debug: raw logits (costs an extra forward pass)
            pred_cls2 = model(ims).argmax(1)  # classifier prediction
            print('cls')
            print(time.strftime('%Y-%m-%d %H:%M:%S', time.localtime()))
            print(f'pred_cls2 {pred_cls2}')
            print(pred_cls1)
            if modelname == 'uniform':
                xp[i] = xp[i][pred_cls2 == 0]  # retain detections the classifier labels as class 0 (uniform)
            else:
                xp[i] = xp[i][pred_cls1 == pred_cls2]  # retain matching class detections
            if len(xp[i]) > 0:
                output[i] = torch.cat([xp[i], xh[i]], dim=0)
    return output



def apply_classifierarm(x, model, img, im0, modelname):
    # Apply a second stage classifier to YOLO outputs
    # Example model = torchvision.models.__dict__['efficientnet_b0'](pretrained=True).to(device).eval()
    # im0 = [im0] if isinstance(im0, np.ndarray) else im0
    im0 = [im0] if isinstance(im0, np.ndarray) and len(im0.shape) == 3 else im0  # wrap a single frame in a list
    print(type(x))
    xp = []
    xh = []
    # for xi, xa in enumerate(x):
    #     print(xa)
    #     xp.append(xa[(xa[:, 5:6] == torch.tensor([4], device=xa.device)).any(1)])
    #     x[xi][:, 5] = 1 - x[xi][:, 5]
    #     xh.append(xa[(xa[:, 5:6] == torch.tensor([1, 2], device=xa.device)).any(1)])
    # output = len(xp) * [[]]
    for i, d in enumerate(x):  # per image
        if d is not None and len(d):
            d = d.clone()
            # Reshape and pad cutouts
            # b = xyxy2xywh(d[:, :4])  # boxes
            # b[:, 2:] = b[:, 2:].max(1)[0].unsqueeze(1)  # rectangle to square
            # b[:, 2:] = b[:, 2:] * 1.3 + 30  # pad
            # d[:, :4] = xywh2xyxy(b).long()
            # Rescale boxes from img_size to im0 size
            print(f'orid= {d}')
            scale_boxes(img.shape[2:], d[:, :4], im0[i].shape)
            # Classes
            pred_cls1 = d[:, 5].long()  # detector class predictions
            ims = []
            for a in d:
                cutout = im0[i][int(a[1]):int(a[3]), int(a[0]):int(a[2])]  # crop the detection
                # cv2.imwrite(f'{time.time()}.jpg', cutout)
                # im = cv2.resize(cutout, (224, 224))  # BGR
                # im = cutout[:, :, ::-1]  # BGR to RGB, to 3x416x416
                ims.append(cutout)
            # ims = torch.stack(ims, dim=0)
            # pred_cls2 = model(torch.Tensor(ims).to(d.device)).argmax(1)  # classifier prediction
            print(f'ims= {len(ims)}')
            ims = torch.stack([v8transforms(Image.fromarray(cv2.cvtColor(im, cv2.COLOR_BGR2RGB))) for im in ims]).to(model.device)
            print(ims.size())
            print(model(ims))
            pred_cls2 = model(ims).argmax(1)  # second-stage classifier prediction
            print('cls')
            print(time.strftime('%Y-%m-%d %H:%M:%S', time.localtime()))
            print(f'pred_cls2 {pred_cls2}')
            print(pred_cls1)
            pred_cls2 = 1 - pred_cls2  # classifier labels are inverted relative to the detector's
            x[i] = x[i][pred_cls2 == pred_cls1]  # keep only detections the classifier agrees with
            # x[i] = 1 - x[i][:, 5]
            # x[i][:, 5] = torch.tensor(0).to(x[i].device)
            # if modelname == 'uniform':
            #     print(xp[i])
            #     xp[i] = xp[i][pred_cls2 == 0]  # retain matching class detections
            # else:
            #     xp[i] = xp[i][pred_cls1 == pred_cls2]
            # if len(xp[i]) > 0:
            #     output[i] = torch.cat([xp[i], xh[i]], dim=0)
    # print(o)
    return x
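

# Hedged usage sketch (editor's illustration, not from the original file):
# run NMS on the raw detector output first, then let the second-stage
# classifier veto boxes. `detector_out`, `classifier`, `img` and `im0` are
# placeholder names; non_max_suppression is the one defined earlier in this
# file.
def _example_second_stage(detector_out, classifier, img, im0):
    pred = non_max_suppression(detector_out, conf_thres=0.25, iou_thres=0.45)
    return apply_classifierarm(pred, classifier, img, im0, modelname='arm')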


def numpy_image_to_base64(image_array):
    # Convert the numpy array (BGR) to a PIL image
    image_array = cv2.cvtColor(image_array, cv2.COLOR_BGR2RGB)
    image = Image.fromarray(np.uint8(image_array))
    # Create an in-memory binary buffer
    buffer = io.BytesIO()
    # Save the PIL image into the buffer as PNG
    image.save(buffer, format="PNG")
    # Pull the raw bytes out of the buffer
    img_bytes = buffer.getvalue()
    # Base64-encode the bytes
    base64_encoded = base64.b64encode(img_bytes).decode('utf-8')
    return base64_encoded
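

# Hedged round-trip check (editor's illustration): decode the Base64 string
# back into a BGR array. PNG encoding is lossless, so the result should
# equal the input array exactly.
def _example_base64_roundtrip(image_bgr):
    decoded = base64.b64decode(numpy_image_to_base64(image_bgr))
    rgb = np.array(Image.open(io.BytesIO(decoded)))
    return cv2.cvtColor(rgb, cv2.COLOR_RGB2BGR)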


def compute_IOU(rec1, rec2):
    """
    Compute the intersection-over-union of two rectangles.
    :param rec1: (x0, y0, x1, y1), where (x0, y0) is the top-left corner and (x1, y1) the bottom-right corner; likewise below.
    :param rec2: (x0, y0, x1, y1)
    :return: (IoU, enclosing-box tensor), or (0, 0) when the rectangles do not overlap.
    """
    left_column_max = max(rec1[0], rec2[0])
    right_column_min = min(rec1[2], rec2[2])
    up_row_max = max(rec1[1], rec2[1])
    down_row_min = min(rec1[3], rec2[3])
    # Case 1: the rectangles have no overlapping region
    if left_column_max >= right_column_min or down_row_min <= up_row_max:
        return 0, 0
    # Case 2: the rectangles overlap
    else:
        S1 = (rec1[2] - rec1[0]) * (rec1[3] - rec1[1])
        S2 = (rec2[2] - rec2[0]) * (rec2[3] - rec2[1])
        S_cross = (down_row_min - up_row_max) * (right_column_min - left_column_max)
        x1 = min(rec1[0], rec2[0])
        y1 = min(rec1[1], rec2[1])
        x2 = max(rec1[2], rec2[2])
        y2 = max(rec1[3], rec2[3])
        return S_cross / (S1 + S2 - S_cross), torch.tensor((x1, y1, x2, y2))
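

# Worked example (editor's illustration): boxes (0,0,2,2) and (1,1,3,3)
# intersect in a 1x1 square, so IoU = 1 / (4 + 4 - 1) = 1/7, and the second
# return value is the enclosing box (0, 0, 3, 3).
def _example_compute_IOU():
    iou, merged = compute_IOU((0, 0, 2, 2), (1, 1, 3, 3))
    assert abs(iou - 1 / 7) < 1e-6
    return iou, merged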


def task(cur, conn, url, urla):
    # Map algorithm codes -> model names, detector labels, and back
    modelnamedir = {'0': 'helmet', '8': 'danager', '10': 'uniform', '14': 'smoke', '16': 'fire',
                    '21': 'cross', '25': 'fall', '29': 'occupancy', '30': 'liquid', '31': 'pressure',
                    '32': 'sleep', '33': 'conveyor', '34': 'personcount', '35': 'gloves', '36': 'sit',
                    '37': 'other', '38': 'duty', '98': 'face', '51': 'run', '64': 'jump', '62': 'clear'}
    modellabeldir = {'0': 'head', '8': 'person', '10': 'other', '14': 'smoke', '16': 'fire',
                     '21': 'cross', '25': 'fall', '29': 'car', '30': 'liquid', '31': 'pressure',
                     '32': 'sleep', '33': 'conveyor', '34': 'personcount', '35': 'gloves', '36': 'sit',
                     '37': 'other', '38': 'person', '98': 'face', '51': 'person', '64': 'person', '62': 'hand'}
    modelalgdir = {'helmet': '0', 'danager': '8', 'uniform': '10', 'smoke': '14', 'fire': '16',
                   'cross': '21', 'fall': '25', 'occupancy': '29', 'liquid': '30', 'pressure': '31',
                   'sleep': '32', 'conveyor': '33', 'personcount': '34', 'gloves': '35', 'sit': '36',
                   'other': '37', 'duty': '38', 'face': '98', 'run': '51', 'jump': '64', 'clear': '62'}
    data = {
        "algorithmCode": None,
        "deviceIp": None
    }
    result = requests.post(url=url, data=data).json()['data']
    postlist = []
    for info in result:
        # print(f"{info['deviceIp']},{info['deviceAlgorithmIp']},{info['deviceChannel']},{info['videoName']},{info['videoPassword']},{info['algorithmCode']},{info['electricFence']}")
        cursor = cur.execute("select point from stream where classindex = (?) and channel = (?)",
                             (info['algorithmCode'], info['deviceChannel']))
        resultsub = cursor.fetchall()
        print(resultsub)
        postlist.append((info['deviceChannel'], info['algorithmCode']))
        if resultsub:
            if info['algorithmCode'] not in modelnamedir:
                continue
            pointres = info['electricFence']
            # Compare the stored fence with the incoming one (the original code
            # compared a list of rows against a string, which is always unequal)
            if resultsub[0][0][:-1] != pointres[:-1]:
                print('true')
                print(info['deviceChannel'])
                print(f'electricFence {info["electricFence"]}')
                if len(pointres) > 0:
                    pointres = pointres[:-1] + ':'  # normalise the fence-string terminator
                    print(f'in{pointres}')
                    print(info['algorithmCode'], info['deviceChannel'])
                    cur.execute("UPDATE STREAM set fence = '1', point = (?) where classindex = (?) and channel = (?)",
                                (pointres, info['algorithmCode'], info['deviceChannel']))
                    cur.execute("select * from STREAM where classindex = (?) and channel = (?)",
                                (info['algorithmCode'], info['deviceChannel']))
                    print(cur.fetchall())
                else:
                    cur.execute("UPDATE STREAM set fence = '0', point = (?) where classindex = (?) and channel = (?)",
                                (pointres, info['algorithmCode'], info['deviceChannel']))
                    cur.execute("SELECT * from STREAM where classindex = (?) and channel = (?)",
                                (info['algorithmCode'], info['deviceChannel']))
                    print(cur.fetchall())
                print('commit')
                conn.commit()
        else:
            # cur.execute("UPDATE STREAM set fence = '1', point = (?) where channel = (?) and classindex = (?)",
            #             (pointres, info['algorithmCode'], info['deviceChannel']))
            if info['deviceIp'] is not None:
                if info['algorithmCode'] not in modelnamedir:
                    continue
                # address = f"rtsp://{info['videoName']}:{info['videoPassword']}@{info['deviceIp']}:554/Streaming/Channels/1"
                address = info['playbackAddress']
                label = modellabeldir[info['algorithmCode']]
                modelname = modelnamedir[info['algorithmCode']]
                print(modelname)
                print(label)
                print(address)
                if modelname == 'duty':
                    durtime = 300
                else:
                    durtime = 0
                if info['electricFence']:
                    fence = 1
                    point = f"{info['electricFence'][:-1]}:"
                    print(info['electricFence'])
                else:
                    fence = 0
                    point = '0'
                channel = info['deviceChannel']
                code = info['algorithmCode']
                ip = info['deviceIp']
                algdevice = info['fwqCode']
                cur.execute("INSERT INTO STREAM (MODELNAME,ADDRESS,FENCE,POINT,CHANNEL,CLASSINDEX,IP,ALGIP,ALGDEVICE,LABEL,DURTIME) "
                            "VALUES ((?),(?),(?),(?),(?),(?),(?),(?),(?),(?),(?))",
                            (modelname, address, fence, point, channel, code, ip,
                             info['deviceAlgorithmIp'], algdevice, label, durtime))
                print('add--------------------------------------------------------')
    # Remove rows for channel/algorithm pairs the server no longer reports
    cur.execute("select channel,classindex from stream")
    realist = cur.fetchall()
    for r in realist:
        if r not in postlist:
            cur.execute("delete from stream where channel = (?) and classindex = (?)", (r[0], r[1]))
    conn.commit()
    cursor = cur.execute("select modelname from stream")
    modellist = set(cursor.fetchall())
    cursor = cur.execute("select modelname from changestream")
    changemodellist = set(cursor.fetchall())
    for model in modellist:
        if model not in changemodellist:
            print(model[0])
            a = modelalgdir[model[0]]
            # if a != '0':
            rea = requests.post(url=urla, data={'algorithmCode': a}).json()['data']
            # print(rea)
            if len(rea) > 0:
                con = rea[0]['confidence']
            else:
                con = 0.25
            cla = 0
            cur.execute("INSERT INTO CHANGESTREAM (MODELNAME,ADDSTREAM,DELSTREAM,STREAMING,CONF,CLA) "
                        "VALUES ((?),0,0,0,(?),(?))", (model[0], con, cla))
            # else:
            #     con = 0.25
            #     cla = 0
            #     cur.execute("INSERT INTO CHANGESTREAM (MODELNAME,ADDSTREAM,DELSTREAM,STREAMING,CONF,CLA) "
            #                 "VALUES ((?),0,0,0,(?),(?))", (model[0], con, cla))
    if ('stream',) not in changemodellist:
        cur.execute("INSERT INTO CHANGESTREAM (MODELNAME,ADDSTREAM,DELSTREAM,STREAMING,CONF,CLA) "
                    "VALUES ('stream',0,0,0,0,0)")
    conn.commit()
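

# Hedged usage sketch (editor's illustration): poll the configuration
# endpoint periodically and sync it into the local sqlite database. The
# database filename and both URLs are placeholders, not the real endpoints.
def _example_task_loop(db='stream.db',
                       url='http://server/api/devices',
                       urla='http://server/api/algorithm'):
    import sqlite3
    conn = sqlite3.connect(db)
    cur = conn.cursor()
    while True:
        task(cur, conn, url, urla)
        time.sleep(30)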


def increment_path(path, exist_ok=False, sep='', mkdir=False):
    # Increment file or directory path, i.e. runs/exp --> runs/exp{sep}2, runs/exp{sep}3, ... etc.
    path = Path(path)  # os-agnostic
    if path.exists() and not exist_ok:
        path, suffix = (path.with_suffix(''), path.suffix) if path.is_file() else (path, '')

        # Method 1
        for n in range(2, 9999):
            p = f'{path}{sep}{n}{suffix}'  # increment path
            if not os.path.exists(p):
                break
        path = Path(p)

        # Method 2 (deprecated)
        # dirs = glob.glob(f"{path}{sep}*")  # similar paths
        # matches = [re.search(rf"{path.stem}{sep}(\d+)", d) for d in dirs]
        # i = [int(m.groups()[0]) for m in matches if m]  # indices
        # n = max(i) + 1 if i else 2  # increment number
        # path = Path(f"{path}{sep}{n}{suffix}")  # increment path

    if mkdir:
        path.mkdir(parents=True, exist_ok=True)  # make directory

    return path
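

# Usage example (editor's illustration): repeated calls with the same base
# path yield runs/exp, runs/exp2, runs/exp3, ...; mkdir=True also creates
# the directory.
def _example_increment_path():
    return increment_path(Path('runs') / 'exp', mkdir=True)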


# OpenCV Multilanguage-friendly functions ------------------------------------------------------------------------------
imshow_ = cv2.imshow  # copy to avoid recursion errors


def imread(filename, flags=cv2.IMREAD_COLOR):
    return cv2.imdecode(np.fromfile(filename, np.uint8), flags)


def imwrite(filename, img):
    try:
        cv2.imencode(Path(filename).suffix, img)[1].tofile(filename)
        return True
    except Exception:
        return False


def imshow(path, im):
    imshow_(path.encode('unicode_escape').decode(), im)


if Path(inspect.stack()[0].filename).parent.parent.as_posix() in inspect.stack()[-1].filename:
    cv2.imread, cv2.imwrite, cv2.imshow = imread, imwrite, imshow  # redefine
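

# Hedged sketch (editor's illustration): with the redefinitions above,
# image I/O works for paths containing non-ASCII characters, which the
# stock cv2.imread/cv2.imwrite cannot handle on some platforms. The
# filename below is illustrative.
def _example_unicode_io(im, path='测试图片.jpg'):
    imwrite(path, im)    # cv2.imencode + numpy .tofile sidestep the limitation
    return imread(path)  # numpy fromfile + cv2.imdecode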


# Variables ------------------------------------------------------------------------------------------------------------