# ============================================================================
# ============================================================================
# Copyright (c) 2018 Diamond Light Source Ltd. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
# Author: Nghia T. Vo
# E-mail:
# Publication date: 10th July 2018
# ============================================================================
# Contributors:
# ============================================================================
"""
Module for I/O tasks:
- Load data from an image file (tif, png, jpg) or a hdf file.
- Save a 2D array as a tif/png/jpg image or a 2D, 3D array to a hdf file.
- Save a plot of data points to an image.
- Save/load metadata to/from a text/json file.
- Save/load python list.
"""
import json
import pickle
import platform
from pathlib import Path
import h5py
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
from matplotlib import font_manager
from collections import OrderedDict
def __correct_path(file_path):
    """
    Repair a Windows-style file path and return it as a Path object.

    Control characters that appear when a path was typed as a normal
    (non-raw) string literal, e.g. the tab character produced by a backslash
    followed by "t", are turned back into a separator followed by the
    letter. Every backslash is converted to a forward slash.

    Parameters
    ----------
    file_path : str or pathlib.Path
        Path to correct.

    Returns
    -------
    pathlib.Path
        Corrected path.
    """
    text = str(file_path) if isinstance(file_path, Path) else file_path
    # One translation pass: each control character becomes "/<letter>",
    # which equals re-escaping it and then swapping backslashes for slashes.
    table = str.maketrans({
        '\a': '/a',
        '\b': '/b',
        '\f': '/f',
        '\n': '/n',
        '\r': '/r',
        '\t': '/t',
        '\v': '/v',
        '\0': '/0',
        '\\': '/',
    })
    return Path(text.translate(table))
def __get_path(file_path, check_exist=True):
    """
    Convert a file path to a Path object and optionally verify it exists.

    Parameters
    ----------
    file_path : str or pathlib.Path
        Path to a file or folder.
    check_exist : bool, optional
        Raise an error if the path does not exist.

    Returns
    -------
    pathlib.Path

    Raises
    ------
    ValueError
        If `check_exist` is True and the path does not exist.
    """
    on_windows = platform.system() == 'Windows'
    # Escape-sequence repair is applied on Windows only.
    checked = __correct_path(file_path) if on_windows else Path(file_path)
    if check_exist and not checked.exists():
        raise ValueError(f"No such file: {checked}")
    return checked
[docs]
def load_image(file_path, average=True):
    """
    Load data from an image as a float32 array.

    Parameters
    ----------
    file_path : str
        Path to a file.
    average : bool, optional
        Average a multichannel image if True.

    Returns
    -------
    array_like
        Image data as float32.

    Raises
    ------
    ValueError
        If the file does not exist or cannot be read as an image.
    """
    try:
        # The context manager releases the file handle as soon as the
        # pixel data have been copied into the array.
        with Image.open(__get_path(file_path)) as image:
            mat = np.array(image, dtype=np.float32)
    except Exception as error:
        raise ValueError(error) from error
    if mat.ndim > 2 and average is True:
        # The shortest axis is taken to be the channel axis.
        mat = np.mean(mat, axis=np.argmin(mat.shape))
    return mat
[docs]
def find_hdf_key(file_path, pattern, display=False):
    """
    Find datasets matching the name-pattern in a hdf/nxs file.

    Parameters
    ----------
    file_path : str
        Path to the file.
    pattern : str
        Pattern to find the full names of the datasets. A key matches if it
        contains this text.
    display : bool
        Print the results onto the screen if True.

    Returns
    -------
    list_key : list of str
        Keys matching the pattern.
    list_shape : list of tuple of int or None
        Shapes of the datasets. None for a key which is not a dataset.
    list_type : list of dtype or None
        Types of the datasets. None for a key which is not a dataset.
    """
    list_dkey, list_dshape, list_dtype = [], [], []
    # The context manager closes the file even if walking the tree fails.
    with h5py.File(__get_path(file_path), 'r') as hdf_object:
        visited = []
        hdf_object.visit(visited.append)
        # NOTE(review): a dataset can be listed twice, once through its
        # parent group and once through its own name. Confirm that this is
        # intended before de-duplicating.
        list_key = []
        for key in visited:
            try:
                node = hdf_object[key]
                if isinstance(node, h5py.Group):
                    children = list(node.items())
                    if children:
                        list_key.extend(
                            key + "/" + name for name, _ in children)
                    else:
                        list_key.append(key)
                else:
                    list_key.append(node.name)
            except KeyError:
                # Best effort: skip entries that cannot be opened.
                continue
        for key in list_key:
            if pattern not in key:
                continue
            shape, dtype = None, None
            try:
                node = hdf_object[key]
            except KeyError:
                node = None
            if isinstance(node, h5py.Dataset):
                shape, dtype = node.shape, node.dtype
            list_dkey.append(key)
            list_dshape.append(shape)
            list_dtype.append(dtype)
    if display:
        if list_dkey:
            for key, shape, dtype in zip(list_dkey, list_dshape, list_dtype):
                print(key + " : " + str(shape) + " : " + str(dtype))
        else:
            print("Can't find datasets with keys matching the "
                  "pattern: {}".format(pattern))
    return list_dkey, list_dshape, list_dtype
