Readable rules from a Decision Tree model

One of the many algorithms supported by Oracle Data Mining (ODM) is the decision tree algorithm.  It is popular in large part because its internals are transparent.  ODM provides model details for its algorithms, and decision tree is no exception.  The dbms_data_mining.get_model_details_xml function retrieves a PMML-compliant XML representation of the tree, which contains the complete description needed for scoring.  Although the XML is complete, it is not easy to read: it is not a simple table of rules.

So how can we produce something that is easy to understand?  Oracle has been adding XML support to its query processing engine, and this functionality can be used to parse the XML document and translate it to relational form.  In addition, since trees are hierarchical in nature, Oracle's hierarchical query processing (CONNECT BY functionality) can be leveraged to roll up information along a path in the tree.

Given a decision tree model named DT_SH_CLAS_SAMPLE (as created by the provided ODM sample code), the Oracle SQL engine can be used to translate the XML into readable rules.

The distribution of target class values in each node can be generated with:
SELECT * FROM
    XMLTable('for $s in /PMML/TreeModel//ScoreDistribution
              return
                <scores id="{$s/../@id}"
                        tvalue="{$s/@value}"
                        tcount="{$s/@recordCount}"
                />'
      passing dbms_data_mining.get_model_details_xml('DT_SH_CLAS_SAMPLE')
            COLUMNS
              node_id      NUMBER PATH '/scores/@id',
              target_value VARCHAR2(4000) PATH '/scores/@tvalue',
              target_count NUMBER PATH '/scores/@tcount')
ORDER BY node_id, target_value;

This code uses XMLTable to parse the XML and convert the result to relational form, which is then returned with little further processing.  To apply the query to a different model, change only the model name that is passed to the get_model_details_xml function.
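
For readers who want to see the same flattening outside the database, here is a minimal Python sketch of what the XMLTable step does: each ScoreDistribution element becomes one row, paired with the id of the Node that contains it. The PMML fragment is hypothetical and heavily trimmed (no XML namespace, invented counts); real output from get_model_details_xml carries more attributes and elements.

```python
import xml.etree.ElementTree as ET

# Hypothetical, trimmed PMML fragment; ids, values and counts are invented.
PMML = """
<PMML>
  <TreeModel>
    <Node id="0" score="0" recordCount="100">
      <True/>
      <ScoreDistribution value="0" recordCount="70"/>
      <ScoreDistribution value="1" recordCount="30"/>
      <Node id="1" score="0" recordCount="60">
        <SimplePredicate field="AGE" operator="lessOrEqual" value="30"/>
        <ScoreDistribution value="0" recordCount="55"/>
        <ScoreDistribution value="1" recordCount="5"/>
      </Node>
      <Node id="2" score="1" recordCount="40">
        <SimplePredicate field="AGE" operator="greaterThan" value="30"/>
        <ScoreDistribution value="0" recordCount="15"/>
        <ScoreDistribution value="1" recordCount="25"/>
      </Node>
    </Node>
  </TreeModel>
</PMML>
"""


def score_distribution(pmml_text):
    """Return (node_id, target_value, target_count) rows, sorted like the query."""
    root = ET.fromstring(pmml_text)
    rows = []
    # iter("Node") visits every Node at any depth, like //Node in the XQuery.
    for node in root.iter("Node"):
        # findall without a path prefix matches direct children only, so each
        # ScoreDistribution is paired with its own parent Node.
        for dist in node.findall("ScoreDistribution"):
            rows.append(
                (int(node.get("id")), dist.get("value"), int(dist.get("recordCount")))
            )
    return sorted(rows)


rows = score_distribution(PMML)
for row in rows:
    print(row)
```

The tuple order (node id, target value, count) matches the node_id, target_value and target_count columns of the query, and the sort matches its ORDER BY clause.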

Generating the readable rules requires quite a bit more code, but that query can also be applied to new models with the same replacement:
WITH X as
(SELECT * FROM
 XMLTable('for $n in /PMML/TreeModel//Node
            let $rf :=
              if (count($n/CompoundPredicate) > 0) then
                $n/CompoundPredicate/*[1]/@field
              else
                if (count($n/SimplePredicate) > 0) then
                  $n/SimplePredicate/@field
                else
                  $n/SimpleSetPredicate/@field
            let $ro :=
              if (count($n/CompoundPredicate) > 0) then
                if ($n/CompoundPredicate/*[1] instance of
                    element(SimplePredicate)) then
                  $n/CompoundPredicate/*[1]/@operator
                else if ($n/CompoundPredicate/*[1] instance of
                    element(SimpleSetPredicate)) then
                  ("in")
                else ()
              else
                if (count($n/SimplePredicate) > 0) then
                  $n/SimplePredicate/@operator
                else if (count($n/SimpleSetPredicate) > 0) then
                  ("in")
                else ()
            let $rv :=
              if (count($n/CompoundPredicate) > 0) then
                if ($n/CompoundPredicate/*[1] instance of
                    element(SimplePredicate)) then
                  $n/CompoundPredicate/*[1]/@value
                else
                  $n/CompoundPredicate/*[1]/Array/text()
              else
                if (count($n/SimplePredicate) > 0) then
                  $n/SimplePredicate/@value
                else
                  $n/SimpleSetPredicate/Array/text()
            let $sf :=
              if (count($n/CompoundPredicate) > 0) then
                $n/CompoundPredicate/*[2]/@field
              else ()
            let $so :=
              if (count($n/CompoundPredicate) > 0) then
                if ($n/CompoundPredicate/*[2] instance of
                    element(SimplePredicate)) then
                  $n/CompoundPredicate/*[2]/@operator
                else if ($n/CompoundPredicate/*[2] instance of
                    element(SimpleSetPredicate)) then
                  ("in")
                else ()
              else ()
            let $sv :=
              if (count($n/CompoundPredicate) > 0) then
                if ($n/CompoundPredicate/*[2] instance of
                    element(SimplePredicate)) then
                  $n/CompoundPredicate/*[2]/@value
                else
                  $n/CompoundPredicate/*[2]/Array/text()
              else ()
            return
              <pred id="{$n/../@id}"
                    score="{$n/@score}"
                    rec="{$n/@recordCount}"
                    cid="{$n/@id}"
                    rf="{$rf}"
                    ro="{$ro}"
                    rv="{$rv}"
                    sf="{$sf}"
                    so="{$so}"
                    sv="{$sv}"
              />'
      passing dbms_data_mining.get_model_details_xml('DT_SH_CLAS_SAMPLE')
            COLUMNS
              parent_node_id   NUMBER PATH '/pred/@id',
              child_node_id    NUMBER PATH '/pred/@cid',
              rec              NUMBER PATH '/pred/@rec',
              score            VARCHAR2(4000) PATH '/pred/@score',
              rule_field       VARCHAR2(4000) PATH '/pred/@rf',
              rule_op          VARCHAR2(20) PATH '/pred/@ro',
              rule_value       VARCHAR2(4000) PATH '/pred/@rv',
              surr_field       VARCHAR2(4000) PATH '/pred/@sf',
              surr_op          VARCHAR2(20) PATH '/pred/@so',
              surr_value       VARCHAR2(4000) PATH '/pred/@sv'))
select pid parent_node, nid node, rec record_count,
      score prediction, rule_pred local_rule, surr_pred local_surrogate,
      rtrim(replace(full_rule,'$O$D$M$'),' AND') full_simple_rule from (
select row_number() over (partition by nid order by rn desc) rn,
 pid, nid, rec, score, rule_pred, surr_pred, full_rule from (
 select rn, pid, nid, rec, score, rule_pred, surr_pred,
   sys_connect_by_path(pred, '$O$D$M$') full_rule from (
  select row_number() over (partition by nid order by rid) rn,
    pid, nid, rec, score, rule_pred, surr_pred,
    nvl2(pred,pred || ' AND ',null) pred from(
   select rid, pid, nid, rec, score, rule_pred, surr_pred,
     decode(rn, 1, pred, null) pred from (
    select rid, nid, rec, score, pid, rule_pred, surr_pred,
     nvl2(root_op, '(' || root_field || ' ' || root_op || ' ' || root_value || ')', null) pred,
     row_number() over (partition by nid, root_field, root_op order by rid desc) rn from (
     SELECT
       connect_by_root(parent_node_id) rid,
       child_node_id nid,
       rec, score,
       connect_by_root(rule_field) root_field,
       connect_by_root(rule_op) root_op,
       connect_by_root(rule_value) root_value,
       nvl2(rule_op, '(' || rule_field || ' ' || rule_op || ' ' || rule_value || ')',  null) rule_pred,
       nvl2(surr_op, '(' || surr_field || ' ' || surr_op || ' ' || surr_value || ')',  null) surr_pred,
       parent_node_id pid
       FROM (
        SELECT parent_node_id, child_node_id, rec, score, rule_field, surr_field, rule_op, surr_op,
               replace(replace(rule_value,'&quot; &quot;', ''', '''),'&quot;', '''') rule_value,
               replace(replace(surr_value,'&quot; &quot;', ''', '''),'&quot;', '''') surr_value
        FROM (
          SELECT parent_node_id, child_node_id, rec, score, rule_field, surr_field,
                 decode(rule_op,'lessOrEqual','<=','greaterThan','>',rule_op) rule_op,
                 decode(rule_op,'in','('||rule_value||')',rule_value) rule_value,
                 decode(surr_op,'lessOrEqual','<=','greaterThan','>',surr_op) surr_op,
                 decode(surr_op,'in','('||surr_value||')',surr_value) surr_value
          FROM X)
       )
       CONNECT BY PRIOR child_node_id = parent_node_id
     )
    )
   )
  )
  CONNECT BY PRIOR rn = rn - 1
         AND PRIOR nid = nid
  START WITH rn = 1
)
)
where rn = 1;

If this query is being run from SQL*Plus, make sure to first run:
set define off
so that the & characters in the query are not treated as substitution variables.
For each node in the tree, this query provides the id of the parent node, the number of records in the node, the top predicted class for the node, the rule followed to get to the node from its parent, the surrogate rule that would be followed if the main rule cannot be used (e.g., the attribute is null), and the entire, simplified rule from the root to the node.
The simplified rule ignores surrogates and combines multiple rule pieces for the same attribute into a single piece to increase readability.
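
To make the simplification concrete, the following Python sketch mirrors the idea rather than porting the SQL line by line. It walks from a node up to the root, keeps only the deepest predicate for each (attribute, operator) pair, and joins the surviving pieces with AND. The node rows, attribute names and thresholds are hypothetical, and keeping the deepest piece assumes that a deeper split on the same attribute and operator is always at least as restrictive as a shallower one.

```python
# Hypothetical rows shaped like the relational output of the XMLTable step:
# (parent_id, node_id, field, operator, value). The root has no predicate.
NODES = [
    (None, 0, None, None, None),
    (0, 1, "AGE", "<=", "50"),
    (0, 2, "AGE", ">", "50"),
    (1, 3, "AGE", "<=", "30"),
    (1, 4, "AGE", ">", "30"),
    (3, 5, "CUST_GENDER", "in", "('F')"),
]


def full_simple_rule(node_id, nodes=NODES):
    """Build the simplified root-to-node rule for one node."""
    by_id = {row[1]: row for row in nodes}

    # Collect predicates while climbing from the node to the root.
    path = []
    current = node_id
    while current is not None:
        parent, _, field, op, value = by_id[current]
        if op is not None:
            path.append((field, op, value))
        current = parent
    path.reverse()  # root-most predicate first

    # Remember the last (deepest) position of each (field, operator) pair.
    deepest = {(field, op): i for i, (field, op, _) in enumerate(path)}

    # Keep a predicate only if it is the deepest one for its pair.
    pieces = [
        f"({field} {op} {value})"
        for i, (field, op, value) in enumerate(path)
        if deepest[(field, op)] == i
    ]
    return " AND ".join(pieces)


print(full_simple_rule(5))
print(full_simple_rule(4))
```

In this toy tree, node 5 sits below both AGE <= 50 and AGE <= 30, so only the tighter AGE <= 30 piece survives next to the CUST_GENDER test. Node 4 keeps both AGE pieces because they use different operators and together describe a range.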

Note that these queries provide information for branches and leaves of the tree, which is useful since ODM can score a record as stopping in a branch if there is not enough information to continue further down the tree.