/* University of Strasbourg - Master ILC-ISI-RISE
 * Compilation Lab - Compiler for the arith language
 * Written by Cedric Bastoul cedric.bastoul@unistra.fr
 */

#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include "ast.h"

// Create a new AST operation node (e.g. ast_type_add or ast_type_mul).
//
// op:    node type stored in the new node.
// left:  left operand subtree; it becomes owned by the new node and is
//        released by ast_free() together with it.
// right: right operand subtree; same ownership rule as left.
// Returns a heap-allocated node to be released with ast_free().
// Exits the program with status 1 if memory cannot be allocated.
struct ast* ast_new_operation(enum ast_type op,
                              struct ast* left, struct ast* right) {
  struct ast* node = malloc(sizeof *node);
  if (node == NULL) {
    fprintf(stderr, "ast_new_operation: out of memory\n");
    exit(1);
  }
  node->type = op;
  node->u.operation.left = left;
  node->u.operation.right = right;
  return node;
}

// Create a new AST statement node that is not yet linked to a successor.
//
// identifier: name of the assigned variable, or NULL for a print statement.
//             The node takes ownership: ast_free() calls free() on it, so it
//             must be heap-allocated (or NULL).
// expression: expression subtree; it becomes owned by the new node.
// Returns a heap-allocated node with u.statement.next set to NULL; use
// ast_concat() to chain statements. Exits with status 1 if memory cannot
// be allocated.
struct ast* ast_new_statement(char* identifier,
                              struct ast* expression) {
  struct ast* node = malloc(sizeof *node);
  if (node == NULL) {
    fprintf(stderr, "ast_new_statement: out of memory\n");
    exit(1);
  }
  node->type = ast_type_statement;
  node->u.statement.identifier = identifier;
  node->u.statement.expression = expression;
  node->u.statement.next = NULL;
  return node;
}

// Create a new AST leaf node holding an integer constant.
//
// number: value stored in the node.
// Returns a heap-allocated node to be released with ast_free().
// Exits the program with status 1 if memory cannot be allocated.
struct ast* ast_new_number(int number) {
  struct ast* node = malloc(sizeof *node);
  if (node == NULL) {
    fprintf(stderr, "ast_new_number: out of memory\n");
    exit(1);
  }
  node->type = ast_type_number;
  node->u.number = number;
  return node;
}

// Create a new AST leaf node referring to a variable by name.
//
// identifier: variable name. The node takes ownership: ast_free() calls
//             free() on it, so it must be heap-allocated.
// Returns a heap-allocated node to be released with ast_free().
// Exits the program with status 1 if memory cannot be allocated.
struct ast* ast_new_identifier(char* identifier) {
  struct ast* node = malloc(sizeof *node);
  if (node == NULL) {
    fprintf(stderr, "ast_new_identifier: out of memory\n");
    exit(1);
  }
  node->type = ast_type_identifier;
  node->u.identifier = identifier;
  return node;
}

// Append statement list ast2 after the last statement of list ast1.
//
// The lists are linked in place through u.statement.next; no node is copied.
// Returns the head of the combined list: ast1, or ast2 when ast1 is empty
// (NULL), so concatenating onto an empty list does not lose ast2.
// If the chain starting at ast1 ends on a node that is not a statement, that
// node has no next field and ast2 is left unlinked.
struct ast* ast_concat(struct ast* ast1, struct ast* ast2) {
  if (ast1 == NULL)
    return ast2;

  struct ast* last = ast1;
  while (last->type == ast_type_statement &&
         last->u.statement.next != NULL)
    last = last->u.statement.next;

  if (last->type == ast_type_statement)
    last->u.statement.next = ast2;
  return ast1;
}

// Free an AST and everything it owns.
//
// Identifier strings held by identifier and statement nodes are freed, and
// operand/expression subtrees are released recursively. A statement list is
// walked iteratively along u.statement.next instead of recursing per
// statement. Passing NULL is a no-op (so empty lists and missing subtrees
// are safe). Exits with status 1 on an unknown node type.
void ast_free(struct ast* ast) {
  while (ast != NULL) {
    struct ast* next = NULL;  // successor in a statement list, if any

    switch (ast->type) {
      case ast_type_number:
        break;
      case ast_type_identifier:
        free(ast->u.identifier);
        break;
      case ast_type_statement:
        free(ast->u.statement.identifier);
        ast_free(ast->u.statement.expression);
        next = ast->u.statement.next;
        break;
      case ast_type_add:
      case ast_type_mul:
        ast_free(ast->u.operation.left);
        ast_free(ast->u.operation.right);
        break;
      default:
        fprintf(stderr, "Unknown AST node type\n");
        exit(1);
    }

    free(ast);
    ast = next;
  }
}

// Print an AST to stdout as an indented tree, one node per line.
//
// indent: nesting depth of ast; each level is printed as four spaces and
//         child nodes are printed one level deeper. A value <= 0 prints no
//         leading indentation.
// Consecutive statements of a list are printed at the same depth by
// following u.statement.next. A statement whose identifier is NULL is shown
// as a print statement. Printing NULL does nothing. Exits with status 1 on
// an unknown node type.
void ast_print(struct ast* ast, int indent) {
  if (ast == NULL)
    return;

  do {
    // int index: comparing a size_t against a negative int would convert
    // the int to a huge unsigned value and loop (almost) forever.
    for (int i = 0; i < indent; i++)
      printf("    ");

    switch (ast->type) {
      case ast_type_number:
        printf("number (%d)\n", ast->u.number);
        break;
      case ast_type_identifier:
        printf("identifier (%s)\n", ast->u.identifier);
        break;
      case ast_type_statement:
        if (ast->u.statement.identifier == NULL)
          printf("statement print\n");
        else
          printf("statement %s = \n", ast->u.statement.identifier);
        ast_print(ast->u.statement.expression, indent + 1);
        break;
      case ast_type_add:
        printf("+\n");
        ast_print(ast->u.operation.left,  indent + 1);
        ast_print(ast->u.operation.right, indent + 1);
        break;
      case ast_type_mul:
        printf("*\n");
        ast_print(ast->u.operation.left,  indent + 1);
        ast_print(ast->u.operation.right, indent + 1);
        break;
      default:
        fprintf(stderr, "Unknown AST node type\n");
        exit(1);
    }

    // Only statements chain to a successor; any other node ends the walk.
    if (ast->type == ast_type_statement)
      ast = ast->u.statement.next;
  } while ((ast != NULL) && (ast->type == ast_type_statement));
}