def _get_key(name, obj):
    """
    Callback for Group.visititems(): return the key path of a dataset named
    'data' that sits directly inside a group, otherwise None so that the
    walk through the HDF5 tree continues.
    """
    if not isinstance(obj, h5py.Group):
        return None
    for child_name, child in obj.items():
        if child_name == "data" and isinstance(child, h5py.Dataset):
            return f"{obj.name}/{child_name}"
    return None
[docs]
def load_hdf_file(file_path, key_path=None, index=None, axis=0):
    """
    Load data from a hdf5/nxs file.

    Parameters
    ----------
    file_path : str
        Path to a hdf/nxs file.
    key_path : str
        Key path to a dataset. If None, the first group holding a dataset
        named 'data' is used.
    index : int or tuple of int
        Values for slicing a 3D dataset. Can be an integer, or a tuple/list:
        three values are read as (start, stop, step), two values as
        (start, stop), any other length as an explicit list of indices,
        e.g. index=(slice1, slice2, slice3, slice4).
    axis : int
        Slice direction. Clipped to the range [0, 2].

    Returns
    -------
    array_like
        2D array or 3D array. If the selection leaves a single element along
        the slicing axis, that axis is dropped.

    Raises
    ------
    ValueError
        If the file can't be opened, the key is missing, the dataset is not
        2D or 3D, the index type is unsupported, or the selection is empty.
    """
    try:
        ifile = h5py.File(__get_path(file_path), 'r')
    except Exception as error:
        raise ValueError(f"Error: {error}") from error
    # Every return value below is an in-memory copy, so the file is closed
    # on all paths instead of being left open.
    with ifile:
        if key_path is None:
            key_path = ifile.visititems(_get_key)  # Find the key automatically
            if key_path is None:
                raise ValueError("Please provide the key path to the dataset!")
        if key_path not in ifile:
            raise ValueError("Couldn't open object with the key path: "
                             "{}".format(key_path))
        idata = ifile[key_path]
        if not isinstance(idata, h5py.Dataset):
            # A group has no shape; fail with the same message as a dataset
            # of the wrong rank.
            raise ValueError("Require a 2D or 3D dataset!")
        ndim = len(idata.shape)
        if ndim < 2 or ndim > 3:
            raise ValueError("Require a 2D or 3D dataset!")
        if ndim == 2:
            return np.asarray(idata)
        axis = int(np.clip(axis, 0, 2))
        if index is None:
            return np.float32(idata[:, :, :])
        if isinstance(index, (int, np.integer)):
            selection = int(index)
        elif isinstance(index, (tuple, list)):
            if len(index) == 3:
                selection = list(range(index[0], index[1], index[2]))
            elif len(index) == 2:
                selection = list(range(index[0], index[1]))
            else:
                selection = list(index)
        else:
            raise ValueError("Index must be an integer, a tuple, or a list. "
                             "Given: {}".format(type(index).__name__))
        slicer = [slice(None)] * 3
        slicer[axis] = selection
        mat = np.float32(idata[tuple(slicer)])
        if isinstance(selection, int):
            # An integer index removes one dimension, so only axes 0 and 1
            # exist in the result.
            axis = min(axis, 1)
        # Read the size before squeezing: afterwards `axis` may no longer be
        # a valid dimension of the result.
        size = mat.shape[axis]
        if size == 0:
            raise ValueError("Empty indices!")
        if size == 1:
            mat = np.swapaxes(mat, axis, 0)[0]
        return mat
