Week 2 Code Snippets

Simple Enums

Colors

Here is a simple enumerated type with two functions that pattern match on it:

enum Color:
  case Red
  case Blue
  case Green

def getInitial(c:Color) = c match
  case Color.Red => 'R'
  case Color.Blue => 'B'
  case Color.Green => 'G'

def isRed(c:Color) = c match
  case Color.Red => true
  case _ => false

Without the enum syntax, we can encode the same type as a sealed trait with one object per value:

sealed trait ColorT
object Red extends ColorT
object Blue extends ColorT
object Green extends ColorT
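As a sketch, assuming the sealed-trait encoding with top-level objects Red, Blue and Green, the two functions from the enum version carry over almost unchanged: only the Color. prefixes disappear. Because ColorT is sealed, all its subtypes live in the same file, so the compiler can check that a match covers every color.

```scala
// Sealed: all subtypes are defined in this file, which lets the
// compiler check matches on ColorT for exhaustiveness.
sealed trait ColorT
object Red extends ColorT
object Blue extends ColorT
object Green extends ColorT

// Same shape as getInitial on the enum; the objects are matched directly.
def getInitial(c: ColorT): Char = c match
  case Red   => 'R'
  case Blue  => 'B'
  case Green => 'G'

// The wildcard case covers every color other than Red.
def isRed(c: ColorT): Boolean = c match
  case Red => true
  case _   => false
```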

Colors with parameters

enum Color:
  case RGB(red:Int, green:Int, blue:Int)
  case CMYK(cyan:Int, magenta:Int, yellow:Int, black:Int)

def chroma(c:Color) = c match
  case Color.RGB(r, g, b) => r+g+b
  case Color.CMYK(c, m, y, b) => 255-c-m-y-b

val x = Color.RGB(10,255,0)
chroma(x)
==> c match
     case Color.RGB(r, g, b) => r+g+b
     case Color.CMYK(c, m, y, b) => 255-c-m-y-b
    // where c = Color.RGB(10,255,0)
==> Color.RGB(10,255,0) match
     case Color.RGB(r, g, b) => r+g+b
     case Color.CMYK(c, m, y, b) => 255-c-m-y-b
==> r+g+b
    // where r=10, g=255, b=0
==> 10+255+0
==> 265

Here is the same thing using a sealed trait and case classes:

sealed trait ColorT
case class RGB(red:Int, green:Int, blue:Int) extends ColorT
case class CMYK(cyan:Int, magenta:Int, yellow:Int, black:Int) extends ColorT
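A minimal sketch of chroma on the trait-based encoding, assuming the sealed trait ColorT with the case classes RGB and CMYK: the case classes supply the extractors used in the patterns, so the function body mirrors the enum version. The pattern variables cy and k are renamed here only to avoid shadowing the parameter c.

```scala
sealed trait ColorT
case class RGB(red: Int, green: Int, blue: Int) extends ColorT
case class CMYK(cyan: Int, magenta: Int, yellow: Int, black: Int) extends ColorT

// Same computation as the enum version of chroma.
def chroma(c: ColorT): Int = c match
  case RGB(r, g, b)      => r + g + b
  case CMYK(cy, m, y, k) => 255 - cy - m - y - k

val x = RGB(10, 255, 0)
val total = chroma(x) // 10 + 255 + 0
```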

Lists as our own type

enum LList:
  case Empty
  case Node(head:Int, tail:LList)

import LList.*

def isEmpty(xs:LList) : Boolean = xs match
  case Empty => true
  case _ => false

def sum(xs:LList) : Int = xs match
  case Empty => 0
  case Node(h, t) => sum(t) + h

def f(xs:LList) : Int = xs match
  case Empty => ???
  case Node(h, t) => ???

val as = Empty
val bs = Node(31, as)
val cs = Node(21, bs)
val ds = Node(11, cs)
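One way to fill in the template f, sketched here with a function we call length (a name chosen for illustration): there is one case per constructor of LList, and the Node case recurses on the tail, exactly like sum. With the lists above, sum(ds) unfolds to ((0 + 31) + 21) + 11 = 63.

```scala
enum LList:
  case Empty
  case Node(head: Int, tail: LList)

import LList.*

def sum(xs: LList): Int = xs match
  case Empty      => 0
  case Node(h, t) => sum(t) + h

// Template instance: ignore the head, count one per Node.
def length(xs: LList): Int = xs match
  case Empty      => 0
  case Node(_, t) => length(t) + 1

val as = Empty
val bs = Node(31, as)
val cs = Node(21, bs)
val ds = Node(11, cs)
```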

Built-in Lists

// Simplified sketch of how the built-in List is structured:
//enum List[T]:
//  case Nil
//  case ::(head:T, tail:List[T])

def isEmpty(xs:List[Int]) : Boolean = xs match
  case Nil => true
  case _ => false

def sum(xs:List[Int]) : Int = xs match
  case Nil => 0
  case h :: t => sum(t) + h

def f(xs:List[Int]) : Int = xs match
  case Nil => ???
  case h :: t => ???

val as = Nil
val bs = 31 :: as
val cs = 21 :: bs
val ds = 11 :: cs
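Another instance of the same template, as a sketch with a function we name product for illustration: it multiplies the elements of a built-in list. The only real design decision is the base case, which must return 1, the neutral element of multiplication; returning 0 as in sum would make every product 0.

```scala
// Same recursion scheme as sum, with * instead of + and 1 instead of 0.
def product(xs: List[Int]): Int = xs match
  case Nil    => 1
  case h :: t => product(t) * h

val as = Nil
val bs = 31 :: as
val cs = 21 :: bs
val ds = 11 :: cs
```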

Expressions

enum Expr:
  case Num(x: Int)
  case Plus(l: Expr, r: Expr)
  case Sub(l: Expr, r: Expr)
  case Mult(l: Expr, r: Expr)

import Expr.*
def eval (e : Expr) : Int = e match
  case Num(x: Int) => x
  case Plus(l: Expr, r: Expr) => eval(l) + eval(r)
  case Sub(l: Expr, r: Expr) => eval(l) - eval(r)
  case Mult(l: Expr, r: Expr) => eval(l) * eval(r)

val e1 = Plus(Num(55), Num(11))
val e2 = Sub(Num(44), Num(66))
val e3 = Mult(e1, e2)

// i.e., e3 = Mult(Plus(Num(55), Num(11)), Sub(Num(44), Num(66)))

eval(e3)
==>
e match
  case Num(x: Int) => x
  case Plus(l: Expr, r: Expr) => eval(l) + eval(r)
  case Sub(l: Expr, r: Expr) => eval(l) - eval(r)
  case Mult(l: Expr, r: Expr) => eval(l) * eval(r)
// where e = Mult(Plus(Num(55), Num(11)), Sub(Num(44), Num(66)))
==>
Mult(Plus(Num(55), Num(11)), Sub(Num(44), Num(66))) match
  case Num(x: Int) => x
  case Plus(l: Expr, r: Expr) => eval(l) + eval(r)
  case Sub(l: Expr, r: Expr) => eval(l) - eval(r)
  case Mult(l: Expr, r: Expr) => eval(l) * eval(r)
==>
eval(l) * eval(r)
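// where l = Plus(Num(55), Num(11))
// and   r = Sub(Num(44), Num(66))
==>
eval(Plus(Num(55), Num(11))) * eval(Sub(Num(44), Num(66)))
==>
(eval(Num(55)) + eval(Num(11))) * (eval(Num(44)) - eval(Num(66)))
==>
(55 + 11) * (44 - 66)
==>
66 * (-22)
==>
-1452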

Calculator: Object-Oriented vs. Functional Implementation

We implement a simple calculator. Expressions are created as objects in memory, and can be evaluated and printed.

Calculator in Scala

Scala blends object-oriented and functional programming.

Object-Oriented Implementation

We use a trait (similar to a Java interface) to specify the operations common to all expressions. The concrete classes Number, Plus, and Times each extend the trait and implement their own fragment of the eval and print functionality.

trait Expr:
  def eval: Int
  def print: String
end Expr

class Number(val x: Int) extends Expr:
  def eval: Int = x
  def print: String = x.toString
end Number

class Plus(val left: Expr, val right: Expr) extends Expr:
  def eval: Int = left.eval + right.eval
  def print: String = s"(${left.print} + ${right.print})"
end Plus

class Times(val left: Expr, val right: Expr) extends Expr:
  def eval: Int = left.eval * right.eval
  def print: String = s"(${left.print} * ${right.print})"
end Times

In object-oriented languages, adding new classes is easy; it is harder to understand the full functionality of eval and print, because their implementation is spread over many classes. Adding new functionality may require changing many existing classes (as opposed to a single function in functional programming).
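To illustrate the easy direction, here is a sketch that adds a new kind of expression, a hypothetical Minus class, without touching any existing class (Times is left out for brevity). The hard direction is the opposite one: adding a new operation, say a depth method, would mean editing the trait and every class.

```scala
trait Expr:
  def eval: Int
  def print: String

class Number(val x: Int) extends Expr:
  def eval: Int = x
  def print: String = x.toString

class Plus(val left: Expr, val right: Expr) extends Expr:
  def eval: Int = left.eval + right.eval
  def print: String = s"(${left.print} + ${right.print})"

// New case (hypothetical): one new class; Number and Plus stay unchanged.
class Minus(val left: Expr, val right: Expr) extends Expr:
  def eval: Int = left.eval - right.eval
  def print: String = s"(${left.print} - ${right.print})"

// Scala 3 allows creating class instances without `new`.
val e = Minus(Number(10), Plus(Number(2), Number(3)))
```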

Functional Implementation

We implement the same calculator in a functional style in Scala: an enum expresses the algebraic datatype, and the functions eval and print pattern match on the shape of their argument.

enum Expr:
  case Number(x: Int)
  case Plus(left: Expr, right: Expr)
  case Times(left: Expr, right: Expr)
end Expr

import Expr.*

def eval(expr: Expr): Int = expr match
  case Number(x)          => x
  case Plus(left, right)  => eval(left) + eval(right)
  case Times(left, right) => eval(left) * eval(right)
end eval

def print(e: Expr): String = e match
  case Number(x)          => x.toString
  case Plus(left, right)  => s"(${print(left)} + ${print(right)})"
  case Times(left, right) => s"(${print(left)} * ${print(right)})"
end print

In functional programming languages, adding new functions is easy; it is harder to understand the full functionality of a particular case, like Plus, because its behavior is spread over many functions. Extending the algebraic datatype with a new case may require changing many different functions (not just adding a single class as in object-oriented programming).
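For the easy direction in the functional style, here is a sketch that adds a new operation as a single function, a hypothetical countNumbers that counts the Number leaves of an expression. The enum and the existing eval and print need no changes; adding a new case such as Minus, by contrast, would require a new line in every one of these functions.

```scala
enum Expr:
  case Number(x: Int)
  case Plus(left: Expr, right: Expr)
  case Times(left: Expr, right: Expr)

import Expr.*

// New operation (hypothetical): one function with one case per constructor.
def countNumbers(e: Expr): Int = e match
  case Number(_)          => 1
  case Plus(left, right)  => countNumbers(left) + countNumbers(right)
  case Times(left, right) => countNumbers(left) + countNumbers(right)

val e = Plus(Number(1), Times(Number(2), Number(3)))
```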

Calculator in Java

Java was originally designed as an object-oriented language, but modern Java has been adopting more and more features from functional languages, such as lambdas, records, and pattern matching.

Object-Oriented Implementation

public interface Expr {
  public int eval();
  public String print();
}

public class Number implements Expr {
  private final int x;

  public Number(int x) {
    this.x = x;
  }

  @Override
  public int eval() {
    return x;
  }

  @Override
  public String print() {
    return Integer.toString(x);
  }
}

public class Plus implements Expr {
  private final Expr left;
  private final Expr right;

  public Plus(Expr left, Expr right) {
    this.left = left;
    this.right = right;
  }

  @Override
  public int eval() {
    return left.eval() + right.eval();
  }

  @Override
  public String print() {
    return "(" + left.print() + " + " + right.print() + ")";
  }
}

public class Times implements Expr {
  private final Expr left;
  private final Expr right;

  public Times(Expr left, Expr right) {
    this.left = left;
    this.right = right;
  }

  @Override
  public int eval() {
    return left.eval() * right.eval();
  }

  @Override
  public String print() {
    return "(" + left.print() + " * " + right.print() + ")";
  }
}

Functional Implementation

Java 21 provides record patterns and pattern matching for switch. These let us implement our calculator in a functional style in Java. The base interface is sealed so that the compiler can check that the switch covers all records exhaustively; without sealed, other implementations of Expr could be added elsewhere in the program, and the switch would need a default case. In the record patterns, the Java code uses var to bind components without writing their types (i.e., using type inference). Unlike in Scala, where var marks a mutable variable, var in Java only requests type inference: a variable declared with var can be reassigned like any other Java local variable, and we can add final to prevent this (final var).

public class FunctionalCalculator {

  sealed interface Expr {}

  record Number(int x) implements Expr {}

  record Plus(Expr left, Expr right) implements Expr {}

  record Times(Expr left, Expr right) implements Expr {}

  public static int eval(Expr expr) {
    return switch (expr) {
      case Number(var x)              -> x;
      case Plus(var left, var right)  -> eval(left) + eval(right);
      case Times(var left, var right) -> eval(left) * eval(right);
    };
  }

  public static String print(Expr e) {
    return switch (e) {
      case Number(var x)              -> Integer.toString(x);
      case Plus(var left, var right)  -> "(" + print(left) + " + " + print(right) + ")";
      case Times(var left, var right) -> "(" + print(left) + " * " + print(right) + ")";
    };
  }

  public static void main(String[] args) {
    Expr expr = new Plus(new Number(1), new Times(new Number(2), new Number(3)));
    System.out.println("Expression: " + print(expr));
    System.out.println("Evaluated Result: " + eval(expr));
  }
}

Implementation in Python

Python has object-oriented features, and with @dataclass we can emulate algebraic datatypes. Since Python 3.10, structural pattern matching (the match statement, PEP 634) is available as well; here we use conditionals based on isinstance type tests instead.

Object-Oriented Implementation

class Expr:
  def eval(self):
    raise NotImplementedError

  def print(self):
    raise NotImplementedError

class Number(Expr):
  def __init__(self, x):
    self.x = x

  def eval(self):
    return self.x

  def print(self):
    return str(self.x)

class Plus(Expr):
  def __init__(self, left, right):
    self.left = left
    self.right = right

  def eval(self):
    return self.left.eval() + self.right.eval()

  def print(self):
    return f"({self.left.print()} + {self.right.print()})"

class Times(Expr):
  def __init__(self, left, right):
    self.left = left
    self.right = right

  def eval(self):
    return self.left.eval() * self.right.eval()

  def print(self):
    return f"({self.left.print()} * {self.right.print()})"

if __name__ == "__main__":
    expr1 = Plus(Number(3), Number(4))
    expr2 = Times(Number(5), Plus(Number(2), Number(3)))

    print(f"Expression 1: {expr1.print()} = {expr1.eval()}")
    print(f"Expression 2: {expr2.print()} = {expr2.eval()}")

Emulating Functional Implementation

from dataclasses import dataclass

@dataclass
class Number:
  x: int

@dataclass
class Plus:
  left: 'Expr'
  right: 'Expr'

@dataclass
class Times:
  left: 'Expr'
  right: 'Expr'

# Union type alias; using | on classes at runtime requires Python 3.10+
Expr = Number | Plus | Times

def eval(expr: Expr) -> int:
  if isinstance(expr, Number):
    return expr.x
  elif isinstance(expr, Plus):
    return eval(expr.left) + eval(expr.right)
  elif isinstance(expr, Times):
    return eval(expr.left) * eval(expr.right)
  else:
    raise ValueError("Unknown expression type")

def print_expr(e: Expr) -> str:
  if isinstance(e, Number):
    return str(e.x)
  elif isinstance(e, Plus):
    return f"({print_expr(e.left)} + {print_expr(e.right)})"
  elif isinstance(e, Times):
    return f"({print_expr(e.left)} * {print_expr(e.right)})"
  else:
    raise ValueError("Unknown expression type")

if __name__ == "__main__":
    expr1 = Plus(Number(3), Number(4))
    expr2 = Times(Number(5), Plus(Number(2), Number(3)))

    print(f"Expression 1: {print_expr(expr1)} = {eval(expr1)}")
    print(f"Expression 2: {print_expr(expr2)} = {eval(expr2)}")

Implementation in C

C has no runtime type information, so we must create type tags ourselves (the enum ExprType). Each Expr stores such a tag to record which member of its union is in use: the int for NUMBER, or the struct with two child pointers for PLUS and TIMES. The switch statements in eval and print_expr dispatch on this tag to choose the implementation. Due to manual memory management, print_expr is especially verbose: it allocates a buffer for every partial string and must free the strings returned for its subexpressions.

#include <stdio.h>
#include <stdlib.h>
#include <string.h>

typedef enum {
  NUMBER,
  PLUS,
  TIMES
} ExprType;

typedef struct Expr Expr;

struct Expr {
  ExprType type;
  union {
      int number;
      struct {
          Expr* left;
          Expr* right;
      } binary;
  } data;
};

Expr* make_number(int x) {
  Expr* e = malloc(sizeof(Expr));
  e->type = NUMBER;
  e->data.number = x;
  return e;
}

Expr* make_binary(Expr* left, Expr* right, ExprType type) {
  Expr* e = malloc(sizeof(Expr));
  e->type = type;
  e->data.binary.left = left;
  e->data.binary.right = right;
  return e;
}

Expr* make_plus(Expr* left, Expr* right) {
  return make_binary(left, right, PLUS);
}

Expr* make_times(Expr* left, Expr* right) {
  return make_binary(left, right, TIMES);
}

int eval(Expr* expr) {
  switch (expr->type) {
    case NUMBER:
      return expr->data.number;
    case PLUS:
      return eval(expr->data.binary.left) + eval(expr->data.binary.right);
    case TIMES:
      return eval(expr->data.binary.left) * eval(expr->data.binary.right);
    default:
      return 0; // error
  }
}

char* print_expr(Expr* e) {
  char* left_str;
  char* right_str;
  char* result;
  switch (e->type) {
    case NUMBER:
      result = malloc(12); // enough for int
      sprintf(result, "%d", e->data.number);
      return result;
    case PLUS:
      left_str = print_expr(e->data.binary.left);
      right_str = print_expr(e->data.binary.right);
      result = malloc(strlen(left_str) + strlen(right_str) + 6); // 5 chars for "(", " + ", ")" plus the terminating NUL
      sprintf(result, "(%s + %s)", left_str, right_str);
      free(left_str);
      free(right_str);
      return result;
    case TIMES:
      left_str = print_expr(e->data.binary.left);
      right_str = print_expr(e->data.binary.right);
      result = malloc(strlen(left_str) + strlen(right_str) + 6);
      sprintf(result, "(%s * %s)", left_str, right_str);
      free(left_str);
      free(right_str);
      return result;
    default:
      return NULL;
  }
}

void free_expr(Expr* expr) {
  if (expr->type != NUMBER) {
    free_expr(expr->data.binary.left);
    free_expr(expr->data.binary.right);
  }
  free(expr);
}

int main() {
  Expr* expr = make_plus(
    make_number(3),
    make_times(
      make_number(4),
      make_number(5)
    )
  );

  char* expr_str = print_expr(expr);
  int result = eval(expr);

  printf("Expression: %s\n", expr_str);
  printf("Result: %d\n", result);

  free(expr_str);
  free_expr(expr);

  return 0;
}

Class Example: Matrix

We implement a class Matrix that can be initialized with a two-dimensional array. The assumption that all rows have the same number of columns is encoded in a require expression in the class body. For illustration purposes, this class has a design flaw that will let us trigger a runtime exception in the + operation. Can you spot it?

class Matrix(val data: Array[Array[Int]]):
  require (data != null && data.nonEmpty && data.forall(_.length == data(0).length), "All rows must have the same number of columns")

  val rows = data.length
  val cols = data(0).length

  def +(other: Matrix): Matrix =
    val result = Array.ofDim[Int](rows, cols)

    for (i <- 0 until rows; j <- 0 until cols) do
      result(i)(j) = this.data(i)(j) + other.data(i)(j)
    end for

    new Matrix(result)
  end +

  override def toString: String =
    data.map(row => row.mkString(" ")).mkString("\n")
  end toString
end Matrix

We instantiate a new object of class Matrix:

val m = new Matrix(Array(
  Array(1, 2, 3),
  Array(4, 5, 6)
))

The intended way to update a cell of the matrix is:

m.data(0)(0) = 10
println(m)

But because data is a mutable Array (even though the field data itself is a val and cannot be reassigned), we can also replace whole rows of the matrix:

m.data(0) = Array(10)
println(m)

As a result, the following sum operation throws a runtime exception. The require check runs only once, when m is constructed, and cols was computed at that time as 3. After the row replacement, m.data(0) has only one element, so the loop in + fails with an ArrayIndexOutOfBoundsException as soon as it reads this.data(0)(1):

val sum = m + new Matrix(Array(
  Array(1, 1, 1),
  Array(1, 1, 1)
))
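One possible fix, sketched under our own assumptions (the constructor parameter rowsIn and the methods apply and update are our choices, not part of the original class): copy the rows on construction, keep the array private, and expose only cell access. Scala rewrites m(i, j) = v into m.update(i, j, v), so the intended cell update keeps a convenient syntax, while replacing a row no longer compiles. A dimension check in + also turns mismatched sizes into a clear IllegalArgumentException.

```scala
class Matrix(rowsIn: Array[Array[Int]]):
  require(rowsIn != null && rowsIn.nonEmpty && rowsIn.forall(_.length == rowsIn(0).length),
    "All rows must have the same number of columns")

  // Defensive copy: later changes to rowsIn cannot affect this matrix.
  private val data: Array[Array[Int]] = rowsIn.map(_.clone())

  val rows = data.length
  val cols = data(0).length

  // Read a cell: m(i, j)
  def apply(i: Int, j: Int): Int = data(i)(j)

  // Write a cell: m(i, j) = v (rows themselves cannot be replaced)
  def update(i: Int, j: Int, v: Int): Unit = data(i)(j) = v

  def +(other: Matrix): Matrix =
    require(rows == other.rows && cols == other.cols, "Dimensions must match")
    val result = Array.ofDim[Int](rows, cols)
    for (i <- 0 until rows; j <- 0 until cols) do
      result(i)(j) = data(i)(j) + other(i, j)
    new Matrix(result)

  override def toString: String =
    data.map(row => row.mkString(" ")).mkString("\n")

val src = Array(Array(1, 2, 3), Array(4, 5, 6))
val m = new Matrix(src)
m(0, 0) = 10          // intended cell update still works
src(0) = Array(10)    // changes only src, not m
val sum = m + new Matrix(Array(Array(1, 1, 1), Array(1, 1, 1)))
```

An alternative design would use immutable Vector[Vector[Int]] rows and return a new matrix from every update, trading in-place writes for full immutability.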