Skip to content

Commit fe48a65

Browse files
authored
[ML] Data Frame Analytics: Fix feature importance cell value and decision path chart (#82011) (#82107)
Fixes a regression that left data grid cells for feature importance empty and caused the whole page to render empty when clicking the button that opens the decision path chart popover.
1 parent ae05572 commit fe48a65

File tree

7 files changed

+131
-10
lines changed

7 files changed

+131
-10
lines changed

x-pack/plugins/ml/common/types/feature_importance.ts

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,10 +8,13 @@ export interface ClassFeatureImportance {
88
class_name: string | boolean;
99
importance: number;
1010
}
11+
12+
// TODO We should separate the interface because classes/importance
13+
// aren't both optional but mutually exclusive (either/or).
1114
export interface FeatureImportance {
1215
feature_name: string;
13-
importance?: number;
1416
classes?: ClassFeatureImportance[];
17+
importance?: number;
1518
}
1619

1720
export interface TopClass {

x-pack/plugins/ml/public/application/components/data_grid/common.test.ts

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ import { EuiDataGridSorting } from '@elastic/eui';
88

99
import { multiColumnSortFactory } from './common';
1010

11-
describe('Transform: Define Pivot Common', () => {
11+
describe('Data Frame Analytics: Data Grid Common', () => {
1212
test('multiColumnSortFactory()', () => {
1313
const data = [
1414
{ s: 'a', n: 1 },

x-pack/plugins/ml/public/application/components/data_grid/common.ts

Lines changed: 95 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,9 @@ import {
2424
KBN_FIELD_TYPES,
2525
} from '../../../../../../../src/plugins/data/public';
2626

27+
import { DEFAULT_RESULTS_FIELD } from '../../../../common/constants/data_frame_analytics';
2728
import { extractErrorMessage } from '../../../../common/util/errors';
29+
import { FeatureImportance, TopClasses } from '../../../../common/types/feature_importance';
2830

2931
import {
3032
BASIC_NUMERICAL_TYPES,
@@ -158,6 +160,90 @@ export const getDataGridSchemaFromKibanaFieldType = (
158160
return schema;
159161
};
160162

163+
/**
 * Normalizes a flattened class name to the type used by the job's classes.
 *
 * @param className - Class name as read from the flattened `classes.class_name` field
 * @param isClassTypeBoolean - Whether class names should be treated as booleans
 * @returns a boolean when `isClassTypeBoolean` is set, otherwise the unchanged class name
 */
const getClassName = (className: string | boolean, isClassTypeBoolean: boolean) => {
  if (!isClassTypeBoolean) {
    return className;
  }

  // Values that already are booleans pass through unchanged; for strings only
  // the literal 'true' maps to `true`.
  return typeof className === 'boolean' ? className : className === 'true';
};

/**
 * Helper to transform feature importance flattened fields with arrays back to object structure
 *
 * @param row - EUI data grid data row
 * @param mlResultsField - Data frame analytics results field
 * @param isClassTypeBoolean - Whether class names should be converted to booleans (defaults to `false`)
 * @returns nested object structure of feature importance values; an empty array
 *          when the row does not contain the flattened fields needed to build it
 */
export const getFeatureImportance = (
  row: Record<string, any>,
  mlResultsField: string,
  isClassTypeBoolean = false
): FeatureImportance[] => {
  const fieldPrefix = `${mlResultsField}.feature_importance`;

  const featureNames: string[] | undefined = row[`${fieldPrefix}.feature_name`];
  const classNames: string[] | undefined = row[`${fieldPrefix}.classes.class_name`];
  const classImportance: number[] | undefined = row[`${fieldPrefix}.classes.importance`];

  if (featureNames === undefined || featureNames.length === 0) {
    return [];
  }

  // return object structure for classification job
  if (classNames !== undefined && classImportance !== undefined) {
    // Assumes the flattened class arrays list the classes of each feature back
    // to back, with every feature reporting the same classes in the same order,
    // so the first chunk of class names applies to all features.
    const classesPerFeature = Math.floor(classNames.length / featureNames.length);
    const overallClassNames = classNames.slice(0, classesPerFeature);

    return featureNames.map((fName, index) => {
      const offset = classesPerFeature * index;

      return {
        feature_name: fName,
        classes: overallClassNames.map((fClassName, fIndex) => ({
          class_name: getClassName(fClassName, isClassTypeBoolean),
          importance: classImportance[offset + fIndex],
        })),
      };
    });
  }

  // return object structure for regression job
  const importance: number[] | undefined = row[`${fieldPrefix}.importance`];

  // Without importance values there is nothing to pair the feature names with;
  // return an empty result instead of throwing on `undefined[index]`.
  if (importance === undefined) {
    return [];
  }

  return featureNames.map((fName, index) => ({
    feature_name: fName,
    importance: importance[index],
  }));
};
222+
223+
/**
 * Helper to transform top classes flattened fields with arrays back to object structure
 *
 * @param row - EUI data grid data row
 * @param mlResultsField - Data frame analytics results field
 * @returns array of top class objects (`class_name`, `class_probability`,
 *          `class_score`); an empty array when any of the three flattened
 *          fields is missing from the row
 */
export const getTopClasses = (row: Record<string, any>, mlResultsField: string): TopClasses => {
  const fieldPrefix = `${mlResultsField}.top_classes`;

  const classNames: string[] | undefined = row[`${fieldPrefix}.class_name`];
  const classProbabilities: number[] | undefined = row[`${fieldPrefix}.class_probability`];
  const classScores: number[] | undefined = row[`${fieldPrefix}.class_score`];

  // All three flattened arrays are required to rebuild the objects.
  if (classNames === undefined || classProbabilities === undefined || classScores === undefined) {
    return [];
  }

  // The arrays are combined by position: entry `index` of each array is taken
  // to describe the same class.
  return classNames.map((className, index) => ({
    class_name: className,
    class_probability: classProbabilities[index],
    class_score: classScores[index],
  }));
};
246+
161247
export const useRenderCellValue = (
162248
indexPattern: IndexPattern | undefined,
163249
pagination: IndexPagination,
@@ -207,6 +293,15 @@ export const useRenderCellValue = (
207293
return item[cId];
208294
}
209295

296+
// For classification and regression results, we need to treat some fields with a custom transform.
297+
if (cId === `${resultsField}.feature_importance`) {
298+
return getFeatureImportance(fullItem, resultsField ?? DEFAULT_RESULTS_FIELD);
299+
}
300+
301+
if (cId === `${resultsField}.top_classes`) {
302+
return getTopClasses(fullItem, resultsField ?? DEFAULT_RESULTS_FIELD);
303+
}
304+
210305
// Try if the field name is available as a nested field.
211306
return getNestedProperty(tableItems[adjustedRowIndex], cId, null);
212307
}

x-pack/plugins/ml/public/application/components/data_grid/data_grid.tsx

Lines changed: 21 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -27,10 +27,15 @@ import { DEFAULT_SAMPLER_SHARD_SIZE } from '../../../../common/constants/field_h
2727

2828
import { ANALYSIS_CONFIG_TYPE, INDEX_STATUS } from '../../data_frame_analytics/common';
2929

30-
import { euiDataGridStyle, euiDataGridToolbarSettings } from './common';
30+
import {
31+
euiDataGridStyle,
32+
euiDataGridToolbarSettings,
33+
getFeatureImportance,
34+
getTopClasses,
35+
} from './common';
3136
import { UseIndexDataReturnType } from './types';
3237
import { DecisionPathPopover } from './feature_importance/decision_path_popover';
33-
import { TopClasses } from '../../../../common/types/feature_importance';
38+
import { FeatureImportance, TopClasses } from '../../../../common/types/feature_importance';
3439
import { DEFAULT_RESULTS_FIELD } from '../../../../common/constants/data_frame_analytics';
3540
import { DataFrameAnalysisConfigType } from '../../../../common/types/data_frame_analytics';
3641

@@ -118,18 +123,28 @@ export const DataGrid: FC<Props> = memo(
118123
if (!row) return <div />;
119124
// if resultsField for some reason is not available then use ml
120125
const mlResultsField = resultsField ?? DEFAULT_RESULTS_FIELD;
121-
const parsedFIArray = row[mlResultsField].feature_importance;
122126
let predictedValue: string | number | undefined;
123127
let topClasses: TopClasses = [];
124128
if (
125129
predictionFieldName !== undefined &&
126130
row &&
127-
row[mlResultsField][predictionFieldName] !== undefined
131+
row[`${mlResultsField}.${predictionFieldName}`] !== undefined
128132
) {
129-
predictedValue = row[mlResultsField][predictionFieldName];
130-
topClasses = row[mlResultsField].top_classes;
133+
predictedValue = row[`${mlResultsField}.${predictionFieldName}`];
134+
topClasses = getTopClasses(row, mlResultsField);
131135
}
132136

137+
const isClassTypeBoolean = topClasses.reduce(
138+
(p, c) => typeof c.class_name === 'boolean' || p,
139+
false
140+
);
141+
142+
const parsedFIArray: FeatureImportance[] = getFeatureImportance(
143+
row,
144+
mlResultsField,
145+
isClassTypeBoolean
146+
);
147+
133148
return (
134149
<DecisionPathPopover
135150
analysisType={analysisType}

x-pack/plugins/ml/public/application/data_frame_analytics/common/fields.ts

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -213,6 +213,10 @@ export const getDefaultFieldsFromJobCaps = (
213213
name: `${resultsField}.${FEATURE_IMPORTANCE}`,
214214
type: KBN_FIELD_TYPES.UNKNOWN,
215215
});
216+
// remove flattened feature importance fields
217+
fields = fields.filter(
218+
(field: any) => !field.name.includes(`${resultsField}.${FEATURE_IMPORTANCE}.`)
219+
);
216220
}
217221

218222
if ((numTopClasses ?? 0) > 0) {
@@ -221,6 +225,10 @@ export const getDefaultFieldsFromJobCaps = (
221225
name: `${resultsField}.${TOP_CLASSES}`,
222226
type: KBN_FIELD_TYPES.UNKNOWN,
223227
});
228+
// remove flattened top classes fields
229+
fields = fields.filter(
230+
(field: any) => !field.name.includes(`${resultsField}.${TOP_CLASSES}.`)
231+
);
224232
}
225233

226234
// Only need to add these fields if we didn't use dest index pattern to get the fields

x-pack/plugins/ml/public/application/data_frame_analytics/common/get_index_data.ts

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ export const getIndexData = async (
5353
index: jobConfig.dest.index,
5454
body: {
5555
fields: ['*'],
56-
_source: [],
56+
_source: false,
5757
query: searchQuery,
5858
from: pageIndex * pageSize,
5959
size: pageSize,

x-pack/plugins/ml/public/application/data_frame_analytics/pages/analytics_exploration/components/exploration_results_table/exploration_results_table.tsx

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ interface Props {
2929
}
3030

3131
export const ExplorationResultsTable: FC<Props> = React.memo(
32-
({ indexPattern, jobConfig, jobStatus, needsDestIndexPattern, searchQuery }) => {
32+
({ indexPattern, jobConfig, needsDestIndexPattern, searchQuery }) => {
3333
const {
3434
services: {
3535
mlServices: { mlApiServices },

0 commit comments

Comments
 (0)