[docs]
def load_hdf_object(file_path, key_path):
    """
    Load a hdf/nexus dataset as an object.

    The file is deliberately left open on success, because the returned
    object reads its data from the open file.

    Parameters
    ----------
    file_path : str
        Path to a hdf/nxs file.
    key_path : str
        Key path to a dataset.

    Returns
    -------
    object
        hdf/nxs object.

    Raises
    ------
    ValueError
        If the file can't be opened or the key is not in the file.
    """
    try:
        ifile = h5py.File(__get_path(file_path), 'r')
    except Exception as error:
        raise ValueError(f"Error: {error}") from error
    try:
        if key_path not in ifile:
            raise ValueError(f"Couldn't open object with the key: {key_path}")
        return ifile[key_path]
    except Exception:
        # Nothing is handed back to the caller, so release the file handle.
        ifile.close()
        raise
def _create_folder(file_path):
    """
    Make sure the folder needed for an output path exists.

    A path with a suffix is treated as a file and its parent folder is
    created; a path without a suffix is treated as the folder itself.

    Parameters
    ----------
    file_path : str
        Path to a file

    Raises
    ------
    ValueError
        If the folder can't be created.
    """
    resolved = Path(file_path).resolve()
    target_dir = resolved.parent if resolved.suffix else resolved
    if target_dir.exists():
        return
    try:
        target_dir.mkdir(parents=True, exist_ok=True)
    except Exception as e:
        raise ValueError(f"Can't create : {target_dir}. Error: {e}")
def _create_file_name(file_path):
    """
    Pick a file name that does not clash with an existing file.

    If the given path is free it is returned unchanged. Otherwise a
    four-digit counter ("_0000", "_0001", ...) is inserted before the
    extension until an unused name is found.

    Parameters
    ----------
    file_path : str
        Path to a file

    Returns
    -------
    str
        Updated file path.
    """
    target = Path(file_path)
    if not target.exists():
        return str(target)
    stem, ext, folder = target.stem, target.suffix, target.parent
    counter = 0
    candidate = folder / f"{stem}_{counter:04d}{ext}"
    while candidate.exists():
        counter += 1
        candidate = folder / f"{stem}_{counter:04d}{ext}"
    return str(candidate)
[docs]
def save_image(file_path, mat, overwrite=True):
    """
    Save 2D data to an image.

    For tif/tiff output a multichannel array is averaged along its shortest
    axis and the values are kept as they are. For any other format, data
    which are not already uint8 are rescaled to the range [0, 255] and
    converted to uint8.

    Parameters
    ----------
    file_path : str
        Output file path.
    mat : array_like
        2D array.
    overwrite : bool, optional
        Overwrite an existing file if True.

    Returns
    -------
    str
        Updated file path.

    Raises
    ------
    ValueError
        If the image can't be written.
    """
    file_path = __get_path(file_path, check_exist=False).resolve()
    mat = np.asarray(mat)
    # Case-insensitive so that ".TIF" is not silently reduced to 8-bit.
    if file_path.suffix.lower() in ('.tif', '.tiff'):
        if len(mat.shape) > 2:
            mat = np.mean(mat, axis=np.argmin(mat.shape))
    elif mat.dtype != np.uint8:
        nmin, nmax = np.min(mat), np.max(mat)
        if nmax != nmin:
            mat = np.uint8(255.0 * (mat - nmin) / (nmax - nmin))
        else:
            # Flat image: normalization would divide by zero.
            mat = np.uint8(mat)
    # Work with a string from here on so the return type does not depend on
    # the overwrite flag.
    file_path = str(file_path)
    _create_folder(file_path)
    if not overwrite:
        file_path = _create_file_name(file_path)
    image = Image.fromarray(mat)
    try:
        image.save(file_path)
    except Exception as error:
        raise ValueError(
            f"Couldn't write to file: {file_path}. Error {error}") from error
    return file_path
