Saturday, February 6, 2016

PRHYME SPOJ (Trie) Solution - C++


1. For every letter, each node stores how many strings pass through the child for that letter (i.e., the number of strings in that child's subtree).
2. Every node stores information about two strings: the lexicographically smallest of all strings that pass through that node, and the second smallest.
3. Strings are inserted into the trie reversed (last character first), so that strings sharing a suffix share a path from the root.
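4. To answer a query, walk the trie along the reversed query. Before moving to the child for the next character, check how many strings pass through that child. Continue only if it holds some string other than the query: its count is at least two, or it is one and that single string is not the query. When you stop, or the query is exhausted, check whether the smallest string stored at the current node is the query itself. If it is, print the second-smallest string; otherwise print the smallest.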


#include <stdio.h>
#include <stdlib.h>
#include <cstring>
#include <vector>
#include <math.h>
#include <algorithm>

using namespace std;

#define INF 0x7fffffff  // "no word recorded" sentinel for trie ranks (largest 32-bit int)
#define W   250001      // number of slots in the dictionary arrays dict[] and sorted[]
#define L   31          // unused in this listing; presumably max word length + NUL
#define debug if(0)     // `debug stmt;` compiles stmt but never executes it

struct node{
 int children[26]; 
 int num[26];
 int fs, ss; 

 void init(){
  for(int i=0;i<26;++i){
   children[i] = 0;
   fs = INF;
   ss = INF;

// Global state shared by add(), find() and main().
// NOTE(review): each inserted word can allocate up to one node per letter, so
// 800000 nodes may be too few for a large dictionary - confirm against the limits.
node trie[800000];  // trie[0] is the root; further nodes are handed out via ++gIndex
int gIndex;         // index of the most recently allocated trie node
char dict[W][31];   // dictionary words as read (up to 30 characters + NUL)
int sorted[W];      // sorted[r] = index into dict of the word with lexicographic rank r

void add(char *s, int jz, int cI, int len, int rank){
 int q[31];int qI=0;
 q[qI++] = cI; 
 for(int i=len-1;i>=0;--i){  
  int nI = trie[cI].children[s[i]-'a']; 
  int cS = cI;
   trie[cI].children[s[i]-'a'] = ++gIndex;
   cI = gIndex;  
   cI = nI;
  q[qI++] = cI;
 for(int i=qI-1;i>=0;--i){
  int v = q[i];  
   trie[v].ss = trie[v].fs;
   trie[v].fs = rank; 
  }else if(rank<trie[v].ss){ 
   trie[v].ss = rank; 

void find(char *s, int len){
  int cI = 0;  
   if(strcmp(s, dict[sorted[trie[cI].fs]])==0){puts(dict[sorted[trie[cI].ss]]);}
   else {puts(dict[sorted[trie[cI].fs]]);} 

  for(int i=len-1;i>=0;--i){
   cI = trie[cI].children[s[i]-'a']; 
   if((i-1)<0 || trie[cI].num[s[i-1]-'a']==1){
    if(strcmp(s, dict[sorted[trie[cI].fs]])==0){puts(dict[sorted[trie[cI].ss]]);}
    else {puts(dict[sorted[trie[cI].fs]]);}

/* Orders dictionary indices by the strcmp (lexicographic) order of their words. */
inline bool comp(int a, int b){
 return strcmp(dict[a], dict[b]) < 0;
}

int main(){
 char s[31];
 int n=0, i;
 while(gets(dict[n]) && dict[n][0])n++;
 for(i=0;i<n;++i)sorted[i] = i;
 sort(sorted, sorted+n, comp);  

 gIndex= 0;
  int len = strlen(dict[sorted[i]]);
  add(dict[sorted[i]], len-1, 0, len, i); 
  find(s, strlen(s)); 

1 comment: