Refactor code that extracts terms for relevance-ranking
authorAdam Dickmeiss <adam@indexdata.dk>
Mon, 19 Sep 2011 12:03:16 +0000 (14:03 +0200)
committerAdam Dickmeiss <adam@indexdata.dk>
Mon, 19 Sep 2011 12:03:16 +0000 (14:03 +0200)
Function pull_terms checks for max number of terms in termlist -
to prevent SEGV/exploit.

src/client.c
src/relevance.c
src/relevance.h
src/session.c
src/session.h

index b7aa66a..ba2567d 100644 (file)
@@ -851,14 +851,6 @@ void client_disconnect(struct client *cl)
     client_set_connection(cl, 0);
 }
 
-// Extract terms from query into null-terminated termlist
-static void extract_terms(NMEM nmem, struct ccl_rpn_node *query, char **termlist)
-{
-    int num = 0;
-
-    pull_terms(nmem, query, termlist, &num);
-    termlist[num] = 0;
-}
 
 // Initialize CCL map for a target
 static CCL_bibset prepare_cclmap(struct client *cl)
@@ -1077,11 +1069,8 @@ int client_parse_query(struct client *cl, const char *query,
     if (!se->relevance)
     {
         // Initialize relevance structure with query terms
-        char *p[512];
-        extract_terms(se->nmem, cn, p);
-        se->relevance = relevance_create(
-            se->service->charsets,
-            se->nmem, (const char **) p);
+        se->relevance = relevance_create_ccl(
+            se->service->charsets, se->nmem, cn);
     }
 
     ccl_rpn_delete(cn);
index 708f2ba..933ca20 100644 (file)
@@ -120,8 +120,8 @@ void relevance_countwords(struct relevance *r, struct record_cluster *cluster,
     cluster->term_frequency_vec[0] += length;
 }
 
-struct relevance *relevance_create(pp2_charset_fact_t pft,
-                                   NMEM nmem, const char **terms)
+static struct relevance *relevance_create(pp2_charset_fact_t pft,
+                                          NMEM nmem, const char **terms)
 {
     struct relevance *res = nmem_malloc(nmem, sizeof(struct relevance));
     const char **p;
@@ -138,6 +138,47 @@ struct relevance *relevance_create(pp2_charset_fact_t pft,
     return res;
 }
 
+// Recursively traverse query structure to extract terms.
+static void pull_terms(NMEM nmem, struct ccl_rpn_node *n,
+                       char **termlist, int *num, int max_terms)
+{
+    char **words;
+    int numwords;
+    int i;
+
+    switch (n->kind)
+    {
+    case CCL_RPN_AND:
+    case CCL_RPN_OR:
+    case CCL_RPN_NOT:
+    case CCL_RPN_PROX:
+        pull_terms(nmem, n->u.p[0], termlist, num, max_terms);
+        pull_terms(nmem, n->u.p[1], termlist, num, max_terms);
+        break;
+    case CCL_RPN_TERM:
+        nmem_strsplit(nmem, " ", n->u.t.term, &words, &numwords);
+        for (i = 0; i < numwords; i++)
+        {
+            if (*num < max_terms)
+                termlist[(*num)++] = words[i];
+        }
+        break;
+    default: // NOOP
+        break;
+    }
+}
+
+struct relevance *relevance_create_ccl(pp2_charset_fact_t pft,
+                                       NMEM nmem, struct ccl_rpn_node *query)
+{
+    char *termlist[512];
+    int num = 0;
+
+    pull_terms(nmem, query, termlist, &num, sizeof(termlist)/sizeof(*termlist));
+    termlist[num] = 0;
+    return relevance_create(pft, nmem, (const char **) termlist);
+}
+
 void relevance_destroy(struct relevance **rp)
 {
     if (*rp)
index e357382..16682e6 100644 (file)
@@ -21,14 +21,15 @@ Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA  02110-1301  USA
 #define RELEVANCE_H
 
 #include <yaz/yaz-util.h>
+#include <yaz/ccl.h>
 #include "charsets.h"
 
 struct relevance;
 struct record_cluster;
 struct reclist;
 
-struct relevance *relevance_create(pp2_charset_fact_t pft,
-                                   NMEM nmem, const char **terms);
+struct relevance *relevance_create_ccl(pp2_charset_fact_t pft,
+                                       NMEM nmem, struct ccl_rpn_node *query);
 void relevance_destroy(struct relevance **rp);
 void relevance_newrec(struct relevance *r, struct record_cluster *cluster);
 void relevance_countwords(struct relevance *r, struct record_cluster *cluster,
index 9875515..72feef6 100644 (file)
@@ -111,15 +111,17 @@ static int session_use(int delta)
         no_session_total += delta;
     sessions = no_sessions;
     yaz_mutex_leave(g_session_mutex);
-    yaz_log(YLOG_DEBUG, "%s sesions=%d", delta == 0 ? "" : (delta > 0 ? "INC" : "DEC"), no_sessions);
+    yaz_log(YLOG_DEBUG, "%s sessions=%d", delta == 0 ? "" : (delta > 0 ? "INC" : "DEC"), no_sessions);
     return sessions;
 }
 
-int sessions_count(void) {
+int sessions_count(void)
+{
     return session_use(0);
 }
 
-int  session_count_total(void) {
+int session_count_total(void)
+{
     int total = 0;
     if (!g_session_mutex)
         return 0;
@@ -129,7 +131,6 @@ int  session_count_total(void) {
     return total;
 }
 
-
 static void log_xml_doc(xmlDoc *doc)
 {
     FILE *lf = yaz_log_file();
@@ -158,33 +159,6 @@ static void session_leave(struct session *s)
     yaz_mutex_leave(s->session_mutex);
 }
 
-// Recursively traverse query structure to extract terms.
-void pull_terms(NMEM nmem, struct ccl_rpn_node *n, char **termlist, int *num)
-{
-    char **words;
-    int numwords;
-    int i;
-
-    switch (n->kind)
-    {
-    case CCL_RPN_AND:
-    case CCL_RPN_OR:
-    case CCL_RPN_NOT:
-    case CCL_RPN_PROX:
-        pull_terms(nmem, n->u.p[0], termlist, num);
-        pull_terms(nmem, n->u.p[1], termlist, num);
-        break;
-    case CCL_RPN_TERM:
-        nmem_strsplit(nmem, " ", n->u.t.term, &words, &numwords);
-        for (i = 0; i < numwords; i++)
-            termlist[(*num)++] = words[i];
-        break;
-    default: // NOOP
-        break;
-    }
-}
-
-
 void add_facet(struct session *s, const char *type, const char *value, int count)
 {
     struct conf_service *service = s->service;
index 2c943d9..fbb4e3c 100644 (file)
@@ -178,8 +178,6 @@ int host_getaddrinfo(struct host *host, iochan_man_t iochan_man);
 
 int ingest_record(struct client *cl, const char *rec, int record_no, NMEM nmem);
 void session_alert_watch(struct session *s, int what);
-void pull_terms(NMEM nmem, struct ccl_rpn_node *n, char **termlist, int *num);
-
 void add_facet(struct session *s, const char *type, const char *value, int count);
 void session_log(struct session *s, int level, const char *fmt, ...)
 #ifdef __GNUC__