[docs]
def save_plot_image(file_path, list_lines, height, width, overwrite=True,
                    dpi=100):
    """
    Save the plot of points to an image. Useful to check if the points are
    arranged properly where points on the same line having the same color.

    Parameters
    ----------
    file_path : str
        Output file path.
    list_lines : list of array_like
        List of 2D arrays. Each list is the coordinates of points on a line.
        Column 0 is plotted as y (flipped by `height`), column 1 as x.
    height : int
        Height of the image.
    width : int
        Width of the image.
    overwrite : bool, optional
        Overwrite the existing file if True.
    dpi : int, optional
        The resolution in points per inch.

    Returns
    -------
    str
        Updated file path.

    Raises
    ------
    ValueError
        If the image can't be written.
    """
    file_path = str(__get_path(file_path, check_exist=False))
    _create_folder(file_path)
    if not overwrite:
        file_path = _create_file_name(file_path)
    fig = plt.figure(frameon=False)
    try:
        fig.set_size_inches(width / dpi, height / dpi)
        # Axes fill the whole figure so the output has no border.
        ax = plt.Axes(fig, [0., 0., 1.0, 1.0])
        ax.set_axis_off()
        fig.add_axes(ax)
        ax.axis((0, width, 0, height))
        m_size = 0.5 * min(height / dpi, width / dpi)
        for line in list_lines:
            line = np.asarray(line)
            ax.plot(line[:, 1], height - line[:, 0], '-o', markersize=m_size)
        try:
            fig.savefig(file_path, dpi=dpi)
        except Exception as error:
            raise ValueError(
                f"Couldn't write to file: {file_path}. "
                f"Error {error}") from error
    finally:
        # Close this figure on every path, including failures.
        plt.close(fig)
    return file_path
def __check_font(font_family):
    """
    Check if a specific font is available in Matplotlib.

    Parameters
    ----------
    font_family : str
        Name of the font to check.

    Returns
    -------
    bool
        True if font is available, False otherwise.
    """
    try:
        font_manager.findfont(font_family, fallback_to_default=False)
    except Exception:
        # Any lookup failure means the font is unusable. KeyboardInterrupt
        # and SystemExit are deliberately not caught.
        return False
    return True
[docs]
def save_residual_plot(file_path, list_data, height, width, overwrite=True,
                       dpi=100, font_family='Times New Roman'):
    """
    Save the plot of residual against radius to an image. Useful to check the
    accuracy of unwarping results.

    Font settings are applied only while the figure is built and saved; the
    caller's Matplotlib rcParams are restored afterwards.

    Parameters
    ----------
    file_path : str
        Output file path.
    list_data : array_like
        2D array. Column 0 is plotted on the x-axis (labelled 'Radius') and
        column 1 on the y-axis (labelled 'Residual').
    height : int
        Height of the output image.
    width : int
        Width of the output image.
    overwrite : bool, optional
        Overwrite the existing file if True.
    dpi : int, optional
        The resolution in points per inch.
    font_family : str, optional
        To set the font family. Ignored if the font is not available.

    Returns
    -------
    str
        Updated file path.

    Raises
    ------
    ValueError
        If the image can't be written.
    """
    file_path = str(__get_path(file_path, check_exist=False))
    _create_folder(file_path)
    if not overwrite:
        file_path = _create_file_name(file_path)
    list_data = np.asarray(list_data)
    m_size = 0.5 * min(height / dpi, width / dpi)
    font_settings = {'font.size': int(m_size * 4), 'font.weight': 'bold'}
    if __check_font(font_family):
        font_settings['font.family'] = font_family
    # rc_context restores the previous rcParams on exit, even on failure,
    # instead of resetting every setting to Matplotlib's defaults.
    with plt.rc_context(font_settings):
        fig = plt.figure(frameon=False)
        try:
            fig.set_size_inches(width / dpi, height / dpi)
            ax = fig.add_subplot(1, 1, 1)
            ax.set_xlabel('Radius', fontweight='bold')
            ax.set_ylabel('Residual', fontweight='bold')
            ax.plot(list_data[:, 0], list_data[:, 1], '.', markersize=m_size)
            try:
                fig.savefig(file_path, dpi=dpi, bbox_inches='tight')
            except Exception as error:
                raise ValueError(
                    f"Couldn't write to file: {file_path}. "
                    f"Error {error}") from error
        finally:
            plt.close(fig)
    return file_path
