regular structs for stmt,expr / track span

Docs are somewhat out of date.
Tracking span should enable re-use of type-check results.
Niko Matsakis 2022-08-19 06:46:54 -04:00
parent d83d3c44f8
commit 590c5ce8d3
5 changed files with 415 additions and 83 deletions


@ -69,17 +69,49 @@ When you change the value of an input field, that increments a 'revision counter
indicating that some inputs are different now.
When we talk about a "revision" of the database, we are referring to the state of the database in between changes to the input values.
## Tracked structs
### Representing the parsed program
Next we will define a **tracked struct** to represent the functions in our input.
Next we will define a **tracked struct**.
Whereas inputs represent the *start* of a computation, tracked structs represent intermediate values created during your computation.
In this case, we are going to parse the raw input program, and create a `Function` for each of the functions defined by the user.
In this case, the parser is going to take in the `SourceProgram` struct that we saw and return a `Program` that represents the fully parsed program:
```rust
{{#include ../../../calc-example/calc/src/ir.rs:program}}
```
Like with an input, the fields of a tracked struct are also stored in the database.
Unlike an input, those fields are immutable (they cannot be "set"), and salsa compares them across revisions to know when they have changed.
In this case, if parsing the input produced the same `Program` result (e.g., because the only change to the input was some trailing whitespace, perhaps),
then subsequent parts of the computation won't need to re-execute.
(We'll revisit the role of tracked structs in reuse in later parts of the tutorial.)
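
As a rough illustration of the reuse described above, here is a minimal sketch in plain Rust that does not use salsa at all. The `Cache` type, its `downstream_runs` counter, and the line-based `parse` function are hypothetical stand-ins invented for this example; salsa's real bookkeeping is more involved. The sketch only shows the idea that downstream work is skipped when the newly parsed value compares equal to the old one.

```rust
/// Stand-in for the tracked `Program` struct: plain data that can be compared.
#[derive(Debug, Clone, PartialEq, Eq)]
struct Program {
    statements: Vec<String>,
}

/// Toy parser (an assumption of this sketch): one statement per non-empty line,
/// with surrounding whitespace ignored.
fn parse(source: &str) -> Program {
    let statements = source
        .lines()
        .map(str::trim)
        .filter(|line| !line.is_empty())
        .map(String::from)
        .collect();
    Program { statements }
}

/// Remembers the last parse result and a value derived from it.
struct Cache {
    program: Program,
    statement_count: usize,
    downstream_runs: u32,
}

impl Cache {
    fn new(source: &str) -> Self {
        let program = parse(source);
        let statement_count = program.statements.len();
        Cache { program, statement_count, downstream_runs: 1 }
    }

    /// Re-parses after an input change, but only redoes the downstream work
    /// when the parsed value actually differs from the previous one.
    fn set_source(&mut self, source: &str) {
        let new_program = parse(source);
        if new_program != self.program {
            self.statement_count = new_program.statements.len();
            self.downstream_runs += 1;
            self.program = new_program;
        }
    }
}

fn main() {
    let mut cache = Cache::new("print 1 + 2");

    // Only trailing whitespace changed: the parsed value is equal, nothing re-runs.
    cache.set_source("print 1 + 2   \n");
    assert_eq!(cache.downstream_runs, 1);

    // A real change: the downstream computation runs again.
    cache.set_source("print 1 + 2\nprint 3");
    assert_eq!(cache.downstream_runs, 2);
    assert_eq!(cache.statement_count, 2);
}
```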
Apart from the fields being immutable, the API for working with a tracked struct is quite similar to an input:
* You can create a new value by using `new`, but with a tracked struct, you only need an `&dyn` database, not `&mut` (e.g., `Program::new(&db, some_statements)`)
* You use a getter to read the value of a field, just like with an input (e.g., `my_program.statements(db)` to read the `statements` field).
* In this case, the field is tagged as `#[return_ref]`, which means that the getter will return a `&Vec<Statement>`, instead of cloning the vector.
## Representing functions
We will also use a tracked struct to represent each function:
Next we will define a **tracked struct**.
Whereas inputs represent the *start* of a computation, tracked structs represent intermediate values created during your computation.
The `Function` struct is going to be created by the parser to represent each of the functions defined by the user:
```rust
{{#include ../../../calc-example/calc/src/ir.rs:functions}}
```
Unlike with inputs, the fields of tracked structs are immutable once created. Otherwise, working with a tracked struct is quite similar to an input:
Like with an input, the fields of a tracked struct are also stored in the database.
Unlike an input, those fields are immutable (they cannot be "set"), and salsa compares them across revisions to know when they have changed.
If we had created some `Function` instance `f`, for example, we might find that the `f.body` field changes
because the user changed the definition of `f`.
This would mean that we have to re-execute those parts of the code that depended on `f.body`
(but not those parts of the code that depended on the body of *other* functions).
Apart from the fields being immutable, the API for working with a tracked struct is quite similar to an input:
* You can create a new value by using `new`, but with a tracked struct, you only need an `&dyn` database, not `&mut` (e.g., `Function::new(&db, some_name, some_args, some_body)`)
* You use a getter to read the value of a field, just like with an input (e.g., `my_func.args(db)` to read the `args` field).
@ -116,14 +148,6 @@ let f2 = FunctionId::new(&db, "my_string".to_string());
assert_eq!(f1, f2);
```
### Expressions and statements
We'll also intern expressions and statements. This is convenient primarily because it allows us to have recursive structures very easily. Since we don't really need the "cheap equality comparison" aspect of interning, this isn't the most efficient choice, and many compilers would opt to represent expressions/statements in some other way.
```rust
{{#include ../../../calc-example/calc/src/ir.rs:statements_and_expressions}}
```
### Interned ids are guaranteed to be consistent within a revision, but not across revisions (but you don't have to care)
Interned ids are guaranteed not to change within a single revision, so you can intern things from all over your program and get back consistent results.
@ -134,3 +158,16 @@ just a different one than they saw in a previous revision.
In other words, within a salsa computation, you can assume that interning produces a single consistent integer, and you don't have to think about it.
If, however, you export interned identifiers outside the computation, and then change the inputs, they may no longer be valid or may refer to different values.
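
To make the consistency guarantee concrete, here is a minimal interner sketch in plain Rust. It is not how salsa implements interning; the `Interner` type and its `intern` and `lookup` methods are hypothetical names used only for illustration, and a fresh `Interner` is used as a crude model of a later revision. Within one interner the same string always maps to the same id, while an id carried over to another interner may name something else.

```rust
use std::collections::HashMap;

/// Small integer handle standing in for an interned id.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
struct FunctionId(u32);

/// Hypothetical interner: maps each distinct string to one id.
#[derive(Default)]
struct Interner {
    ids: HashMap<String, FunctionId>,
    names: Vec<String>,
}

impl Interner {
    /// Returns the existing id for `name`, or allocates the next one.
    fn intern(&mut self, name: &str) -> FunctionId {
        if let Some(&id) = self.ids.get(name) {
            return id;
        }
        let id = FunctionId(self.names.len() as u32);
        self.names.push(name.to_string());
        self.ids.insert(name.to_string(), id);
        id
    }

    /// Reads back the string behind an id handed out by this interner.
    fn lookup(&self, id: FunctionId) -> &str {
        &self.names[id.0 as usize]
    }
}

fn main() {
    // Within one "revision", interning the same string twice gives the same id.
    let mut rev1 = Interner::default();
    let f1 = rev1.intern("area");
    let f2 = rev1.intern("area");
    assert_eq!(f1, f2);
    assert_eq!(rev1.lookup(f1), "area");

    // A fresh interner models a later revision that hands out ids in another order.
    let mut rev2 = Interner::default();
    rev2.intern("double");
    let g = rev2.intern("area");
    assert_ne!(f1, g);

    // The id exported from the first revision now refers to a different value.
    assert_eq!(rev2.lookup(f1), "double");
}
```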
### Expressions and statements
We won't use any special "salsa structs" for expressions and statements:
```rust
{{#include ../../../calc-example/calc/src/ir.rs:statements_and_expressions}}
```
Since statements and expressions are not tracked, this implies that we are only attempting to get incremental re-use at the granularity of functions --
whenever anything in a function body changes, we consider the entire function body dirty and re-execute anything that depended on it.
It usually makes sense to draw some kind of "reasonably coarse" boundary like this.
One downside of the way we have set things up: we inlined the position into each of the structs.
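
One reading of that downside can be sketched in plain Rust. In the sketch below, `Span` is ordinary data with byte offsets, which is a simplifying assumption; the helpers `number`, `add`, and `same_shape` are hypothetical and exist only for this example. Because the position is part of each node, two expressions with identical structure compare unequal as soon as the text before them shifts.

```rust
/// Plain-data span (an assumption of this sketch): byte offsets into the source.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
struct Span {
    start: usize,
    end: usize,
}

/// Expression with its position inlined next to its data.
#[derive(Debug, Clone, PartialEq, Eq)]
struct Expression {
    span: Span,
    data: ExpressionData,
}

#[derive(Debug, Clone, PartialEq, Eq)]
enum ExpressionData {
    Number(i64),
    Add(Box<Expression>, Box<Expression>),
}

/// Hypothetical helper: builds a number literal covering `start..end`.
fn number(start: usize, end: usize, value: i64) -> Expression {
    Expression { span: Span { start, end }, data: ExpressionData::Number(value) }
}

/// Hypothetical helper: builds `lhs + rhs`, spanning both operands.
fn add(lhs: Expression, rhs: Expression) -> Expression {
    let span = Span { start: lhs.span.start, end: rhs.span.end };
    Expression { span, data: ExpressionData::Add(Box::new(lhs), Box::new(rhs)) }
}

/// Compares two expressions while ignoring every span.
fn same_shape(a: &Expression, b: &Expression) -> bool {
    match (&a.data, &b.data) {
        (ExpressionData::Number(x), ExpressionData::Number(y)) => x == y,
        (ExpressionData::Add(a1, a2), ExpressionData::Add(b1, b2)) => {
            same_shape(a1, b1) && same_shape(a2, b2)
        }
        _ => false,
    }
}

fn main() {
    // "1+2" at offset 0, and the same text after one leading space was inserted.
    let before = add(number(0, 1, 1), number(2, 3, 2));
    let after = add(number(1, 2, 1), number(3, 4, 2));

    // Derived equality sees the shifted spans and reports a change...
    assert_ne!(before, after);
    // ...even though the expressions are structurally identical.
    assert!(same_shape(&before, &after));
}
```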


@ -6,6 +6,7 @@ edition = "2021"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
[dependencies]
derive-new = "0.5.9"
salsa = { path = "../../components/salsa-2022", package = "salsa-2022" }
ordered-float = "3.0"


@ -1,3 +1,4 @@
use derive_new::new;
use ordered_float::OrderedFloat;
use salsa::debug::DebugWithDb;
@ -23,18 +24,23 @@ pub struct FunctionId {
}
// ANCHOR_END: interned_ids
// ANCHOR: statements_and_expressions
// ANCHOR: program
#[salsa::tracked]
pub struct Program {
#[return_ref]
statements: Vec<Statement>,
}
// ANCHOR_END: program
#[salsa::interned]
// ANCHOR: statements_and_expressions
#[derive(Eq, PartialEq, Debug, Hash, new)]
pub struct Statement {
span: Span,
data: StatementData,
}
#[derive(Eq, PartialEq, Clone, Hash)]
#[derive(Eq, PartialEq, Debug, Hash)]
pub enum StatementData {
/// Defines `fn <name>(<args>) = <body>`
Function(Function),
@ -42,15 +48,16 @@ pub enum StatementData {
Print(Expression),
}
#[salsa::interned]
#[derive(Eq, PartialEq, Debug, Hash, new)]
pub struct Expression {
#[return_ref]
span: Span,
data: ExpressionData,
}
#[derive(Eq, PartialEq, Clone, Hash)]
#[derive(Eq, PartialEq, Debug, Hash)]
pub enum ExpressionData {
Op(Expression, Op, Expression),
Op(Box<Expression>, Op, Box<Expression>),
Number(OrderedFloat<f64>),
Variable(VariableId),
Call(FunctionId, Vec<Expression>),
@ -77,7 +84,7 @@ impl DebugWithDb<dyn crate::Db + '_> for Function {
impl DebugWithDb<dyn crate::Db + '_> for Statement {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>, db: &dyn crate::Db) -> std::fmt::Result {
match self.data(db) {
match &self.data {
StatementData::Function(a) => DebugWithDb::fmt(&a, f, db),
StatementData::Print(a) => DebugWithDb::fmt(&a, f, db),
}
@ -87,7 +94,7 @@ impl DebugWithDb<dyn crate::Db + '_> for Statement {
// ANCHOR: expression_debug_impl
impl DebugWithDb<dyn crate::Db + '_> for Expression {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>, db: &dyn crate::Db) -> std::fmt::Result {
match self.data(db) {
match &self.data {
ExpressionData::Op(a, b, c) => f
.debug_tuple("ExpressionData::Op")
.field(&a.debug(db)) // use `a.debug(db)` for interned things
@ -149,11 +156,22 @@ impl DebugWithDb<dyn crate::Db + '_> for Diagnostic {
pub struct Function {
#[id]
name: FunctionId,
name_span: Span,
args: Vec<VariableId>,
#[return_ref]
body: Expression,
}
// ANCHOR_END: functions
#[salsa::tracked]
pub struct Span {
pub start: usize,
pub end: usize,
}
// ANCHOR: diagnostic
#[salsa::accumulator]
pub struct Diagnostics(Diagnostic);


@ -5,11 +5,11 @@ pub struct Jar(
crate::ir::Program,
crate::ir::VariableId,
crate::ir::FunctionId,
crate::ir::Expression,
crate::ir::Statement,
crate::ir::Function,
crate::ir::Diagnostics,
crate::ir::Span,
crate::parser::parse_statements,
crate::type_check::type_check_program,
);
// ANCHOR_END: jar_struct


@ -2,11 +2,11 @@ use ordered_float::OrderedFloat;
use crate::ir::{
Diagnostic, Diagnostics, Expression, ExpressionData, Function, FunctionId, Op, Program,
SourceProgram, Statement, StatementData, VariableId,
SourceProgram, Span, Statement, StatementData, VariableId,
};
// ANCHOR: parse_statements
#[salsa::tracked(return_ref)]
#[salsa::tracked]
pub fn parse_statements(db: &dyn crate::Db, source: SourceProgram) -> Program {
// Get the source text from the database
let source_text = source.text(db);
@ -93,34 +93,44 @@ impl Parser<'_> {
self.source_text[self.position..].chars().next()
}
// Returns a span ranging from `start_position` until the current position (exclusive)
fn span_from(&self, start_position: usize) -> Span {
Span::new(self.db, start_position, self.position)
}
fn consume(&mut self, ch: char) {
debug_assert!(self.peek() == Some(ch));
self.position += ch.len_utf8();
}
fn skip_whitespace(&mut self) {
/// Skips whitespace and returns the new position.
fn skip_whitespace(&mut self) -> usize {
while let Some(ch) = self.peek() {
if ch.is_whitespace() {
self.consume(ch);
} else {
return;
break;
}
}
self.position
}
// ANCHOR: parse_statement
fn parse_statement(&mut self) -> Option<Statement> {
self.skip_whitespace();
let start_position = self.skip_whitespace();
let word = self.word()?;
if word == "fn" {
let func = self.parse_function()?;
Some(Statement::new(self.db, StatementData::Function(func)))
// ^^^^^^^^^^^^^^^ ^^^^^^^^^^^^^^^^^^^^^^^
// Create a new interned enum... |
// using the "data" type.
Some(Statement::new(
self.span_from(start_position),
StatementData::Function(func),
))
} else if word == "print" {
let expr = self.parse_expression()?;
Some(Statement::new(self.db, StatementData::Print(expr)))
Some(Statement::new(
self.span_from(start_position),
StatementData::Print(expr),
))
} else {
None
}
@ -129,7 +139,9 @@ impl Parser<'_> {
// ANCHOR: parse_function
fn parse_function(&mut self) -> Option<Function> {
let start_position = self.skip_whitespace();
let name = self.word()?;
let name_span = self.span_from(start_position);
let name: FunctionId = FunctionId::new(self.db, name);
// ^^^^^^^^^^^^^^^
// Create a new interned struct.
@ -138,7 +150,7 @@ impl Parser<'_> {
self.ch(')')?;
self.ch('=')?;
let body = self.parse_expression()?;
Some(Function::new(self.db, name, args, body))
Some(Function::new(self.db, name, name_span, args, body))
// ^^^^^^^^^^^^^
// Create a new entity struct.
}
@ -149,9 +161,9 @@ impl Parser<'_> {
}
fn low_op(&mut self) -> Option<Op> {
if self.ch('+').is_some() {
if let Some(_) = self.ch('+') {
Some(Op::Add)
} else if self.ch('-').is_some() {
} else if let Some(_) = self.ch('-') {
Some(Op::Subtract)
} else {
None
@ -166,9 +178,9 @@ impl Parser<'_> {
}
fn high_op(&mut self) -> Option<Op> {
if self.ch('*').is_some() {
if let Some(_) = self.ch('*') {
Some(Op::Multiply)
} else if self.ch('/').is_some() {
} else if let Some(_) = self.ch('/') {
Some(Op::Divide)
} else {
None
@ -180,11 +192,15 @@ impl Parser<'_> {
mut parse_expr: impl FnMut(&mut Self) -> Option<Expression>,
mut op: impl FnMut(&mut Self) -> Option<Op>,
) -> Option<Expression> {
let start_position = self.skip_whitespace();
let mut expr1 = parse_expr(self)?;
while let Some(op) = op(self) {
let expr2 = parse_expr(self)?;
expr1 = Expression::new(self.db, ExpressionData::Op(expr1, op, expr2));
expr1 = Expression::new(
self.span_from(start_position),
ExpressionData::Op(Box::new(expr1), op, Box::new(expr2)),
);
}
Some(expr1)
@ -194,22 +210,29 @@ impl Parser<'_> {
///
/// On failure, skips arbitrary tokens.
fn parse_expression2(&mut self) -> Option<Expression> {
let start_position = self.skip_whitespace();
if let Some(w) = self.word() {
if let Some(()) = self.ch('(') {
if let Some(_) = self.ch('(') {
let f = FunctionId::new(self.db, w);
let args = self.parse_expressions()?;
self.ch(')')?;
return Some(Expression::new(self.db, ExpressionData::Call(f, args)));
return Some(Expression::new(
self.span_from(start_position),
ExpressionData::Call(f, args),
));
}
let v = VariableId::new(self.db, w);
Some(Expression::new(self.db, ExpressionData::Variable(v)))
Some(Expression::new(
self.span_from(start_position),
ExpressionData::Variable(v),
))
} else if let Some(n) = self.number() {
Some(Expression::new(
self.db,
self.span_from(start_position),
ExpressionData::Number(OrderedFloat::from(n)),
))
} else if let Some(()) = self.ch('(') {
} else if let Some(_) = self.ch('(') {
let expr = self.parse_expression()?;
self.ch(')')?;
Some(expr)
@ -249,12 +272,12 @@ impl Parser<'_> {
/// Parses a single character.
///
/// Even on failure, only skips whitespace.
fn ch(&mut self, c: char) -> Option<()> {
self.skip_whitespace();
fn ch(&mut self, c: char) -> Option<Span> {
let start_position = self.skip_whitespace();
match self.peek() {
Some(p) if c == p => {
self.consume(c);
Some(())
Some(self.span_from(start_position))
}
_ => None,
}
@ -269,6 +292,7 @@ impl Parser<'_> {
// In this loop, if we consume any characters, we always
// return `Some`.
let mut s = String::new();
let position = self.position;
while let Some(ch) = self.peek() {
if ch.is_alphabetic() || ch == '_' {
s.push(ch);
@ -292,7 +316,7 @@ impl Parser<'_> {
///
/// Even on failure, only skips whitespace.
fn number(&mut self) -> Option<f64> {
self.skip_whitespace();
let start_position = self.skip_whitespace();
self.probe(|this| {
// 👆 We need the call to `probe` here because we could consume
@ -354,11 +378,49 @@ fn parse_print() {
(
Program {
statements: [
Statement(
Id {
value: 1,
},
),
Statement {
span: Span(
Id {
value: 5,
},
),
data: Print(
Expression {
span: Span(
Id {
value: 4,
},
),
data: Op(
Expression {
span: Span(
Id {
value: 1,
},
),
data: Number(
OrderedFloat(
1.0,
),
),
},
Add,
Expression {
span: Span(
Id {
value: 3,
},
),
data: Number(
OrderedFloat(
2.0,
),
),
},
),
},
),
},
],
},
[],
@ -382,31 +444,163 @@ fn parse_example() {
(
Program {
statements: [
Statement(
Id {
value: 1,
},
),
Statement(
Id {
value: 2,
},
),
Statement(
Id {
value: 3,
},
),
Statement(
Id {
value: 4,
},
),
Statement(
Id {
value: 5,
},
),
Statement {
span: Span(
Id {
value: 10,
},
),
data: Function(
Function(
Id {
value: 1,
},
),
),
},
Statement {
span: Span(
Id {
value: 22,
},
),
data: Function(
Function(
Id {
value: 2,
},
),
),
},
Statement {
span: Span(
Id {
value: 29,
},
),
data: Print(
Expression {
span: Span(
Id {
value: 28,
},
),
data: Call(
FunctionId(
Id {
value: 1,
},
),
[
Expression {
span: Span(
Id {
value: 24,
},
),
data: Number(
OrderedFloat(
3.0,
),
),
},
Expression {
span: Span(
Id {
value: 26,
},
),
data: Number(
OrderedFloat(
4.0,
),
),
},
],
),
},
),
},
Statement {
span: Span(
Id {
value: 34,
},
),
data: Print(
Expression {
span: Span(
Id {
value: 33,
},
),
data: Call(
FunctionId(
Id {
value: 2,
},
),
[
Expression {
span: Span(
Id {
value: 31,
},
),
data: Number(
OrderedFloat(
1.0,
),
),
},
],
),
},
),
},
Statement {
span: Span(
Id {
value: 39,
},
),
data: Print(
Expression {
span: Span(
Id {
value: 38,
},
),
data: Op(
Expression {
span: Span(
Id {
value: 35,
},
),
data: Number(
OrderedFloat(
11.0,
),
),
},
Multiply,
Expression {
span: Span(
Id {
value: 37,
},
),
data: Number(
OrderedFloat(
2.0,
),
),
},
),
},
),
},
],
},
[],
@ -440,11 +634,93 @@ fn parse_precedence() {
(
Program {
statements: [
Statement(
Id {
value: 1,
},
),
Statement {
span: Span(
Id {
value: 11,
},
),
data: Print(
Expression {
span: Span(
Id {
value: 10,
},
),
data: Op(
Expression {
span: Span(
Id {
value: 7,
},
),
data: Op(
Expression {
span: Span(
Id {
value: 1,
},
),
data: Number(
OrderedFloat(
1.0,
),
),
},
Add,
Expression {
span: Span(
Id {
value: 6,
},
),
data: Op(
Expression {
span: Span(
Id {
value: 3,
},
),
data: Number(
OrderedFloat(
2.0,
),
),
},
Multiply,
Expression {
span: Span(
Id {
value: 5,
},
),
data: Number(
OrderedFloat(
3.0,
),
),
},
),
},
),
},
Add,
Expression {
span: Span(
Id {
value: 9,
},
),
data: Number(
OrderedFloat(
4.0,
),
),
},
),
},
),
},
],
},
[],