diff --git a/src/aixd/data/data_objects.py b/src/aixd/data/data_objects.py index 4808330186cf88b580b6b7112182d8ff1c258092..d753879384e9ddc3d9e1a57b595e180b14971870 100644 --- a/src/aixd/data/data_objects.py +++ b/src/aixd/data/data_objects.py @@ -1012,6 +1012,7 @@ class DataBool(DataCategorical): def __init__(self, name: str, **kwargs): domain = kwargs.pop("domain", dd.Options(["True", "False"], type="categorical")) super().__init__(name=name, domain=domain, **kwargs) + self.type = "bool" class DataOrdinal(DataDiscrete): diff --git a/src/aixd/visualisation/plotter.py b/src/aixd/visualisation/plotter.py index 89c305359e3b89900ef1136b988e05cd06f50613..420480c68617a6f1da11559e4484e176f0c2ac03 100644 --- a/src/aixd/visualisation/plotter.py +++ b/src/aixd/visualisation/plotter.py @@ -1336,6 +1336,123 @@ class Plotter: return self._output(fig, output_name=output_name or f"PerformanceSummary_{block.name}") + def parallel_coordinates( + self, + data: pd.DataFrame, + variables: Optional[List[str]] = None, + value_ranges: Optional[dict[str, Tuple[float, float]]] = None, + color_by: Optional[str] = None, + output_name: Optional[str] = None, + ): + """ + Parallel coordinates plot for the given data. + + Parameters + ---------- + data: [:class:`pd.DataFrame`] + Dataframe containing data to plot. + This could be data from the dataset or new samples generated by the model. + The order of the parallel coordinates corresponds to the order of columns. + variables: List[str], optional + List of variable names to plot (a selection from the data given). If not specified, all columns are plotted. + value_ranges: dict[str, Tuple[float, float]], optional + Dictionary with column names (or variable names, if `variables` are given) as keys and [min,max] as values. + If not specified, the ranges are derived from the dataset (if `variables` are given), or from the provided data (if otherwise). + color_by: str, optional + Name of the variable to color the lines by. If not specified, default color (grey) is used for all samples. + output_name : str, optional, default=None + Name of the output file. If None, the name is automatically generated from the data block name. + + Returns + ------- + Optional[:class:`plotly.graph_objects.Figure`] + Plotly figure object, if self.output is None, otherwise None. + """ + # TODO: add transformed flag + # TODO: separate function for dataset and separate for generated samples? + + # if data is real or int --> calc value ranges + # if data is bool or cat --> set tickvals and ticktext + # need to create a dummy variable: + # https://stackoverflow.com/questions/64139316/plotly-how-to-insert-a-categorical-variable-into-a-parallel-coordinates-plot + + df = data + + all_parcoorddimensions = [] + linesettings = None + + if variables: + """ + If variable names are given, some settings are derived from the data objects via Dataset. + """ + all_dobjs = self.dataset.get_data_objects_by_name(variables) + all_column_names = list(chain(*[dobj.columns_df for dobj in all_dobjs])) + for dobj in all_dobjs: + if color_by is not None and dobj.name == color_by: + if dobj.dim > 1: + raise ValueError(f"Cannot color by multi-dimensional attribute {dobj.name}.") + + for colname in dobj.columns_df: + if dobj.type in ["real", "integer"]: + if value_ranges is not None: + vrange = value_ranges[colname] + else: + vrange = [dobj.domain.min_value, dobj.domain.max_value] + parcoorddimension = dict(range=vrange, label=colname, values=df[colname].tolist()) + if color_by is not None and dobj.name == color_by: + linesettings = dict(color=df[color_by], colorscale=color_divergent_centered, cmin=vrange[0], cmax=vrange[1]) + elif dobj.type in ["bool", "categorical"]: + options = [str(val) for val in dobj.domain.array] + options.sort() + dummyindex = {opt: idx for idx, opt in enumerate(options)} + parcoorddimension = dict( + tickvals=list(dummyindex.values()), ticktext=list(dummyindex.keys()), label=colname, values=[dummyindex[str(val)] for val in df[colname].tolist()] + ) + else: + raise ValueError(f"DataObject type {dobj.type} not supported for parallel coordinates plot.") + all_parcoorddimensions.append(parcoorddimension) + + else: + """ + If variable names are not given, some settings are derived from the dataframe directly. + """ + all_column_names = df.columns + df = df.infer_objects() + for colname in all_column_names: + if df[colname].dtype in [ + "int64", + "float64", + "float32", + "int32", + ]: + if value_ranges is not None: + vrange = value_ranges[colname] + else: + vrange = [df[colname].min(), df[colname].max()] + parcoorddimension = dict(range=vrange, label=colname, values=df[colname].tolist()) + if color_by is not None and colname == color_by: + linesettings = dict(color=df[color_by], colorscale=color_divergent_centered, cmin=vrange[0], cmax=vrange[1]) + elif df[colname].dtype in ["str", "object", "bool"]: + options = [str(val) for val in df[colname].unique()] + options.sort() + dummyindex = {opt: idx for idx, opt in enumerate(options)} + parcoorddimension = dict( + tickvals=list(dummyindex.values()), ticktext=list(dummyindex.keys()), label=colname, values=[dummyindex[str(val)] for val in df[colname].tolist()] + ) + else: + raise ValueError(f"Data type {df[colname].dtype} not supported for parallel coordinates plot.") + all_parcoorddimensions.append(parcoorddimension) + + if color_by is not None: + if linesettings is None: + # linesettings not set yet --> it is a categorical/bool variable + linesettings = dict(color=df[color_by], colorscale=color_qualitative10) + else: + linesettings = dict(color="grey") + + fig = go.Figure(data=go.Parcoords(line=linesettings, dimensions=all_parcoorddimensions)) + return self._output(fig, output_name=output_name or "ParallelCoordinates") + @staticmethod def _open_fig(size: Tuple[int, int] = (1, 1), **kwargs) -> go.Figure: """Helper method to open a figure with the desired number of rows and columns."""