[docs]
def save_hdf_file(file_path, idata, key_path='entry', overwrite=True):
    """
    Write data to a hdf5 file.

    The data are stored in a dataset named "data" inside the group
    `key_path`. If the file extension is not one of .hdf, .h5, .nxs, .hdf5,
    it is replaced by .hdf.

    Parameters
    ----------
    file_path : str
        Output file path.
    idata : array_like
        Data to be saved.
    key_path : str
        Key path to the group holding the dataset.
    overwrite : bool, optional
        Overwrite an existing file if True.

    Returns
    -------
    str
        Updated file path.

    Raises
    ------
    ValueError
        If the file can't be opened for writing.
    """
    file_path = __get_path(file_path, check_exist=False).resolve()
    if file_path.suffix.lower() not in {'.hdf', '.h5', '.nxs', '.hdf5'}:
        file_path = file_path.with_suffix('.hdf')
    # Work with a string from here on so the return type does not depend on
    # the overwrite flag.
    file_path = str(file_path)
    _create_folder(file_path)
    if not overwrite:
        file_path = _create_file_name(file_path)
    try:
        ofile = h5py.File(file_path, 'w')
    except Exception as error:
        raise ValueError(
            f"Couldn't write to file: {file_path}. Error {error}") from error
    # The context manager closes the file even if writing the data fails.
    with ofile:
        grp = ofile.create_group(key_path)
        grp.create_dataset("data", data=idata)
    return file_path
[docs]
def open_hdf_stream(file_path, data_shape, key_path='entry/data',
                    data_type='float32', overwrite=True, **options):
    """
    Open stream to write data to a hdf/nxs file with options to add metadata.

    The file stays open on success so that the returned dataset can be
    written to. If the file extension is not one of .hdf, .h5, .nxs, .hdf5,
    it is replaced by .hdf.

    Parameters
    ----------
    file_path : str
        Path to the file.
    data_shape : tuple of int
        Shape of the data.
    key_path : str
        Key path to the dataset.
    data_type: str
        Type of data.
    overwrite : bool
        Overwrite the existing file if True.
    options : dict, optional
        Add metadata. Example:
        options={"entry/angles": angles, "entry/energy": 53}.

    Returns
    -------
    object
        hdf object.

    Raises
    ------
    ValueError
        If the file can't be opened for writing, or a metadata key contains
        `key_path`.
    """
    file_path = __get_path(file_path, check_exist=False).resolve()
    if file_path.suffix.lower() not in {'.hdf', '.h5', '.nxs', '.hdf5'}:
        file_path = file_path.with_suffix('.hdf')
    _create_folder(str(file_path))
    if not overwrite:
        file_path = _create_file_name(str(file_path))
    try:
        ofile = h5py.File(file_path, 'w')
    except Exception as error:
        raise ValueError(
            f"Couldn't write to file: {file_path}. Error {error}") from error
    try:
        # Each keyword argument is expected to be a dict of
        # {key-path: value} metadata entries.
        for opts in options.values():
            for key in opts:
                if key_path in key:
                    msg = "!!! Selected key path, '{0}', can not be a " \
                          "child key-path of '{1}' !!!\n!!! Change to make " \
                          "sure they are at the same level " \
                          "!!!".format(key, key_path)
                    raise ValueError(msg)
                ofile.create_dataset(key, data=opts[key])
        data_out = ofile.create_dataset(key_path, data_shape, dtype=data_type)
    except Exception:
        # No dataset is handed back, so don't leave the file open.
        ofile.close()
        raise
    return data_out
[docs]
def save_plot_points(file_path, list_points, height, width, overwrite=True,
                     dpi=100, marker="o", color="blue"):
    """
    Save the plot of points to an image. Useful to check if the points are
    arranged properly where points on the same line having the same color.

    Parameters
    ----------
    file_path : str
        Output file path.
    list_points : list of 1D-array
        List of the (y-x)-coordinates of points.
    height : int
        Height of the image.
    width : int
        Width of the image.
    overwrite : bool, optional
        Overwrite the existing file if True.
    dpi : int, optional
        The resolution in points per inch.
    marker : str
        Plot marker. Full list is at:
        https://matplotlib.org/stable/api/markers_api.html
    color : str
        Marker color. Full list is at:
        https://matplotlib.org/stable/tutorials/colors/colors.html

    Returns
    -------
    str
        Updated file path.

    Raises
    ------
    ValueError
        If the image can't be written.
    """
    # Work with a string so the return type does not depend on the
    # overwrite flag.
    file_path = str(__get_path(file_path, check_exist=False).resolve())
    _create_folder(file_path)
    if not overwrite:
        file_path = _create_file_name(file_path)
    fig = plt.figure(frameon=False)
    try:
        fig.set_size_inches(width / dpi, height / dpi)
        # Axes fill the whole figure so the output has no border.
        ax = plt.Axes(fig, [0., 0., 1.0, 1.0])
        ax.set_axis_off()
        fig.add_axes(ax)
        ax.axis((0, width, 0, height))
        m_size = 0.5 * min(height / dpi, width / dpi)
        for point in list_points:
            ax.plot(point[1], height - point[0], marker, color=color,
                    markersize=m_size)
        try:
            fig.savefig(file_path, dpi=dpi)
        except IOError as error:
            raise ValueError(
                "Couldn't write to file {}".format(file_path)) from error
    finally:
        # Close this figure on every path, including failures.
        plt.close(fig)
    return file_path
def __numpy_encoder(obj):
    """
    Convert a NumPy scalar or array to a plain Python object.

    The error message suggests this is meant as the ``default`` hook of the
    json encoder (the caller is not shown here - TODO confirm).

    Parameters
    ----------
    obj : object
        Object the encoder could not serialize.

    Returns
    -------
    bool, int, float, or list
        Plain Python equivalent of `obj`.

    Raises
    ------
    TypeError
        If `obj` is not a supported NumPy type.
    """
    if isinstance(obj, np.bool_):
        return bool(obj)
    # The abstract base classes cover every width of integer and float,
    # including platform-specific ones.
    if isinstance(obj, np.integer):
        return int(obj)
    if isinstance(obj, np.floating):
        return float(obj)
    if isinstance(obj, np.ndarray):
        return obj.tolist()
    raise TypeError(f"Object of type '{type(obj).__name__}' "
                    f"is not JSON serializable")
[docs]
def load_python_list(file_path):
    """
    Read back a Python list stored in a pickle file (.pkl).

    Unpickling can run arbitrary code, so only load files from a trusted
    source.

    Parameters
    ----------
    file_path : str
        Path to the pickle file.

    Returns
    -------
    list
        The Python list.
    """
    pickle_path = __get_path(file_path)
    with open(pickle_path, 'rb') as handle:
        return pickle.load(handle)
[docs]
def save_python_list(file_path, python_list, overwrite=True):
    """
    Write python list to a pickle file (.pkl).

    If the file extension is not .pkl, it is replaced by .pkl.

    Parameters
    ----------
    file_path : str
        Output file path.
    python_list : list
        Python list.
    overwrite : bool, optional
        Overwrite an existing file if True.

    Returns
    -------
    str
        Updated file path.
    """
    file_path = __get_path(file_path, check_exist=False).resolve()
    if file_path.suffix.lower() != '.pkl':
        file_path = file_path.with_suffix('.pkl')
    # Work with a string from here on so the return type does not depend on
    # the overwrite flag.
    file_path = str(file_path)
    _create_folder(file_path)
    if not overwrite:
        file_path = _create_file_name(file_path)
    with open(file_path, 'wb') as f:
        pickle.dump(python_list, f)
    return file_path
[docs]
def find_file(path):
    """
    Search for files matching a path pattern.

    The last component of `path` is used as a glob pattern inside the
    parent folder.

    Parameters
    ----------
    path : str
        Path and pattern to find files.

    Returns
    -------
    str or list of str
        List of files, sorted, with forward slashes as separators.

    Raises
    ------
    FileNotFoundError
        If no file matches the pattern.
    """
    path = __correct_path(path)
    matches = [hit.as_posix() for hit in path.parent.glob(path.name)]
    if not matches:
        raise FileNotFoundError(f"No files found matching: {path}")
    matches.sort()
    return matches